Diffstat (limited to 'llvm/lib/Analysis')
102 files changed, 83637 insertions, 0 deletions
diff --git a/llvm/lib/Analysis/AliasAnalysis.cpp b/llvm/lib/Analysis/AliasAnalysis.cpp new file mode 100644 index 000000000000..55dd9a4cda08 --- /dev/null +++ b/llvm/lib/Analysis/AliasAnalysis.cpp @@ -0,0 +1,907 @@ +//==- AliasAnalysis.cpp - Generic Alias Analysis Interface Implementation --==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the generic AliasAnalysis interface, which is the +// common interface used by all clients and implementations of alias analysis. +// +// This file also implements the default version of the AliasAnalysis interface +// that is to be used when no other implementation is specified. This does some +// simple tests that detect obvious cases: two different global pointers cannot +// alias, a global cannot alias a malloc, two different mallocs cannot alias, +// etc. +// +// This alias analysis implementation really isn't very good for anything, but +// it is very fast, and makes a nice clean default implementation. Because it +// handles lots of little corner cases, other, more complex, alias analysis +// implementations may choose to rely on this pass to resolve these simple and +// easy cases. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/CFLAndersAliasAnalysis.h" +#include "llvm/Analysis/CFLSteensAliasAnalysis.h" +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/ObjCARCAliasAnalysis.h" +#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/Analysis/ScopedNoAliasAA.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TypeBasedAliasAnalysis.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include <algorithm> +#include <cassert> +#include <functional> +#include <iterator> + +using namespace llvm; + +/// Allow disabling BasicAA from the AA results. This is particularly useful +/// when testing to isolate a single AA implementation. +static cl::opt<bool> DisableBasicAA("disable-basicaa", cl::Hidden, + cl::init(false)); + +AAResults::AAResults(AAResults &&Arg) + : TLI(Arg.TLI), AAs(std::move(Arg.AAs)), AADeps(std::move(Arg.AADeps)) { + for (auto &AA : AAs) + AA->setAAResults(this); +} + +AAResults::~AAResults() { +// FIXME: It would be nice to at least clear out the pointers back to this +// aggregation here, but we end up with non-nesting lifetimes in the legacy +// pass manager that prevent this from working. In the legacy pass manager +// we'll end up with dangling references here in some cases.
+#if 0 + for (auto &AA : AAs) + AA->setAAResults(nullptr); +#endif +} + +bool AAResults::invalidate(Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &Inv) { + // AAResults preserves the AAManager by default, due to the stateless nature + // of AliasAnalysis. There is no need to check whether it has been preserved + // explicitly. Check if any module dependency was invalidated and caused the + // AAManager to be invalidated. Invalidate ourselves in that case. + auto PAC = PA.getChecker<AAManager>(); + if (!PAC.preservedWhenStateless()) + return true; + + // Check if any of the function dependencies were invalidated, and invalidate + // ourselves in that case. + for (AnalysisKey *ID : AADeps) + if (Inv.invalidate(ID, F, PA)) + return true; + + // Everything we depend on is still fine, so are we. Nothing to invalidate. + return false; +} + +//===----------------------------------------------------------------------===// +// Default chaining methods +//===----------------------------------------------------------------------===// + +AliasResult AAResults::alias(const MemoryLocation &LocA, + const MemoryLocation &LocB) { + AAQueryInfo AAQIP; + return alias(LocA, LocB, AAQIP); +} + +AliasResult AAResults::alias(const MemoryLocation &LocA, + const MemoryLocation &LocB, AAQueryInfo &AAQI) { + for (const auto &AA : AAs) { + auto Result = AA->alias(LocA, LocB, AAQI); + if (Result != MayAlias) + return Result; + } + return MayAlias; +} + +bool AAResults::pointsToConstantMemory(const MemoryLocation &Loc, + bool OrLocal) { + AAQueryInfo AAQIP; + return pointsToConstantMemory(Loc, AAQIP, OrLocal); +} + +bool AAResults::pointsToConstantMemory(const MemoryLocation &Loc, + AAQueryInfo &AAQI, bool OrLocal) { + for (const auto &AA : AAs) + if (AA->pointsToConstantMemory(Loc, AAQI, OrLocal)) + return true; + + return false; +} + +ModRefInfo AAResults::getArgModRefInfo(const CallBase *Call, unsigned ArgIdx) { + ModRefInfo Result = ModRefInfo::ModRef; + + for (const auto &AA : AAs) { + Result = intersectModRef(Result, AA->getArgModRefInfo(Call, ArgIdx)); + + // Early-exit the moment we reach the bottom of the lattice. + if (isNoModRef(Result)) + return ModRefInfo::NoModRef; + } + + return Result; +} + +ModRefInfo AAResults::getModRefInfo(Instruction *I, const CallBase *Call2) { + AAQueryInfo AAQIP; + return getModRefInfo(I, Call2, AAQIP); +} + +ModRefInfo AAResults::getModRefInfo(Instruction *I, const CallBase *Call2, + AAQueryInfo &AAQI) { + // We may have two calls. + if (const auto *Call1 = dyn_cast<CallBase>(I)) { + // Check if the two calls modify the same memory. + return getModRefInfo(Call1, Call2, AAQI); + } else if (I->isFenceLike()) { + // If this is a fence, just return ModRef. + return ModRefInfo::ModRef; + } else { + // Otherwise, check if the call modifies or references the + // location this memory access defines. The best we can say + // is that if the call references what this instruction + // defines, it must be clobbered by this location. 
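+ // For example (hypothetical IR), if I is `store i32 0, i32* %p` and + // Call2 is a call that only reads %p, the query below returns Ref for + // the call, and setModAndRef widens that to ModRef: the store writes + // the very location the call reads.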
+ const MemoryLocation DefLoc = MemoryLocation::get(I); + ModRefInfo MR = getModRefInfo(Call2, DefLoc, AAQI); + if (isModOrRefSet(MR)) + return setModAndRef(MR); + } + return ModRefInfo::NoModRef; +} + +ModRefInfo AAResults::getModRefInfo(const CallBase *Call, + const MemoryLocation &Loc) { + AAQueryInfo AAQIP; + return getModRefInfo(Call, Loc, AAQIP); +} + +ModRefInfo AAResults::getModRefInfo(const CallBase *Call, + const MemoryLocation &Loc, + AAQueryInfo &AAQI) { + ModRefInfo Result = ModRefInfo::ModRef; + + for (const auto &AA : AAs) { + Result = intersectModRef(Result, AA->getModRefInfo(Call, Loc, AAQI)); + + // Early-exit the moment we reach the bottom of the lattice. + if (isNoModRef(Result)) + return ModRefInfo::NoModRef; + } + + // Try to refine the mod-ref info further using other API entry points to the + // aggregate set of AA results. + auto MRB = getModRefBehavior(Call); + if (MRB == FMRB_DoesNotAccessMemory || + MRB == FMRB_OnlyAccessesInaccessibleMem) + return ModRefInfo::NoModRef; + + if (onlyReadsMemory(MRB)) + Result = clearMod(Result); + else if (doesNotReadMemory(MRB)) + Result = clearRef(Result); + + if (onlyAccessesArgPointees(MRB) || onlyAccessesInaccessibleOrArgMem(MRB)) { + bool IsMustAlias = true; + ModRefInfo AllArgsMask = ModRefInfo::NoModRef; + if (doesAccessArgPointees(MRB)) { + for (auto AI = Call->arg_begin(), AE = Call->arg_end(); AI != AE; ++AI) { + const Value *Arg = *AI; + if (!Arg->getType()->isPointerTy()) + continue; + unsigned ArgIdx = std::distance(Call->arg_begin(), AI); + MemoryLocation ArgLoc = + MemoryLocation::getForArgument(Call, ArgIdx, TLI); + AliasResult ArgAlias = alias(ArgLoc, Loc); + if (ArgAlias != NoAlias) { + ModRefInfo ArgMask = getArgModRefInfo(Call, ArgIdx); + AllArgsMask = unionModRef(AllArgsMask, ArgMask); + } + // Conservatively clear IsMustAlias unless only MustAlias is found. + IsMustAlias &= (ArgAlias == MustAlias); + } + } + // Return NoModRef if no alias found with any argument. + if (isNoModRef(AllArgsMask)) + return ModRefInfo::NoModRef; + // Logical & between other AA analyses and argument analysis. + Result = intersectModRef(Result, AllArgsMask); + // If only MustAlias found above, set Must bit. + Result = IsMustAlias ? setMust(Result) : clearMust(Result); + } + + // If Loc is a constant memory location, the call definitely could not + // modify the memory location. + if (isModSet(Result) && pointsToConstantMemory(Loc, /*OrLocal*/ false)) + Result = clearMod(Result); + + return Result; +} + +ModRefInfo AAResults::getModRefInfo(const CallBase *Call1, + const CallBase *Call2) { + AAQueryInfo AAQIP; + return getModRefInfo(Call1, Call2, AAQIP); +} + +ModRefInfo AAResults::getModRefInfo(const CallBase *Call1, + const CallBase *Call2, AAQueryInfo &AAQI) { + ModRefInfo Result = ModRefInfo::ModRef; + + for (const auto &AA : AAs) { + Result = intersectModRef(Result, AA->getModRefInfo(Call1, Call2, AAQI)); + + // Early-exit the moment we reach the bottom of the lattice. + if (isNoModRef(Result)) + return ModRefInfo::NoModRef; + } + + // Try to refine the mod-ref info further using other API entry points to the + // aggregate set of AA results. + + // If Call1 or Call2 are readnone, they don't interact. + auto Call1B = getModRefBehavior(Call1); + if (Call1B == FMRB_DoesNotAccessMemory) + return ModRefInfo::NoModRef; + + auto Call2B = getModRefBehavior(Call2); + if (Call2B == FMRB_DoesNotAccessMemory) + return ModRefInfo::NoModRef; + + // If they both only read from memory, there is no dependence. 
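+ // For example (hypothetical IR), two read-only calls such as + // %a = call i64 @strlen(i8* %p) + // %b = call i64 @strlen(i8* %q) + // cannot affect each other: neither call writes memory, so there is + // nothing for the other one to observe or clobber.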
+ if (onlyReadsMemory(Call1B) && onlyReadsMemory(Call2B)) + return ModRefInfo::NoModRef; + + // If Call1 only reads memory, the only dependence on Call2 can be + // from Call1 reading memory written by Call2. + if (onlyReadsMemory(Call1B)) + Result = clearMod(Result); + else if (doesNotReadMemory(Call1B)) + Result = clearRef(Result); + + // If Call2 only accesses memory through arguments, accumulate the mod/ref + // information from Call1's references to the memory referenced by + // Call2's arguments. + if (onlyAccessesArgPointees(Call2B)) { + if (!doesAccessArgPointees(Call2B)) + return ModRefInfo::NoModRef; + ModRefInfo R = ModRefInfo::NoModRef; + bool IsMustAlias = true; + for (auto I = Call2->arg_begin(), E = Call2->arg_end(); I != E; ++I) { + const Value *Arg = *I; + if (!Arg->getType()->isPointerTy()) + continue; + unsigned Call2ArgIdx = std::distance(Call2->arg_begin(), I); + auto Call2ArgLoc = + MemoryLocation::getForArgument(Call2, Call2ArgIdx, TLI); + + // ArgModRefC2 indicates what Call2 might do to Call2ArgLoc, and the + // dependence of Call1 on that location is the inverse: + // - If Call2 modifies location, dependence exists if Call1 reads or + // writes. + // - If Call2 only reads location, dependence exists if Call1 writes. + ModRefInfo ArgModRefC2 = getArgModRefInfo(Call2, Call2ArgIdx); + ModRefInfo ArgMask = ModRefInfo::NoModRef; + if (isModSet(ArgModRefC2)) + ArgMask = ModRefInfo::ModRef; + else if (isRefSet(ArgModRefC2)) + ArgMask = ModRefInfo::Mod; + + // ModRefC1 indicates what Call1 might do to Call2ArgLoc, and we use + // the ArgMask above to update dependence info. + ModRefInfo ModRefC1 = getModRefInfo(Call1, Call2ArgLoc); + ArgMask = intersectModRef(ArgMask, ModRefC1); + + // Conservatively clear IsMustAlias unless only MustAlias is found. + IsMustAlias &= isMustSet(ModRefC1); + + R = intersectModRef(unionModRef(R, ArgMask), Result); + if (R == Result) { + // On early exit, not all args were checked, cannot set Must. + if (I + 1 != E) + IsMustAlias = false; + break; + } + } + + if (isNoModRef(R)) + return ModRefInfo::NoModRef; + + // If MustAlias found above, set Must bit. + return IsMustAlias ? setMust(R) : clearMust(R); + } + + // If Call1 only accesses memory through arguments, check if Call2 references + // any of the memory referenced by Call1's arguments. If not, return NoModRef. + if (onlyAccessesArgPointees(Call1B)) { + if (!doesAccessArgPointees(Call1B)) + return ModRefInfo::NoModRef; + ModRefInfo R = ModRefInfo::NoModRef; + bool IsMustAlias = true; + for (auto I = Call1->arg_begin(), E = Call1->arg_end(); I != E; ++I) { + const Value *Arg = *I; + if (!Arg->getType()->isPointerTy()) + continue; + unsigned Call1ArgIdx = std::distance(Call1->arg_begin(), I); + auto Call1ArgLoc = + MemoryLocation::getForArgument(Call1, Call1ArgIdx, TLI); + + // ArgModRefC1 indicates what Call1 might do to Call1ArgLoc; if Call1 + // might Mod Call1ArgLoc, then we care about either a Mod or a Ref by + // Call2. If Call1 might Ref, then we care only about a Mod by Call2. + ModRefInfo ArgModRefC1 = getArgModRefInfo(Call1, Call1ArgIdx); + ModRefInfo ModRefC2 = getModRefInfo(Call2, Call1ArgLoc); + if ((isModSet(ArgModRefC1) && isModOrRefSet(ModRefC2)) || + (isRefSet(ArgModRefC1) && isModSet(ModRefC2))) + R = intersectModRef(unionModRef(R, ArgModRefC1), Result); + + // Conservatively clear IsMustAlias unless only MustAlias is found. + IsMustAlias &= isMustSet(ModRefC2); + + if (R == Result) { + // On early exit, not all args were checked, cannot set Must.
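+ // (For instance, if R already equals the chained-AA result after the + // first of several pointer arguments, we break out of the loop; the + // remaining arguments were never proven MustAlias, so the Must bit + // has to stay clear.)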
+ if (I + 1 != E) + IsMustAlias = false; + break; + } + } + + if (isNoModRef(R)) + return ModRefInfo::NoModRef; + + // If MustAlias found above, set Must bit. + return IsMustAlias ? setMust(R) : clearMust(R); + } + + return Result; +} + +FunctionModRefBehavior AAResults::getModRefBehavior(const CallBase *Call) { + FunctionModRefBehavior Result = FMRB_UnknownModRefBehavior; + + for (const auto &AA : AAs) { + Result = FunctionModRefBehavior(Result & AA->getModRefBehavior(Call)); + + // Early-exit the moment we reach the bottom of the lattice. + if (Result == FMRB_DoesNotAccessMemory) + return Result; + } + + return Result; +} + +FunctionModRefBehavior AAResults::getModRefBehavior(const Function *F) { + FunctionModRefBehavior Result = FMRB_UnknownModRefBehavior; + + for (const auto &AA : AAs) { + Result = FunctionModRefBehavior(Result & AA->getModRefBehavior(F)); + + // Early-exit the moment we reach the bottom of the lattice. + if (Result == FMRB_DoesNotAccessMemory) + return Result; + } + + return Result; +} + +raw_ostream &llvm::operator<<(raw_ostream &OS, AliasResult AR) { + switch (AR) { + case NoAlias: + OS << "NoAlias"; + break; + case MustAlias: + OS << "MustAlias"; + break; + case MayAlias: + OS << "MayAlias"; + break; + case PartialAlias: + OS << "PartialAlias"; + break; + } + return OS; +} + +//===----------------------------------------------------------------------===// +// Helper method implementation +//===----------------------------------------------------------------------===// + +ModRefInfo AAResults::getModRefInfo(const LoadInst *L, + const MemoryLocation &Loc) { + AAQueryInfo AAQIP; + return getModRefInfo(L, Loc, AAQIP); +} +ModRefInfo AAResults::getModRefInfo(const LoadInst *L, + const MemoryLocation &Loc, + AAQueryInfo &AAQI) { + // Be conservative in the face of atomic. + if (isStrongerThan(L->getOrdering(), AtomicOrdering::Unordered)) + return ModRefInfo::ModRef; + + // If the load address doesn't alias the given address, it doesn't read + // or write the specified memory. + if (Loc.Ptr) { + AliasResult AR = alias(MemoryLocation::get(L), Loc, AAQI); + if (AR == NoAlias) + return ModRefInfo::NoModRef; + if (AR == MustAlias) + return ModRefInfo::MustRef; + } + // Otherwise, a load just reads. + return ModRefInfo::Ref; +} + +ModRefInfo AAResults::getModRefInfo(const StoreInst *S, + const MemoryLocation &Loc) { + AAQueryInfo AAQIP; + return getModRefInfo(S, Loc, AAQIP); +} +ModRefInfo AAResults::getModRefInfo(const StoreInst *S, + const MemoryLocation &Loc, + AAQueryInfo &AAQI) { + // Be conservative in the face of atomic. + if (isStrongerThan(S->getOrdering(), AtomicOrdering::Unordered)) + return ModRefInfo::ModRef; + + if (Loc.Ptr) { + AliasResult AR = alias(MemoryLocation::get(S), Loc, AAQI); + // If the store address cannot alias the pointer in question, then the + // specified memory cannot be modified by the store. + if (AR == NoAlias) + return ModRefInfo::NoModRef; + + // If the pointer is a pointer to constant memory, then it could not have + // been modified by this store. + if (pointsToConstantMemory(Loc, AAQI)) + return ModRefInfo::NoModRef; + + // If the store address aliases the pointer as must alias, set Must. + if (AR == MustAlias) + return ModRefInfo::MustMod; + } + + // Otherwise, a store just writes. 
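+ // (Reaching this point means there was no location to compare against, + // e.g. a query whose MemoryLocation has a null pointer, so Mod is the + // most precise safe answer for a store.)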
+ return ModRefInfo::Mod; +} + +ModRefInfo AAResults::getModRefInfo(const FenceInst *S, const MemoryLocation &Loc) { + AAQueryInfo AAQIP; + return getModRefInfo(S, Loc, AAQIP); +} + +ModRefInfo AAResults::getModRefInfo(const FenceInst *S, + const MemoryLocation &Loc, + AAQueryInfo &AAQI) { + // If we know that the location is a constant memory location, the fence + // cannot modify this location. + if (Loc.Ptr && pointsToConstantMemory(Loc, AAQI)) + return ModRefInfo::Ref; + return ModRefInfo::ModRef; +} + +ModRefInfo AAResults::getModRefInfo(const VAArgInst *V, + const MemoryLocation &Loc) { + AAQueryInfo AAQIP; + return getModRefInfo(V, Loc, AAQIP); +} + +ModRefInfo AAResults::getModRefInfo(const VAArgInst *V, + const MemoryLocation &Loc, + AAQueryInfo &AAQI) { + if (Loc.Ptr) { + AliasResult AR = alias(MemoryLocation::get(V), Loc, AAQI); + // If the va_arg address cannot alias the pointer in question, then the + // specified memory cannot be accessed by the va_arg. + if (AR == NoAlias) + return ModRefInfo::NoModRef; + + // If the pointer is a pointer to constant memory, then it could not have + // been modified by this va_arg. + if (pointsToConstantMemory(Loc, AAQI)) + return ModRefInfo::NoModRef; + + // If the va_arg aliases the pointer as must alias, set Must. + if (AR == MustAlias) + return ModRefInfo::MustModRef; + } + + // Otherwise, a va_arg reads and writes. + return ModRefInfo::ModRef; +} + +ModRefInfo AAResults::getModRefInfo(const CatchPadInst *CatchPad, + const MemoryLocation &Loc) { + AAQueryInfo AAQIP; + return getModRefInfo(CatchPad, Loc, AAQIP); +} + +ModRefInfo AAResults::getModRefInfo(const CatchPadInst *CatchPad, + const MemoryLocation &Loc, + AAQueryInfo &AAQI) { + if (Loc.Ptr) { + // If the pointer is a pointer to constant memory, + // then it could not have been modified by this catchpad. + if (pointsToConstantMemory(Loc, AAQI)) + return ModRefInfo::NoModRef; + } + + // Otherwise, a catchpad reads and writes. + return ModRefInfo::ModRef; +} + +ModRefInfo AAResults::getModRefInfo(const CatchReturnInst *CatchRet, + const MemoryLocation &Loc) { + AAQueryInfo AAQIP; + return getModRefInfo(CatchRet, Loc, AAQIP); +} + +ModRefInfo AAResults::getModRefInfo(const CatchReturnInst *CatchRet, + const MemoryLocation &Loc, + AAQueryInfo &AAQI) { + if (Loc.Ptr) { + // If the pointer is a pointer to constant memory, + // then it could not have been modified by this catchret. + if (pointsToConstantMemory(Loc, AAQI)) + return ModRefInfo::NoModRef; + } + + // Otherwise, a catchret reads and writes. + return ModRefInfo::ModRef; +} + +ModRefInfo AAResults::getModRefInfo(const AtomicCmpXchgInst *CX, + const MemoryLocation &Loc) { + AAQueryInfo AAQIP; + return getModRefInfo(CX, Loc, AAQIP); +} + +ModRefInfo AAResults::getModRefInfo(const AtomicCmpXchgInst *CX, + const MemoryLocation &Loc, + AAQueryInfo &AAQI) { + // Acquire/Release cmpxchg has properties that matter for arbitrary addresses. + if (isStrongerThanMonotonic(CX->getSuccessOrdering())) + return ModRefInfo::ModRef; + + if (Loc.Ptr) { + AliasResult AR = alias(MemoryLocation::get(CX), Loc, AAQI); + // If the cmpxchg address does not alias the location, it does not access + // it. + if (AR == NoAlias) + return ModRefInfo::NoModRef; + + // If the cmpxchg address aliases the pointer as must alias, set Must.
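+ // (MustModRef rather than MustMod: a cmpxchg always reads the location + // to perform its comparison and may also write it.)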
+ if (AR == MustAlias) + return ModRefInfo::MustModRef; + } + + return ModRefInfo::ModRef; +} + +ModRefInfo AAResults::getModRefInfo(const AtomicRMWInst *RMW, + const MemoryLocation &Loc) { + AAQueryInfo AAQIP; + return getModRefInfo(RMW, Loc, AAQIP); +} + +ModRefInfo AAResults::getModRefInfo(const AtomicRMWInst *RMW, + const MemoryLocation &Loc, + AAQueryInfo &AAQI) { + // Acquire/Release atomicrmw has properties that matter for arbitrary addresses. + if (isStrongerThanMonotonic(RMW->getOrdering())) + return ModRefInfo::ModRef; + + if (Loc.Ptr) { + AliasResult AR = alias(MemoryLocation::get(RMW), Loc, AAQI); + // If the atomicrmw address does not alias the location, it does not access + // it. + if (AR == NoAlias) + return ModRefInfo::NoModRef; + + // If the atomicrmw address aliases the pointer as must alias, set Must. + if (AR == MustAlias) + return ModRefInfo::MustModRef; + } + + return ModRefInfo::ModRef; +} + +/// Return information about whether a particular call site modifies +/// or reads the specified memory location \p MemLoc before instruction \p I +/// in a BasicBlock. An ordered basic block \p OBB can be used to speed up +/// instruction-ordering queries inside the BasicBlock containing \p I. +/// FIXME: this is really just shoring-up a deficiency in alias analysis. +/// BasicAA isn't willing to spend linear time determining whether an alloca +/// was captured before or after this particular call, while we are. However, +/// with a smarter AA in place, this test is just wasting compile time. +ModRefInfo AAResults::callCapturesBefore(const Instruction *I, + const MemoryLocation &MemLoc, + DominatorTree *DT, + OrderedBasicBlock *OBB) { + if (!DT) + return ModRefInfo::ModRef; + + const Value *Object = + GetUnderlyingObject(MemLoc.Ptr, I->getModule()->getDataLayout()); + if (!isIdentifiedObject(Object) || isa<GlobalValue>(Object) || + isa<Constant>(Object)) + return ModRefInfo::ModRef; + + const auto *Call = dyn_cast<CallBase>(I); + if (!Call || Call == Object) + return ModRefInfo::ModRef; + + if (PointerMayBeCapturedBefore(Object, /* ReturnCaptures */ true, + /* StoreCaptures */ true, I, DT, + /* include Object */ true, + /* OrderedBasicBlock */ OBB)) + return ModRefInfo::ModRef; + + unsigned ArgNo = 0; + ModRefInfo R = ModRefInfo::NoModRef; + bool IsMustAlias = true; + // Set flag only if no May found and all operands processed. + for (auto CI = Call->data_operands_begin(), CE = Call->data_operands_end(); + CI != CE; ++CI, ++ArgNo) { + // Only look at the no-capture or byval pointer arguments. If this + // pointer were passed to arguments that were neither of these, then it + // couldn't be no-capture. + if (!(*CI)->getType()->isPointerTy() || + (!Call->doesNotCapture(ArgNo) && ArgNo < Call->getNumArgOperands() && + !Call->isByValArgument(ArgNo))) + continue; + + AliasResult AR = alias(MemoryLocation(*CI), MemoryLocation(Object)); + // If this is a no-capture pointer argument, see if we can tell that it + // is impossible to alias the pointer we're checking. If not, we have to + // assume that the call could touch the pointer, even though it doesn't + // escape. + if (AR != MustAlias) + IsMustAlias = false; + if (AR == NoAlias) + continue; + if (Call->doesNotAccessMemory(ArgNo)) + continue; + if (Call->onlyReadsMemory(ArgNo)) { + R = ModRefInfo::Ref; + continue; + } + // Not returning MustModRef since we have not seen all the arguments. + return ModRefInfo::ModRef; + } + return IsMustAlias ? 
setMust(R) : clearMust(R); +} + +/// canBasicBlockModify - Return true if it is possible for execution of the +/// specified basic block to modify the location Loc. +/// +bool AAResults::canBasicBlockModify(const BasicBlock &BB, + const MemoryLocation &Loc) { + return canInstructionRangeModRef(BB.front(), BB.back(), Loc, ModRefInfo::Mod); +} + +/// canInstructionRangeModRef - Return true if it is possible for the +/// execution of the specified instructions to mod/ref (according to the +/// mode) the location Loc. The instructions to consider are all +/// of the instructions in the range of [I1,I2] INCLUSIVE. +/// I1 and I2 must be in the same basic block. +bool AAResults::canInstructionRangeModRef(const Instruction &I1, + const Instruction &I2, + const MemoryLocation &Loc, + const ModRefInfo Mode) { + assert(I1.getParent() == I2.getParent() && + "Instructions not in same basic block!"); + BasicBlock::const_iterator I = I1.getIterator(); + BasicBlock::const_iterator E = I2.getIterator(); + ++E; // Convert from inclusive to exclusive range. + + for (; I != E; ++I) // Check every instruction in range + if (isModOrRefSet(intersectModRef(getModRefInfo(&*I, Loc), Mode))) + return true; + return false; +} + +// Provide a definition for the root virtual destructor. +AAResults::Concept::~Concept() = default; + +// Provide a definition for the static object used to identify passes. +AnalysisKey AAManager::Key; + +namespace { + + +} // end anonymous namespace + +char ExternalAAWrapperPass::ID = 0; + +INITIALIZE_PASS(ExternalAAWrapperPass, "external-aa", "External Alias Analysis", + false, true) + +ImmutablePass * +llvm::createExternalAAWrapperPass(ExternalAAWrapperPass::CallbackT Callback) { + return new ExternalAAWrapperPass(std::move(Callback)); +} + +AAResultsWrapperPass::AAResultsWrapperPass() : FunctionPass(ID) { + initializeAAResultsWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +char AAResultsWrapperPass::ID = 0; + +INITIALIZE_PASS_BEGIN(AAResultsWrapperPass, "aa", + "Function Alias Analysis Results", false, true) +INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(CFLAndersAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(CFLSteensAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ExternalAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ObjCARCAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScopedNoAliasAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TypeBasedAAWrapperPass) +INITIALIZE_PASS_END(AAResultsWrapperPass, "aa", + "Function Alias Analysis Results", false, true) + +FunctionPass *llvm::createAAResultsWrapperPass() { + return new AAResultsWrapperPass(); +} + +/// Run the wrapper pass to rebuild an aggregation over known AA passes. +/// +/// This is the legacy pass manager's interface to the new-style AA results +/// aggregation object. Because this is somewhat shoe-horned into the legacy +/// pass manager, we hard code all the specific alias analyses available into +/// it. While the particular set enabled is configured via commandline flags, +/// adding a new alias analysis to LLVM will require adding support for it to +/// this list. +bool AAResultsWrapperPass::runOnFunction(Function &F) { + // NB! This *must* be reset before adding new AA results to the new + // AAResults object because in the legacy pass manager, each instance + // of these will refer to the *same* immutable analyses, registering and + // unregistering themselves with them.
We need to carefully tear down the + // previous object first, in this case replacing it with an empty one, before + // registering new results. + AAR.reset( + new AAResults(getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F))); + + // BasicAA is always available for function analyses. Also, we add it first + // so that it can trump TBAA results when it proves MustAlias. + // FIXME: TBAA should have an explicit mode to support this and then we + // should reconsider the ordering here. + if (!DisableBasicAA) + AAR->addAAResult(getAnalysis<BasicAAWrapperPass>().getResult()); + + // Populate the results with the currently available AAs. + if (auto *WrapperPass = getAnalysisIfAvailable<ScopedNoAliasAAWrapperPass>()) + AAR->addAAResult(WrapperPass->getResult()); + if (auto *WrapperPass = getAnalysisIfAvailable<TypeBasedAAWrapperPass>()) + AAR->addAAResult(WrapperPass->getResult()); + if (auto *WrapperPass = + getAnalysisIfAvailable<objcarc::ObjCARCAAWrapperPass>()) + AAR->addAAResult(WrapperPass->getResult()); + if (auto *WrapperPass = getAnalysisIfAvailable<GlobalsAAWrapperPass>()) + AAR->addAAResult(WrapperPass->getResult()); + if (auto *WrapperPass = getAnalysisIfAvailable<SCEVAAWrapperPass>()) + AAR->addAAResult(WrapperPass->getResult()); + if (auto *WrapperPass = getAnalysisIfAvailable<CFLAndersAAWrapperPass>()) + AAR->addAAResult(WrapperPass->getResult()); + if (auto *WrapperPass = getAnalysisIfAvailable<CFLSteensAAWrapperPass>()) + AAR->addAAResult(WrapperPass->getResult()); + + // If available, run an external AA providing callback over the results as + // well. + if (auto *WrapperPass = getAnalysisIfAvailable<ExternalAAWrapperPass>()) + if (WrapperPass->CB) + WrapperPass->CB(*this, F, *AAR); + + // Analyses don't mutate the IR, so return false. + return false; +} + +void AAResultsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<BasicAAWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + + // We also need to mark all the alias analysis passes we will potentially + // probe in runOnFunction as used here to ensure the legacy pass manager + // preserves them. This hard coding of lists of alias analyses is specific to + // the legacy pass manager. + AU.addUsedIfAvailable<ScopedNoAliasAAWrapperPass>(); + AU.addUsedIfAvailable<TypeBasedAAWrapperPass>(); + AU.addUsedIfAvailable<objcarc::ObjCARCAAWrapperPass>(); + AU.addUsedIfAvailable<GlobalsAAWrapperPass>(); + AU.addUsedIfAvailable<SCEVAAWrapperPass>(); + AU.addUsedIfAvailable<CFLAndersAAWrapperPass>(); + AU.addUsedIfAvailable<CFLSteensAAWrapperPass>(); +} + +AAResults llvm::createLegacyPMAAResults(Pass &P, Function &F, + BasicAAResult &BAR) { + AAResults AAR(P.getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F)); + + // Add in our explicitly constructed BasicAA results. + if (!DisableBasicAA) + AAR.addAAResult(BAR); + + // Populate the results with the other currently available AAs. 
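+ // Each lookup below uses getAnalysisIfAvailable, so every AA other than + // BasicAA is optional. A typical caller builds this aggregation roughly as + // follows (a sketch; it assumes the calling pass declared its dependencies + // via getAAResultsAnalysisUsage()): + // BasicAAResult BAR(createLegacyPMBasicAAResult(*this, F)); + // AAResults AAR = createLegacyPMAAResults(*this, F, BAR);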
+ if (auto *WrapperPass = + P.getAnalysisIfAvailable<ScopedNoAliasAAWrapperPass>()) + AAR.addAAResult(WrapperPass->getResult()); + if (auto *WrapperPass = P.getAnalysisIfAvailable<TypeBasedAAWrapperPass>()) + AAR.addAAResult(WrapperPass->getResult()); + if (auto *WrapperPass = + P.getAnalysisIfAvailable<objcarc::ObjCARCAAWrapperPass>()) + AAR.addAAResult(WrapperPass->getResult()); + if (auto *WrapperPass = P.getAnalysisIfAvailable<GlobalsAAWrapperPass>()) + AAR.addAAResult(WrapperPass->getResult()); + if (auto *WrapperPass = P.getAnalysisIfAvailable<CFLAndersAAWrapperPass>()) + AAR.addAAResult(WrapperPass->getResult()); + if (auto *WrapperPass = P.getAnalysisIfAvailable<CFLSteensAAWrapperPass>()) + AAR.addAAResult(WrapperPass->getResult()); + + return AAR; +} + +bool llvm::isNoAliasCall(const Value *V) { + if (const auto *Call = dyn_cast<CallBase>(V)) + return Call->hasRetAttr(Attribute::NoAlias); + return false; +} + +bool llvm::isNoAliasArgument(const Value *V) { + if (const Argument *A = dyn_cast<Argument>(V)) + return A->hasNoAliasAttr(); + return false; +} + +bool llvm::isIdentifiedObject(const Value *V) { + if (isa<AllocaInst>(V)) + return true; + if (isa<GlobalValue>(V) && !isa<GlobalAlias>(V)) + return true; + if (isNoAliasCall(V)) + return true; + if (const Argument *A = dyn_cast<Argument>(V)) + return A->hasNoAliasAttr() || A->hasByValAttr(); + return false; +} + +bool llvm::isIdentifiedFunctionLocal(const Value *V) { + return isa<AllocaInst>(V) || isNoAliasCall(V) || isNoAliasArgument(V); +} + +void llvm::getAAResultsAnalysisUsage(AnalysisUsage &AU) { + // This function needs to be in sync with llvm::createLegacyPMAAResults -- if + // more alias analyses are added to llvm::createLegacyPMAAResults, they need + // to be added here also. + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addUsedIfAvailable<ScopedNoAliasAAWrapperPass>(); + AU.addUsedIfAvailable<TypeBasedAAWrapperPass>(); + AU.addUsedIfAvailable<objcarc::ObjCARCAAWrapperPass>(); + AU.addUsedIfAvailable<GlobalsAAWrapperPass>(); + AU.addUsedIfAvailable<CFLAndersAAWrapperPass>(); + AU.addUsedIfAvailable<CFLSteensAAWrapperPass>(); +} diff --git a/llvm/lib/Analysis/AliasAnalysisEvaluator.cpp b/llvm/lib/Analysis/AliasAnalysisEvaluator.cpp new file mode 100644 index 000000000000..e83703867e09 --- /dev/null +++ b/llvm/lib/Analysis/AliasAnalysisEvaluator.cpp @@ -0,0 +1,434 @@ +//===- AliasAnalysisEvaluator.cpp - Alias Analysis Accuracy Evaluator -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AliasAnalysisEvaluator.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +static cl::opt<bool> PrintAll("print-all-alias-modref-info", cl::ReallyHidden); + +static cl::opt<bool> PrintNoAlias("print-no-aliases", cl::ReallyHidden); +static cl::opt<bool> PrintMayAlias("print-may-aliases", cl::ReallyHidden); +static cl::opt<bool> PrintPartialAlias("print-partial-aliases", cl::ReallyHidden); +static cl::opt<bool> PrintMustAlias("print-must-aliases", cl::ReallyHidden); + +static cl::opt<bool> PrintNoModRef("print-no-modref", cl::ReallyHidden); +static cl::opt<bool> PrintRef("print-ref", cl::ReallyHidden); +static cl::opt<bool> PrintMod("print-mod", cl::ReallyHidden); +static cl::opt<bool> PrintModRef("print-modref", cl::ReallyHidden); +static cl::opt<bool> PrintMust("print-must", cl::ReallyHidden); +static cl::opt<bool> PrintMustRef("print-mustref", cl::ReallyHidden); +static cl::opt<bool> PrintMustMod("print-mustmod", cl::ReallyHidden); +static cl::opt<bool> PrintMustModRef("print-mustmodref", cl::ReallyHidden); + +static cl::opt<bool> EvalAAMD("evaluate-aa-metadata", cl::ReallyHidden); + +static void PrintResults(AliasResult AR, bool P, const Value *V1, + const Value *V2, const Module *M) { + if (PrintAll || P) { + std::string o1, o2; + { + raw_string_ostream os1(o1), os2(o2); + V1->printAsOperand(os1, true, M); + V2->printAsOperand(os2, true, M); + } + + if (o2 < o1) + std::swap(o1, o2); + errs() << " " << AR << ":\t" << o1 << ", " << o2 << "\n"; + } +} + +static inline void PrintModRefResults(const char *Msg, bool P, Instruction *I, + Value *Ptr, Module *M) { + if (PrintAll || P) { + errs() << " " << Msg << ": Ptr: "; + Ptr->printAsOperand(errs(), true, M); + errs() << "\t<->" << *I << '\n'; + } +} + +static inline void PrintModRefResults(const char *Msg, bool P, CallBase *CallA, + CallBase *CallB, Module *M) { + if (PrintAll || P) { + errs() << " " << Msg << ": " << *CallA << " <-> " << *CallB << '\n'; + } +} + +static inline void PrintLoadStoreResults(AliasResult AR, bool P, + const Value *V1, const Value *V2, + const Module *M) { + if (PrintAll || P) { + errs() << " " << AR << ": " << *V1 << " <-> " << *V2 << '\n'; + } +} + +static inline bool isInterestingPointer(Value *V) { + return V->getType()->isPointerTy() + && !isa<ConstantPointerNull>(V); +} + +PreservedAnalyses AAEvaluator::run(Function &F, FunctionAnalysisManager &AM) { + runInternal(F, AM.getResult<AAManager>(F)); + return PreservedAnalyses::all(); +} + +void AAEvaluator::runInternal(Function &F, AAResults &AA) { + const DataLayout &DL = F.getParent()->getDataLayout(); + + ++FunctionCount; + + SetVector<Value *> Pointers; + SmallSetVector<CallBase *, 16> Calls; + SetVector<Value *> Loads; + SetVector<Value *> Stores; + + for (auto &I : F.args()) + if (I.getType()->isPointerTy()) // Add all pointer arguments. 
+ Pointers.insert(&I); + + for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) { + if (I->getType()->isPointerTy()) // Add all pointer instructions. + Pointers.insert(&*I); + if (EvalAAMD && isa<LoadInst>(&*I)) + Loads.insert(&*I); + if (EvalAAMD && isa<StoreInst>(&*I)) + Stores.insert(&*I); + Instruction &Inst = *I; + if (auto *Call = dyn_cast<CallBase>(&Inst)) { + Value *Callee = Call->getCalledValue(); + // Skip actual functions for direct function calls. + if (!isa<Function>(Callee) && isInterestingPointer(Callee)) + Pointers.insert(Callee); + // Consider formals. + for (Use &DataOp : Call->data_ops()) + if (isInterestingPointer(DataOp)) + Pointers.insert(DataOp); + Calls.insert(Call); + } else { + // Consider all operands. + for (Instruction::op_iterator OI = Inst.op_begin(), OE = Inst.op_end(); + OI != OE; ++OI) + if (isInterestingPointer(*OI)) + Pointers.insert(*OI); + } + } + + if (PrintAll || PrintNoAlias || PrintMayAlias || PrintPartialAlias || + PrintMustAlias || PrintNoModRef || PrintMod || PrintRef || PrintModRef) + errs() << "Function: " << F.getName() << ": " << Pointers.size() + << " pointers, " << Calls.size() << " call sites\n"; + + // iterate over the worklist, and run the full (n^2)/2 disambiguations + for (SetVector<Value *>::iterator I1 = Pointers.begin(), E = Pointers.end(); + I1 != E; ++I1) { + auto I1Size = LocationSize::unknown(); + Type *I1ElTy = cast<PointerType>((*I1)->getType())->getElementType(); + if (I1ElTy->isSized()) + I1Size = LocationSize::precise(DL.getTypeStoreSize(I1ElTy)); + + for (SetVector<Value *>::iterator I2 = Pointers.begin(); I2 != I1; ++I2) { + auto I2Size = LocationSize::unknown(); + Type *I2ElTy = cast<PointerType>((*I2)->getType())->getElementType(); + if (I2ElTy->isSized()) + I2Size = LocationSize::precise(DL.getTypeStoreSize(I2ElTy)); + + AliasResult AR = AA.alias(*I1, I1Size, *I2, I2Size); + switch (AR) { + case NoAlias: + PrintResults(AR, PrintNoAlias, *I1, *I2, F.getParent()); + ++NoAliasCount; + break; + case MayAlias: + PrintResults(AR, PrintMayAlias, *I1, *I2, F.getParent()); + ++MayAliasCount; + break; + case PartialAlias: + PrintResults(AR, PrintPartialAlias, *I1, *I2, F.getParent()); + ++PartialAliasCount; + break; + case MustAlias: + PrintResults(AR, PrintMustAlias, *I1, *I2, F.getParent()); + ++MustAliasCount; + break; + } + } + } + + if (EvalAAMD) { + // iterate over all pairs of load, store + for (Value *Load : Loads) { + for (Value *Store : Stores) { + AliasResult AR = AA.alias(MemoryLocation::get(cast<LoadInst>(Load)), + MemoryLocation::get(cast<StoreInst>(Store))); + switch (AR) { + case NoAlias: + PrintLoadStoreResults(AR, PrintNoAlias, Load, Store, F.getParent()); + ++NoAliasCount; + break; + case MayAlias: + PrintLoadStoreResults(AR, PrintMayAlias, Load, Store, F.getParent()); + ++MayAliasCount; + break; + case PartialAlias: + PrintLoadStoreResults(AR, PrintPartialAlias, Load, Store, F.getParent()); + ++PartialAliasCount; + break; + case MustAlias: + PrintLoadStoreResults(AR, PrintMustAlias, Load, Store, F.getParent()); + ++MustAliasCount; + break; + } + } + } + + // iterate over all pairs of store, store + for (SetVector<Value *>::iterator I1 = Stores.begin(), E = Stores.end(); + I1 != E; ++I1) { + for (SetVector<Value *>::iterator I2 = Stores.begin(); I2 != I1; ++I2) { + AliasResult AR = AA.alias(MemoryLocation::get(cast<StoreInst>(*I1)), + MemoryLocation::get(cast<StoreInst>(*I2))); + switch (AR) { + case NoAlias: + PrintLoadStoreResults(AR, PrintNoAlias, *I1, *I2, F.getParent()); + ++NoAliasCount; 
+ break; + case MayAlias: + PrintLoadStoreResults(AR, PrintMayAlias, *I1, *I2, F.getParent()); + ++MayAliasCount; + break; + case PartialAlias: + PrintLoadStoreResults(AR, PrintPartialAlias, *I1, *I2, F.getParent()); + ++PartialAliasCount; + break; + case MustAlias: + PrintLoadStoreResults(AR, PrintMustAlias, *I1, *I2, F.getParent()); + ++MustAliasCount; + break; + } + } + } + } + + // Mod/ref alias analysis: compare all pairs of calls and values + for (CallBase *Call : Calls) { + for (auto Pointer : Pointers) { + auto Size = LocationSize::unknown(); + Type *ElTy = cast<PointerType>(Pointer->getType())->getElementType(); + if (ElTy->isSized()) + Size = LocationSize::precise(DL.getTypeStoreSize(ElTy)); + + switch (AA.getModRefInfo(Call, Pointer, Size)) { + case ModRefInfo::NoModRef: + PrintModRefResults("NoModRef", PrintNoModRef, Call, Pointer, + F.getParent()); + ++NoModRefCount; + break; + case ModRefInfo::Mod: + PrintModRefResults("Just Mod", PrintMod, Call, Pointer, F.getParent()); + ++ModCount; + break; + case ModRefInfo::Ref: + PrintModRefResults("Just Ref", PrintRef, Call, Pointer, F.getParent()); + ++RefCount; + break; + case ModRefInfo::ModRef: + PrintModRefResults("Both ModRef", PrintModRef, Call, Pointer, + F.getParent()); + ++ModRefCount; + break; + case ModRefInfo::Must: + PrintModRefResults("Must", PrintMust, Call, Pointer, F.getParent()); + ++MustCount; + break; + case ModRefInfo::MustMod: + PrintModRefResults("Just Mod (MustAlias)", PrintMustMod, Call, Pointer, + F.getParent()); + ++MustModCount; + break; + case ModRefInfo::MustRef: + PrintModRefResults("Just Ref (MustAlias)", PrintMustRef, Call, Pointer, + F.getParent()); + ++MustRefCount; + break; + case ModRefInfo::MustModRef: + PrintModRefResults("Both ModRef (MustAlias)", PrintMustModRef, Call, + Pointer, F.getParent()); + ++MustModRefCount; + break; + } + } + } + + // Mod/ref alias analysis: compare all pairs of calls + for (CallBase *CallA : Calls) { + for (CallBase *CallB : Calls) { + if (CallA == CallB) + continue; + switch (AA.getModRefInfo(CallA, CallB)) { + case ModRefInfo::NoModRef: + PrintModRefResults("NoModRef", PrintNoModRef, CallA, CallB, + F.getParent()); + ++NoModRefCount; + break; + case ModRefInfo::Mod: + PrintModRefResults("Just Mod", PrintMod, CallA, CallB, F.getParent()); + ++ModCount; + break; + case ModRefInfo::Ref: + PrintModRefResults("Just Ref", PrintRef, CallA, CallB, F.getParent()); + ++RefCount; + break; + case ModRefInfo::ModRef: + PrintModRefResults("Both ModRef", PrintModRef, CallA, CallB, + F.getParent()); + ++ModRefCount; + break; + case ModRefInfo::Must: + PrintModRefResults("Must", PrintMust, CallA, CallB, F.getParent()); + ++MustCount; + break; + case ModRefInfo::MustMod: + PrintModRefResults("Just Mod (MustAlias)", PrintMustMod, CallA, CallB, + F.getParent()); + ++MustModCount; + break; + case ModRefInfo::MustRef: + PrintModRefResults("Just Ref (MustAlias)", PrintMustRef, CallA, CallB, + F.getParent()); + ++MustRefCount; + break; + case ModRefInfo::MustModRef: + PrintModRefResults("Both ModRef (MustAlias)", PrintMustModRef, CallA, + CallB, F.getParent()); + ++MustModRefCount; + break; + } + } + } +} + +static void PrintPercent(int64_t Num, int64_t Sum) { + errs() << "(" << Num * 100LL / Sum << "." 
<< ((Num * 1000LL / Sum) % 10) + << "%)\n"; +} + +AAEvaluator::~AAEvaluator() { + if (FunctionCount == 0) + return; + + int64_t AliasSum = + NoAliasCount + MayAliasCount + PartialAliasCount + MustAliasCount; + errs() << "===== Alias Analysis Evaluator Report =====\n"; + if (AliasSum == 0) { + errs() << " Alias Analysis Evaluator Summary: No pointers!\n"; + } else { + errs() << " " << AliasSum << " Total Alias Queries Performed\n"; + errs() << " " << NoAliasCount << " no alias responses "; + PrintPercent(NoAliasCount, AliasSum); + errs() << " " << MayAliasCount << " may alias responses "; + PrintPercent(MayAliasCount, AliasSum); + errs() << " " << PartialAliasCount << " partial alias responses "; + PrintPercent(PartialAliasCount, AliasSum); + errs() << " " << MustAliasCount << " must alias responses "; + PrintPercent(MustAliasCount, AliasSum); + errs() << " Alias Analysis Evaluator Pointer Alias Summary: " + << NoAliasCount * 100 / AliasSum << "%/" + << MayAliasCount * 100 / AliasSum << "%/" + << PartialAliasCount * 100 / AliasSum << "%/" + << MustAliasCount * 100 / AliasSum << "%\n"; + } + + // Display the summary for mod/ref analysis + int64_t ModRefSum = NoModRefCount + RefCount + ModCount + ModRefCount + + MustCount + MustRefCount + MustModCount + MustModRefCount; + if (ModRefSum == 0) { + errs() << " Alias Analysis Mod/Ref Evaluator Summary: no " + "mod/ref!\n"; + } else { + errs() << " " << ModRefSum << " Total ModRef Queries Performed\n"; + errs() << " " << NoModRefCount << " no mod/ref responses "; + PrintPercent(NoModRefCount, ModRefSum); + errs() << " " << ModCount << " mod responses "; + PrintPercent(ModCount, ModRefSum); + errs() << " " << RefCount << " ref responses "; + PrintPercent(RefCount, ModRefSum); + errs() << " " << ModRefCount << " mod & ref responses "; + PrintPercent(ModRefCount, ModRefSum); + errs() << " " << MustCount << " must responses "; + PrintPercent(MustCount, ModRefSum); + errs() << " " << MustModCount << " must mod responses "; + PrintPercent(MustModCount, ModRefSum); + errs() << " " << MustRefCount << " must ref responses "; + PrintPercent(MustRefCount, ModRefSum); + errs() << " " << MustModRefCount << " must mod & ref responses "; + PrintPercent(MustModRefCount, ModRefSum); + errs() << " Alias Analysis Evaluator Mod/Ref Summary: " + << NoModRefCount * 100 / ModRefSum << "%/" + << ModCount * 100 / ModRefSum << "%/" << RefCount * 100 / ModRefSum + << "%/" << ModRefCount * 100 / ModRefSum << "%/" + << MustCount * 100 / ModRefSum << "%/" + << MustRefCount * 100 / ModRefSum << "%/" + << MustModCount * 100 / ModRefSum << "%/" + << MustModRefCount * 100 / ModRefSum << "%\n"; + } +} + +namespace llvm { +class AAEvalLegacyPass : public FunctionPass { + std::unique_ptr<AAEvaluator> P; + +public: + static char ID; // Pass identification, replacement for typeid + AAEvalLegacyPass() : FunctionPass(ID) { + initializeAAEvalLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AAResultsWrapperPass>(); + AU.setPreservesAll(); + } + + bool doInitialization(Module &M) override { + P.reset(new AAEvaluator()); + return false; + } + + bool runOnFunction(Function &F) override { + P->runInternal(F, getAnalysis<AAResultsWrapperPass>().getAAResults()); + return false; + } + bool doFinalization(Module &M) override { + P.reset(); + return false; + } +}; +} + +char AAEvalLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(AAEvalLegacyPass, "aa-eval", + "Exhaustive Alias Analysis Precision Evaluator", false, + true) 
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(AAEvalLegacyPass, "aa-eval", + "Exhaustive Alias Analysis Precision Evaluator", false, + true) + +FunctionPass *llvm::createAAEvalPass() { return new AAEvalLegacyPass(); } diff --git a/llvm/lib/Analysis/AliasAnalysisSummary.cpp b/llvm/lib/Analysis/AliasAnalysisSummary.cpp new file mode 100644 index 000000000000..2f3396a44117 --- /dev/null +++ b/llvm/lib/Analysis/AliasAnalysisSummary.cpp @@ -0,0 +1,103 @@ +#include "AliasAnalysisSummary.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Compiler.h" + +namespace llvm { +namespace cflaa { + +namespace { +const unsigned AttrEscapedIndex = 0; +const unsigned AttrUnknownIndex = 1; +const unsigned AttrGlobalIndex = 2; +const unsigned AttrCallerIndex = 3; +const unsigned AttrFirstArgIndex = 4; +const unsigned AttrLastArgIndex = NumAliasAttrs; +const unsigned AttrMaxNumArgs = AttrLastArgIndex - AttrFirstArgIndex; + +// It would be *slightly* prettier if we changed these to AliasAttrs, but it +// seems that both GCC and MSVC emit dynamic initializers for const bitsets. +using AliasAttr = unsigned; +const AliasAttr AttrNone = 0; +const AliasAttr AttrEscaped = 1 << AttrEscapedIndex; +const AliasAttr AttrUnknown = 1 << AttrUnknownIndex; +const AliasAttr AttrGlobal = 1 << AttrGlobalIndex; +const AliasAttr AttrCaller = 1 << AttrCallerIndex; +const AliasAttr ExternalAttrMask = AttrEscaped | AttrUnknown | AttrGlobal; +} + +AliasAttrs getAttrNone() { return AttrNone; } + +AliasAttrs getAttrUnknown() { return AttrUnknown; } +bool hasUnknownAttr(AliasAttrs Attr) { return Attr.test(AttrUnknownIndex); } + +AliasAttrs getAttrCaller() { return AttrCaller; } +bool hasCallerAttr(AliasAttrs Attr) { return Attr.test(AttrCallerIndex); } +bool hasUnknownOrCallerAttr(AliasAttrs Attr) { + return Attr.test(AttrUnknownIndex) || Attr.test(AttrCallerIndex); +} + +AliasAttrs getAttrEscaped() { return AttrEscaped; } +bool hasEscapedAttr(AliasAttrs Attr) { return Attr.test(AttrEscapedIndex); } + +static AliasAttr argNumberToAttr(unsigned ArgNum) { + if (ArgNum >= AttrMaxNumArgs) + return AttrUnknown; + // N.B. MSVC complains if we use `1U` here, since AliasAttr's ctor takes + // an unsigned long long. + return AliasAttr(1ULL << (ArgNum + AttrFirstArgIndex)); +} + +AliasAttrs getGlobalOrArgAttrFromValue(const Value &Val) { + if (isa<GlobalValue>(Val)) + return AttrGlobal; + + if (auto *Arg = dyn_cast<Argument>(&Val)) + // Only pointer arguments should have the argument attribute, + // because things can't escape through scalars without us seeing a + // cast, and thus, interaction with them doesn't matter. + if (!Arg->hasNoAliasAttr() && Arg->getType()->isPointerTy()) + return argNumberToAttr(Arg->getArgNo()); + return AttrNone; +} + +bool isGlobalOrArgAttr(AliasAttrs Attr) { + return Attr.reset(AttrEscapedIndex) + .reset(AttrUnknownIndex) + .reset(AttrCallerIndex) + .any(); +} + +AliasAttrs getExternallyVisibleAttrs(AliasAttrs Attr) { + return Attr & AliasAttrs(ExternalAttrMask); +} + +Optional<InstantiatedValue> instantiateInterfaceValue(InterfaceValue IValue, + CallBase &Call) { + auto Index = IValue.Index; + auto *V = (Index == 0) ?
&Call : Call.getArgOperand(Index - 1); + if (V->getType()->isPointerTy()) + return InstantiatedValue{V, IValue.DerefLevel}; + return None; +} + +Optional<InstantiatedRelation> +instantiateExternalRelation(ExternalRelation ERelation, CallBase &Call) { + auto From = instantiateInterfaceValue(ERelation.From, Call); + if (!From) + return None; + auto To = instantiateInterfaceValue(ERelation.To, Call); + if (!To) + return None; + return InstantiatedRelation{*From, *To, ERelation.Offset}; +} + +Optional<InstantiatedAttr> instantiateExternalAttribute(ExternalAttribute EAttr, + CallBase &Call) { + auto Value = instantiateInterfaceValue(EAttr.IValue, Call); + if (!Value) + return None; + return InstantiatedAttr{*Value, EAttr.Attr}; +} +} +} diff --git a/llvm/lib/Analysis/AliasAnalysisSummary.h b/llvm/lib/Analysis/AliasAnalysisSummary.h new file mode 100644 index 000000000000..fe75b03cedef --- /dev/null +++ b/llvm/lib/Analysis/AliasAnalysisSummary.h @@ -0,0 +1,265 @@ +//=====- AliasAnalysisSummary.h - Summary-based alias analysis. ---------=====// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// This file defines various utility types and functions useful to +/// summary-based alias analysis. +/// +/// Summary-based analysis, also known as bottom-up analysis, is a style of +/// interprocedural static analysis that tries to analyze the callees before the +/// callers get analyzed. The key idea of summary-based analysis is to first +/// process each function independently, outline its behavior in a condensed +/// summary, and then instantiate the summary at the callsite when the said +/// function is called elsewhere. This is often in contrast to another style +/// called top-down analysis, in which callers are always analyzed first before +/// the callees. +/// +/// In a summary-based analysis, functions must be examined independently and +/// out-of-context. We have no information on the state of the memory, the +/// arguments, the global values, and anything else external to the function. To +/// carry out the analysis, conservative assumptions have to be made about those +/// external states. In exchange for the potential loss of precision, the +/// summary we obtain this way is highly reusable, which makes the analysis +/// easier to scale to large programs even if carried out context-sensitively. +/// +/// Currently, all CFL-based alias analyses adopt the summary-based approach +/// and therefore heavily rely on this header. +/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ANALYSIS_ALIASANALYSISSUMMARY_H +#define LLVM_ANALYSIS_ALIASANALYSISSUMMARY_H + +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/InstrTypes.h" +#include <bitset> + +namespace llvm { +namespace cflaa { + +//===----------------------------------------------------------------------===// +// AliasAttr related stuff +//===----------------------------------------------------------------------===// + +/// The number of attributes that AliasAttr should contain. Attributes are +/// described below, and 32 was an arbitrary choice because it fits nicely in 32 +/// bits (because we use a bitset for AliasAttr).
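+/// The 32 bits are partitioned into four fixed attributes (escaped, unknown, +/// global, caller) plus one bit per argument position, so 28 argument +/// positions can be tracked individually; any argument beyond that is +/// conservatively treated as coming from an unknown source.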
+static const unsigned NumAliasAttrs = 32; + +/// These are attributes that an alias analysis can use to mark certain special +/// properties of a given pointer. Refer to the related functions below to see +/// what kinds of attributes are currently defined. +typedef std::bitset<NumAliasAttrs> AliasAttrs; + +/// AttrNone represents a pointer that has none of the special attributes +/// below set. +AliasAttrs getAttrNone(); + +/// AttrUnknown represents whether the said pointer comes from a source not known +/// to alias analyses (such as opaque memory or an integer cast). +AliasAttrs getAttrUnknown(); +bool hasUnknownAttr(AliasAttrs); + +/// AttrCaller represents whether the said pointer comes from a source not known +/// to the current function but known to the caller. Values pointed to by the +/// arguments of the current function have this attribute set. +AliasAttrs getAttrCaller(); +bool hasCallerAttr(AliasAttrs); +bool hasUnknownOrCallerAttr(AliasAttrs); + +/// AttrEscaped represents whether the said pointer comes from a known source but +/// escapes to the unknown world (e.g. cast to an integer, or passed as an +/// argument to an opaque function). Unlike non-escaped pointers, escaped ones may +/// alias pointers coming from unknown sources. +AliasAttrs getAttrEscaped(); +bool hasEscapedAttr(AliasAttrs); + +/// AttrGlobal represents whether the said pointer is a global value. +/// AttrArg represents whether the said pointer is an argument, and if so, what +/// index the argument has. +AliasAttrs getGlobalOrArgAttrFromValue(const Value &); +bool isGlobalOrArgAttr(AliasAttrs); + +/// Given an AliasAttrs, return a new AliasAttrs that only contains attributes +/// meaningful to the caller. This function is primarily used for +/// interprocedural analysis. +/// Currently, externally visible AliasAttrs include AttrUnknown, AttrGlobal, +/// and AttrEscaped. +AliasAttrs getExternallyVisibleAttrs(AliasAttrs); + +//===----------------------------------------------------------------------===// +// Function summary related stuff +//===----------------------------------------------------------------------===// + +/// The maximum number of arguments we can put into a summary. +static const unsigned MaxSupportedArgsInSummary = 50; + +/// We use InterfaceValue to describe parameters/return value, as well as +/// potential memory locations that are pointed to by parameters/return value, +/// of a function. +/// Index is an integer which represents a single parameter or the return value. +/// When the index is 0, it refers to the return value. Non-zero index i refers +/// to the i-th parameter. +/// DerefLevel indicates the number of dereferences one must perform on the +/// parameter/return value to get this InterfaceValue.
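+/// For example, for a hypothetical function `int *f(int **p)`, +/// InterfaceValue{0, 0} denotes the returned pointer, InterfaceValue{1, 0} +/// denotes the parameter p itself, and InterfaceValue{1, 1} denotes the +/// pointer *p that p points to.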
+struct InterfaceValue { + unsigned Index; + unsigned DerefLevel; +}; + +inline bool operator==(InterfaceValue LHS, InterfaceValue RHS) { + return LHS.Index == RHS.Index && LHS.DerefLevel == RHS.DerefLevel; +} +inline bool operator!=(InterfaceValue LHS, InterfaceValue RHS) { + return !(LHS == RHS); +} +inline bool operator<(InterfaceValue LHS, InterfaceValue RHS) { + return LHS.Index < RHS.Index || + (LHS.Index == RHS.Index && LHS.DerefLevel < RHS.DerefLevel); +} +inline bool operator>(InterfaceValue LHS, InterfaceValue RHS) { + return RHS < LHS; +} +inline bool operator<=(InterfaceValue LHS, InterfaceValue RHS) { + return !(RHS < LHS); +} +inline bool operator>=(InterfaceValue LHS, InterfaceValue RHS) { + return !(LHS < RHS); +} + +// We use UnknownOffset to represent pointer offsets that cannot be determined +// at compile time. Note that MemoryLocation::UnknownSize cannot be used here +// because we require a signed value. +static const int64_t UnknownOffset = INT64_MAX; + +inline int64_t addOffset(int64_t LHS, int64_t RHS) { + if (LHS == UnknownOffset || RHS == UnknownOffset) + return UnknownOffset; + // FIXME: Do we need to guard against integer overflow here? + return LHS + RHS; +} + +/// We use ExternalRelation to describe externally visible aliasing relations +/// between parameters/return value of a function. +struct ExternalRelation { + InterfaceValue From, To; + int64_t Offset; +}; + +inline bool operator==(ExternalRelation LHS, ExternalRelation RHS) { + return LHS.From == RHS.From && LHS.To == RHS.To && LHS.Offset == RHS.Offset; +} +inline bool operator!=(ExternalRelation LHS, ExternalRelation RHS) { + return !(LHS == RHS); +} +inline bool operator<(ExternalRelation LHS, ExternalRelation RHS) { + if (LHS.From < RHS.From) + return true; + if (LHS.From > RHS.From) + return false; + if (LHS.To < RHS.To) + return true; + if (LHS.To > RHS.To) + return false; + return LHS.Offset < RHS.Offset; +} +inline bool operator>(ExternalRelation LHS, ExternalRelation RHS) { + return RHS < LHS; +} +inline bool operator<=(ExternalRelation LHS, ExternalRelation RHS) { + return !(RHS < LHS); +} +inline bool operator>=(ExternalRelation LHS, ExternalRelation RHS) { + return !(LHS < RHS); +} + +/// We use ExternalAttribute to describe an externally visible AliasAttrs +/// for parameters/return value. +struct ExternalAttribute { + InterfaceValue IValue; + AliasAttrs Attr; +}; + +/// AliasSummary is just a collection of ExternalRelation and ExternalAttribute. +struct AliasSummary { + // RetParamRelations is a collection of ExternalRelations. + SmallVector<ExternalRelation, 8> RetParamRelations; + + // RetParamAttributes is a collection of ExternalAttributes.
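+ // (For example, a hypothetical entry {{0, 0}, getAttrUnknown()} would + // record that the returned pointer may point to memory the analysis + // cannot see.)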
+  SmallVector<ExternalAttribute, 8> RetParamAttributes;
+};
+
+/// This is the result of instantiating InterfaceValue at a particular
+/// callsite.
+struct InstantiatedValue {
+  Value *Val;
+  unsigned DerefLevel;
+};
+Optional<InstantiatedValue> instantiateInterfaceValue(InterfaceValue IValue,
+                                                      CallBase &Call);
+
+inline bool operator==(InstantiatedValue LHS, InstantiatedValue RHS) {
+  return LHS.Val == RHS.Val && LHS.DerefLevel == RHS.DerefLevel;
+}
+inline bool operator!=(InstantiatedValue LHS, InstantiatedValue RHS) {
+  return !(LHS == RHS);
+}
+inline bool operator<(InstantiatedValue LHS, InstantiatedValue RHS) {
+  return std::less<Value *>()(LHS.Val, RHS.Val) ||
+         (LHS.Val == RHS.Val && LHS.DerefLevel < RHS.DerefLevel);
+}
+inline bool operator>(InstantiatedValue LHS, InstantiatedValue RHS) {
+  return RHS < LHS;
+}
+inline bool operator<=(InstantiatedValue LHS, InstantiatedValue RHS) {
+  return !(RHS < LHS);
+}
+inline bool operator>=(InstantiatedValue LHS, InstantiatedValue RHS) {
+  return !(LHS < RHS);
+}
+
+/// This is the result of instantiating ExternalRelation at a particular
+/// callsite.
+struct InstantiatedRelation {
+  InstantiatedValue From, To;
+  int64_t Offset;
+};
+Optional<InstantiatedRelation>
+instantiateExternalRelation(ExternalRelation ERelation, CallBase &Call);
+
+/// This is the result of instantiating ExternalAttribute at a particular
+/// callsite.
+struct InstantiatedAttr {
+  InstantiatedValue IValue;
+  AliasAttrs Attr;
+};
+Optional<InstantiatedAttr> instantiateExternalAttribute(ExternalAttribute EAttr,
+                                                        CallBase &Call);
+} // end namespace cflaa
+
+template <> struct DenseMapInfo<cflaa::InstantiatedValue> {
+  static inline cflaa::InstantiatedValue getEmptyKey() {
+    return cflaa::InstantiatedValue{DenseMapInfo<Value *>::getEmptyKey(),
+                                    DenseMapInfo<unsigned>::getEmptyKey()};
+  }
+  static inline cflaa::InstantiatedValue getTombstoneKey() {
+    return cflaa::InstantiatedValue{DenseMapInfo<Value *>::getTombstoneKey(),
+                                    DenseMapInfo<unsigned>::getTombstoneKey()};
+  }
+  static unsigned getHashValue(const cflaa::InstantiatedValue &IV) {
+    return DenseMapInfo<std::pair<Value *, unsigned>>::getHashValue(
+        std::make_pair(IV.Val, IV.DerefLevel));
+  }
+  static bool isEqual(const cflaa::InstantiatedValue &LHS,
+                      const cflaa::InstantiatedValue &RHS) {
+    return LHS.Val == RHS.Val && LHS.DerefLevel == RHS.DerefLevel;
+  }
+};
+} // end namespace llvm
+
+#endif
diff --git a/llvm/lib/Analysis/AliasSetTracker.cpp b/llvm/lib/Analysis/AliasSetTracker.cpp
new file mode 100644
index 000000000000..79fbcd464c1b
--- /dev/null
+++ b/llvm/lib/Analysis/AliasSetTracker.cpp
@@ -0,0 +1,776 @@
+//===- AliasSetTracker.cpp - Alias Sets Tracker implementation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the AliasSetTracker and AliasSet classes.
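+//
+// A typical use (an illustrative sketch, mirroring the printer pass at the
+// bottom of this file) is to add every instruction in a function and then
+// inspect the resulting partition of memory locations:
+//
+//   AliasSetTracker AST(AA);
+//   for (Instruction &I : instructions(F))
+//     AST.add(&I);
+//   for (const AliasSet &AS : AST)
+//     AS.print(errs());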
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/AliasSetTracker.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/GuardUtils.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/AtomicOrdering.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+#include <cstdint>
+#include <vector>
+
+using namespace llvm;
+
+static cl::opt<unsigned>
+    SaturationThreshold("alias-set-saturation-threshold", cl::Hidden,
+                        cl::init(250),
+                        cl::desc("The maximum number of pointers that "
+                                 "may-alias sets may contain before "
+                                 "degradation"));
+
+/// mergeSetIn - Merge the specified alias set into this alias set.
+///
+void AliasSet::mergeSetIn(AliasSet &AS, AliasSetTracker &AST) {
+  assert(!AS.Forward && "Alias set is already forwarding!");
+  assert(!Forward && "This set is a forwarding set!!");
+
+  bool WasMustAlias = (Alias == SetMustAlias);
+  // Update the alias and access types of this set...
+  Access |= AS.Access;
+  Alias |= AS.Alias;
+
+  if (Alias == SetMustAlias) {
+    // Check that these two merged sets really are must aliases. Since both
+    // used to be must-alias sets, we can just check any pointer from each set
+    // for aliasing.
+    AliasAnalysis &AA = AST.getAliasAnalysis();
+    PointerRec *L = getSomePointer();
+    PointerRec *R = AS.getSomePointer();
+
+    // If the pointers are not a must-alias pair, this set becomes a may alias.
+    if (AA.alias(MemoryLocation(L->getValue(), L->getSize(), L->getAAInfo()),
+                 MemoryLocation(R->getValue(), R->getSize(), R->getAAInfo())) !=
+        MustAlias)
+      Alias = SetMayAlias;
+  }
+
+  if (Alias == SetMayAlias) {
+    if (WasMustAlias)
+      AST.TotalMayAliasSetSize += size();
+    if (AS.Alias == SetMustAlias)
+      AST.TotalMayAliasSetSize += AS.size();
+  }
+
+  bool ASHadUnknownInsts = !AS.UnknownInsts.empty();
+  if (UnknownInsts.empty()) { // Merge unknown instructions (e.g. calls)...
+    if (ASHadUnknownInsts) {
+      std::swap(UnknownInsts, AS.UnknownInsts);
+      addRef();
+    }
+  } else if (ASHadUnknownInsts) {
+    UnknownInsts.insert(UnknownInsts.end(), AS.UnknownInsts.begin(),
+                        AS.UnknownInsts.end());
+    AS.UnknownInsts.clear();
+  }
+
+  AS.Forward = this; // Forward across AS now...
+  addRef();          // AS is now pointing to us...
+
+  // Merge the list of constituent pointers...
+  if (AS.PtrList) {
+    SetSize += AS.size();
+    AS.SetSize = 0;
+    *PtrListEnd = AS.PtrList;
+    AS.PtrList->setPrevInList(PtrListEnd);
+    PtrListEnd = AS.PtrListEnd;
+
+    AS.PtrList = nullptr;
+    AS.PtrListEnd = &AS.PtrList;
+    assert(*AS.PtrListEnd == nullptr && "End of list is not null?");
+  }
+  if (ASHadUnknownInsts)
+    AS.dropRef(AST);
+}
+
+void AliasSetTracker::removeAliasSet(AliasSet *AS) {
+  if (AliasSet *Fwd = AS->Forward) {
+    Fwd->dropRef(*this);
+    AS->Forward = nullptr;
+  } else // Update TotalMayAliasSetSize only if not forwarding.
+    if (AS->Alias == AliasSet::SetMayAlias)
+      TotalMayAliasSetSize -= AS->size();
+
+  AliasSets.erase(AS);
+  // If we've removed the saturated alias set, set the saturated marker back to
+  // nullptr and ensure this tracker is empty.
+  if (AS == AliasAnyAS) {
+    AliasAnyAS = nullptr;
+    assert(AliasSets.empty() && "Tracker not empty");
+  }
+}
+
+void AliasSet::removeFromTracker(AliasSetTracker &AST) {
+  assert(RefCount == 0 && "Cannot remove non-dead alias set from tracker!");
+  AST.removeAliasSet(this);
+}
+
+void AliasSet::addPointer(AliasSetTracker &AST, PointerRec &Entry,
+                          LocationSize Size, const AAMDNodes &AAInfo,
+                          bool KnownMustAlias, bool SkipSizeUpdate) {
+  assert(!Entry.hasAliasSet() && "Entry already in set!");
+
+  // Check to see if we have to downgrade to _may_ alias.
+  if (isMustAlias())
+    if (PointerRec *P = getSomePointer()) {
+      if (!KnownMustAlias) {
+        AliasAnalysis &AA = AST.getAliasAnalysis();
+        AliasResult Result = AA.alias(
+            MemoryLocation(P->getValue(), P->getSize(), P->getAAInfo()),
+            MemoryLocation(Entry.getValue(), Size, AAInfo));
+        if (Result != MustAlias) {
+          Alias = SetMayAlias;
+          AST.TotalMayAliasSetSize += size();
+        }
+        assert(Result != NoAlias && "Cannot be part of must set!");
+      } else if (!SkipSizeUpdate)
+        P->updateSizeAndAAInfo(Size, AAInfo);
+    }
+
+  Entry.setAliasSet(this);
+  Entry.updateSizeAndAAInfo(Size, AAInfo);
+
+  // Add it to the end of the list...
+  ++SetSize;
+  assert(*PtrListEnd == nullptr && "End of list is not null?");
+  *PtrListEnd = &Entry;
+  PtrListEnd = Entry.setPrevInList(PtrListEnd);
+  assert(*PtrListEnd == nullptr && "End of list is not null?");
+  // Entry points to alias set.
+  addRef();
+
+  if (Alias == SetMayAlias)
+    AST.TotalMayAliasSetSize++;
+}
+
+void AliasSet::addUnknownInst(Instruction *I, AliasAnalysis &AA) {
+  if (UnknownInsts.empty())
+    addRef();
+  UnknownInsts.emplace_back(I);
+
+  // Guards are marked as modifying memory for control flow modelling purposes,
+  // but don't actually modify any specific memory location.
+  using namespace PatternMatch;
+  bool MayWriteMemory = I->mayWriteToMemory() && !isGuard(I) &&
+    !(I->use_empty() && match(I, m_Intrinsic<Intrinsic::invariant_start>()));
+  if (!MayWriteMemory) {
+    Alias = SetMayAlias;
+    Access |= RefAccess;
+    return;
+  }
+
+  // FIXME: This should use mod/ref information to make this not suck so bad
+  Alias = SetMayAlias;
+  Access = ModRefAccess;
+}
+
+/// aliasesPointer - If the specified pointer "may" (or must) alias one of the
+/// members in the set, return the appropriate AliasResult. Otherwise return
+/// NoAlias.
+///
+AliasResult AliasSet::aliasesPointer(const Value *Ptr, LocationSize Size,
+                                     const AAMDNodes &AAInfo,
+                                     AliasAnalysis &AA) const {
+  if (AliasAny)
+    return MayAlias;
+
+  if (Alias == SetMustAlias) {
+    assert(UnknownInsts.empty() && "Illegal must alias set!");
+
+    // If this is a set of MustAliases, only check to see if the pointer aliases
+    // SOME value in the set.
+    PointerRec *SomePtr = getSomePointer();
+    assert(SomePtr && "Empty must-alias set??");
+    return AA.alias(MemoryLocation(SomePtr->getValue(), SomePtr->getSize(),
+                                   SomePtr->getAAInfo()),
+                    MemoryLocation(Ptr, Size, AAInfo));
+  }
+
+  // If this is a may-alias set, we have to check all of the pointers in the
+  // set to be sure the queried pointer doesn't alias any of them...
+  for (iterator I = begin(), E = end(); I != E; ++I)
+    if (AliasResult AR = AA.alias(
+            MemoryLocation(Ptr, Size, AAInfo),
+            MemoryLocation(I.getPointer(), I.getSize(), I.getAAInfo())))
+      return AR;
+
+  // Check the unknown instructions...
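+  // (These are the instructions recorded via addUnknownInst -- e.g. calls
+  // whose accessed locations we could not enumerate. We conservatively query
+  // the alias analysis for their mod/ref behavior against the given location.)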
+  if (!UnknownInsts.empty()) {
+    for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i)
+      if (auto *Inst = getUnknownInst(i))
+        if (isModOrRefSet(
+                AA.getModRefInfo(Inst, MemoryLocation(Ptr, Size, AAInfo))))
+          return MayAlias;
+  }
+
+  return NoAlias;
+}
+
+bool AliasSet::aliasesUnknownInst(const Instruction *Inst,
+                                  AliasAnalysis &AA) const {
+  if (AliasAny)
+    return true;
+
+  assert(Inst->mayReadOrWriteMemory() &&
+         "Instruction must either read or write memory.");
+
+  for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i) {
+    if (auto *UnknownInst = getUnknownInst(i)) {
+      const auto *C1 = dyn_cast<CallBase>(UnknownInst);
+      const auto *C2 = dyn_cast<CallBase>(Inst);
+      if (!C1 || !C2 || isModOrRefSet(AA.getModRefInfo(C1, C2)) ||
+          isModOrRefSet(AA.getModRefInfo(C2, C1)))
+        return true;
+    }
+  }
+
+  for (iterator I = begin(), E = end(); I != E; ++I)
+    if (isModOrRefSet(AA.getModRefInfo(
+            Inst, MemoryLocation(I.getPointer(), I.getSize(), I.getAAInfo()))))
+      return true;
+
+  return false;
+}
+
+Instruction *AliasSet::getUniqueInstruction() {
+  if (AliasAny)
+    // May be a collapsed alias set.
+    return nullptr;
+  if (begin() != end()) {
+    if (!UnknownInsts.empty())
+      // Another instruction found
+      return nullptr;
+    if (std::next(begin()) != end())
+      // Another instruction found
+      return nullptr;
+    Value *Addr = begin()->getValue();
+    assert(!Addr->user_empty() &&
+           "where's the instruction which added this pointer?");
+    if (std::next(Addr->user_begin()) != Addr->user_end())
+      // Another instruction found -- this is really restrictive
+      // TODO: generalize!
+      return nullptr;
+    return cast<Instruction>(*(Addr->user_begin()));
+  }
+  if (UnknownInsts.size() != 1)
+    return nullptr;
+  return cast<Instruction>(UnknownInsts[0]);
+}
+
+void AliasSetTracker::clear() {
+  // Delete all the PointerRec entries.
+  for (PointerMapType::iterator I = PointerMap.begin(), E = PointerMap.end();
+       I != E; ++I)
+    I->second->eraseFromList();
+
+  PointerMap.clear();
+
+  // The alias sets should all be clear now.
+  AliasSets.clear();
+}
+
+/// mergeAliasSetsForPointer - Given a pointer, merge all alias sets that may
+/// alias the pointer. Return the unified set, or nullptr if no set that aliases
+/// the pointer was found. MustAliasAll is set to true if the pointer
+/// must-aliases all of the sets it merged, and to false otherwise.
+AliasSet *AliasSetTracker::mergeAliasSetsForPointer(const Value *Ptr,
+                                                    LocationSize Size,
+                                                    const AAMDNodes &AAInfo,
+                                                    bool &MustAliasAll) {
+  AliasSet *FoundSet = nullptr;
+  AliasResult AllAR = MustAlias;
+  for (iterator I = begin(), E = end(); I != E;) {
+    iterator Cur = I++;
+    if (Cur->Forward)
+      continue;
+
+    AliasResult AR = Cur->aliasesPointer(Ptr, Size, AAInfo, AA);
+    if (AR == NoAlias)
+      continue;
+
+    AllAR =
+        AliasResult(AllAR & AR); // Possible downgrade to May/Partial, even No
+
+    if (!FoundSet) {
+      // If this is the first alias set ptr can go into, remember it.
+      FoundSet = &*Cur;
+    } else {
+      // Otherwise, we must merge the sets.
+      FoundSet->mergeSetIn(*Cur, *this);
+    }
+  }
+
+  MustAliasAll = (AllAR == MustAlias);
+  return FoundSet;
+}
+
+AliasSet *AliasSetTracker::findAliasSetForUnknownInst(Instruction *Inst) {
+  AliasSet *FoundSet = nullptr;
+  for (iterator I = begin(), E = end(); I != E;) {
+    iterator Cur = I++;
+    if (Cur->Forward || !Cur->aliasesUnknownInst(Inst, AA))
+      continue;
+    if (!FoundSet) {
+      // If this is the first alias set ptr can go into, remember it.
+      FoundSet = &*Cur;
+    } else {
+      // Otherwise, we must merge the sets.
+      FoundSet->mergeSetIn(*Cur, *this);
+    }
+  }
+  return FoundSet;
+}
+
+AliasSet &AliasSetTracker::getAliasSetFor(const MemoryLocation &MemLoc) {
+  Value *const Pointer = const_cast<Value *>(MemLoc.Ptr);
+  const LocationSize Size = MemLoc.Size;
+  const AAMDNodes &AAInfo = MemLoc.AATags;
+
+  AliasSet::PointerRec &Entry = getEntryFor(Pointer);
+
+  if (AliasAnyAS) {
+    // At this point, the AST is saturated, so we only have one active alias
+    // set. That means we already know which alias set we want to return, and
+    // just need to add the pointer to that set to keep the data structure
+    // consistent.
+    // This, of course, means that we will never need a merge here.
+    if (Entry.hasAliasSet()) {
+      Entry.updateSizeAndAAInfo(Size, AAInfo);
+      assert(Entry.getAliasSet(*this) == AliasAnyAS &&
+             "Entry in saturated AST must belong to the only alias set");
+    } else {
+      AliasAnyAS->addPointer(*this, Entry, Size, AAInfo);
+    }
+    return *AliasAnyAS;
+  }
+
+  bool MustAliasAll = false;
+  // Check to see if the pointer is already known.
+  if (Entry.hasAliasSet()) {
+    // If the size changed, we may need to merge several alias sets.
+    // Note that we can *not* return the result of mergeAliasSetsForPointer
+    // due to a quirk of alias analysis behavior. Since alias(undef, undef)
+    // is NoAlias, mergeAliasSetsForPointer(undef, ...) will not find the
+    // right set for undef, even if it exists.
+    if (Entry.updateSizeAndAAInfo(Size, AAInfo))
+      mergeAliasSetsForPointer(Pointer, Size, AAInfo, MustAliasAll);
+    // Return the set!
+    return *Entry.getAliasSet(*this)->getForwardedTarget(*this);
+  }
+
+  if (AliasSet *AS =
+          mergeAliasSetsForPointer(Pointer, Size, AAInfo, MustAliasAll)) {
+    // Add it to the alias set it aliases.
+    AS->addPointer(*this, Entry, Size, AAInfo, MustAliasAll);
+    return *AS;
+  }
+
+  // Otherwise create a new alias set to hold the loaded pointer.
+  AliasSets.push_back(new AliasSet());
+  AliasSets.back().addPointer(*this, Entry, Size, AAInfo, true);
+  return AliasSets.back();
+}
+
+void AliasSetTracker::add(Value *Ptr, LocationSize Size,
+                          const AAMDNodes &AAInfo) {
+  addPointer(MemoryLocation(Ptr, Size, AAInfo), AliasSet::NoAccess);
+}
+
+void AliasSetTracker::add(LoadInst *LI) {
+  if (isStrongerThanMonotonic(LI->getOrdering()))
+    return addUnknown(LI);
+  addPointer(MemoryLocation::get(LI), AliasSet::RefAccess);
+}
+
+void AliasSetTracker::add(StoreInst *SI) {
+  if (isStrongerThanMonotonic(SI->getOrdering()))
+    return addUnknown(SI);
+  addPointer(MemoryLocation::get(SI), AliasSet::ModAccess);
+}
+
+void AliasSetTracker::add(VAArgInst *VAAI) {
+  addPointer(MemoryLocation::get(VAAI), AliasSet::ModRefAccess);
+}
+
+void AliasSetTracker::add(AnyMemSetInst *MSI) {
+  addPointer(MemoryLocation::getForDest(MSI), AliasSet::ModAccess);
+}
+
+void AliasSetTracker::add(AnyMemTransferInst *MTI) {
+  addPointer(MemoryLocation::getForDest(MTI), AliasSet::ModAccess);
+  addPointer(MemoryLocation::getForSource(MTI), AliasSet::RefAccess);
+}
+
+void AliasSetTracker::addUnknown(Instruction *Inst) {
+  if (isa<DbgInfoIntrinsic>(Inst))
+    return; // Ignore DbgInfo Intrinsics.
+
+  if (auto *II = dyn_cast<IntrinsicInst>(Inst)) {
+    // These intrinsics will show up as affecting memory, but they are just
+    // markers.
+    switch (II->getIntrinsicID()) {
+    default:
+      break;
+      // FIXME: Add lifetime/invariant intrinsics (See: PR30807).
+    case Intrinsic::assume:
+    case Intrinsic::sideeffect:
+      return;
+    }
+  }
+  if (!Inst->mayReadOrWriteMemory())
+    return; // doesn't alias anything
+
+  if (AliasSet *AS = findAliasSetForUnknownInst(Inst)) {
+    AS->addUnknownInst(Inst, AA);
+    return;
+  }
+  AliasSets.push_back(new AliasSet());
+  AliasSets.back().addUnknownInst(Inst, AA);
+}
+
+void AliasSetTracker::add(Instruction *I) {
+  // Dispatch to one of the other add methods.
+  if (LoadInst *LI = dyn_cast<LoadInst>(I))
+    return add(LI);
+  if (StoreInst *SI = dyn_cast<StoreInst>(I))
+    return add(SI);
+  if (VAArgInst *VAAI = dyn_cast<VAArgInst>(I))
+    return add(VAAI);
+  if (AnyMemSetInst *MSI = dyn_cast<AnyMemSetInst>(I))
+    return add(MSI);
+  if (AnyMemTransferInst *MTI = dyn_cast<AnyMemTransferInst>(I))
+    return add(MTI);
+
+  // Handle all calls with known mod/ref sets generically.
+  if (auto *Call = dyn_cast<CallBase>(I))
+    if (Call->onlyAccessesArgMemory()) {
+      auto getAccessFromModRef = [](ModRefInfo MRI) {
+        if (isRefSet(MRI) && isModSet(MRI))
+          return AliasSet::ModRefAccess;
+        else if (isModSet(MRI))
+          return AliasSet::ModAccess;
+        else if (isRefSet(MRI))
+          return AliasSet::RefAccess;
+        else
+          return AliasSet::NoAccess;
+      };
+
+      ModRefInfo CallMask = createModRefInfo(AA.getModRefBehavior(Call));
+
+      // Some intrinsics are marked as modifying memory for control flow
+      // modelling purposes, but don't actually modify any specific memory
+      // location.
+      using namespace PatternMatch;
+      if (Call->use_empty() &&
+          match(Call, m_Intrinsic<Intrinsic::invariant_start>()))
+        CallMask = clearMod(CallMask);
+
+      for (auto IdxArgPair : enumerate(Call->args())) {
+        int ArgIdx = IdxArgPair.index();
+        const Value *Arg = IdxArgPair.value();
+        if (!Arg->getType()->isPointerTy())
+          continue;
+        MemoryLocation ArgLoc =
+            MemoryLocation::getForArgument(Call, ArgIdx, nullptr);
+        ModRefInfo ArgMask = AA.getArgModRefInfo(Call, ArgIdx);
+        ArgMask = intersectModRef(CallMask, ArgMask);
+        if (!isNoModRef(ArgMask))
+          addPointer(ArgLoc, getAccessFromModRef(ArgMask));
+      }
+      return;
+    }
+
+  return addUnknown(I);
+}
+
+void AliasSetTracker::add(BasicBlock &BB) {
+  for (auto &I : BB)
+    add(&I);
+}
+
+void AliasSetTracker::add(const AliasSetTracker &AST) {
+  assert(&AA == &AST.AA &&
+         "Merging AliasSetTracker objects with different Alias Analyses!");
+
+  // Loop over all of the alias sets in AST, adding the pointers contained
+  // therein into the current alias sets. This can cause alias sets to be
+  // merged together in the current AST.
+  for (const AliasSet &AS : AST) {
+    if (AS.Forward)
+      continue; // Ignore forwarding alias sets
+
+    // If there are any call sites in the alias set, add them to this AST.
+    for (unsigned i = 0, e = AS.UnknownInsts.size(); i != e; ++i)
+      if (auto *Inst = AS.getUnknownInst(i))
+        add(Inst);
+
+    // Loop over all of the pointers in this alias set.
+    for (AliasSet::iterator ASI = AS.begin(), E = AS.end(); ASI != E; ++ASI)
+      addPointer(
+          MemoryLocation(ASI.getPointer(), ASI.getSize(), ASI.getAAInfo()),
+          (AliasSet::AccessLattice)AS.Access);
+  }
+}
+
+void AliasSetTracker::addAllInstructionsInLoopUsingMSSA() {
+  assert(MSSA && L && "MSSA and L must be available");
+  for (const BasicBlock *BB : L->blocks())
+    if (auto *Accesses = MSSA->getBlockAccesses(BB))
+      for (auto &Access : *Accesses)
+        if (auto *MUD = dyn_cast<MemoryUseOrDef>(&Access))
+          add(MUD->getMemoryInst());
+}
+
+// deleteValue method - This method is used to remove a pointer value from the
+// AliasSetTracker entirely. It should be used when an instruction is deleted
+// from the program to update the AST. If you don't use this, you would have
+// dangling pointers to deleted instructions.
+//
+void AliasSetTracker::deleteValue(Value *PtrVal) {
+  // First, look up the PointerRec for this pointer.
+  PointerMapType::iterator I = PointerMap.find_as(PtrVal);
+  if (I == PointerMap.end()) return; // Noop
+
+  // If we found one, remove the pointer from the alias set it is in.
+  AliasSet::PointerRec *PtrValEnt = I->second;
+  AliasSet *AS = PtrValEnt->getAliasSet(*this);
+
+  // Unlink and delete from the list of values.
+  PtrValEnt->eraseFromList();
+
+  if (AS->Alias == AliasSet::SetMayAlias) {
+    AS->SetSize--;
+    TotalMayAliasSetSize--;
+  }
+
+  // Stop using the alias set.
+  AS->dropRef(*this);
+
+  PointerMap.erase(I);
+}
+
+// copyValue - This method should be used whenever a preexisting value in the
+// program is copied or cloned, introducing a new value. Note that it is ok for
+// clients that use this method to introduce the same value multiple times: if
+// the tracker already knows about a value, it will ignore the request.
+//
+void AliasSetTracker::copyValue(Value *From, Value *To) {
+  // First, look up the PointerRec for this pointer.
+  PointerMapType::iterator I = PointerMap.find_as(From);
+  if (I == PointerMap.end())
+    return; // Noop
+  assert(I->second->hasAliasSet() && "Dead entry?");
+
+  AliasSet::PointerRec &Entry = getEntryFor(To);
+  if (Entry.hasAliasSet()) return; // Already in the tracker!
+
+  // getEntryFor above may invalidate iterator \c I, so reinitialize it.
+  I = PointerMap.find_as(From);
+  // Add it to the alias set it aliases...
+  AliasSet *AS = I->second->getAliasSet(*this);
+  AS->addPointer(*this, Entry, I->second->getSize(), I->second->getAAInfo(),
+                 true, true);
+}
+
+AliasSet &AliasSetTracker::mergeAllAliasSets() {
+  assert(!AliasAnyAS && (TotalMayAliasSetSize > SaturationThreshold) &&
+         "Full merge should happen once, when the saturation threshold is "
+         "reached");
+
+  // Collect all alias sets, so that we can drop references with impunity
+  // without worrying about iterator invalidation.
+  std::vector<AliasSet *> ASVector;
+  ASVector.reserve(SaturationThreshold);
+  for (iterator I = begin(), E = end(); I != E; ++I)
+    ASVector.push_back(&*I);
+
+  // Copy all instructions and pointers into a new set, and forward all other
+  // sets to it.
+  AliasSets.push_back(new AliasSet());
+  AliasAnyAS = &AliasSets.back();
+  AliasAnyAS->Alias = AliasSet::SetMayAlias;
+  AliasAnyAS->Access = AliasSet::ModRefAccess;
+  AliasAnyAS->AliasAny = true;
+
+  for (auto Cur : ASVector) {
+    // If Cur was already forwarding, just forward to the new AS instead.
+    AliasSet *FwdTo = Cur->Forward;
+    if (FwdTo) {
+      Cur->Forward = AliasAnyAS;
+      AliasAnyAS->addRef();
+      FwdTo->dropRef(*this);
+      continue;
+    }
+
+    // Otherwise, perform the actual merge.
+    AliasAnyAS->mergeSetIn(*Cur, *this);
+  }
+
+  return *AliasAnyAS;
+}
+
+AliasSet &AliasSetTracker::addPointer(MemoryLocation Loc,
+                                      AliasSet::AccessLattice E) {
+  AliasSet &AS = getAliasSetFor(Loc);
+  AS.Access |= E;
+
+  if (!AliasAnyAS && (TotalMayAliasSetSize > SaturationThreshold)) {
+    // The AST is now saturated. From here on, we conservatively consider all
+    // pointers to alias each other.
+ return mergeAllAliasSets(); + } + + return AS; +} + +//===----------------------------------------------------------------------===// +// AliasSet/AliasSetTracker Printing Support +//===----------------------------------------------------------------------===// + +void AliasSet::print(raw_ostream &OS) const { + OS << " AliasSet[" << (const void*)this << ", " << RefCount << "] "; + OS << (Alias == SetMustAlias ? "must" : "may") << " alias, "; + switch (Access) { + case NoAccess: OS << "No access "; break; + case RefAccess: OS << "Ref "; break; + case ModAccess: OS << "Mod "; break; + case ModRefAccess: OS << "Mod/Ref "; break; + default: llvm_unreachable("Bad value for Access!"); + } + if (Forward) + OS << " forwarding to " << (void*)Forward; + + if (!empty()) { + OS << "Pointers: "; + for (iterator I = begin(), E = end(); I != E; ++I) { + if (I != begin()) OS << ", "; + I.getPointer()->printAsOperand(OS << "("); + if (I.getSize() == LocationSize::unknown()) + OS << ", unknown)"; + else + OS << ", " << I.getSize() << ")"; + } + } + if (!UnknownInsts.empty()) { + OS << "\n " << UnknownInsts.size() << " Unknown instructions: "; + for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i) { + if (i) OS << ", "; + if (auto *I = getUnknownInst(i)) { + if (I->hasName()) + I->printAsOperand(OS); + else + I->print(OS); + } + } + } + OS << "\n"; +} + +void AliasSetTracker::print(raw_ostream &OS) const { + OS << "Alias Set Tracker: " << AliasSets.size(); + if (AliasAnyAS) + OS << " (Saturated)"; + OS << " alias sets for " << PointerMap.size() << " pointer values.\n"; + for (const AliasSet &AS : *this) + AS.print(OS); + OS << "\n"; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void AliasSet::dump() const { print(dbgs()); } +LLVM_DUMP_METHOD void AliasSetTracker::dump() const { print(dbgs()); } +#endif + +//===----------------------------------------------------------------------===// +// ASTCallbackVH Class Implementation +//===----------------------------------------------------------------------===// + +void AliasSetTracker::ASTCallbackVH::deleted() { + assert(AST && "ASTCallbackVH called with a null AliasSetTracker!"); + AST->deleteValue(getValPtr()); + // this now dangles! 
+} + +void AliasSetTracker::ASTCallbackVH::allUsesReplacedWith(Value *V) { + AST->copyValue(getValPtr(), V); +} + +AliasSetTracker::ASTCallbackVH::ASTCallbackVH(Value *V, AliasSetTracker *ast) + : CallbackVH(V), AST(ast) {} + +AliasSetTracker::ASTCallbackVH & +AliasSetTracker::ASTCallbackVH::operator=(Value *V) { + return *this = ASTCallbackVH(V, AST); +} + +//===----------------------------------------------------------------------===// +// AliasSetPrinter Pass +//===----------------------------------------------------------------------===// + +namespace { + + class AliasSetPrinter : public FunctionPass { + AliasSetTracker *Tracker; + + public: + static char ID; // Pass identification, replacement for typeid + + AliasSetPrinter() : FunctionPass(ID) { + initializeAliasSetPrinterPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + AU.addRequired<AAResultsWrapperPass>(); + } + + bool runOnFunction(Function &F) override { + auto &AAWP = getAnalysis<AAResultsWrapperPass>(); + Tracker = new AliasSetTracker(AAWP.getAAResults()); + errs() << "Alias sets for function '" << F.getName() << "':\n"; + for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) + Tracker->add(&*I); + Tracker->print(errs()); + delete Tracker; + return false; + } + }; + +} // end anonymous namespace + +char AliasSetPrinter::ID = 0; + +INITIALIZE_PASS_BEGIN(AliasSetPrinter, "print-alias-sets", + "Alias Set Printer", false, true) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(AliasSetPrinter, "print-alias-sets", + "Alias Set Printer", false, true) diff --git a/llvm/lib/Analysis/Analysis.cpp b/llvm/lib/Analysis/Analysis.cpp new file mode 100644 index 000000000000..af718526684b --- /dev/null +++ b/llvm/lib/Analysis/Analysis.cpp @@ -0,0 +1,138 @@ +//===-- Analysis.cpp ------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm-c/Analysis.h" +#include "llvm-c/Initialization.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/InitializePasses.h" +#include "llvm/PassRegistry.h" +#include "llvm/Support/raw_ostream.h" +#include <cstring> + +using namespace llvm; + +/// initializeAnalysis - Initialize all passes linked into the Analysis library. 
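+///
+/// For example (an illustrative usage sketch), a tool linking this library
+/// would typically run this once at startup:
+/// \code
+///   initializeAnalysis(*PassRegistry::getPassRegistry());
+/// \endcode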
+void llvm::initializeAnalysis(PassRegistry &Registry) { + initializeAAEvalLegacyPassPass(Registry); + initializeAliasSetPrinterPass(Registry); + initializeBasicAAWrapperPassPass(Registry); + initializeBlockFrequencyInfoWrapperPassPass(Registry); + initializeBranchProbabilityInfoWrapperPassPass(Registry); + initializeCallGraphWrapperPassPass(Registry); + initializeCallGraphDOTPrinterPass(Registry); + initializeCallGraphPrinterLegacyPassPass(Registry); + initializeCallGraphViewerPass(Registry); + initializeCostModelAnalysisPass(Registry); + initializeCFGViewerLegacyPassPass(Registry); + initializeCFGPrinterLegacyPassPass(Registry); + initializeCFGOnlyViewerLegacyPassPass(Registry); + initializeCFGOnlyPrinterLegacyPassPass(Registry); + initializeCFLAndersAAWrapperPassPass(Registry); + initializeCFLSteensAAWrapperPassPass(Registry); + initializeDependenceAnalysisWrapperPassPass(Registry); + initializeDelinearizationPass(Registry); + initializeDemandedBitsWrapperPassPass(Registry); + initializeDominanceFrontierWrapperPassPass(Registry); + initializeDomViewerPass(Registry); + initializeDomPrinterPass(Registry); + initializeDomOnlyViewerPass(Registry); + initializePostDomViewerPass(Registry); + initializeDomOnlyPrinterPass(Registry); + initializePostDomPrinterPass(Registry); + initializePostDomOnlyViewerPass(Registry); + initializePostDomOnlyPrinterPass(Registry); + initializeAAResultsWrapperPassPass(Registry); + initializeGlobalsAAWrapperPassPass(Registry); + initializeIVUsersWrapperPassPass(Registry); + initializeInstCountPass(Registry); + initializeIntervalPartitionPass(Registry); + initializeLazyBranchProbabilityInfoPassPass(Registry); + initializeLazyBlockFrequencyInfoPassPass(Registry); + initializeLazyValueInfoWrapperPassPass(Registry); + initializeLazyValueInfoPrinterPass(Registry); + initializeLegacyDivergenceAnalysisPass(Registry); + initializeLintPass(Registry); + initializeLoopInfoWrapperPassPass(Registry); + initializeMemDepPrinterPass(Registry); + initializeMemDerefPrinterPass(Registry); + initializeMemoryDependenceWrapperPassPass(Registry); + initializeModuleDebugInfoPrinterPass(Registry); + initializeModuleSummaryIndexWrapperPassPass(Registry); + initializeMustExecutePrinterPass(Registry); + initializeMustBeExecutedContextPrinterPass(Registry); + initializeObjCARCAAWrapperPassPass(Registry); + initializeOptimizationRemarkEmitterWrapperPassPass(Registry); + initializePhiValuesWrapperPassPass(Registry); + initializePostDominatorTreeWrapperPassPass(Registry); + initializeRegionInfoPassPass(Registry); + initializeRegionViewerPass(Registry); + initializeRegionPrinterPass(Registry); + initializeRegionOnlyViewerPass(Registry); + initializeRegionOnlyPrinterPass(Registry); + initializeSCEVAAWrapperPassPass(Registry); + initializeScalarEvolutionWrapperPassPass(Registry); + initializeStackSafetyGlobalInfoWrapperPassPass(Registry); + initializeStackSafetyInfoWrapperPassPass(Registry); + initializeTargetTransformInfoWrapperPassPass(Registry); + initializeTypeBasedAAWrapperPassPass(Registry); + initializeScopedNoAliasAAWrapperPassPass(Registry); + initializeLCSSAVerificationPassPass(Registry); + initializeMemorySSAWrapperPassPass(Registry); + initializeMemorySSAPrinterLegacyPassPass(Registry); +} + +void LLVMInitializeAnalysis(LLVMPassRegistryRef R) { + initializeAnalysis(*unwrap(R)); +} + +void LLVMInitializeIPA(LLVMPassRegistryRef R) { + initializeAnalysis(*unwrap(R)); +} + +LLVMBool LLVMVerifyModule(LLVMModuleRef M, LLVMVerifierFailureAction Action, + char **OutMessages) { + raw_ostream 
*DebugOS = Action != LLVMReturnStatusAction ? &errs() : nullptr; + std::string Messages; + raw_string_ostream MsgsOS(Messages); + + LLVMBool Result = verifyModule(*unwrap(M), OutMessages ? &MsgsOS : DebugOS); + + // Duplicate the output to stderr. + if (DebugOS && OutMessages) + *DebugOS << MsgsOS.str(); + + if (Action == LLVMAbortProcessAction && Result) + report_fatal_error("Broken module found, compilation aborted!"); + + if (OutMessages) + *OutMessages = strdup(MsgsOS.str().c_str()); + + return Result; +} + +LLVMBool LLVMVerifyFunction(LLVMValueRef Fn, LLVMVerifierFailureAction Action) { + LLVMBool Result = verifyFunction( + *unwrap<Function>(Fn), Action != LLVMReturnStatusAction ? &errs() + : nullptr); + + if (Action == LLVMAbortProcessAction && Result) + report_fatal_error("Broken function found, compilation aborted!"); + + return Result; +} + +void LLVMViewFunctionCFG(LLVMValueRef Fn) { + Function *F = unwrap<Function>(Fn); + F->viewCFG(); +} + +void LLVMViewFunctionCFGOnly(LLVMValueRef Fn) { + Function *F = unwrap<Function>(Fn); + F->viewCFGOnly(); +} diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp new file mode 100644 index 000000000000..129944743c5e --- /dev/null +++ b/llvm/lib/Analysis/AssumptionCache.cpp @@ -0,0 +1,302 @@ +//===- AssumptionCache.cpp - Cache finding @llvm.assume calls -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a pass that keeps track of @llvm.assume intrinsics in +// the functions of a module. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <utility> + +using namespace llvm; +using namespace llvm::PatternMatch; + +static cl::opt<bool> + VerifyAssumptionCache("verify-assumption-cache", cl::Hidden, + cl::desc("Enable verification of assumption cache"), + cl::init(false)); + +SmallVector<WeakTrackingVH, 1> & +AssumptionCache::getOrInsertAffectedValues(Value *V) { + // Try using find_as first to avoid creating extra value handles just for the + // purpose of doing the lookup. + auto AVI = AffectedValues.find_as(V); + if (AVI != AffectedValues.end()) + return AVI->second; + + auto AVIP = AffectedValues.insert( + {AffectedValueCallbackVH(V, this), SmallVector<WeakTrackingVH, 1>()}); + return AVIP.first->second; +} + +static void findAffectedValues(CallInst *CI, + SmallVectorImpl<Value *> &Affected) { + // Note: This code must be kept in-sync with the code in + // computeKnownBitsFromAssume in ValueTracking. 
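+  // For example (illustrative, assuming %x is an instruction or argument):
+  // given
+  //   %cmp = icmp eq i32 %x, 4
+  //   call void @llvm.assume(i1 %cmp)
+  // both %cmp and %x are recorded as affected, so later queries about %x can
+  // find this assumption.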
+ + auto AddAffected = [&Affected](Value *V) { + if (isa<Argument>(V)) { + Affected.push_back(V); + } else if (auto *I = dyn_cast<Instruction>(V)) { + Affected.push_back(I); + + // Peek through unary operators to find the source of the condition. + Value *Op; + if (match(I, m_BitCast(m_Value(Op))) || + match(I, m_PtrToInt(m_Value(Op))) || + match(I, m_Not(m_Value(Op)))) { + if (isa<Instruction>(Op) || isa<Argument>(Op)) + Affected.push_back(Op); + } + } + }; + + Value *Cond = CI->getArgOperand(0), *A, *B; + AddAffected(Cond); + + CmpInst::Predicate Pred; + if (match(Cond, m_ICmp(Pred, m_Value(A), m_Value(B)))) { + AddAffected(A); + AddAffected(B); + + if (Pred == ICmpInst::ICMP_EQ) { + // For equality comparisons, we handle the case of bit inversion. + auto AddAffectedFromEq = [&AddAffected](Value *V) { + Value *A; + if (match(V, m_Not(m_Value(A)))) { + AddAffected(A); + V = A; + } + + Value *B; + ConstantInt *C; + // (A & B) or (A | B) or (A ^ B). + if (match(V, m_BitwiseLogic(m_Value(A), m_Value(B)))) { + AddAffected(A); + AddAffected(B); + // (A << C) or (A >>_s C) or (A >>_u C) where C is some constant. + } else if (match(V, m_Shift(m_Value(A), m_ConstantInt(C)))) { + AddAffected(A); + } + }; + + AddAffectedFromEq(A); + AddAffectedFromEq(B); + } + } +} + +void AssumptionCache::updateAffectedValues(CallInst *CI) { + SmallVector<Value *, 16> Affected; + findAffectedValues(CI, Affected); + + for (auto &AV : Affected) { + auto &AVV = getOrInsertAffectedValues(AV); + if (std::find(AVV.begin(), AVV.end(), CI) == AVV.end()) + AVV.push_back(CI); + } +} + +void AssumptionCache::unregisterAssumption(CallInst *CI) { + SmallVector<Value *, 16> Affected; + findAffectedValues(CI, Affected); + + for (auto &AV : Affected) { + auto AVI = AffectedValues.find_as(AV); + if (AVI != AffectedValues.end()) + AffectedValues.erase(AVI); + } + + AssumeHandles.erase( + remove_if(AssumeHandles, [CI](WeakTrackingVH &VH) { return CI == VH; }), + AssumeHandles.end()); +} + +void AssumptionCache::AffectedValueCallbackVH::deleted() { + auto AVI = AC->AffectedValues.find(getValPtr()); + if (AVI != AC->AffectedValues.end()) + AC->AffectedValues.erase(AVI); + // 'this' now dangles! +} + +void AssumptionCache::transferAffectedValuesInCache(Value *OV, Value *NV) { + auto &NAVV = getOrInsertAffectedValues(NV); + auto AVI = AffectedValues.find(OV); + if (AVI == AffectedValues.end()) + return; + + for (auto &A : AVI->second) + if (std::find(NAVV.begin(), NAVV.end(), A) == NAVV.end()) + NAVV.push_back(A); + AffectedValues.erase(OV); +} + +void AssumptionCache::AffectedValueCallbackVH::allUsesReplacedWith(Value *NV) { + if (!isa<Instruction>(NV) && !isa<Argument>(NV)) + return; + + // Any assumptions that affected this value now affect the new value. + + AC->transferAffectedValuesInCache(getValPtr(), NV); + // 'this' now might dangle! If the AffectedValues map was resized to add an + // entry for NV then this object might have been destroyed in favor of some + // copy in the grown map. +} + +void AssumptionCache::scanFunction() { + assert(!Scanned && "Tried to scan the function twice!"); + assert(AssumeHandles.empty() && "Already have assumes when scanning!"); + + // Go through all instructions in all blocks, add all calls to @llvm.assume + // to this cache. + for (BasicBlock &B : F) + for (Instruction &II : B) + if (match(&II, m_Intrinsic<Intrinsic::assume>())) + AssumeHandles.push_back(&II); + + // Mark the scan as complete. + Scanned = true; + + // Update affected values. 
+ for (auto &A : AssumeHandles) + updateAffectedValues(cast<CallInst>(A)); +} + +void AssumptionCache::registerAssumption(CallInst *CI) { + assert(match(CI, m_Intrinsic<Intrinsic::assume>()) && + "Registered call does not call @llvm.assume"); + + // If we haven't scanned the function yet, just drop this assumption. It will + // be found when we scan later. + if (!Scanned) + return; + + AssumeHandles.push_back(CI); + +#ifndef NDEBUG + assert(CI->getParent() && + "Cannot register @llvm.assume call not in a basic block"); + assert(&F == CI->getParent()->getParent() && + "Cannot register @llvm.assume call not in this function"); + + // We expect the number of assumptions to be small, so in an asserts build + // check that we don't accumulate duplicates and that all assumptions point + // to the same function. + SmallPtrSet<Value *, 16> AssumptionSet; + for (auto &VH : AssumeHandles) { + if (!VH) + continue; + + assert(&F == cast<Instruction>(VH)->getParent()->getParent() && + "Cached assumption not inside this function!"); + assert(match(cast<CallInst>(VH), m_Intrinsic<Intrinsic::assume>()) && + "Cached something other than a call to @llvm.assume!"); + assert(AssumptionSet.insert(VH).second && + "Cache contains multiple copies of a call!"); + } +#endif + + updateAffectedValues(CI); +} + +AnalysisKey AssumptionAnalysis::Key; + +PreservedAnalyses AssumptionPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F); + + OS << "Cached assumptions for function: " << F.getName() << "\n"; + for (auto &VH : AC.assumptions()) + if (VH) + OS << " " << *cast<CallInst>(VH)->getArgOperand(0) << "\n"; + + return PreservedAnalyses::all(); +} + +void AssumptionCacheTracker::FunctionCallbackVH::deleted() { + auto I = ACT->AssumptionCaches.find_as(cast<Function>(getValPtr())); + if (I != ACT->AssumptionCaches.end()) + ACT->AssumptionCaches.erase(I); + // 'this' now dangles! +} + +AssumptionCache &AssumptionCacheTracker::getAssumptionCache(Function &F) { + // We probe the function map twice to try and avoid creating a value handle + // around the function in common cases. This makes insertion a bit slower, + // but if we have to insert we're going to scan the whole function so that + // shouldn't matter. + auto I = AssumptionCaches.find_as(&F); + if (I != AssumptionCaches.end()) + return *I->second; + + // Ok, build a new cache by scanning the function, insert it and the value + // handle into our map, and return the newly populated cache. + auto IP = AssumptionCaches.insert(std::make_pair( + FunctionCallbackVH(&F, this), std::make_unique<AssumptionCache>(F))); + assert(IP.second && "Scanning function already in the map?"); + return *IP.first->second; +} + +AssumptionCache *AssumptionCacheTracker::lookupAssumptionCache(Function &F) { + auto I = AssumptionCaches.find_as(&F); + if (I != AssumptionCaches.end()) + return I->second.get(); + return nullptr; +} + +void AssumptionCacheTracker::verifyAnalysis() const { + // FIXME: In the long term the verifier should not be controllable with a + // flag. We should either fix all passes to correctly update the assumption + // cache and enable the verifier unconditionally or somehow arrange for the + // assumption list to be updated automatically by passes. 
+ if (!VerifyAssumptionCache) + return; + + SmallPtrSet<const CallInst *, 4> AssumptionSet; + for (const auto &I : AssumptionCaches) { + for (auto &VH : I.second->assumptions()) + if (VH) + AssumptionSet.insert(cast<CallInst>(VH)); + + for (const BasicBlock &B : cast<Function>(*I.first)) + for (const Instruction &II : B) + if (match(&II, m_Intrinsic<Intrinsic::assume>()) && + !AssumptionSet.count(cast<CallInst>(&II))) + report_fatal_error("Assumption in scanned function not in cache"); + } +} + +AssumptionCacheTracker::AssumptionCacheTracker() : ImmutablePass(ID) { + initializeAssumptionCacheTrackerPass(*PassRegistry::getPassRegistry()); +} + +AssumptionCacheTracker::~AssumptionCacheTracker() = default; + +char AssumptionCacheTracker::ID = 0; + +INITIALIZE_PASS(AssumptionCacheTracker, "assumption-cache-tracker", + "Assumption Cache Tracker", false, true) diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp new file mode 100644 index 000000000000..f3c30c258c19 --- /dev/null +++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp @@ -0,0 +1,2100 @@ +//===- BasicAliasAnalysis.cpp - Stateless Alias Analysis Impl -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the primary stateless implementation of the +// Alias Analysis interface that implements identities (two different +// globals cannot alias, etc), but does no stateful analysis. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/PhiValues.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/KnownBits.h" +#include <cassert> +#include <cstdint> +#include <cstdlib> +#include <utility> + +#define DEBUG_TYPE "basicaa" + +using namespace llvm; + +/// Enable analysis of recursive PHI nodes. 
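+///
+/// A recursive PHI reaches itself through one of its incoming values, e.g.
+/// (an illustrative IR snippet) a pointer advanced on each loop iteration:
+/// \code
+///   loop:
+///     %p = phi i8* [ %base, %entry ], [ %p.next, %loop ]
+///     %p.next = getelementptr inbounds i8, i8* %p, i64 1
+/// \endcode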
+static cl::opt<bool> EnableRecPhiAnalysis("basicaa-recphi", cl::Hidden,
+                                          cl::init(false));
+
+/// By default, even on 32-bit architectures we use 64-bit integers for
+/// calculations. This will allow us to more-aggressively decompose indexing
+/// expressions calculated using i64 values (e.g., long long in C) which is
+/// common enough to worry about.
+static cl::opt<bool> ForceAtLeast64Bits("basicaa-force-at-least-64b",
+                                        cl::Hidden, cl::init(true));
+static cl::opt<bool> DoubleCalcBits("basicaa-double-calc-bits",
+                                    cl::Hidden, cl::init(false));
+
+/// SearchLimitReached / SearchTimes show how often the limit on GEP
+/// decomposition is reached. The limit affects the precision of basic
+/// alias analysis.
+STATISTIC(SearchLimitReached, "Number of times the limit to "
+                              "decompose GEPs is reached");
+STATISTIC(SearchTimes, "Number of times a GEP is decomposed");
+
+/// Cutoff after which to stop analysing a set of phi nodes potentially involved
+/// in a cycle. Because we are analysing 'through' phi nodes, we need to be
+/// careful with value equivalence. We use reachability to make sure a value
+/// cannot be involved in a cycle.
+const unsigned MaxNumPhiBBsValueReachabilityCheck = 20;
+
+// The max limit of the search depth in DecomposeGEPExpression() and
+// GetUnderlyingObject(). Both functions need to use the same search
+// depth; otherwise the algorithm in aliasGEP will assert.
+static const unsigned MaxLookupSearchDepth = 6;
+
+bool BasicAAResult::invalidate(Function &Fn, const PreservedAnalyses &PA,
+                               FunctionAnalysisManager::Invalidator &Inv) {
+  // We don't care if this analysis itself is preserved, it has no state. But
+  // we need to check that the analyses it depends on have been. Note that we
+  // may be created without handles to some analyses and in that case don't
+  // depend on them.
+  if (Inv.invalidate<AssumptionAnalysis>(Fn, PA) ||
+      (DT && Inv.invalidate<DominatorTreeAnalysis>(Fn, PA)) ||
+      (LI && Inv.invalidate<LoopAnalysis>(Fn, PA)) ||
+      (PV && Inv.invalidate<PhiValuesAnalysis>(Fn, PA)))
+    return true;
+
+  // Otherwise this analysis result remains valid.
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Useful predicates
+//===----------------------------------------------------------------------===//
+
+/// Returns true if the pointer is to a function-local object that never
+/// escapes from the function.
+static bool isNonEscapingLocalObject(
+    const Value *V,
+    SmallDenseMap<const Value *, bool, 8> *IsCapturedCache = nullptr) {
+  SmallDenseMap<const Value *, bool, 8>::iterator CacheIt;
+  if (IsCapturedCache) {
+    bool Inserted;
+    std::tie(CacheIt, Inserted) = IsCapturedCache->insert({V, false});
+    if (!Inserted)
+      // Found cached result, return it!
+      return CacheIt->second;
+  }
+
+  // If this is a local allocation, check to see if it escapes.
+  if (isa<AllocaInst>(V) || isNoAliasCall(V)) {
+    // Set StoreCaptures to True so that we can assume in our callers that the
+    // pointer is not the result of a load instruction. Currently
+    // PointerMayBeCaptured doesn't have any special analysis for the
+    // StoreCaptures=false case; if it did, our callers could be refined to be
+    // more precise.
+    auto Ret = !PointerMayBeCaptured(V, false, /*StoreCaptures=*/true);
+    if (IsCapturedCache)
+      CacheIt->second = Ret;
+    return Ret;
+  }
+
+  // If this is an argument that corresponds to a byval or noalias argument,
+  // then it has not escaped before entering the function. Check if it escapes
+  // inside the function.
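+  // (Illustrative note: a byval argument refers to a fresh per-call copy of
+  // the caller's object, and a noalias argument is promised not to alias any
+  // other pointer visible at the call; in both cases no alias to the object
+  // exists on function entry, so only escapes inside the function matter.)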
+  if (const Argument *A = dyn_cast<Argument>(V))
+    if (A->hasByValAttr() || A->hasNoAliasAttr()) {
+      // Note even if the argument is marked nocapture, we still need to check
+      // for copies made inside the function. The nocapture attribute only
+      // specifies that there are no copies made that outlive the function.
+      auto Ret = !PointerMayBeCaptured(V, false, /*StoreCaptures=*/true);
+      if (IsCapturedCache)
+        CacheIt->second = Ret;
+      return Ret;
+    }
+
+  return false;
+}
+
+/// Returns true if the pointer is one which would have been considered an
+/// escape by isNonEscapingLocalObject.
+static bool isEscapeSource(const Value *V) {
+  if (isa<CallBase>(V))
+    return true;
+
+  if (isa<Argument>(V))
+    return true;
+
+  // The load case works because isNonEscapingLocalObject considers all
+  // stores to be escapes (it passes true for the StoreCaptures argument
+  // to PointerMayBeCaptured).
+  if (isa<LoadInst>(V))
+    return true;
+
+  return false;
+}
+
+/// Returns the size of the object specified by V or UnknownSize if unknown.
+static uint64_t getObjectSize(const Value *V, const DataLayout &DL,
+                              const TargetLibraryInfo &TLI,
+                              bool NullIsValidLoc,
+                              bool RoundToAlign = false) {
+  uint64_t Size;
+  ObjectSizeOpts Opts;
+  Opts.RoundToAlign = RoundToAlign;
+  Opts.NullIsUnknownSize = NullIsValidLoc;
+  if (getObjectSize(V, Size, DL, &TLI, Opts))
+    return Size;
+  return MemoryLocation::UnknownSize;
+}
+
+/// Returns true if we can prove that the object specified by V is smaller than
+/// Size.
+static bool isObjectSmallerThan(const Value *V, uint64_t Size,
+                                const DataLayout &DL,
+                                const TargetLibraryInfo &TLI,
+                                bool NullIsValidLoc) {
+  // Note that the meanings of the "object" are slightly different in the
+  // following contexts:
+  //   c1: llvm::getObjectSize()
+  //   c2: llvm.objectsize() intrinsic
+  //   c3: isObjectSmallerThan()
+  // c1 and c2 share the same meaning; however, the meaning of "object" in c3
+  // refers to the "entire object".
+  //
+  // Consider this example:
+  //   char *p = (char*)malloc(100)
+  //   char *q = p+80;
+  //
+  // In the context of c1 and c2, the "object" pointed to by q refers to the
+  // stretch of memory of q[0:19]. So, getObjectSize(q) should return 20.
+  //
+  // However, in the context of c3, the "object" refers to the chunk of memory
+  // being allocated. So, the "object" has 100 bytes, and q points to the
+  // middle of the "object". If q is passed to isObjectSmallerThan() as the
+  // first parameter, then before llvm::getObjectSize() is called to get the
+  // size of the entire object, we should:
+  //   - either rewind the pointer q to the base address of the object in
+  //     question (in this case rewind to p), or
+  //   - just give up. It is up to the caller to make sure the pointer points
+  //     to the base address of the object.
+  //
+  // We go for the second option for simplicity.
+  if (!isIdentifiedObject(V))
+    return false;
+
+  // This function needs to use the aligned object size because we allow
+  // reads a bit past the end given sufficient alignment.
+  uint64_t ObjectSize = getObjectSize(V, DL, TLI, NullIsValidLoc,
+                                      /*RoundToAlign*/ true);
+
+  return ObjectSize != MemoryLocation::UnknownSize && ObjectSize < Size;
+}
+
+/// Return the minimal extent from \p V to the end of the underlying object,
+/// assuming the result is used in an aliasing query. E.g., we make use of the
+/// query location size and the fact that null pointers cannot alias here.
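+///
+/// For example (illustrative): if \p V is known dereferenceable for 16 bytes
+/// and the query uses a precise location size of 32 bytes, the minimal extent
+/// is 32, because the access itself is assumed to be valid.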
+static uint64_t getMinimalExtentFrom(const Value &V,
+                                     const LocationSize &LocSize,
+                                     const DataLayout &DL,
+                                     bool NullIsValidLoc) {
+  // If we have dereferenceability information we know a lower bound for the
+  // extent as accesses for a lower offset would be valid. We need to exclude
+  // the "or null" part if null is a valid pointer.
+  bool CanBeNull;
+  uint64_t DerefBytes = V.getPointerDereferenceableBytes(DL, CanBeNull);
+  DerefBytes = (CanBeNull && NullIsValidLoc) ? 0 : DerefBytes;
+  // If queried with a precise location size, we assume that location size is
+  // accessed and is therefore valid.
+  if (LocSize.isPrecise())
+    DerefBytes = std::max(DerefBytes, LocSize.getValue());
+  return DerefBytes;
+}
+
+/// Returns true if we can prove that the object specified by V has size Size.
+static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL,
+                         const TargetLibraryInfo &TLI, bool NullIsValidLoc) {
+  uint64_t ObjectSize = getObjectSize(V, DL, TLI, NullIsValidLoc);
+  return ObjectSize != MemoryLocation::UnknownSize && ObjectSize == Size;
+}
+
+//===----------------------------------------------------------------------===//
+// GetElementPtr Instruction Decomposition and Analysis
+//===----------------------------------------------------------------------===//
+
+/// Analyzes the specified value as a linear expression: "A*V + B", where A and
+/// B are constant integers.
+///
+/// Returns the scale and offset values as APInts, returns V as a Value*, and
+/// reports whether we looked through any sign or zero extends. The incoming
+/// Value is known to have IntegerType, and it may already be sign or zero
+/// extended.
+///
+/// Note that this looks through extends, so the high bits may not be
+/// represented in the result.
+/*static*/ const Value *BasicAAResult::GetLinearExpression(
+    const Value *V, APInt &Scale, APInt &Offset, unsigned &ZExtBits,
+    unsigned &SExtBits, const DataLayout &DL, unsigned Depth,
+    AssumptionCache *AC, DominatorTree *DT, bool &NSW, bool &NUW) {
+  assert(V->getType()->isIntegerTy() && "Not an integer value");
+
+  // Limit our recursion depth.
+  if (Depth == 6) {
+    Scale = 1;
+    Offset = 0;
+    return V;
+  }
+
+  if (const ConstantInt *Const = dyn_cast<ConstantInt>(V)) {
+    // If it's a constant, just convert it to an offset and remove the variable.
+    // If we've been called recursively, the Offset bit width will be greater
+    // than the constant's (the Offset's always as wide as the outermost call),
+    // so we'll zext here and process any extension in the isa<SExtInst> &
+    // isa<ZExtInst> cases below.
+    Offset += Const->getValue().zextOrSelf(Offset.getBitWidth());
+    assert(Scale == 0 && "Constant values don't have a scale");
+    return V;
+  }
+
+  if (const BinaryOperator *BOp = dyn_cast<BinaryOperator>(V)) {
+    if (ConstantInt *RHSC = dyn_cast<ConstantInt>(BOp->getOperand(1))) {
+      // If we've been called recursively, then Offset and Scale will be wider
+      // than the BOp operands. We'll always zext it here as we'll process sign
+      // extensions below (see the isa<SExtInst> / isa<ZExtInst> cases).
+      APInt RHS = RHSC->getValue().zextOrSelf(Offset.getBitWidth());
+
+      switch (BOp->getOpcode()) {
+      default:
+        // We don't understand this instruction, so we can't decompose it any
+        // further.
+        Scale = 1;
+        Offset = 0;
+        return V;
+      case Instruction::Or:
+        // X|C == X+C if all the bits in C are unset in X. Otherwise we can't
+        // analyze it.
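+        // For example (illustrative): if the low two bits of %x are known to
+        // be zero, then
+        //   %y = or i32 %x, 3
+        // computes the same value as "add i32 %x, 3", so decomposition can
+        // continue through the "or".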
+        if (!MaskedValueIsZero(BOp->getOperand(0), RHSC->getValue(), DL, 0, AC,
+                               BOp, DT)) {
+          Scale = 1;
+          Offset = 0;
+          return V;
+        }
+        LLVM_FALLTHROUGH;
+      case Instruction::Add:
+        V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits,
+                                SExtBits, DL, Depth + 1, AC, DT, NSW, NUW);
+        Offset += RHS;
+        break;
+      case Instruction::Sub:
+        V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits,
+                                SExtBits, DL, Depth + 1, AC, DT, NSW, NUW);
+        Offset -= RHS;
+        break;
+      case Instruction::Mul:
+        V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits,
+                                SExtBits, DL, Depth + 1, AC, DT, NSW, NUW);
+        Offset *= RHS;
+        Scale *= RHS;
+        break;
+      case Instruction::Shl:
+        V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits,
+                                SExtBits, DL, Depth + 1, AC, DT, NSW, NUW);
+
+        // We're trying to linearize an expression of the kind:
+        //   shl i8 -128, 36
+        // where the shift count exceeds the bitwidth of the type.
+        // We can't decompose this further (the expression would return
+        // a poison value).
+        if (Offset.getBitWidth() < RHS.getLimitedValue() ||
+            Scale.getBitWidth() < RHS.getLimitedValue()) {
+          Scale = 1;
+          Offset = 0;
+          return V;
+        }
+
+        Offset <<= RHS.getLimitedValue();
+        Scale <<= RHS.getLimitedValue();
+        // The semantics of nsw and nuw for left shifts don't match those of
+        // multiplications, so we won't propagate them.
+        NSW = NUW = false;
+        return V;
+      }
+
+      if (isa<OverflowingBinaryOperator>(BOp)) {
+        NUW &= BOp->hasNoUnsignedWrap();
+        NSW &= BOp->hasNoSignedWrap();
+      }
+      return V;
+    }
+  }
+
+  // Since GEP indices are sign extended anyway, we don't care about the high
+  // bits of a sign or zero extended value - just scales and offsets. The
+  // extensions have to be consistent though.
+  if (isa<SExtInst>(V) || isa<ZExtInst>(V)) {
+    Value *CastOp = cast<CastInst>(V)->getOperand(0);
+    unsigned NewWidth = V->getType()->getPrimitiveSizeInBits();
+    unsigned SmallWidth = CastOp->getType()->getPrimitiveSizeInBits();
+    unsigned OldZExtBits = ZExtBits, OldSExtBits = SExtBits;
+    const Value *Result =
+        GetLinearExpression(CastOp, Scale, Offset, ZExtBits, SExtBits, DL,
+                            Depth + 1, AC, DT, NSW, NUW);
+
+    // zext(zext(%x)) == zext(%x), and similarly for sext; we'll handle this
+    // by just incrementing the number of bits we've extended by.
+    unsigned ExtendedBy = NewWidth - SmallWidth;
+
+    if (isa<SExtInst>(V) && ZExtBits == 0) {
+      // sext(sext(%x, a), b) == sext(%x, a + b)
+
+      if (NSW) {
+        // We haven't sign-wrapped, so it's valid to decompose sext(%x + c)
+        // into sext(%x) + sext(c). We'll sext the Offset ourselves:
+        unsigned OldWidth = Offset.getBitWidth();
+        Offset = Offset.trunc(SmallWidth).sext(NewWidth).zextOrSelf(OldWidth);
+      } else {
+        // We may have signed-wrapped, so don't decompose sext(%x + c) into
+        // sext(%x) + sext(c)
+        Scale = 1;
+        Offset = 0;
+        Result = CastOp;
+        ZExtBits = OldZExtBits;
+        SExtBits = OldSExtBits;
+      }
+      SExtBits += ExtendedBy;
+    } else {
+      // sext(zext(%x, a), b) = zext(zext(%x, a), b) = zext(%x, a + b)
+
+      if (!NUW) {
+        // We may have unsigned-wrapped, so don't decompose zext(%x + c) into
+        // zext(%x) + zext(c)
+        Scale = 1;
+        Offset = 0;
+        Result = CastOp;
+        ZExtBits = OldZExtBits;
+        SExtBits = OldSExtBits;
+      }
+      ZExtBits += ExtendedBy;
+    }
+
+    return Result;
+  }
+
+  Scale = 1;
+  Offset = 0;
+  return V;
+}
+
+/// Ensures that a pointer offset fits in an integer of size PointerSize
+/// (in bits) when that size is smaller than the maximum pointer size. This is
+/// an issue in particular for 32b pointers with negative indices that rely on
+/// two's complement wrap-arounds for precise alias information where the
+/// maximum pointer size is 64b.
This is
+/// an issue, in particular, for 32b pointers with negative indices that rely
+/// on two's complement wrap-arounds for precise alias information where the
+/// maximum pointer size is 64b.
+static APInt adjustToPointerSize(APInt Offset, unsigned PointerSize) {
+  assert(PointerSize <= Offset.getBitWidth() && "Invalid PointerSize!");
+  unsigned ShiftBits = Offset.getBitWidth() - PointerSize;
+  return (Offset << ShiftBits).ashr(ShiftBits);
+}
+
+static unsigned getMaxPointerSize(const DataLayout &DL) {
+  unsigned MaxPointerSize = DL.getMaxPointerSizeInBits();
+  if (MaxPointerSize < 64 && ForceAtLeast64Bits) MaxPointerSize = 64;
+  if (DoubleCalcBits) MaxPointerSize *= 2;
+
+  return MaxPointerSize;
+}
+
+/// If V is a symbolic pointer expression, decompose it into a base pointer
+/// with a constant offset and a number of scaled symbolic offsets.
+///
+/// The scaled symbolic offsets (represented by pairs of a Value* and a scale
+/// in the VarIndices vector) are Value*'s that are known to be scaled by the
+/// specified amount, but which may have other unrepresented high bits. As
+/// such, the gep cannot necessarily be reconstructed from its decomposed form.
+///
+/// When DataLayout is around, this function is capable of analyzing everything
+/// that GetUnderlyingObject can look through. To be able to do that,
+/// GetUnderlyingObject and DecomposeGEPExpression must use the same search
+/// depth (MaxLookupSearchDepth). When DataLayout is not around, it just looks
+/// through pointer casts.
+bool BasicAAResult::DecomposeGEPExpression(const Value *V,
+    DecomposedGEP &Decomposed, const DataLayout &DL, AssumptionCache *AC,
+    DominatorTree *DT) {
+  // Limit recursion depth to limit compile time in crazy cases.
+  unsigned MaxLookup = MaxLookupSearchDepth;
+  SearchTimes++;
+
+  unsigned MaxPointerSize = getMaxPointerSize(DL);
+  Decomposed.VarIndices.clear();
+  do {
+    // See if this is a bitcast or GEP.
+    const Operator *Op = dyn_cast<Operator>(V);
+    if (!Op) {
+      // The only non-operator case we can handle is a GlobalAlias.
+      if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
+        if (!GA->isInterposable()) {
+          V = GA->getAliasee();
+          continue;
+        }
+      }
+      Decomposed.Base = V;
+      return false;
+    }
+
+    if (Op->getOpcode() == Instruction::BitCast ||
+        Op->getOpcode() == Instruction::AddrSpaceCast) {
+      V = Op->getOperand(0);
+      continue;
+    }
+
+    const GEPOperator *GEPOp = dyn_cast<GEPOperator>(Op);
+    if (!GEPOp) {
+      if (const auto *Call = dyn_cast<CallBase>(V)) {
+        // CaptureTracking knows about special capturing properties of some
+        // intrinsics, such as launder.invariant.group, that can't be expressed
+        // with attributes but that may return an aliasing pointer. Because an
+        // analysis may assume that a nocapture pointer is not returned from
+        // such an intrinsic (the function would otherwise have to be marked
+        // with the returned attribute), it is crucial to use this helper so
+        // that we stay in sync with CaptureTracking. Not using it may cause
+        // weird miscompilations where two aliasing pointers are assumed to be
+        // noalias.
+        if (auto *RP = getArgumentAliasingToReturnedPointer(Call, false)) {
+          V = RP;
+          continue;
+        }
+      }
+
+      // If it's not a GEP, hand it off to SimplifyInstruction to see if it
+      // can come up with something. This matches what GetUnderlyingObject does.
+ if (const Instruction *I = dyn_cast<Instruction>(V)) + // TODO: Get a DominatorTree and AssumptionCache and use them here + // (these are both now available in this function, but this should be + // updated when GetUnderlyingObject is updated). TLI should be + // provided also. + if (const Value *Simplified = + SimplifyInstruction(const_cast<Instruction *>(I), DL)) { + V = Simplified; + continue; + } + + Decomposed.Base = V; + return false; + } + + // Don't attempt to analyze GEPs over unsized objects. + if (!GEPOp->getSourceElementType()->isSized()) { + Decomposed.Base = V; + return false; + } + + unsigned AS = GEPOp->getPointerAddressSpace(); + // Walk the indices of the GEP, accumulating them into BaseOff/VarIndices. + gep_type_iterator GTI = gep_type_begin(GEPOp); + unsigned PointerSize = DL.getPointerSizeInBits(AS); + // Assume all GEP operands are constants until proven otherwise. + bool GepHasConstantOffset = true; + for (User::const_op_iterator I = GEPOp->op_begin() + 1, E = GEPOp->op_end(); + I != E; ++I, ++GTI) { + const Value *Index = *I; + // Compute the (potentially symbolic) offset in bytes for this index. + if (StructType *STy = GTI.getStructTypeOrNull()) { + // For a struct, add the member offset. + unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue(); + if (FieldNo == 0) + continue; + + Decomposed.StructOffset += + DL.getStructLayout(STy)->getElementOffset(FieldNo); + continue; + } + + // For an array/pointer, add the element offset, explicitly scaled. + if (const ConstantInt *CIdx = dyn_cast<ConstantInt>(Index)) { + if (CIdx->isZero()) + continue; + Decomposed.OtherOffset += + (DL.getTypeAllocSize(GTI.getIndexedType()) * + CIdx->getValue().sextOrSelf(MaxPointerSize)) + .sextOrTrunc(MaxPointerSize); + continue; + } + + GepHasConstantOffset = false; + + APInt Scale(MaxPointerSize, DL.getTypeAllocSize(GTI.getIndexedType())); + unsigned ZExtBits = 0, SExtBits = 0; + + // If the integer type is smaller than the pointer size, it is implicitly + // sign extended to pointer size. + unsigned Width = Index->getType()->getIntegerBitWidth(); + if (PointerSize > Width) + SExtBits += PointerSize - Width; + + // Use GetLinearExpression to decompose the index into a C1*V+C2 form. + APInt IndexScale(Width, 0), IndexOffset(Width, 0); + bool NSW = true, NUW = true; + const Value *OrigIndex = Index; + Index = GetLinearExpression(Index, IndexScale, IndexOffset, ZExtBits, + SExtBits, DL, 0, AC, DT, NSW, NUW); + + // The GEP index scale ("Scale") scales C1*V+C2, yielding (C1*V+C2)*Scale. + // This gives us an aggregate computation of (C1*Scale)*V + C2*Scale. + + // It can be the case that, even through C1*V+C2 does not overflow for + // relevant values of V, (C2*Scale) can overflow. In that case, we cannot + // decompose the expression in this way. + // + // FIXME: C1*Scale and the other operations in the decomposed + // (C1*Scale)*V+C2*Scale can also overflow. We should check for this + // possibility. + APInt WideScaledOffset = IndexOffset.sextOrTrunc(MaxPointerSize*2) * + Scale.sext(MaxPointerSize*2); + if (WideScaledOffset.getMinSignedBits() > MaxPointerSize) { + Index = OrigIndex; + IndexScale = 1; + IndexOffset = 0; + + ZExtBits = SExtBits = 0; + if (PointerSize > Width) + SExtBits += PointerSize - Width; + } else { + Decomposed.OtherOffset += IndexOffset.sextOrTrunc(MaxPointerSize) * Scale; + Scale *= IndexScale.sextOrTrunc(MaxPointerSize); + } + + // If we already had an occurrence of this index variable, merge this + // scale into it. 
For example, we want to handle: + // A[x][x] -> x*16 + x*4 -> x*20 + // This also ensures that 'x' only appears in the index list once. + for (unsigned i = 0, e = Decomposed.VarIndices.size(); i != e; ++i) { + if (Decomposed.VarIndices[i].V == Index && + Decomposed.VarIndices[i].ZExtBits == ZExtBits && + Decomposed.VarIndices[i].SExtBits == SExtBits) { + Scale += Decomposed.VarIndices[i].Scale; + Decomposed.VarIndices.erase(Decomposed.VarIndices.begin() + i); + break; + } + } + + // Make sure that we have a scale that makes sense for this target's + // pointer size. + Scale = adjustToPointerSize(Scale, PointerSize); + + if (!!Scale) { + VariableGEPIndex Entry = {Index, ZExtBits, SExtBits, Scale}; + Decomposed.VarIndices.push_back(Entry); + } + } + + // Take care of wrap-arounds + if (GepHasConstantOffset) { + Decomposed.StructOffset = + adjustToPointerSize(Decomposed.StructOffset, PointerSize); + Decomposed.OtherOffset = + adjustToPointerSize(Decomposed.OtherOffset, PointerSize); + } + + // Analyze the base pointer next. + V = GEPOp->getOperand(0); + } while (--MaxLookup); + + // If the chain of expressions is too deep, just return early. + Decomposed.Base = V; + SearchLimitReached++; + return true; +} + +/// Returns whether the given pointer value points to memory that is local to +/// the function, with global constants being considered local to all +/// functions. +bool BasicAAResult::pointsToConstantMemory(const MemoryLocation &Loc, + AAQueryInfo &AAQI, bool OrLocal) { + assert(Visited.empty() && "Visited must be cleared after use!"); + + unsigned MaxLookup = 8; + SmallVector<const Value *, 16> Worklist; + Worklist.push_back(Loc.Ptr); + do { + const Value *V = GetUnderlyingObject(Worklist.pop_back_val(), DL); + if (!Visited.insert(V).second) { + Visited.clear(); + return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal); + } + + // An alloca instruction defines local memory. + if (OrLocal && isa<AllocaInst>(V)) + continue; + + // A global constant counts as local memory for our purposes. + if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) { + // Note: this doesn't require GV to be "ODR" because it isn't legal for a + // global to be marked constant in some modules and non-constant in + // others. GV may even be a declaration, not a definition. + if (!GV->isConstant()) { + Visited.clear(); + return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal); + } + continue; + } + + // If both select values point to local memory, then so does the select. + if (const SelectInst *SI = dyn_cast<SelectInst>(V)) { + Worklist.push_back(SI->getTrueValue()); + Worklist.push_back(SI->getFalseValue()); + continue; + } + + // If all values incoming to a phi node point to local memory, then so does + // the phi. + if (const PHINode *PN = dyn_cast<PHINode>(V)) { + // Don't bother inspecting phi nodes with many operands. + if (PN->getNumIncomingValues() > MaxLookup) { + Visited.clear(); + return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal); + } + for (Value *IncValue : PN->incoming_values()) + Worklist.push_back(IncValue); + continue; + } + + // Otherwise be conservative. + Visited.clear(); + return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal); + } while (!Worklist.empty() && --MaxLookup); + + Visited.clear(); + return Worklist.empty(); +} + +/// Returns the behavior when calling the given call site. +FunctionModRefBehavior BasicAAResult::getModRefBehavior(const CallBase *Call) { + if (Call->doesNotAccessMemory()) + // Can't do better than this. 
+ return FMRB_DoesNotAccessMemory; + + FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior; + + // If the callsite knows it only reads memory, don't return worse + // than that. + if (Call->onlyReadsMemory()) + Min = FMRB_OnlyReadsMemory; + else if (Call->doesNotReadMemory()) + Min = FMRB_DoesNotReadMemory; + + if (Call->onlyAccessesArgMemory()) + Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesArgumentPointees); + else if (Call->onlyAccessesInaccessibleMemory()) + Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesInaccessibleMem); + else if (Call->onlyAccessesInaccessibleMemOrArgMem()) + Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesInaccessibleOrArgMem); + + // If the call has operand bundles then aliasing attributes from the function + // it calls do not directly apply to the call. This can be made more precise + // in the future. + if (!Call->hasOperandBundles()) + if (const Function *F = Call->getCalledFunction()) + Min = + FunctionModRefBehavior(Min & getBestAAResults().getModRefBehavior(F)); + + return Min; +} + +/// Returns the behavior when calling the given function. For use when the call +/// site is not known. +FunctionModRefBehavior BasicAAResult::getModRefBehavior(const Function *F) { + // If the function declares it doesn't access memory, we can't do better. + if (F->doesNotAccessMemory()) + return FMRB_DoesNotAccessMemory; + + FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior; + + // If the function declares it only reads memory, go with that. + if (F->onlyReadsMemory()) + Min = FMRB_OnlyReadsMemory; + else if (F->doesNotReadMemory()) + Min = FMRB_DoesNotReadMemory; + + if (F->onlyAccessesArgMemory()) + Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesArgumentPointees); + else if (F->onlyAccessesInaccessibleMemory()) + Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesInaccessibleMem); + else if (F->onlyAccessesInaccessibleMemOrArgMem()) + Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesInaccessibleOrArgMem); + + return Min; +} + +/// Returns true if this is a writeonly (i.e Mod only) parameter. +static bool isWriteOnlyParam(const CallBase *Call, unsigned ArgIdx, + const TargetLibraryInfo &TLI) { + if (Call->paramHasAttr(ArgIdx, Attribute::WriteOnly)) + return true; + + // We can bound the aliasing properties of memset_pattern16 just as we can + // for memcpy/memset. This is particularly important because the + // LoopIdiomRecognizer likes to turn loops into calls to memset_pattern16 + // whenever possible. + // FIXME Consider handling this in InferFunctionAttr.cpp together with other + // attributes. + LibFunc F; + if (Call->getCalledFunction() && + TLI.getLibFunc(*Call->getCalledFunction(), F) && + F == LibFunc_memset_pattern16 && TLI.has(F)) + if (ArgIdx == 0) + return true; + + // TODO: memset_pattern4, memset_pattern8 + // TODO: _chk variants + // TODO: strcmp, strcpy + + return false; +} + +ModRefInfo BasicAAResult::getArgModRefInfo(const CallBase *Call, + unsigned ArgIdx) { + // Checking for known builtin intrinsics and target library functions. 
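+  // (A minimal illustrative sketch, assuming the usual attributes on the
+  // memset intrinsic: for
+  //   call void @llvm.memset.p0i8.i64(i8* %p, i8 0, i64 %n, i1 false)
+  // argument 0 carries the writeonly attribute, so the checks below would
+  // answer ModRefInfo::Mod for %p.)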
+ if (isWriteOnlyParam(Call, ArgIdx, TLI)) + return ModRefInfo::Mod; + + if (Call->paramHasAttr(ArgIdx, Attribute::ReadOnly)) + return ModRefInfo::Ref; + + if (Call->paramHasAttr(ArgIdx, Attribute::ReadNone)) + return ModRefInfo::NoModRef; + + return AAResultBase::getArgModRefInfo(Call, ArgIdx); +} + +static bool isIntrinsicCall(const CallBase *Call, Intrinsic::ID IID) { + const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Call); + return II && II->getIntrinsicID() == IID; +} + +#ifndef NDEBUG +static const Function *getParent(const Value *V) { + if (const Instruction *inst = dyn_cast<Instruction>(V)) { + if (!inst->getParent()) + return nullptr; + return inst->getParent()->getParent(); + } + + if (const Argument *arg = dyn_cast<Argument>(V)) + return arg->getParent(); + + return nullptr; +} + +static bool notDifferentParent(const Value *O1, const Value *O2) { + + const Function *F1 = getParent(O1); + const Function *F2 = getParent(O2); + + return !F1 || !F2 || F1 == F2; +} +#endif + +AliasResult BasicAAResult::alias(const MemoryLocation &LocA, + const MemoryLocation &LocB, + AAQueryInfo &AAQI) { + assert(notDifferentParent(LocA.Ptr, LocB.Ptr) && + "BasicAliasAnalysis doesn't support interprocedural queries."); + + // If we have a directly cached entry for these locations, we have recursed + // through this once, so just return the cached results. Notably, when this + // happens, we don't clear the cache. + auto CacheIt = AAQI.AliasCache.find(AAQueryInfo::LocPair(LocA, LocB)); + if (CacheIt != AAQI.AliasCache.end()) + return CacheIt->second; + + CacheIt = AAQI.AliasCache.find(AAQueryInfo::LocPair(LocB, LocA)); + if (CacheIt != AAQI.AliasCache.end()) + return CacheIt->second; + + AliasResult Alias = aliasCheck(LocA.Ptr, LocA.Size, LocA.AATags, LocB.Ptr, + LocB.Size, LocB.AATags, AAQI); + + VisitedPhiBBs.clear(); + return Alias; +} + +/// Checks to see if the specified callsite can clobber the specified memory +/// object. +/// +/// Since we only look at local properties of this function, we really can't +/// say much about this query. We do, however, use simple "address taken" +/// analysis on local objects. +ModRefInfo BasicAAResult::getModRefInfo(const CallBase *Call, + const MemoryLocation &Loc, + AAQueryInfo &AAQI) { + assert(notDifferentParent(Call, Loc.Ptr) && + "AliasAnalysis query involving multiple functions!"); + + const Value *Object = GetUnderlyingObject(Loc.Ptr, DL); + + // Calls marked 'tail' cannot read or write allocas from the current frame + // because the current frame might be destroyed by the time they run. However, + // a tail call may use an alloca with byval. Calling with byval copies the + // contents of the alloca into argument registers or stack slots, so there is + // no lifetime issue. + if (isa<AllocaInst>(Object)) + if (const CallInst *CI = dyn_cast<CallInst>(Call)) + if (CI->isTailCall() && + !CI->getAttributes().hasAttrSomewhere(Attribute::ByVal)) + return ModRefInfo::NoModRef; + + // Stack restore is able to modify unescaped dynamic allocas. Assume it may + // modify them even though the alloca is not escaped. + if (auto *AI = dyn_cast<AllocaInst>(Object)) + if (!AI->isStaticAlloca() && isIntrinsicCall(Call, Intrinsic::stackrestore)) + return ModRefInfo::Mod; + + // If the pointer is to a locally allocated object that does not escape, + // then the call can not mod/ref the pointer unless the call takes the pointer + // as an argument, and itself doesn't capture it. 
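+  // (Hypothetical IR sketch of the idea:
+  //   %obj = alloca i32
+  //   call void @f()                      ; cannot touch the non-escaping %obj
+  //   call void @g(i32* nocapture %obj)   ; may touch %obj, but only via the
+  //                                       ; argument operand checked below.)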
+  if (!isa<Constant>(Object) && Call != Object &&
+      isNonEscapingLocalObject(Object, &AAQI.IsCapturedCache)) {
+
+    // Optimistically assume that call doesn't touch Object and check this
+    // assumption in the following loop.
+    ModRefInfo Result = ModRefInfo::NoModRef;
+    bool IsMustAlias = true;
+
+    unsigned OperandNo = 0;
+    for (auto CI = Call->data_operands_begin(), CE = Call->data_operands_end();
+         CI != CE; ++CI, ++OperandNo) {
+      // Only look at the no-capture or byval pointer arguments. If this
+      // pointer were passed to arguments that were neither of these, then it
+      // couldn't be no-capture.
+      if (!(*CI)->getType()->isPointerTy() ||
+          (!Call->doesNotCapture(OperandNo) &&
+           OperandNo < Call->getNumArgOperands() &&
+           !Call->isByValArgument(OperandNo)))
+        continue;
+
+      // Call doesn't access memory through this operand, so we don't care
+      // if it aliases with Object.
+      if (Call->doesNotAccessMemory(OperandNo))
+        continue;
+
+      // If this is a no-capture pointer argument, see if we can tell that it
+      // is impossible to alias the pointer we're checking.
+      AliasResult AR = getBestAAResults().alias(MemoryLocation(*CI),
+                                                MemoryLocation(Object), AAQI);
+      if (AR != MustAlias)
+        IsMustAlias = false;
+      // Operand doesn't alias 'Object'; continue looking for other aliases.
+      if (AR == NoAlias)
+        continue;
+      // Operand aliases 'Object', but the call doesn't modify it. Strengthen
+      // the initial assumption and keep looking in case there are more
+      // aliases.
+      if (Call->onlyReadsMemory(OperandNo)) {
+        Result = setRef(Result);
+        continue;
+      }
+      // Operand aliases 'Object' but the call only writes into it.
+      if (Call->doesNotReadMemory(OperandNo)) {
+        Result = setMod(Result);
+        continue;
+      }
+      // This operand aliases 'Object' and the call reads and writes into it.
+      // Setting ModRef will not yield an early return below; MustAlias is not
+      // used further.
+      Result = ModRefInfo::ModRef;
+      break;
+    }
+
+    // No operand aliases; reset the Must bit. It is added below if at least
+    // one operand aliases and all aliases found are MustAlias.
+    if (isNoModRef(Result))
+      IsMustAlias = false;
+
+    // Early return if we improved the mod/ref information.
+    if (!isModAndRefSet(Result)) {
+      if (isNoModRef(Result))
+        return ModRefInfo::NoModRef;
+      return IsMustAlias ? setMust(Result) : clearMust(Result);
+    }
+  }
+
+  // If the call is to malloc or calloc, we can assume that it doesn't
+  // modify any IR visible value. This is only valid because we assume these
+  // routines do not read values visible in the IR. TODO: Consider special
+  // casing realloc and strdup routines which access only their arguments as
+  // well. Or alternatively, replace all of this with inaccessiblememonly once
+  // that's implemented fully.
+  if (isMallocOrCallocLikeFn(Call, &TLI)) {
+    // Be conservative if the accessed pointer may alias the allocation -
+    // fall back to the generic handling below.
+    if (getBestAAResults().alias(MemoryLocation(Call), Loc, AAQI) == NoAlias)
+      return ModRefInfo::NoModRef;
+  }
+
+  // The semantics of memcpy intrinsics forbid overlap between their respective
+  // operands, i.e., source and destination of any given memcpy must no-alias.
+  // If Loc must-aliases either one of these two locations, then it necessarily
+  // no-aliases the other.
+  if (auto *Inst = dyn_cast<AnyMemCpyInst>(Call)) {
+    AliasResult SrcAA, DestAA;
+
+    if ((SrcAA = getBestAAResults().alias(MemoryLocation::getForSource(Inst),
+                                          Loc, AAQI)) == MustAlias)
+      // Loc is exactly the memcpy source thus disjoint from memcpy dest.
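+      // (E.g., for a hypothetical call to memcpy(%dst, %src, %n): a location
+      // that must-aliases %src can be read but never written by the call.)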
+      return ModRefInfo::Ref;
+    if ((DestAA = getBestAAResults().alias(MemoryLocation::getForDest(Inst),
+                                           Loc, AAQI)) == MustAlias)
+      // The converse case.
+      return ModRefInfo::Mod;
+
+    // It's also possible for Loc to alias both src and dest, or neither.
+    ModRefInfo rv = ModRefInfo::NoModRef;
+    if (SrcAA != NoAlias)
+      rv = setRef(rv);
+    if (DestAA != NoAlias)
+      rv = setMod(rv);
+    return rv;
+  }
+
+  // While the assume intrinsic is marked as arbitrarily writing so that
+  // proper control dependencies will be maintained, it never aliases any
+  // particular memory location.
+  if (isIntrinsicCall(Call, Intrinsic::assume))
+    return ModRefInfo::NoModRef;
+
+  // Like assumes, guard intrinsics are also marked as arbitrarily writing so
+  // that proper control dependencies are maintained but they never mod any
+  // particular memory location.
+  //
+  // *Unlike* assumes, guard intrinsics are modeled as reading memory since the
+  // heap state at the point the guard is issued needs to be consistent in case
+  // the guard invokes the "deopt" continuation.
+  if (isIntrinsicCall(Call, Intrinsic::experimental_guard))
+    return ModRefInfo::Ref;
+
+  // Like assumes, invariant.start intrinsics were also marked as arbitrarily
+  // writing so that proper control dependencies are maintained but they never
+  // mod any particular memory location visible to the IR.
+  // *Unlike* assumes (which are now modeled as NoModRef), the invariant.start
+  // intrinsic is now modeled as reading memory. This prevents hoisting the
+  // invariant.start intrinsic over stores. Consider:
+  // *ptr = 40;
+  // *ptr = 50;
+  // invariant_start(ptr)
+  // int val = *ptr;
+  // print(val);
+  //
+  // This cannot be transformed to:
+  //
+  // *ptr = 40;
+  // invariant_start(ptr)
+  // *ptr = 50;
+  // int val = *ptr;
+  // print(val);
+  //
+  // The transformation will cause the second store to be ignored (based on
+  // rules of invariant.start) and print 40, while the first program always
+  // prints 50.
+  if (isIntrinsicCall(Call, Intrinsic::invariant_start))
+    return ModRefInfo::Ref;
+
+  // The AAResultBase base class has some smarts, let's use them.
+  return AAResultBase::getModRefInfo(Call, Loc, AAQI);
+}
+
+ModRefInfo BasicAAResult::getModRefInfo(const CallBase *Call1,
+                                        const CallBase *Call2,
+                                        AAQueryInfo &AAQI) {
+  // While the assume intrinsic is marked as arbitrarily writing so that
+  // proper control dependencies will be maintained, it never aliases any
+  // particular memory location.
+  if (isIntrinsicCall(Call1, Intrinsic::assume) ||
+      isIntrinsicCall(Call2, Intrinsic::assume))
+    return ModRefInfo::NoModRef;
+
+  // Like assumes, guard intrinsics are also marked as arbitrarily writing so
+  // that proper control dependencies are maintained but they never mod any
+  // particular memory location.
+  //
+  // *Unlike* assumes, guard intrinsics are modeled as reading memory since the
+  // heap state at the point the guard is issued needs to be consistent in case
+  // the guard invokes the "deopt" continuation.
+
+  // NB! This function is *not* commutative, so we special case two
+  // possibilities for guard intrinsics.
+
+  if (isIntrinsicCall(Call1, Intrinsic::experimental_guard))
+    return isModSet(createModRefInfo(getModRefBehavior(Call2)))
+               ? ModRefInfo::Ref
+               : ModRefInfo::NoModRef;
+
+  if (isIntrinsicCall(Call2, Intrinsic::experimental_guard))
+    return isModSet(createModRefInfo(getModRefBehavior(Call1)))
+               ? ModRefInfo::Mod
+               : ModRefInfo::NoModRef;
+
+  // The AAResultBase base class has some smarts, let's use them.
+  return AAResultBase::getModRefInfo(Call1, Call2, AAQI);
+}
+
+/// Provide ad-hoc rules to disambiguate accesses through two GEP operators,
+/// both having the exact same pointer operand.
+static AliasResult aliasSameBasePointerGEPs(const GEPOperator *GEP1,
+                                            LocationSize MaybeV1Size,
+                                            const GEPOperator *GEP2,
+                                            LocationSize MaybeV2Size,
+                                            const DataLayout &DL) {
+  assert(GEP1->getPointerOperand()->stripPointerCastsAndInvariantGroups() ==
+             GEP2->getPointerOperand()->stripPointerCastsAndInvariantGroups() &&
+         GEP1->getPointerOperandType() == GEP2->getPointerOperandType() &&
+         "Expected GEPs with the same pointer operand");
+
+  // Try to determine whether GEP1 and GEP2 index through arrays, into structs,
+  // such that the struct field accesses provably cannot alias.
+  // We also need at least two indices (the pointer, and the struct field).
+  if (GEP1->getNumIndices() != GEP2->getNumIndices() ||
+      GEP1->getNumIndices() < 2)
+    return MayAlias;
+
+  // If we don't know the size of the accesses through both GEPs, we can't
+  // determine whether the struct fields accessed can't alias.
+  if (MaybeV1Size == LocationSize::unknown() ||
+      MaybeV2Size == LocationSize::unknown())
+    return MayAlias;
+
+  const uint64_t V1Size = MaybeV1Size.getValue();
+  const uint64_t V2Size = MaybeV2Size.getValue();
+
+  ConstantInt *C1 =
+      dyn_cast<ConstantInt>(GEP1->getOperand(GEP1->getNumOperands() - 1));
+  ConstantInt *C2 =
+      dyn_cast<ConstantInt>(GEP2->getOperand(GEP2->getNumOperands() - 1));
+
+  // If the last (struct) indices are constants and are equal, the other
+  // indices might also be dynamically equal, so the GEPs can alias.
+  if (C1 && C2) {
+    unsigned BitWidth = std::max(C1->getBitWidth(), C2->getBitWidth());
+    if (C1->getValue().sextOrSelf(BitWidth) ==
+        C2->getValue().sextOrSelf(BitWidth))
+      return MayAlias;
+  }
+
+  // Find the last-indexed type of the GEP, i.e., the type you'd get if
+  // you stripped the last index.
+  // On the way, look at each indexed type. If there's something other
+  // than an array, different indices can lead to different final types.
+  SmallVector<Value *, 8> IntermediateIndices;
+
+  // Insert the first index; we don't need to check the type indexed
+  // through it as it only drops the pointer indirection.
+  assert(GEP1->getNumIndices() > 1 && "Not enough GEP indices to examine");
+  IntermediateIndices.push_back(GEP1->getOperand(1));
+
+  // Insert all the remaining indices but the last one.
+  // Also, check that they all index through arrays.
+  for (unsigned i = 1, e = GEP1->getNumIndices() - 1; i != e; ++i) {
+    if (!isa<ArrayType>(GetElementPtrInst::getIndexedType(
+            GEP1->getSourceElementType(), IntermediateIndices)))
+      return MayAlias;
+    IntermediateIndices.push_back(GEP1->getOperand(i + 1));
+  }
+
+  auto *Ty = GetElementPtrInst::getIndexedType(
+      GEP1->getSourceElementType(), IntermediateIndices);
+  StructType *LastIndexedStruct = dyn_cast<StructType>(Ty);
+
+  if (isa<SequentialType>(Ty)) {
+    // We know that:
+    // - both GEPs begin indexing from the exact same pointer;
+    // - the last indices in both GEPs are constants, indexing into a
+    //   sequential type (array or pointer);
+    // - both GEPs only index through arrays prior to that.
+    //
+    // Because array indices greater than the number of elements are valid in
+    // GEPs, unless we know the intermediate indices are identical between
+    // GEP1 and GEP2 we cannot guarantee that the last indexed arrays don't
+    // partially overlap.
We also need to check that the loaded size matches + // the element size, otherwise we could still have overlap. + const uint64_t ElementSize = + DL.getTypeStoreSize(cast<SequentialType>(Ty)->getElementType()); + if (V1Size != ElementSize || V2Size != ElementSize) + return MayAlias; + + for (unsigned i = 0, e = GEP1->getNumIndices() - 1; i != e; ++i) + if (GEP1->getOperand(i + 1) != GEP2->getOperand(i + 1)) + return MayAlias; + + // Now we know that the array/pointer that GEP1 indexes into and that + // that GEP2 indexes into must either precisely overlap or be disjoint. + // Because they cannot partially overlap and because fields in an array + // cannot overlap, if we can prove the final indices are different between + // GEP1 and GEP2, we can conclude GEP1 and GEP2 don't alias. + + // If the last indices are constants, we've already checked they don't + // equal each other so we can exit early. + if (C1 && C2) + return NoAlias; + { + Value *GEP1LastIdx = GEP1->getOperand(GEP1->getNumOperands() - 1); + Value *GEP2LastIdx = GEP2->getOperand(GEP2->getNumOperands() - 1); + if (isa<PHINode>(GEP1LastIdx) || isa<PHINode>(GEP2LastIdx)) { + // If one of the indices is a PHI node, be safe and only use + // computeKnownBits so we don't make any assumptions about the + // relationships between the two indices. This is important if we're + // asking about values from different loop iterations. See PR32314. + // TODO: We may be able to change the check so we only do this when + // we definitely looked through a PHINode. + if (GEP1LastIdx != GEP2LastIdx && + GEP1LastIdx->getType() == GEP2LastIdx->getType()) { + KnownBits Known1 = computeKnownBits(GEP1LastIdx, DL); + KnownBits Known2 = computeKnownBits(GEP2LastIdx, DL); + if (Known1.Zero.intersects(Known2.One) || + Known1.One.intersects(Known2.Zero)) + return NoAlias; + } + } else if (isKnownNonEqual(GEP1LastIdx, GEP2LastIdx, DL)) + return NoAlias; + } + return MayAlias; + } else if (!LastIndexedStruct || !C1 || !C2) { + return MayAlias; + } + + if (C1->getValue().getActiveBits() > 64 || + C2->getValue().getActiveBits() > 64) + return MayAlias; + + // We know that: + // - both GEPs begin indexing from the exact same pointer; + // - the last indices in both GEPs are constants, indexing into a struct; + // - said indices are different, hence, the pointed-to fields are different; + // - both GEPs only index through arrays prior to that. + // + // This lets us determine that the struct that GEP1 indexes into and the + // struct that GEP2 indexes into must either precisely overlap or be + // completely disjoint. Because they cannot partially overlap, indexing into + // different non-overlapping fields of the struct will never alias. + + // Therefore, the only remaining thing needed to show that both GEPs can't + // alias is that the fields are not overlapping. 
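+  // (Illustrative sketch with hypothetical IR: for %struct.S = { i32, i32 }
+  // and
+  //   %p = getelementptr [8 x %struct.S], [8 x %struct.S]* %b, i64 0, i64 %i, i32 0
+  //   %q = getelementptr [8 x %struct.S], [8 x %struct.S]* %b, i64 0, i64 %j, i32 1
+  // 4-byte accesses through %p and %q cannot overlap for any %i and %j,
+  // because the two fields live at distinct offsets modulo the struct size.)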
+  const StructLayout *SL = DL.getStructLayout(LastIndexedStruct);
+  const uint64_t StructSize = SL->getSizeInBytes();
+  const uint64_t V1Off = SL->getElementOffset(C1->getZExtValue());
+  const uint64_t V2Off = SL->getElementOffset(C2->getZExtValue());
+
+  auto EltsDontOverlap = [StructSize](uint64_t V1Off, uint64_t V1Size,
+                                      uint64_t V2Off, uint64_t V2Size) {
+    return V1Off < V2Off && V1Off + V1Size <= V2Off &&
+           ((V2Off + V2Size <= StructSize) ||
+            (V2Off + V2Size - StructSize <= V1Off));
+  };
+
+  if (EltsDontOverlap(V1Off, V1Size, V2Off, V2Size) ||
+      EltsDontOverlap(V2Off, V2Size, V1Off, V1Size))
+    return NoAlias;
+
+  return MayAlias;
+}
+
+// If we have (a) a GEP and (b) a pointer based on an alloca, and the
+// beginning of the object the GEP points to would have a negative offset with
+// respect to the alloca, then the GEP cannot alias pointer (b).
+// Note that the pointer based on the alloca may not be a GEP. For
+// example, it may be the alloca itself.
+// The same applies if (b) is based on a GlobalVariable. Note that just being
+// based on isIdentifiedObject() is not enough - we need an identified object
+// that does not permit access to negative offsets. For example, a negative
+// offset from a noalias argument or call can be inbounds w.r.t the actual
+// underlying object.
+//
+// For example, consider:
+//
+//   struct { int f0, f1, ...; } foo;
+//   foo alloca;
+//   foo *random = bar(alloca);
+//   int *f0 = &alloca.f0;
+//   int *f1 = &random->f1;
+//
+// Which is lowered, approximately, to:
+//
+//   %alloca = alloca %struct.foo
+//   %random = call %struct.foo* @random(%struct.foo* %alloca)
+//   %f0 = getelementptr inbounds %struct.foo, %struct.foo* %alloca, i32 0, i32 0
+//   %f1 = getelementptr inbounds %struct.foo, %struct.foo* %random, i32 0, i32 1
+//
+// Assume %f1 and %f0 alias. Then %f1 would point into the object allocated
+// by %alloca. Since the %f1 GEP is inbounds, that means %random must also
+// point into the same object. But since %f0 points to the beginning of
+// %alloca, the highest %f1 can be is (%alloca + 3). This means %random cannot
+// be higher than (%alloca - 1), and so is not inbounds, a contradiction.
+bool BasicAAResult::isGEPBaseAtNegativeOffset(const GEPOperator *GEPOp,
+    const DecomposedGEP &DecompGEP, const DecomposedGEP &DecompObject,
+    LocationSize MaybeObjectAccessSize) {
+  // If the object access size is unknown, or the GEP isn't inbounds, bail.
+  if (MaybeObjectAccessSize == LocationSize::unknown() || !GEPOp->isInBounds())
+    return false;
+
+  const uint64_t ObjectAccessSize = MaybeObjectAccessSize.getValue();
+
+  // We need the object to be an alloca or a GlobalVariable, and want to know
+  // the offset of the pointer from the object precisely, so no variable
+  // indices are allowed.
+  if (!(isa<AllocaInst>(DecompObject.Base) ||
+        isa<GlobalVariable>(DecompObject.Base)) ||
+      !DecompObject.VarIndices.empty())
+    return false;
+
+  APInt ObjectBaseOffset = DecompObject.StructOffset +
+                           DecompObject.OtherOffset;
+
+  // If the GEP has no variable indices, we know the precise offset
+  // from the base and can use it. If the GEP has variable indices, we can't
+  // get the exact GEP offset to identify a pointer alias, so return
+  // false in that case.
+ if (!DecompGEP.VarIndices.empty()) + return false; + + APInt GEPBaseOffset = DecompGEP.StructOffset; + GEPBaseOffset += DecompGEP.OtherOffset; + + return GEPBaseOffset.sge(ObjectBaseOffset + (int64_t)ObjectAccessSize); +} + +/// Provides a bunch of ad-hoc rules to disambiguate a GEP instruction against +/// another pointer. +/// +/// We know that V1 is a GEP, but we don't know anything about V2. +/// UnderlyingV1 is GetUnderlyingObject(GEP1, DL), UnderlyingV2 is the same for +/// V2. +AliasResult BasicAAResult::aliasGEP( + const GEPOperator *GEP1, LocationSize V1Size, const AAMDNodes &V1AAInfo, + const Value *V2, LocationSize V2Size, const AAMDNodes &V2AAInfo, + const Value *UnderlyingV1, const Value *UnderlyingV2, AAQueryInfo &AAQI) { + DecomposedGEP DecompGEP1, DecompGEP2; + unsigned MaxPointerSize = getMaxPointerSize(DL); + DecompGEP1.StructOffset = DecompGEP1.OtherOffset = APInt(MaxPointerSize, 0); + DecompGEP2.StructOffset = DecompGEP2.OtherOffset = APInt(MaxPointerSize, 0); + + bool GEP1MaxLookupReached = + DecomposeGEPExpression(GEP1, DecompGEP1, DL, &AC, DT); + bool GEP2MaxLookupReached = + DecomposeGEPExpression(V2, DecompGEP2, DL, &AC, DT); + + APInt GEP1BaseOffset = DecompGEP1.StructOffset + DecompGEP1.OtherOffset; + APInt GEP2BaseOffset = DecompGEP2.StructOffset + DecompGEP2.OtherOffset; + + assert(DecompGEP1.Base == UnderlyingV1 && DecompGEP2.Base == UnderlyingV2 && + "DecomposeGEPExpression returned a result different from " + "GetUnderlyingObject"); + + // If the GEP's offset relative to its base is such that the base would + // fall below the start of the object underlying V2, then the GEP and V2 + // cannot alias. + if (!GEP1MaxLookupReached && !GEP2MaxLookupReached && + isGEPBaseAtNegativeOffset(GEP1, DecompGEP1, DecompGEP2, V2Size)) + return NoAlias; + // If we have two gep instructions with must-alias or not-alias'ing base + // pointers, figure out if the indexes to the GEP tell us anything about the + // derived pointer. + if (const GEPOperator *GEP2 = dyn_cast<GEPOperator>(V2)) { + // Check for the GEP base being at a negative offset, this time in the other + // direction. + if (!GEP1MaxLookupReached && !GEP2MaxLookupReached && + isGEPBaseAtNegativeOffset(GEP2, DecompGEP2, DecompGEP1, V1Size)) + return NoAlias; + // Do the base pointers alias? + AliasResult BaseAlias = + aliasCheck(UnderlyingV1, LocationSize::unknown(), AAMDNodes(), + UnderlyingV2, LocationSize::unknown(), AAMDNodes(), AAQI); + + // Check for geps of non-aliasing underlying pointers where the offsets are + // identical. + if ((BaseAlias == MayAlias) && V1Size == V2Size) { + // Do the base pointers alias assuming type and size. + AliasResult PreciseBaseAlias = aliasCheck( + UnderlyingV1, V1Size, V1AAInfo, UnderlyingV2, V2Size, V2AAInfo, AAQI); + if (PreciseBaseAlias == NoAlias) { + // See if the computed offset from the common pointer tells us about the + // relation of the resulting pointer. + // If the max search depth is reached the result is undefined + if (GEP2MaxLookupReached || GEP1MaxLookupReached) + return MayAlias; + + // Same offsets. + if (GEP1BaseOffset == GEP2BaseOffset && + DecompGEP1.VarIndices == DecompGEP2.VarIndices) + return NoAlias; + } + } + + // If we get a No or May, then return it immediately, no amount of analysis + // will improve this situation. + if (BaseAlias != MustAlias) { + assert(BaseAlias == NoAlias || BaseAlias == MayAlias); + return BaseAlias; + } + + // Otherwise, we have a MustAlias. 
Since the base pointers alias each other + // exactly, see if the computed offset from the common pointer tells us + // about the relation of the resulting pointer. + // If we know the two GEPs are based off of the exact same pointer (and not + // just the same underlying object), see if that tells us anything about + // the resulting pointers. + if (GEP1->getPointerOperand()->stripPointerCastsAndInvariantGroups() == + GEP2->getPointerOperand()->stripPointerCastsAndInvariantGroups() && + GEP1->getPointerOperandType() == GEP2->getPointerOperandType()) { + AliasResult R = aliasSameBasePointerGEPs(GEP1, V1Size, GEP2, V2Size, DL); + // If we couldn't find anything interesting, don't abandon just yet. + if (R != MayAlias) + return R; + } + + // If the max search depth is reached, the result is undefined + if (GEP2MaxLookupReached || GEP1MaxLookupReached) + return MayAlias; + + // Subtract the GEP2 pointer from the GEP1 pointer to find out their + // symbolic difference. + GEP1BaseOffset -= GEP2BaseOffset; + GetIndexDifference(DecompGEP1.VarIndices, DecompGEP2.VarIndices); + + } else { + // Check to see if these two pointers are related by the getelementptr + // instruction. If one pointer is a GEP with a non-zero index of the other + // pointer, we know they cannot alias. + + // If both accesses are unknown size, we can't do anything useful here. + if (V1Size == LocationSize::unknown() && V2Size == LocationSize::unknown()) + return MayAlias; + + AliasResult R = aliasCheck(UnderlyingV1, LocationSize::unknown(), + AAMDNodes(), V2, LocationSize::unknown(), + V2AAInfo, AAQI, nullptr, UnderlyingV2); + if (R != MustAlias) { + // If V2 may alias GEP base pointer, conservatively returns MayAlias. + // If V2 is known not to alias GEP base pointer, then the two values + // cannot alias per GEP semantics: "Any memory access must be done through + // a pointer value associated with an address range of the memory access, + // otherwise the behavior is undefined.". + assert(R == NoAlias || R == MayAlias); + return R; + } + + // If the max search depth is reached the result is undefined + if (GEP1MaxLookupReached) + return MayAlias; + } + + // In the two GEP Case, if there is no difference in the offsets of the + // computed pointers, the resultant pointers are a must alias. This + // happens when we have two lexically identical GEP's (for example). + // + // In the other case, if we have getelementptr <ptr>, 0, 0, 0, 0, ... and V2 + // must aliases the GEP, the end result is a must alias also. + if (GEP1BaseOffset == 0 && DecompGEP1.VarIndices.empty()) + return MustAlias; + + // If there is a constant difference between the pointers, but the difference + // is less than the size of the associated memory object, then we know + // that the objects are partially overlapping. If the difference is + // greater, we know they do not overlap. + if (GEP1BaseOffset != 0 && DecompGEP1.VarIndices.empty()) { + if (GEP1BaseOffset.sge(0)) { + if (V2Size != LocationSize::unknown()) { + if (GEP1BaseOffset.ult(V2Size.getValue())) + return PartialAlias; + return NoAlias; + } + } else { + // We have the situation where: + // + + + // | BaseOffset | + // ---------------->| + // |-->V1Size |-------> V2Size + // GEP1 V2 + // We need to know that V2Size is not unknown, otherwise we might have + // stripped a gep with negative index ('gep <ptr>, -1, ...). 
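+      // (Worked instance, hypothetical: GEP1BaseOffset == -4 with V1Size == 8
+      // means GEP1 accesses [V2-4, V2+4), so the first 4 bytes of V2's access
+      // overlap and we report PartialAlias; with V1Size <= 4 we report
+      // NoAlias.)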
+ if (V1Size != LocationSize::unknown() && + V2Size != LocationSize::unknown()) { + if ((-GEP1BaseOffset).ult(V1Size.getValue())) + return PartialAlias; + return NoAlias; + } + } + } + + if (!DecompGEP1.VarIndices.empty()) { + APInt Modulo(MaxPointerSize, 0); + bool AllPositive = true; + for (unsigned i = 0, e = DecompGEP1.VarIndices.size(); i != e; ++i) { + + // Try to distinguish something like &A[i][1] against &A[42][0]. + // Grab the least significant bit set in any of the scales. We + // don't need std::abs here (even if the scale's negative) as we'll + // be ^'ing Modulo with itself later. + Modulo |= DecompGEP1.VarIndices[i].Scale; + + if (AllPositive) { + // If the Value could change between cycles, then any reasoning about + // the Value this cycle may not hold in the next cycle. We'll just + // give up if we can't determine conditions that hold for every cycle: + const Value *V = DecompGEP1.VarIndices[i].V; + + KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, DT); + bool SignKnownZero = Known.isNonNegative(); + bool SignKnownOne = Known.isNegative(); + + // Zero-extension widens the variable, and so forces the sign + // bit to zero. + bool IsZExt = DecompGEP1.VarIndices[i].ZExtBits > 0 || isa<ZExtInst>(V); + SignKnownZero |= IsZExt; + SignKnownOne &= !IsZExt; + + // If the variable begins with a zero then we know it's + // positive, regardless of whether the value is signed or + // unsigned. + APInt Scale = DecompGEP1.VarIndices[i].Scale; + AllPositive = + (SignKnownZero && Scale.sge(0)) || (SignKnownOne && Scale.slt(0)); + } + } + + Modulo = Modulo ^ (Modulo & (Modulo - 1)); + + // We can compute the difference between the two addresses + // mod Modulo. Check whether that difference guarantees that the + // two locations do not alias. + APInt ModOffset = GEP1BaseOffset & (Modulo - 1); + if (V1Size != LocationSize::unknown() && + V2Size != LocationSize::unknown() && ModOffset.uge(V2Size.getValue()) && + (Modulo - ModOffset).uge(V1Size.getValue())) + return NoAlias; + + // If we know all the variables are positive, then GEP1 >= GEP1BasePtr. + // If GEP1BasePtr > V2 (GEP1BaseOffset > 0) then we know the pointers + // don't alias if V2Size can fit in the gap between V2 and GEP1BasePtr. + if (AllPositive && GEP1BaseOffset.sgt(0) && + V2Size != LocationSize::unknown() && + GEP1BaseOffset.uge(V2Size.getValue())) + return NoAlias; + + if (constantOffsetHeuristic(DecompGEP1.VarIndices, V1Size, V2Size, + GEP1BaseOffset, &AC, DT)) + return NoAlias; + } + + // Statically, we can see that the base objects are the same, but the + // pointers have dynamic offsets which we can't resolve. And none of our + // little tricks above worked. + return MayAlias; +} + +static AliasResult MergeAliasResults(AliasResult A, AliasResult B) { + // If the results agree, take it. + if (A == B) + return A; + // A mix of PartialAlias and MustAlias is PartialAlias. + if ((A == PartialAlias && B == MustAlias) || + (B == PartialAlias && A == MustAlias)) + return PartialAlias; + // Otherwise, we don't know anything. + return MayAlias; +} + +/// Provides a bunch of ad-hoc rules to disambiguate a Select instruction +/// against another. 
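+///
+/// (Illustrative, hypothetical IR: for
+///    %p = select i1 %c, i32* %a, i32* %b
+///    %q = select i1 %c, i32* %x, i32* %y
+/// it suffices to compare %a with %x and %b with %y, since both selects take
+/// the same arm on any given execution.)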
+AliasResult +BasicAAResult::aliasSelect(const SelectInst *SI, LocationSize SISize, + const AAMDNodes &SIAAInfo, const Value *V2, + LocationSize V2Size, const AAMDNodes &V2AAInfo, + const Value *UnderV2, AAQueryInfo &AAQI) { + // If the values are Selects with the same condition, we can do a more precise + // check: just check for aliases between the values on corresponding arms. + if (const SelectInst *SI2 = dyn_cast<SelectInst>(V2)) + if (SI->getCondition() == SI2->getCondition()) { + AliasResult Alias = + aliasCheck(SI->getTrueValue(), SISize, SIAAInfo, SI2->getTrueValue(), + V2Size, V2AAInfo, AAQI); + if (Alias == MayAlias) + return MayAlias; + AliasResult ThisAlias = + aliasCheck(SI->getFalseValue(), SISize, SIAAInfo, + SI2->getFalseValue(), V2Size, V2AAInfo, AAQI); + return MergeAliasResults(ThisAlias, Alias); + } + + // If both arms of the Select node NoAlias or MustAlias V2, then returns + // NoAlias / MustAlias. Otherwise, returns MayAlias. + AliasResult Alias = aliasCheck(V2, V2Size, V2AAInfo, SI->getTrueValue(), + SISize, SIAAInfo, AAQI, UnderV2); + if (Alias == MayAlias) + return MayAlias; + + AliasResult ThisAlias = aliasCheck(V2, V2Size, V2AAInfo, SI->getFalseValue(), + SISize, SIAAInfo, AAQI, UnderV2); + return MergeAliasResults(ThisAlias, Alias); +} + +/// Provide a bunch of ad-hoc rules to disambiguate a PHI instruction against +/// another. +AliasResult BasicAAResult::aliasPHI(const PHINode *PN, LocationSize PNSize, + const AAMDNodes &PNAAInfo, const Value *V2, + LocationSize V2Size, + const AAMDNodes &V2AAInfo, + const Value *UnderV2, AAQueryInfo &AAQI) { + // Track phi nodes we have visited. We use this information when we determine + // value equivalence. + VisitedPhiBBs.insert(PN->getParent()); + + // If the values are PHIs in the same block, we can do a more precise + // as well as efficient check: just check for aliases between the values + // on corresponding edges. + if (const PHINode *PN2 = dyn_cast<PHINode>(V2)) + if (PN2->getParent() == PN->getParent()) { + AAQueryInfo::LocPair Locs(MemoryLocation(PN, PNSize, PNAAInfo), + MemoryLocation(V2, V2Size, V2AAInfo)); + if (PN > V2) + std::swap(Locs.first, Locs.second); + // Analyse the PHIs' inputs under the assumption that the PHIs are + // NoAlias. + // If the PHIs are May/MustAlias there must be (recursively) an input + // operand from outside the PHIs' cycle that is MayAlias/MustAlias or + // there must be an operation on the PHIs within the PHIs' value cycle + // that causes a MayAlias. + // Pretend the phis do not alias. + AliasResult Alias = NoAlias; + AliasResult OrigAliasResult; + { + // Limited lifetime iterator invalidated by the aliasCheck call below. + auto CacheIt = AAQI.AliasCache.find(Locs); + assert((CacheIt != AAQI.AliasCache.end()) && + "There must exist an entry for the phi node"); + OrigAliasResult = CacheIt->second; + CacheIt->second = NoAlias; + } + + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + AliasResult ThisAlias = + aliasCheck(PN->getIncomingValue(i), PNSize, PNAAInfo, + PN2->getIncomingValueForBlock(PN->getIncomingBlock(i)), + V2Size, V2AAInfo, AAQI); + Alias = MergeAliasResults(ThisAlias, Alias); + if (Alias == MayAlias) + break; + } + + // Reset if speculation failed. 
+      if (Alias != NoAlias) {
+        auto Pair =
+            AAQI.AliasCache.insert(std::make_pair(Locs, OrigAliasResult));
+        assert(!Pair.second && "Entry must have existed");
+        Pair.first->second = OrigAliasResult;
+      }
+      return Alias;
+    }
+
+  SmallVector<Value *, 4> V1Srcs;
+  bool isRecursive = false;
+  if (PV) {
+    // If we have PhiValues then use it to get the underlying phi values.
+    const PhiValues::ValueSet &PhiValueSet = PV->getValuesForPhi(PN);
+    // If we have more phi values than the search depth then return MayAlias
+    // conservatively to avoid compile time explosion. The worst possible case
+    // is if both sides are PHI nodes, in which case this is O(m x n) time
+    // where 'm' and 'n' are the number of PHI sources.
+    if (PhiValueSet.size() > MaxLookupSearchDepth)
+      return MayAlias;
+    // Add the values to V1Srcs.
+    for (Value *PV1 : PhiValueSet) {
+      if (EnableRecPhiAnalysis) {
+        if (GEPOperator *PV1GEP = dyn_cast<GEPOperator>(PV1)) {
+          // Check whether the incoming value is a GEP that advances the
+          // pointer result of this PHI node (e.g. in a loop). If this is the
+          // case, we would recurse and always get a MayAlias. Handle this
+          // case specially below.
+          if (PV1GEP->getPointerOperand() == PN &&
+              PV1GEP->getNumIndices() == 1 &&
+              isa<ConstantInt>(PV1GEP->idx_begin())) {
+            isRecursive = true;
+            continue;
+          }
+        }
+      }
+      V1Srcs.push_back(PV1);
+    }
+  } else {
+    // If we don't have PhiValues then just look at the operands of the phi
+    // itself.
+    // FIXME: Remove this once we can guarantee that we have PhiValues always.
+    SmallPtrSet<Value *, 4> UniqueSrc;
+    for (Value *PV1 : PN->incoming_values()) {
+      if (isa<PHINode>(PV1))
+        // If any of the sources is itself a PHI, return MayAlias
+        // conservatively to avoid compile time explosion. The worst possible
+        // case is if both sides are PHI nodes, in which case this is
+        // O(m x n) time where 'm' and 'n' are the number of PHI sources.
+        return MayAlias;
+
+      if (EnableRecPhiAnalysis)
+        if (GEPOperator *PV1GEP = dyn_cast<GEPOperator>(PV1)) {
+          // Check whether the incoming value is a GEP that advances the
+          // pointer result of this PHI node (e.g. in a loop). If this is the
+          // case, we would recurse and always get a MayAlias. Handle this
+          // case specially below.
+          if (PV1GEP->getPointerOperand() == PN &&
+              PV1GEP->getNumIndices() == 1 &&
+              isa<ConstantInt>(PV1GEP->idx_begin())) {
+            isRecursive = true;
+            continue;
+          }
+        }
+
+      if (UniqueSrc.insert(PV1).second)
+        V1Srcs.push_back(PV1);
+    }
+  }
+
+  // If V1Srcs is empty then that means that the phi has no underlying non-phi
+  // value. This should only be possible in blocks unreachable from the entry
+  // block, but return MayAlias just in case.
+  if (V1Srcs.empty())
+    return MayAlias;
+
+  // If this PHI node is recursive, set the size of the accessed memory to
+  // unknown to represent all the possible values the GEP could advance the
+  // pointer to.
+  if (isRecursive)
+    PNSize = LocationSize::unknown();
+
+  AliasResult Alias = aliasCheck(V2, V2Size, V2AAInfo, V1Srcs[0], PNSize,
+                                 PNAAInfo, AAQI, UnderV2);
+
+  // Early exit if the check of the first PHI source against V2 is MayAlias.
+  // Other results are not possible.
+  if (Alias == MayAlias)
+    return MayAlias;
+
+  // If all sources of the PHI node NoAlias or MustAlias V2, then returns
+  // NoAlias / MustAlias. Otherwise, returns MayAlias.
+  for (unsigned i = 1, e = V1Srcs.size(); i != e; ++i) {
+    Value *V = V1Srcs[i];
+
+    AliasResult ThisAlias =
+        aliasCheck(V2, V2Size, V2AAInfo, V, PNSize, PNAAInfo, AAQI, UnderV2);
+    Alias = MergeAliasResults(ThisAlias, Alias);
+    if (Alias == MayAlias)
+      break;
+  }
+
+  return Alias;
+}
+
+/// Provides a bunch of ad-hoc rules to disambiguate in common cases, such as
+/// array references.
+AliasResult BasicAAResult::aliasCheck(const Value *V1, LocationSize V1Size,
+                                      AAMDNodes V1AAInfo, const Value *V2,
+                                      LocationSize V2Size, AAMDNodes V2AAInfo,
+                                      AAQueryInfo &AAQI, const Value *O1,
+                                      const Value *O2) {
+  // If either of the memory references is empty, it doesn't matter what the
+  // pointer values are.
+  if (V1Size.isZero() || V2Size.isZero())
+    return NoAlias;
+
+  // Strip off any casts if they exist.
+  V1 = V1->stripPointerCastsAndInvariantGroups();
+  V2 = V2->stripPointerCastsAndInvariantGroups();
+
+  // If V1 or V2 is undef, the result is NoAlias because we can always pick a
+  // value for undef that aliases nothing in the program.
+  if (isa<UndefValue>(V1) || isa<UndefValue>(V2))
+    return NoAlias;
+
+  // Are we checking for alias of the same value?
+  // Because we look 'through' phi nodes, we could look at "Value" pointers
+  // from different iterations. We must therefore make sure that this is not
+  // the case. The function isValueEqualInPotentialCycles ensures that this
+  // cannot happen by looking at the visited phi nodes and making sure they
+  // cannot reach the value.
+  if (isValueEqualInPotentialCycles(V1, V2))
+    return MustAlias;
+
+  if (!V1->getType()->isPointerTy() || !V2->getType()->isPointerTy())
+    return NoAlias; // Scalars cannot alias each other
+
+  // Figure out what objects these things are pointing to if we can.
+  if (O1 == nullptr)
+    O1 = GetUnderlyingObject(V1, DL, MaxLookupSearchDepth);
+
+  if (O2 == nullptr)
+    O2 = GetUnderlyingObject(V2, DL, MaxLookupSearchDepth);
+
+  // Null values in the default address space don't point to any object, so
+  // they don't alias any other pointer.
+  if (const ConstantPointerNull *CPN = dyn_cast<ConstantPointerNull>(O1))
+    if (!NullPointerIsDefined(&F, CPN->getType()->getAddressSpace()))
+      return NoAlias;
+  if (const ConstantPointerNull *CPN = dyn_cast<ConstantPointerNull>(O2))
+    if (!NullPointerIsDefined(&F, CPN->getType()->getAddressSpace()))
+      return NoAlias;
+
+  if (O1 != O2) {
+    // If V1/V2 point to two different objects, we know that we have no alias.
+    if (isIdentifiedObject(O1) && isIdentifiedObject(O2))
+      return NoAlias;
+
+    // Constant pointers can't alias with non-const isIdentifiedObject objects.
+    if ((isa<Constant>(O1) && isIdentifiedObject(O2) && !isa<Constant>(O2)) ||
+        (isa<Constant>(O2) && isIdentifiedObject(O1) && !isa<Constant>(O1)))
+      return NoAlias;
+
+    // Function arguments can't alias with things that are known to be
+    // unambiguously identified at the function level.
+    if ((isa<Argument>(O1) && isIdentifiedFunctionLocal(O2)) ||
+        (isa<Argument>(O2) && isIdentifiedFunctionLocal(O1)))
+      return NoAlias;
+
+    // If one pointer is the result of a call/invoke or load and the other is a
+    // non-escaping local object within the same function, then we know the
+    // object couldn't escape to a point where the call could return it.
+    //
+    // Note that if the pointers are in different functions, there are a
+    // variety of complications. A call with a nocapture argument may still
+    // temporarily store the nocapture argument's value in a temporary memory
+    // location if that memory location doesn't escape.
Or it may pass a + // nocapture value to other functions as long as they don't capture it. + if (isEscapeSource(O1) && + isNonEscapingLocalObject(O2, &AAQI.IsCapturedCache)) + return NoAlias; + if (isEscapeSource(O2) && + isNonEscapingLocalObject(O1, &AAQI.IsCapturedCache)) + return NoAlias; + } + + // If the size of one access is larger than the entire object on the other + // side, then we know such behavior is undefined and can assume no alias. + bool NullIsValidLocation = NullPointerIsDefined(&F); + if ((isObjectSmallerThan( + O2, getMinimalExtentFrom(*V1, V1Size, DL, NullIsValidLocation), DL, + TLI, NullIsValidLocation)) || + (isObjectSmallerThan( + O1, getMinimalExtentFrom(*V2, V2Size, DL, NullIsValidLocation), DL, + TLI, NullIsValidLocation))) + return NoAlias; + + // Check the cache before climbing up use-def chains. This also terminates + // otherwise infinitely recursive queries. + AAQueryInfo::LocPair Locs(MemoryLocation(V1, V1Size, V1AAInfo), + MemoryLocation(V2, V2Size, V2AAInfo)); + if (V1 > V2) + std::swap(Locs.first, Locs.second); + std::pair<AAQueryInfo::AliasCacheT::iterator, bool> Pair = + AAQI.AliasCache.try_emplace(Locs, MayAlias); + if (!Pair.second) + return Pair.first->second; + + // FIXME: This isn't aggressively handling alias(GEP, PHI) for example: if the + // GEP can't simplify, we don't even look at the PHI cases. + if (!isa<GEPOperator>(V1) && isa<GEPOperator>(V2)) { + std::swap(V1, V2); + std::swap(V1Size, V2Size); + std::swap(O1, O2); + std::swap(V1AAInfo, V2AAInfo); + } + if (const GEPOperator *GV1 = dyn_cast<GEPOperator>(V1)) { + AliasResult Result = + aliasGEP(GV1, V1Size, V1AAInfo, V2, V2Size, V2AAInfo, O1, O2, AAQI); + if (Result != MayAlias) { + auto ItInsPair = AAQI.AliasCache.insert(std::make_pair(Locs, Result)); + assert(!ItInsPair.second && "Entry must have existed"); + ItInsPair.first->second = Result; + return Result; + } + } + + if (isa<PHINode>(V2) && !isa<PHINode>(V1)) { + std::swap(V1, V2); + std::swap(O1, O2); + std::swap(V1Size, V2Size); + std::swap(V1AAInfo, V2AAInfo); + } + if (const PHINode *PN = dyn_cast<PHINode>(V1)) { + AliasResult Result = + aliasPHI(PN, V1Size, V1AAInfo, V2, V2Size, V2AAInfo, O2, AAQI); + if (Result != MayAlias) { + Pair = AAQI.AliasCache.try_emplace(Locs, Result); + assert(!Pair.second && "Entry must have existed"); + return Pair.first->second = Result; + } + } + + if (isa<SelectInst>(V2) && !isa<SelectInst>(V1)) { + std::swap(V1, V2); + std::swap(O1, O2); + std::swap(V1Size, V2Size); + std::swap(V1AAInfo, V2AAInfo); + } + if (const SelectInst *S1 = dyn_cast<SelectInst>(V1)) { + AliasResult Result = + aliasSelect(S1, V1Size, V1AAInfo, V2, V2Size, V2AAInfo, O2, AAQI); + if (Result != MayAlias) { + Pair = AAQI.AliasCache.try_emplace(Locs, Result); + assert(!Pair.second && "Entry must have existed"); + return Pair.first->second = Result; + } + } + + // If both pointers are pointing into the same object and one of them + // accesses the entire object, then the accesses must overlap in some way. + if (O1 == O2) + if (V1Size.isPrecise() && V2Size.isPrecise() && + (isObjectSize(O1, V1Size.getValue(), DL, TLI, NullIsValidLocation) || + isObjectSize(O2, V2Size.getValue(), DL, TLI, NullIsValidLocation))) { + Pair = AAQI.AliasCache.try_emplace(Locs, PartialAlias); + assert(!Pair.second && "Entry must have existed"); + return Pair.first->second = PartialAlias; + } + + // Recurse back into the best AA results we have, potentially with refined + // memory locations. 
We have already ensured that BasicAA has a MayAlias
+  // cache result for these, so any recursion back into BasicAA won't loop.
+  AliasResult Result = getBestAAResults().alias(Locs.first, Locs.second, AAQI);
+  Pair = AAQI.AliasCache.try_emplace(Locs, Result);
+  assert(!Pair.second && "Entry must have existed");
+  return Pair.first->second = Result;
+}
+
+/// Check whether two Values can be considered equivalent.
+///
+/// In addition to pointer equivalence of \p V1 and \p V2 this checks whether
+/// they cannot be part of a cycle in the value graph by looking at all
+/// visited phi nodes and making sure that the phis cannot reach the value. We
+/// have to do this because we are looking through phi nodes (that is, we say
+/// noalias(V, phi(VA, VB)) if noalias(V, VA) and noalias(V, VB)).
+bool BasicAAResult::isValueEqualInPotentialCycles(const Value *V,
+                                                  const Value *V2) {
+  if (V != V2)
+    return false;
+
+  const Instruction *Inst = dyn_cast<Instruction>(V);
+  if (!Inst)
+    return true;
+
+  if (VisitedPhiBBs.empty())
+    return true;
+
+  if (VisitedPhiBBs.size() > MaxNumPhiBBsValueReachabilityCheck)
+    return false;
+
+  // Make sure that the visited phis cannot reach the Value. This ensures that
+  // the Values cannot come from different iterations of a potential cycle the
+  // phi nodes could be involved in.
+  for (auto *P : VisitedPhiBBs)
+    if (isPotentiallyReachable(&P->front(), Inst, nullptr, DT, LI))
+      return false;
+
+  return true;
+}
+
+/// Computes the symbolic difference between two de-composed GEPs.
+///
+/// Dest and Src are the variable indices from two decomposed GetElementPtr
+/// instructions GEP1 and GEP2 which have common base pointers.
+void BasicAAResult::GetIndexDifference(
+    SmallVectorImpl<VariableGEPIndex> &Dest,
+    const SmallVectorImpl<VariableGEPIndex> &Src) {
+  if (Src.empty())
+    return;
+
+  for (unsigned i = 0, e = Src.size(); i != e; ++i) {
+    const Value *V = Src[i].V;
+    unsigned ZExtBits = Src[i].ZExtBits, SExtBits = Src[i].SExtBits;
+    APInt Scale = Src[i].Scale;
+
+    // Find V in Dest. This is N^2, but pointer indices almost never have more
+    // than a few variable indexes.
+    for (unsigned j = 0, e = Dest.size(); j != e; ++j) {
+      if (!isValueEqualInPotentialCycles(Dest[j].V, V) ||
+          Dest[j].ZExtBits != ZExtBits || Dest[j].SExtBits != SExtBits)
+        continue;
+
+      // If we found it, subtract off Scale V's from the entry in Dest. If it
+      // goes to zero, remove the entry.
+      if (Dest[j].Scale != Scale)
+        Dest[j].Scale -= Scale;
+      else
+        Dest.erase(Dest.begin() + j);
+      Scale = 0;
+      break;
+    }
+
+    // If we didn't consume this entry, add it to the end of the Dest list.
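+    // For instance (an illustrative case): if Dest holds {%x, Scale = 4} and
+    // Src holds {%x, Scale = 4}, the entry cancels and is erased above; if Src
+    // instead holds {%y, Scale = 4}, nothing matches and {%y, Scale = -4} is
+    // appended below.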
+ if (!!Scale) { + VariableGEPIndex Entry = {V, ZExtBits, SExtBits, -Scale}; + Dest.push_back(Entry); + } + } +} + +bool BasicAAResult::constantOffsetHeuristic( + const SmallVectorImpl<VariableGEPIndex> &VarIndices, + LocationSize MaybeV1Size, LocationSize MaybeV2Size, APInt BaseOffset, + AssumptionCache *AC, DominatorTree *DT) { + if (VarIndices.size() != 2 || MaybeV1Size == LocationSize::unknown() || + MaybeV2Size == LocationSize::unknown()) + return false; + + const uint64_t V1Size = MaybeV1Size.getValue(); + const uint64_t V2Size = MaybeV2Size.getValue(); + + const VariableGEPIndex &Var0 = VarIndices[0], &Var1 = VarIndices[1]; + + if (Var0.ZExtBits != Var1.ZExtBits || Var0.SExtBits != Var1.SExtBits || + Var0.Scale != -Var1.Scale) + return false; + + unsigned Width = Var1.V->getType()->getIntegerBitWidth(); + + // We'll strip off the Extensions of Var0 and Var1 and do another round + // of GetLinearExpression decomposition. In the example above, if Var0 + // is zext(%x + 1) we should get V1 == %x and V1Offset == 1. + + APInt V0Scale(Width, 0), V0Offset(Width, 0), V1Scale(Width, 0), + V1Offset(Width, 0); + bool NSW = true, NUW = true; + unsigned V0ZExtBits = 0, V0SExtBits = 0, V1ZExtBits = 0, V1SExtBits = 0; + const Value *V0 = GetLinearExpression(Var0.V, V0Scale, V0Offset, V0ZExtBits, + V0SExtBits, DL, 0, AC, DT, NSW, NUW); + NSW = true; + NUW = true; + const Value *V1 = GetLinearExpression(Var1.V, V1Scale, V1Offset, V1ZExtBits, + V1SExtBits, DL, 0, AC, DT, NSW, NUW); + + if (V0Scale != V1Scale || V0ZExtBits != V1ZExtBits || + V0SExtBits != V1SExtBits || !isValueEqualInPotentialCycles(V0, V1)) + return false; + + // We have a hit - Var0 and Var1 only differ by a constant offset! + + // If we've been sext'ed then zext'd the maximum difference between Var0 and + // Var1 is possible to calculate, but we're just interested in the absolute + // minimum difference between the two. The minimum distance may occur due to + // wrapping; consider "add i3 %i, 5": if %i == 7 then 7 + 5 mod 8 == 4, and so + // the minimum distance between %i and %i + 5 is 3. + APInt MinDiff = V0Offset - V1Offset, Wrapped = -MinDiff; + MinDiff = APIntOps::umin(MinDiff, Wrapped); + APInt MinDiffBytes = + MinDiff.zextOrTrunc(Var0.Scale.getBitWidth()) * Var0.Scale.abs(); + + // We can't definitely say whether GEP1 is before or after V2 due to wrapping + // arithmetic (i.e. for some values of GEP1 and V2 GEP1 < V2, and for other + // values GEP1 > V2). We'll therefore only declare NoAlias if both V1Size and + // V2Size can fit in the MinDiffBytes gap. 
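+  // As a worked example (illustrative numbers): if the two indices differ by
+  // one 4-byte element, MinDiffBytes == 4, and two 4-byte accesses with a zero
+  // BaseOffset satisfy 4 u>= 4 + 0 on both sides, yielding NoAlias.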
+ return MinDiffBytes.uge(V1Size + BaseOffset.abs()) && + MinDiffBytes.uge(V2Size + BaseOffset.abs()); +} + +//===----------------------------------------------------------------------===// +// BasicAliasAnalysis Pass +//===----------------------------------------------------------------------===// + +AnalysisKey BasicAA::Key; + +BasicAAResult BasicAA::run(Function &F, FunctionAnalysisManager &AM) { + return BasicAAResult(F.getParent()->getDataLayout(), + F, + AM.getResult<TargetLibraryAnalysis>(F), + AM.getResult<AssumptionAnalysis>(F), + &AM.getResult<DominatorTreeAnalysis>(F), + AM.getCachedResult<LoopAnalysis>(F), + AM.getCachedResult<PhiValuesAnalysis>(F)); +} + +BasicAAWrapperPass::BasicAAWrapperPass() : FunctionPass(ID) { + initializeBasicAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +char BasicAAWrapperPass::ID = 0; + +void BasicAAWrapperPass::anchor() {} + +INITIALIZE_PASS_BEGIN(BasicAAWrapperPass, "basicaa", + "Basic Alias Analysis (stateless AA impl)", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(BasicAAWrapperPass, "basicaa", + "Basic Alias Analysis (stateless AA impl)", false, true) + +FunctionPass *llvm::createBasicAAWrapperPass() { + return new BasicAAWrapperPass(); +} + +bool BasicAAWrapperPass::runOnFunction(Function &F) { + auto &ACT = getAnalysis<AssumptionCacheTracker>(); + auto &TLIWP = getAnalysis<TargetLibraryInfoWrapperPass>(); + auto &DTWP = getAnalysis<DominatorTreeWrapperPass>(); + auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); + auto *PVWP = getAnalysisIfAvailable<PhiValuesWrapperPass>(); + + Result.reset(new BasicAAResult(F.getParent()->getDataLayout(), F, + TLIWP.getTLI(F), ACT.getAssumptionCache(F), + &DTWP.getDomTree(), + LIWP ? &LIWP->getLoopInfo() : nullptr, + PVWP ? &PVWP->getResult() : nullptr)); + + return false; +} + +void BasicAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addUsedIfAvailable<PhiValuesWrapperPass>(); +} + +BasicAAResult llvm::createLegacyPMBasicAAResult(Pass &P, Function &F) { + return BasicAAResult( + F.getParent()->getDataLayout(), F, + P.getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F), + P.getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F)); +} diff --git a/llvm/lib/Analysis/BlockFrequencyInfo.cpp b/llvm/lib/Analysis/BlockFrequencyInfo.cpp new file mode 100644 index 000000000000..de183bbde173 --- /dev/null +++ b/llvm/lib/Analysis/BlockFrequencyInfo.cpp @@ -0,0 +1,342 @@ +//===- BlockFrequencyInfo.cpp - Block Frequency Analysis ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Loops should be simplified before this analysis. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/iterator.h"
+#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <string>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "block-freq"
+
+static cl::opt<GVDAGType> ViewBlockFreqPropagationDAG(
+    "view-block-freq-propagation-dags", cl::Hidden,
+    cl::desc("Pop up a window to show a dag displaying how block "
+             "frequencies propagate through the CFG."),
+    cl::values(clEnumValN(GVDT_None, "none", "do not display graphs."),
+               clEnumValN(GVDT_Fraction, "fraction",
+                          "display a graph using the "
+                          "fractional block frequency representation."),
+               clEnumValN(GVDT_Integer, "integer",
+                          "display a graph using the raw "
+                          "integer fractional block frequency representation."),
+               clEnumValN(GVDT_Count, "count", "display a graph using the real "
+                                               "profile count if available.")));
+
+cl::opt<std::string>
+    ViewBlockFreqFuncName("view-bfi-func-name", cl::Hidden,
+                          cl::desc("The option to specify "
+                                   "the name of the function "
+                                   "whose CFG will be displayed."));
+
+cl::opt<unsigned>
+    ViewHotFreqPercent("view-hot-freq-percent", cl::init(10), cl::Hidden,
+                       cl::desc("An integer in percent used to specify "
+                                "the hot blocks/edges to be displayed "
+                                "in red: a block or edge whose frequency "
+                                "is no less than the max frequency of the "
+                                "function multiplied by this percent."));
+
+// Command line option to turn on CFG dot or text dump after profile annotation.
+cl::opt<PGOViewCountsType> PGOViewCounts(
+    "pgo-view-counts", cl::Hidden,
+    cl::desc("A boolean option to show CFG dag or text with "
+             "block profile counts and branch probabilities "
+             "right after PGO profile annotation step. The "
+             "profile counts are computed using branch "
+             "probabilities from the runtime profile data and "
+             "block frequency propagation algorithm. To view "
+             "the raw counts from the profile, use option "
+             "-pgo-view-raw-counts instead. 
To limit graph " + "display to only one function, use filtering option " + "-view-bfi-func-name."), + cl::values(clEnumValN(PGOVCT_None, "none", "do not show."), + clEnumValN(PGOVCT_Graph, "graph", "show a graph."), + clEnumValN(PGOVCT_Text, "text", "show in text."))); + +static cl::opt<bool> PrintBlockFreq( + "print-bfi", cl::init(false), cl::Hidden, + cl::desc("Print the block frequency info.")); + +cl::opt<std::string> PrintBlockFreqFuncName( + "print-bfi-func-name", cl::Hidden, + cl::desc("The option to specify the name of the function " + "whose block frequency info is printed.")); + +namespace llvm { + +static GVDAGType getGVDT() { + if (PGOViewCounts == PGOVCT_Graph) + return GVDT_Count; + return ViewBlockFreqPropagationDAG; +} + +template <> +struct GraphTraits<BlockFrequencyInfo *> { + using NodeRef = const BasicBlock *; + using ChildIteratorType = succ_const_iterator; + using nodes_iterator = pointer_iterator<Function::const_iterator>; + + static NodeRef getEntryNode(const BlockFrequencyInfo *G) { + return &G->getFunction()->front(); + } + + static ChildIteratorType child_begin(const NodeRef N) { + return succ_begin(N); + } + + static ChildIteratorType child_end(const NodeRef N) { return succ_end(N); } + + static nodes_iterator nodes_begin(const BlockFrequencyInfo *G) { + return nodes_iterator(G->getFunction()->begin()); + } + + static nodes_iterator nodes_end(const BlockFrequencyInfo *G) { + return nodes_iterator(G->getFunction()->end()); + } +}; + +using BFIDOTGTraitsBase = + BFIDOTGraphTraitsBase<BlockFrequencyInfo, BranchProbabilityInfo>; + +template <> +struct DOTGraphTraits<BlockFrequencyInfo *> : public BFIDOTGTraitsBase { + explicit DOTGraphTraits(bool isSimple = false) + : BFIDOTGTraitsBase(isSimple) {} + + std::string getNodeLabel(const BasicBlock *Node, + const BlockFrequencyInfo *Graph) { + + return BFIDOTGTraitsBase::getNodeLabel(Node, Graph, getGVDT()); + } + + std::string getNodeAttributes(const BasicBlock *Node, + const BlockFrequencyInfo *Graph) { + return BFIDOTGTraitsBase::getNodeAttributes(Node, Graph, + ViewHotFreqPercent); + } + + std::string getEdgeAttributes(const BasicBlock *Node, EdgeIter EI, + const BlockFrequencyInfo *BFI) { + return BFIDOTGTraitsBase::getEdgeAttributes(Node, EI, BFI, BFI->getBPI(), + ViewHotFreqPercent); + } +}; + +} // end namespace llvm + +BlockFrequencyInfo::BlockFrequencyInfo() = default; + +BlockFrequencyInfo::BlockFrequencyInfo(const Function &F, + const BranchProbabilityInfo &BPI, + const LoopInfo &LI) { + calculate(F, BPI, LI); +} + +BlockFrequencyInfo::BlockFrequencyInfo(BlockFrequencyInfo &&Arg) + : BFI(std::move(Arg.BFI)) {} + +BlockFrequencyInfo &BlockFrequencyInfo::operator=(BlockFrequencyInfo &&RHS) { + releaseMemory(); + BFI = std::move(RHS.BFI); + return *this; +} + +// Explicitly define the default constructor otherwise it would be implicitly +// defined at the first ODR-use which is the BFI member in the +// LazyBlockFrequencyInfo header. The dtor needs the BlockFrequencyInfoImpl +// template instantiated which is not available in the header. +BlockFrequencyInfo::~BlockFrequencyInfo() = default; + +bool BlockFrequencyInfo::invalidate(Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &) { + // Check whether the analysis, all analyses on functions, or the function's + // CFG have been preserved. 
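+  // For example, a transform returning a PreservedAnalyses set built with
+  // preserveSet<CFGAnalyses>() keeps this result alive, while
+  // PreservedAnalyses::none() invalidates it.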
+ auto PAC = PA.getChecker<BlockFrequencyAnalysis>(); + return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || + PAC.preservedSet<CFGAnalyses>()); +} + +void BlockFrequencyInfo::calculate(const Function &F, + const BranchProbabilityInfo &BPI, + const LoopInfo &LI) { + if (!BFI) + BFI.reset(new ImplType); + BFI->calculate(F, BPI, LI); + if (ViewBlockFreqPropagationDAG != GVDT_None && + (ViewBlockFreqFuncName.empty() || + F.getName().equals(ViewBlockFreqFuncName))) { + view(); + } + if (PrintBlockFreq && + (PrintBlockFreqFuncName.empty() || + F.getName().equals(PrintBlockFreqFuncName))) { + print(dbgs()); + } +} + +BlockFrequency BlockFrequencyInfo::getBlockFreq(const BasicBlock *BB) const { + return BFI ? BFI->getBlockFreq(BB) : 0; +} + +Optional<uint64_t> +BlockFrequencyInfo::getBlockProfileCount(const BasicBlock *BB, + bool AllowSynthetic) const { + if (!BFI) + return None; + + return BFI->getBlockProfileCount(*getFunction(), BB, AllowSynthetic); +} + +Optional<uint64_t> +BlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const { + if (!BFI) + return None; + return BFI->getProfileCountFromFreq(*getFunction(), Freq); +} + +bool BlockFrequencyInfo::isIrrLoopHeader(const BasicBlock *BB) { + assert(BFI && "Expected analysis to be available"); + return BFI->isIrrLoopHeader(BB); +} + +void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) { + assert(BFI && "Expected analysis to be available"); + BFI->setBlockFreq(BB, Freq); +} + +void BlockFrequencyInfo::setBlockFreqAndScale( + const BasicBlock *ReferenceBB, uint64_t Freq, + SmallPtrSetImpl<BasicBlock *> &BlocksToScale) { + assert(BFI && "Expected analysis to be available"); + // Use 128 bits APInt to avoid overflow. + APInt NewFreq(128, Freq); + APInt OldFreq(128, BFI->getBlockFreq(ReferenceBB).getFrequency()); + APInt BBFreq(128, 0); + for (auto *BB : BlocksToScale) { + BBFreq = BFI->getBlockFreq(BB).getFrequency(); + // Multiply first by NewFreq and then divide by OldFreq + // to minimize loss of precision. + BBFreq *= NewFreq; + // udiv is an expensive operation in the general case. If this ends up being + // a hot spot, one of the options proposed in + // https://reviews.llvm.org/D28535#650071 could be used to avoid this. + BBFreq = BBFreq.udiv(OldFreq); + BFI->setBlockFreq(BB, BBFreq.getLimitedValue()); + } + BFI->setBlockFreq(ReferenceBB, Freq); +} + +/// Pop up a ghostview window with the current block frequency propagation +/// rendered using dot. +void BlockFrequencyInfo::view(StringRef title) const { + ViewGraph(const_cast<BlockFrequencyInfo *>(this), title); +} + +const Function *BlockFrequencyInfo::getFunction() const { + return BFI ? BFI->getFunction() : nullptr; +} + +const BranchProbabilityInfo *BlockFrequencyInfo::getBPI() const { + return BFI ? &BFI->getBPI() : nullptr; +} + +raw_ostream &BlockFrequencyInfo:: +printBlockFreq(raw_ostream &OS, const BlockFrequency Freq) const { + return BFI ? BFI->printBlockFreq(OS, Freq) : OS; +} + +raw_ostream & +BlockFrequencyInfo::printBlockFreq(raw_ostream &OS, + const BasicBlock *BB) const { + return BFI ? BFI->printBlockFreq(OS, BB) : OS; +} + +uint64_t BlockFrequencyInfo::getEntryFreq() const { + return BFI ? 
BFI->getEntryFreq() : 0; +} + +void BlockFrequencyInfo::releaseMemory() { BFI.reset(); } + +void BlockFrequencyInfo::print(raw_ostream &OS) const { + if (BFI) + BFI->print(OS); +} + +INITIALIZE_PASS_BEGIN(BlockFrequencyInfoWrapperPass, "block-freq", + "Block Frequency Analysis", true, true) +INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(BlockFrequencyInfoWrapperPass, "block-freq", + "Block Frequency Analysis", true, true) + +char BlockFrequencyInfoWrapperPass::ID = 0; + +BlockFrequencyInfoWrapperPass::BlockFrequencyInfoWrapperPass() + : FunctionPass(ID) { + initializeBlockFrequencyInfoWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +BlockFrequencyInfoWrapperPass::~BlockFrequencyInfoWrapperPass() = default; + +void BlockFrequencyInfoWrapperPass::print(raw_ostream &OS, + const Module *) const { + BFI.print(OS); +} + +void BlockFrequencyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<BranchProbabilityInfoWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.setPreservesAll(); +} + +void BlockFrequencyInfoWrapperPass::releaseMemory() { BFI.releaseMemory(); } + +bool BlockFrequencyInfoWrapperPass::runOnFunction(Function &F) { + BranchProbabilityInfo &BPI = + getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); + LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + BFI.calculate(F, BPI, LI); + return false; +} + +AnalysisKey BlockFrequencyAnalysis::Key; +BlockFrequencyInfo BlockFrequencyAnalysis::run(Function &F, + FunctionAnalysisManager &AM) { + BlockFrequencyInfo BFI; + BFI.calculate(F, AM.getResult<BranchProbabilityAnalysis>(F), + AM.getResult<LoopAnalysis>(F)); + return BFI; +} + +PreservedAnalyses +BlockFrequencyPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { + OS << "Printing analysis results of BFI for function " + << "'" << F.getName() << "':" + << "\n"; + AM.getResult<BlockFrequencyAnalysis>(F).print(OS); + return PreservedAnalyses::all(); +} diff --git a/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp b/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp new file mode 100644 index 000000000000..0db6dd04a7e8 --- /dev/null +++ b/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp @@ -0,0 +1,851 @@ +//===- BlockFrequencyImplInfo.cpp - Block Frequency Info Implementation ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Loops should be simplified before this analysis. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/Function.h" +#include "llvm/Support/BlockFrequency.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ScaledNumber.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <list> +#include <numeric> +#include <utility> +#include <vector> + +using namespace llvm; +using namespace llvm::bfi_detail; + +#define DEBUG_TYPE "block-freq" + +ScaledNumber<uint64_t> BlockMass::toScaled() const { + if (isFull()) + return ScaledNumber<uint64_t>(1, 0); + return ScaledNumber<uint64_t>(getMass() + 1, -64); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void BlockMass::dump() const { print(dbgs()); } +#endif + +static char getHexDigit(int N) { + assert(N < 16); + if (N < 10) + return '0' + N; + return 'a' + N - 10; +} + +raw_ostream &BlockMass::print(raw_ostream &OS) const { + for (int Digits = 0; Digits < 16; ++Digits) + OS << getHexDigit(Mass >> (60 - Digits * 4) & 0xf); + return OS; +} + +namespace { + +using BlockNode = BlockFrequencyInfoImplBase::BlockNode; +using Distribution = BlockFrequencyInfoImplBase::Distribution; +using WeightList = BlockFrequencyInfoImplBase::Distribution::WeightList; +using Scaled64 = BlockFrequencyInfoImplBase::Scaled64; +using LoopData = BlockFrequencyInfoImplBase::LoopData; +using Weight = BlockFrequencyInfoImplBase::Weight; +using FrequencyData = BlockFrequencyInfoImplBase::FrequencyData; + +/// Dithering mass distributer. +/// +/// This class splits up a single mass into portions by weight, dithering to +/// spread out error. No mass is lost. The dithering precision depends on the +/// precision of the product of \a BlockMass and \a BranchProbability. +/// +/// The distribution algorithm follows. +/// +/// 1. Initialize by saving the sum of the weights in \a RemWeight and the +/// mass to distribute in \a RemMass. +/// +/// 2. For each portion: +/// +/// 1. Construct a branch probability, P, as the portion's weight divided +/// by the current value of \a RemWeight. +/// 2. Calculate the portion's mass as \a RemMass times P. +/// 3. Update \a RemWeight and \a RemMass at each portion by subtracting +/// the current portion's weight and mass. +struct DitheringDistributer { + uint32_t RemWeight; + BlockMass RemMass; + + DitheringDistributer(Distribution &Dist, const BlockMass &Mass); + + BlockMass takeMass(uint32_t Weight); +}; + +} // end anonymous namespace + +DitheringDistributer::DitheringDistributer(Distribution &Dist, + const BlockMass &Mass) { + Dist.normalize(); + RemWeight = Dist.Total; + RemMass = Mass; +} + +BlockMass DitheringDistributer::takeMass(uint32_t Weight) { + assert(Weight && "invalid weight"); + assert(Weight <= RemWeight); + BlockMass Mass = RemMass * BranchProbability(Weight, RemWeight); + + // Decrement totals (dither). 
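+  // E.g., splitting full mass across three equal weights hands out roughly a
+  // third of what remains at each call; the final portion sees
+  // Weight == RemWeight and takes exactly the remaining mass, so rounding
+  // error is carried forward rather than lost.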
+ RemWeight -= Weight; + RemMass -= Mass; + return Mass; +} + +void Distribution::add(const BlockNode &Node, uint64_t Amount, + Weight::DistType Type) { + assert(Amount && "invalid weight of 0"); + uint64_t NewTotal = Total + Amount; + + // Check for overflow. It should be impossible to overflow twice. + bool IsOverflow = NewTotal < Total; + assert(!(DidOverflow && IsOverflow) && "unexpected repeated overflow"); + DidOverflow |= IsOverflow; + + // Update the total. + Total = NewTotal; + + // Save the weight. + Weights.push_back(Weight(Type, Node, Amount)); +} + +static void combineWeight(Weight &W, const Weight &OtherW) { + assert(OtherW.TargetNode.isValid()); + if (!W.Amount) { + W = OtherW; + return; + } + assert(W.Type == OtherW.Type); + assert(W.TargetNode == OtherW.TargetNode); + assert(OtherW.Amount && "Expected non-zero weight"); + if (W.Amount > W.Amount + OtherW.Amount) + // Saturate on overflow. + W.Amount = UINT64_MAX; + else + W.Amount += OtherW.Amount; +} + +static void combineWeightsBySorting(WeightList &Weights) { + // Sort so edges to the same node are adjacent. + llvm::sort(Weights, [](const Weight &L, const Weight &R) { + return L.TargetNode < R.TargetNode; + }); + + // Combine adjacent edges. + WeightList::iterator O = Weights.begin(); + for (WeightList::const_iterator I = O, L = O, E = Weights.end(); I != E; + ++O, (I = L)) { + *O = *I; + + // Find the adjacent weights to the same node. + for (++L; L != E && I->TargetNode == L->TargetNode; ++L) + combineWeight(*O, *L); + } + + // Erase extra entries. + Weights.erase(O, Weights.end()); +} + +static void combineWeightsByHashing(WeightList &Weights) { + // Collect weights into a DenseMap. + using HashTable = DenseMap<BlockNode::IndexType, Weight>; + + HashTable Combined(NextPowerOf2(2 * Weights.size())); + for (const Weight &W : Weights) + combineWeight(Combined[W.TargetNode.Index], W); + + // Check whether anything changed. + if (Weights.size() == Combined.size()) + return; + + // Fill in the new weights. + Weights.clear(); + Weights.reserve(Combined.size()); + for (const auto &I : Combined) + Weights.push_back(I.second); +} + +static void combineWeights(WeightList &Weights) { + // Use a hash table for many successors to keep this linear. + if (Weights.size() > 128) { + combineWeightsByHashing(Weights); + return; + } + + combineWeightsBySorting(Weights); +} + +static uint64_t shiftRightAndRound(uint64_t N, int Shift) { + assert(Shift >= 0); + assert(Shift < 64); + if (!Shift) + return N; + return (N >> Shift) + (UINT64_C(1) & N >> (Shift - 1)); +} + +void Distribution::normalize() { + // Early exit for termination nodes. + if (Weights.empty()) + return; + + // Only bother if there are multiple successors. + if (Weights.size() > 1) + combineWeights(Weights); + + // Early exit when combined into a single successor. + if (Weights.size() == 1) { + Total = 1; + Weights.front().Amount = 1; + return; + } + + // Determine how much to shift right so that the total fits into 32-bits. + // + // If we shift at all, shift by 1 extra. Otherwise, the lower limit of 1 + // for each weight can cause a 32-bit overflow. + int Shift = 0; + if (DidOverflow) + Shift = 33; + else if (Total > UINT32_MAX) + Shift = 33 - countLeadingZeros(Total); + + // Early exit if nothing needs to be scaled. + if (!Shift) { + // If we didn't overflow then combineWeights() shouldn't have changed the + // sum of the weights, but let's double-check. 
+ assert(Total == std::accumulate(Weights.begin(), Weights.end(), UINT64_C(0), + [](uint64_t Sum, const Weight &W) { + return Sum + W.Amount; + }) && + "Expected total to be correct"); + return; + } + + // Recompute the total through accumulation (rather than shifting it) so that + // it's accurate after shifting and any changes combineWeights() made above. + Total = 0; + + // Sum the weights to each node and shift right if necessary. + for (Weight &W : Weights) { + // Scale down below UINT32_MAX. Since Shift is larger than necessary, we + // can round here without concern about overflow. + assert(W.TargetNode.isValid()); + W.Amount = std::max(UINT64_C(1), shiftRightAndRound(W.Amount, Shift)); + assert(W.Amount <= UINT32_MAX); + + // Update the total. + Total += W.Amount; + } + assert(Total <= UINT32_MAX); +} + +void BlockFrequencyInfoImplBase::clear() { + // Swap with a default-constructed std::vector, since std::vector<>::clear() + // does not actually clear heap storage. + std::vector<FrequencyData>().swap(Freqs); + IsIrrLoopHeader.clear(); + std::vector<WorkingData>().swap(Working); + Loops.clear(); +} + +/// Clear all memory not needed downstream. +/// +/// Releases all memory not used downstream. In particular, saves Freqs. +static void cleanup(BlockFrequencyInfoImplBase &BFI) { + std::vector<FrequencyData> SavedFreqs(std::move(BFI.Freqs)); + SparseBitVector<> SavedIsIrrLoopHeader(std::move(BFI.IsIrrLoopHeader)); + BFI.clear(); + BFI.Freqs = std::move(SavedFreqs); + BFI.IsIrrLoopHeader = std::move(SavedIsIrrLoopHeader); +} + +bool BlockFrequencyInfoImplBase::addToDist(Distribution &Dist, + const LoopData *OuterLoop, + const BlockNode &Pred, + const BlockNode &Succ, + uint64_t Weight) { + if (!Weight) + Weight = 1; + + auto isLoopHeader = [&OuterLoop](const BlockNode &Node) { + return OuterLoop && OuterLoop->isHeader(Node); + }; + + BlockNode Resolved = Working[Succ.Index].getResolvedNode(); + +#ifndef NDEBUG + auto debugSuccessor = [&](const char *Type) { + dbgs() << " =>" + << " [" << Type << "] weight = " << Weight; + if (!isLoopHeader(Resolved)) + dbgs() << ", succ = " << getBlockName(Succ); + if (Resolved != Succ) + dbgs() << ", resolved = " << getBlockName(Resolved); + dbgs() << "\n"; + }; + (void)debugSuccessor; +#endif + + if (isLoopHeader(Resolved)) { + LLVM_DEBUG(debugSuccessor("backedge")); + Dist.addBackedge(Resolved, Weight); + return true; + } + + if (Working[Resolved.Index].getContainingLoop() != OuterLoop) { + LLVM_DEBUG(debugSuccessor(" exit ")); + Dist.addExit(Resolved, Weight); + return true; + } + + if (Resolved < Pred) { + if (!isLoopHeader(Pred)) { + // If OuterLoop is an irreducible loop, we can't actually handle this. + assert((!OuterLoop || !OuterLoop->isIrreducible()) && + "unhandled irreducible control flow"); + + // Irreducible backedge. Abort. + LLVM_DEBUG(debugSuccessor("abort!!!")); + return false; + } + + // If "Pred" is a loop header, then this isn't really a backedge; rather, + // OuterLoop must be irreducible. These false backedges can come only from + // secondary loop headers. + assert(OuterLoop && OuterLoop->isIrreducible() && !isLoopHeader(Resolved) && + "unhandled irreducible control flow"); + } + + LLVM_DEBUG(debugSuccessor(" local ")); + Dist.addLocal(Resolved, Weight); + return true; +} + +bool BlockFrequencyInfoImplBase::addLoopSuccessorsToDist( + const LoopData *OuterLoop, LoopData &Loop, Distribution &Dist) { + // Copy the exit map into Dist. 
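+  // E.g., once an inner loop has been packaged, the exit edges recorded for it
+  // are replayed here as ordinary successor weights of its pseudo-node in the
+  // enclosing loop or function.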
+ for (const auto &I : Loop.Exits) + if (!addToDist(Dist, OuterLoop, Loop.getHeader(), I.first, + I.second.getMass())) + // Irreducible backedge. + return false; + + return true; +} + +/// Compute the loop scale for a loop. +void BlockFrequencyInfoImplBase::computeLoopScale(LoopData &Loop) { + // Compute loop scale. + LLVM_DEBUG(dbgs() << "compute-loop-scale: " << getLoopName(Loop) << "\n"); + + // Infinite loops need special handling. If we give the back edge an infinite + // mass, they may saturate all the other scales in the function down to 1, + // making all the other region temperatures look exactly the same. Choose an + // arbitrary scale to avoid these issues. + // + // FIXME: An alternate way would be to select a symbolic scale which is later + // replaced to be the maximum of all computed scales plus 1. This would + // appropriately describe the loop as having a large scale, without skewing + // the final frequency computation. + const Scaled64 InfiniteLoopScale(1, 12); + + // LoopScale == 1 / ExitMass + // ExitMass == HeadMass - BackedgeMass + BlockMass TotalBackedgeMass; + for (auto &Mass : Loop.BackedgeMass) + TotalBackedgeMass += Mass; + BlockMass ExitMass = BlockMass::getFull() - TotalBackedgeMass; + + // Block scale stores the inverse of the scale. If this is an infinite loop, + // its exit mass will be zero. In this case, use an arbitrary scale for the + // loop scale. + Loop.Scale = + ExitMass.isEmpty() ? InfiniteLoopScale : ExitMass.toScaled().inverse(); + + LLVM_DEBUG(dbgs() << " - exit-mass = " << ExitMass << " (" + << BlockMass::getFull() << " - " << TotalBackedgeMass + << ")\n" + << " - scale = " << Loop.Scale << "\n"); +} + +/// Package up a loop. +void BlockFrequencyInfoImplBase::packageLoop(LoopData &Loop) { + LLVM_DEBUG(dbgs() << "packaging-loop: " << getLoopName(Loop) << "\n"); + + // Clear the subloop exits to prevent quadratic memory usage. + for (const BlockNode &M : Loop.Nodes) { + if (auto *Loop = Working[M.Index].getPackagedLoop()) + Loop->Exits.clear(); + LLVM_DEBUG(dbgs() << " - node: " << getBlockName(M.Index) << "\n"); + } + Loop.IsPackaged = true; +} + +#ifndef NDEBUG +static void debugAssign(const BlockFrequencyInfoImplBase &BFI, + const DitheringDistributer &D, const BlockNode &T, + const BlockMass &M, const char *Desc) { + dbgs() << " => assign " << M << " (" << D.RemMass << ")"; + if (Desc) + dbgs() << " [" << Desc << "]"; + if (T.isValid()) + dbgs() << " to " << BFI.getBlockName(T); + dbgs() << "\n"; +} +#endif + +void BlockFrequencyInfoImplBase::distributeMass(const BlockNode &Source, + LoopData *OuterLoop, + Distribution &Dist) { + BlockMass Mass = Working[Source.Index].getMass(); + LLVM_DEBUG(dbgs() << " => mass: " << Mass << "\n"); + + // Distribute mass to successors as laid out in Dist. + DitheringDistributer D(Dist, Mass); + + for (const Weight &W : Dist.Weights) { + // Check for a local edge (non-backedge and non-exit). + BlockMass Taken = D.takeMass(W.Amount); + if (W.Type == Weight::Local) { + Working[W.TargetNode.Index].getMass() += Taken; + LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, nullptr)); + continue; + } + + // Backedges and exits only make sense if we're processing a loop. + assert(OuterLoop && "backedge or exit outside of loop"); + + // Check for a backedge. + if (W.Type == Weight::Backedge) { + OuterLoop->BackedgeMass[OuterLoop->getHeaderIndex(W.TargetNode)] += Taken; + LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, "back")); + continue; + } + + // This must be an exit. 
+    assert(W.Type == Weight::Exit);
+    OuterLoop->Exits.push_back(std::make_pair(W.TargetNode, Taken));
+    LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, "exit"));
+  }
+}
+
+static void convertFloatingToInteger(BlockFrequencyInfoImplBase &BFI,
+                                     const Scaled64 &Min, const Scaled64 &Max) {
+  // Scale the Factor to a size that creates integers. Ideally, integers would
+  // be scaled so that Max == UINT64_MAX so that they can be best
+  // differentiated. However, in the presence of large frequency values, small
+  // frequencies are scaled down to 1, making it impossible to differentiate
+  // small, unequal numbers. When the spread between Min and Max frequencies
+  // fits well within MaxBits, we make the scale be at least 8.
+  const unsigned MaxBits = 64;
+  const unsigned SpreadBits = (Max / Min).lg();
+  Scaled64 ScalingFactor;
+  if (SpreadBits <= MaxBits - 3) {
+    // If the values are small enough, make the scaling factor at least 8 to
+    // allow distinguishing small values.
+    ScalingFactor = Min.inverse();
+    ScalingFactor <<= 3;
+  } else {
+    // If the values need more than MaxBits to be represented, saturate small
+    // frequency values down to 1 by using a scaling factor that benefits large
+    // frequency values.
+    ScalingFactor = Scaled64(1, MaxBits) / Max;
+  }
+
+  // Translate the floats to integers.
+  LLVM_DEBUG(dbgs() << "float-to-int: min = " << Min << ", max = " << Max
+                    << ", factor = " << ScalingFactor << "\n");
+  for (size_t Index = 0; Index < BFI.Freqs.size(); ++Index) {
+    Scaled64 Scaled = BFI.Freqs[Index].Scaled * ScalingFactor;
+    BFI.Freqs[Index].Integer = std::max(UINT64_C(1), Scaled.toInt<uint64_t>());
+    LLVM_DEBUG(dbgs() << " - " << BFI.getBlockName(Index) << ": float = "
+                      << BFI.Freqs[Index].Scaled << ", scaled = " << Scaled
+                      << ", int = " << BFI.Freqs[Index].Integer << "\n");
+  }
+}
+
+/// Unwrap a loop package.
+///
+/// Visits all the members of a loop, adjusting their BlockData according to
+/// the loop's pseudo-node.
+static void unwrapLoop(BlockFrequencyInfoImplBase &BFI, LoopData &Loop) {
+  LLVM_DEBUG(dbgs() << "unwrap-loop-package: " << BFI.getLoopName(Loop)
+                    << ": mass = " << Loop.Mass << ", scale = " << Loop.Scale
+                    << "\n");
+  Loop.Scale *= Loop.Mass.toScaled();
+  Loop.IsPackaged = false;
+  LLVM_DEBUG(dbgs() << " => combined-scale = " << Loop.Scale << "\n");
+
+  // Propagate the head scale through the loop. Since members are visited in
+  // RPO, the head scale will be updated by the loop scale first, and then the
+  // final head scale will be used to update the rest of the members.
+  for (const BlockNode &N : Loop.Nodes) {
+    const auto &Working = BFI.Working[N.Index];
+    Scaled64 &F = Working.isAPackage() ? Working.getPackagedLoop()->Scale
+                                       : BFI.Freqs[N.Index].Scaled;
+    Scaled64 New = Loop.Scale * F;
+    LLVM_DEBUG(dbgs() << " - " << BFI.getBlockName(N) << ": " << F << " => "
+                      << New << "\n");
+    F = New;
+  }
+}
+
+void BlockFrequencyInfoImplBase::unwrapLoops() {
+  // Set initial frequencies from loop-local masses.
+  for (size_t Index = 0; Index < Working.size(); ++Index)
+    Freqs[Index].Scaled = Working[Index].Mass.toScaled();
+
+  for (LoopData &Loop : Loops)
+    unwrapLoop(*this, Loop);
+}
+
+void BlockFrequencyInfoImplBase::finalizeMetrics() {
+  // Unwrap loop packages in reverse post-order, tracking min and max
+  // frequencies.
+  auto Min = Scaled64::getLargest();
+  auto Max = Scaled64::getZero();
+  for (size_t Index = 0; Index < Working.size(); ++Index) {
+    // Update min/max scale.
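+    // (Min and Max feed convertFloatingToInteger above. With illustrative
+    // values Min = 1/4 and Max = 8, the spread is 5 bits, the factor becomes
+    // 32, and the coldest block still maps to the distinguishable integer 8.)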
+ Min = std::min(Min, Freqs[Index].Scaled); + Max = std::max(Max, Freqs[Index].Scaled); + } + + // Convert to integers. + convertFloatingToInteger(*this, Min, Max); + + // Clean up data structures. + cleanup(*this); + + // Print out the final stats. + LLVM_DEBUG(dump()); +} + +BlockFrequency +BlockFrequencyInfoImplBase::getBlockFreq(const BlockNode &Node) const { + if (!Node.isValid()) + return 0; + return Freqs[Node.Index].Integer; +} + +Optional<uint64_t> +BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F, + const BlockNode &Node, + bool AllowSynthetic) const { + return getProfileCountFromFreq(F, getBlockFreq(Node).getFrequency(), + AllowSynthetic); +} + +Optional<uint64_t> +BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F, + uint64_t Freq, + bool AllowSynthetic) const { + auto EntryCount = F.getEntryCount(AllowSynthetic); + if (!EntryCount) + return None; + // Use 128 bit APInt to do the arithmetic to avoid overflow. + APInt BlockCount(128, EntryCount.getCount()); + APInt BlockFreq(128, Freq); + APInt EntryFreq(128, getEntryFreq()); + BlockCount *= BlockFreq; + // Rounded division of BlockCount by EntryFreq. Since EntryFreq is unsigned + // lshr by 1 gives EntryFreq/2. + BlockCount = (BlockCount + EntryFreq.lshr(1)).udiv(EntryFreq); + return BlockCount.getLimitedValue(); +} + +bool +BlockFrequencyInfoImplBase::isIrrLoopHeader(const BlockNode &Node) { + if (!Node.isValid()) + return false; + return IsIrrLoopHeader.test(Node.Index); +} + +Scaled64 +BlockFrequencyInfoImplBase::getFloatingBlockFreq(const BlockNode &Node) const { + if (!Node.isValid()) + return Scaled64::getZero(); + return Freqs[Node.Index].Scaled; +} + +void BlockFrequencyInfoImplBase::setBlockFreq(const BlockNode &Node, + uint64_t Freq) { + assert(Node.isValid() && "Expected valid node"); + assert(Node.Index < Freqs.size() && "Expected legal index"); + Freqs[Node.Index].Integer = Freq; +} + +std::string +BlockFrequencyInfoImplBase::getBlockName(const BlockNode &Node) const { + return {}; +} + +std::string +BlockFrequencyInfoImplBase::getLoopName(const LoopData &Loop) const { + return getBlockName(Loop.getHeader()) + (Loop.isIrreducible() ? 
"**" : "*"); +} + +raw_ostream & +BlockFrequencyInfoImplBase::printBlockFreq(raw_ostream &OS, + const BlockNode &Node) const { + return OS << getFloatingBlockFreq(Node); +} + +raw_ostream & +BlockFrequencyInfoImplBase::printBlockFreq(raw_ostream &OS, + const BlockFrequency &Freq) const { + Scaled64 Block(Freq.getFrequency(), 0); + Scaled64 Entry(getEntryFreq(), 0); + + return OS << Block / Entry; +} + +void IrreducibleGraph::addNodesInLoop(const BFIBase::LoopData &OuterLoop) { + Start = OuterLoop.getHeader(); + Nodes.reserve(OuterLoop.Nodes.size()); + for (auto N : OuterLoop.Nodes) + addNode(N); + indexNodes(); +} + +void IrreducibleGraph::addNodesInFunction() { + Start = 0; + for (uint32_t Index = 0; Index < BFI.Working.size(); ++Index) + if (!BFI.Working[Index].isPackaged()) + addNode(Index); + indexNodes(); +} + +void IrreducibleGraph::indexNodes() { + for (auto &I : Nodes) + Lookup[I.Node.Index] = &I; +} + +void IrreducibleGraph::addEdge(IrrNode &Irr, const BlockNode &Succ, + const BFIBase::LoopData *OuterLoop) { + if (OuterLoop && OuterLoop->isHeader(Succ)) + return; + auto L = Lookup.find(Succ.Index); + if (L == Lookup.end()) + return; + IrrNode &SuccIrr = *L->second; + Irr.Edges.push_back(&SuccIrr); + SuccIrr.Edges.push_front(&Irr); + ++SuccIrr.NumIn; +} + +namespace llvm { + +template <> struct GraphTraits<IrreducibleGraph> { + using GraphT = bfi_detail::IrreducibleGraph; + using NodeRef = const GraphT::IrrNode *; + using ChildIteratorType = GraphT::IrrNode::iterator; + + static NodeRef getEntryNode(const GraphT &G) { return G.StartIrr; } + static ChildIteratorType child_begin(NodeRef N) { return N->succ_begin(); } + static ChildIteratorType child_end(NodeRef N) { return N->succ_end(); } +}; + +} // end namespace llvm + +/// Find extra irreducible headers. +/// +/// Find entry blocks and other blocks with backedges, which exist when \c G +/// contains irreducible sub-SCCs. +static void findIrreducibleHeaders( + const BlockFrequencyInfoImplBase &BFI, + const IrreducibleGraph &G, + const std::vector<const IrreducibleGraph::IrrNode *> &SCC, + LoopData::NodeList &Headers, LoopData::NodeList &Others) { + // Map from nodes in the SCC to whether it's an entry block. + SmallDenseMap<const IrreducibleGraph::IrrNode *, bool, 8> InSCC; + + // InSCC also acts the set of nodes in the graph. Seed it. + for (const auto *I : SCC) + InSCC[I] = false; + + for (auto I = InSCC.begin(), E = InSCC.end(); I != E; ++I) { + auto &Irr = *I->first; + for (const auto *P : make_range(Irr.pred_begin(), Irr.pred_end())) { + if (InSCC.count(P)) + continue; + + // This is an entry block. + I->second = true; + Headers.push_back(Irr.Node); + LLVM_DEBUG(dbgs() << " => entry = " << BFI.getBlockName(Irr.Node) + << "\n"); + break; + } + } + assert(Headers.size() >= 2 && + "Expected irreducible CFG; -loop-info is likely invalid"); + if (Headers.size() == InSCC.size()) { + // Every block is a header. + llvm::sort(Headers); + return; + } + + // Look for extra headers from irreducible sub-SCCs. + for (const auto &I : InSCC) { + // Entry blocks are already headers. + if (I.second) + continue; + + auto &Irr = *I.first; + for (const auto *P : make_range(Irr.pred_begin(), Irr.pred_end())) { + // Skip forward edges. + if (P->Node < Irr.Node) + continue; + + // Skip predecessors from entry blocks. These can have inverted + // ordering. + if (InSCC.lookup(P)) + continue; + + // Store the extra header. 
+ Headers.push_back(Irr.Node); + LLVM_DEBUG(dbgs() << " => extra = " << BFI.getBlockName(Irr.Node) + << "\n"); + break; + } + if (Headers.back() == Irr.Node) + // Added this as a header. + continue; + + // This is not a header. + Others.push_back(Irr.Node); + LLVM_DEBUG(dbgs() << " => other = " << BFI.getBlockName(Irr.Node) << "\n"); + } + llvm::sort(Headers); + llvm::sort(Others); +} + +static void createIrreducibleLoop( + BlockFrequencyInfoImplBase &BFI, const IrreducibleGraph &G, + LoopData *OuterLoop, std::list<LoopData>::iterator Insert, + const std::vector<const IrreducibleGraph::IrrNode *> &SCC) { + // Translate the SCC into RPO. + LLVM_DEBUG(dbgs() << " - found-scc\n"); + + LoopData::NodeList Headers; + LoopData::NodeList Others; + findIrreducibleHeaders(BFI, G, SCC, Headers, Others); + + auto Loop = BFI.Loops.emplace(Insert, OuterLoop, Headers.begin(), + Headers.end(), Others.begin(), Others.end()); + + // Update loop hierarchy. + for (const auto &N : Loop->Nodes) + if (BFI.Working[N.Index].isLoopHeader()) + BFI.Working[N.Index].Loop->Parent = &*Loop; + else + BFI.Working[N.Index].Loop = &*Loop; +} + +iterator_range<std::list<LoopData>::iterator> +BlockFrequencyInfoImplBase::analyzeIrreducible( + const IrreducibleGraph &G, LoopData *OuterLoop, + std::list<LoopData>::iterator Insert) { + assert((OuterLoop == nullptr) == (Insert == Loops.begin())); + auto Prev = OuterLoop ? std::prev(Insert) : Loops.end(); + + for (auto I = scc_begin(G); !I.isAtEnd(); ++I) { + if (I->size() < 2) + continue; + + // Translate the SCC into RPO. + createIrreducibleLoop(*this, G, OuterLoop, Insert, *I); + } + + if (OuterLoop) + return make_range(std::next(Prev), Insert); + return make_range(Loops.begin(), Insert); +} + +void +BlockFrequencyInfoImplBase::updateLoopWithIrreducible(LoopData &OuterLoop) { + OuterLoop.Exits.clear(); + for (auto &Mass : OuterLoop.BackedgeMass) + Mass = BlockMass::getEmpty(); + auto O = OuterLoop.Nodes.begin() + 1; + for (auto I = O, E = OuterLoop.Nodes.end(); I != E; ++I) + if (!Working[I->Index].isPackaged()) + *O++ = *I; + OuterLoop.Nodes.erase(O, OuterLoop.Nodes.end()); +} + +void BlockFrequencyInfoImplBase::adjustLoopHeaderMass(LoopData &Loop) { + assert(Loop.isIrreducible() && "this only makes sense on irreducible loops"); + + // Since the loop has more than one header block, the mass flowing back into + // each header will be different. Adjust the mass in each header loop to + // reflect the masses flowing through back edges. + // + // To do this, we distribute the initial mass using the backedge masses + // as weights for the distribution. + BlockMass LoopMass = BlockMass::getFull(); + Distribution Dist; + + LLVM_DEBUG(dbgs() << "adjust-loop-header-mass:\n"); + for (uint32_t H = 0; H < Loop.NumHeaders; ++H) { + auto &HeaderNode = Loop.Nodes[H]; + auto &BackedgeMass = Loop.BackedgeMass[Loop.getHeaderIndex(HeaderNode)]; + LLVM_DEBUG(dbgs() << " - Add back edge mass for node " + << getBlockName(HeaderNode) << ": " << BackedgeMass + << "\n"); + if (BackedgeMass.getMass() > 0) + Dist.addLocal(HeaderNode, BackedgeMass.getMass()); + else + LLVM_DEBUG(dbgs() << " Nothing added. 
Back edge mass is zero\n"); + } + + DitheringDistributer D(Dist, LoopMass); + + LLVM_DEBUG(dbgs() << " Distribute loop mass " << LoopMass + << " to headers using above weights\n"); + for (const Weight &W : Dist.Weights) { + BlockMass Taken = D.takeMass(W.Amount); + assert(W.Type == Weight::Local && "all weights should be local"); + Working[W.TargetNode.Index].getMass() = Taken; + LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, nullptr)); + } +} + +void BlockFrequencyInfoImplBase::distributeIrrLoopHeaderMass(Distribution &Dist) { + BlockMass LoopMass = BlockMass::getFull(); + DitheringDistributer D(Dist, LoopMass); + for (const Weight &W : Dist.Weights) { + BlockMass Taken = D.takeMass(W.Amount); + assert(W.Type == Weight::Local && "all weights should be local"); + Working[W.TargetNode.Index].getMass() = Taken; + LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, nullptr)); + } +} diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp new file mode 100644 index 000000000000..a06ee096d54c --- /dev/null +++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp @@ -0,0 +1,1057 @@ +//===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Loops should be simplified before this analysis. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <cstdint> +#include <iterator> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "branch-prob" + +static cl::opt<bool> PrintBranchProb( + "print-bpi", cl::init(false), cl::Hidden, + cl::desc("Print the branch probability info.")); + +cl::opt<std::string> PrintBranchProbFuncName( + "print-bpi-func-name", cl::Hidden, + cl::desc("The option to specify the name of the function " + "whose branch probability info is printed.")); + +INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob", + "Branch Probability Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob", + "Branch Probability Analysis", false, true) + +char BranchProbabilityInfoWrapperPass::ID = 0; + +// Weights are for internal use only. They are used by heuristics to help to +// estimate edges' probability. 
Example: +// +// Using "Loop Branch Heuristics" we predict weights of edges for the +// block BB2. +// ... +// | +// V +// BB1<-+ +// | | +// | | (Weight = 124) +// V | +// BB2--+ +// | +// | (Weight = 4) +// V +// BB3 +// +// Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875 +// Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125 +static const uint32_t LBH_TAKEN_WEIGHT = 124; +static const uint32_t LBH_NONTAKEN_WEIGHT = 4; +// Unlikely edges within a loop are half as likely as other edges +static const uint32_t LBH_UNLIKELY_WEIGHT = 62; + +/// Unreachable-terminating branch taken probability. +/// +/// This is the probability for a branch being taken to a block that terminates +/// (eventually) in unreachable. These are predicted as unlikely as possible. +/// All reachable probability will equally share the remaining part. +static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1); + +/// Weight for a branch taken going into a cold block. +/// +/// This is the weight for a branch taken toward a block marked +/// cold. A block is marked cold if it's postdominated by a +/// block containing a call to a cold function. Cold functions +/// are those marked with attribute 'cold'. +static const uint32_t CC_TAKEN_WEIGHT = 4; + +/// Weight for a branch not-taken into a cold block. +/// +/// This is the weight for a branch not taken toward a block marked +/// cold. +static const uint32_t CC_NONTAKEN_WEIGHT = 64; + +static const uint32_t PH_TAKEN_WEIGHT = 20; +static const uint32_t PH_NONTAKEN_WEIGHT = 12; + +static const uint32_t ZH_TAKEN_WEIGHT = 20; +static const uint32_t ZH_NONTAKEN_WEIGHT = 12; + +static const uint32_t FPH_TAKEN_WEIGHT = 20; +static const uint32_t FPH_NONTAKEN_WEIGHT = 12; + +/// This is the probability for an ordered floating point comparison. +static const uint32_t FPH_ORD_WEIGHT = 1024 * 1024 - 1; +/// This is the probability for an unordered floating point comparison, it means +/// one or two of the operands are NaN. Usually it is used to test for an +/// exceptional case, so the result is unlikely. +static const uint32_t FPH_UNO_WEIGHT = 1; + +/// Invoke-terminating normal branch taken weight +/// +/// This is the weight for branching to the normal destination of an invoke +/// instruction. We expect this to happen most of the time. Set the weight to an +/// absurdly high value so that nested loops subsume it. +static const uint32_t IH_TAKEN_WEIGHT = 1024 * 1024 - 1; + +/// Invoke-terminating normal branch not-taken weight. +/// +/// This is the weight for branching to the unwind destination of an invoke +/// instruction. This is essentially never taken. +static const uint32_t IH_NONTAKEN_WEIGHT = 1; + +/// Add \p BB to PostDominatedByUnreachable set if applicable. +void +BranchProbabilityInfo::updatePostDominatedByUnreachable(const BasicBlock *BB) { + const Instruction *TI = BB->getTerminator(); + if (TI->getNumSuccessors() == 0) { + if (isa<UnreachableInst>(TI) || + // If this block is terminated by a call to + // @llvm.experimental.deoptimize then treat it like an unreachable since + // the @llvm.experimental.deoptimize call is expected to practically + // never execute. + BB->getTerminatingDeoptimizeCall()) + PostDominatedByUnreachable.insert(BB); + return; + } + + // If the terminator is an InvokeInst, check only the normal destination block + // as the unwind edge of InvokeInst is also very unlikely taken. 
+  if (auto *II = dyn_cast<InvokeInst>(TI)) {
+    if (PostDominatedByUnreachable.count(II->getNormalDest()))
+      PostDominatedByUnreachable.insert(BB);
+    return;
+  }
+
+  for (auto *I : successors(BB))
+    // If any of the successors is not post-dominated, then BB is not either.
+    if (!PostDominatedByUnreachable.count(I))
+      return;
+
+  PostDominatedByUnreachable.insert(BB);
+}
+
+/// Add \p BB to PostDominatedByColdCall set if applicable.
+void
+BranchProbabilityInfo::updatePostDominatedByColdCall(const BasicBlock *BB) {
+  assert(!PostDominatedByColdCall.count(BB));
+  const Instruction *TI = BB->getTerminator();
+  if (TI->getNumSuccessors() == 0)
+    return;
+
+  // If all of the successors are post-dominated, then BB is as well.
+  if (llvm::all_of(successors(BB), [&](const BasicBlock *SuccBB) {
+        return PostDominatedByColdCall.count(SuccBB);
+      })) {
+    PostDominatedByColdCall.insert(BB);
+    return;
+  }
+
+  // If the terminator is an InvokeInst, check only the normal destination
+  // block as the unwind edge of InvokeInst is also very unlikely taken.
+  if (auto *II = dyn_cast<InvokeInst>(TI))
+    if (PostDominatedByColdCall.count(II->getNormalDest())) {
+      PostDominatedByColdCall.insert(BB);
+      return;
+    }
+
+  // Otherwise, if the block itself contains a cold function, add it to the
+  // set of blocks post-dominated by a cold call.
+  for (auto &I : *BB)
+    if (const CallInst *CI = dyn_cast<CallInst>(&I))
+      if (CI->hasFnAttr(Attribute::Cold)) {
+        PostDominatedByColdCall.insert(BB);
+        return;
+      }
+}
+
+/// Calculate edge weights for successors that lead to unreachable.
+///
+/// Predict that a successor which leads necessarily to an
+/// unreachable-terminated block is extremely unlikely.
+bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
+  const Instruction *TI = BB->getTerminator();
+  (void) TI;
+  assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
+  assert(!isa<InvokeInst>(TI) &&
+         "Invokes should have already been handled by calcInvokeHeuristics");
+
+  SmallVector<unsigned, 4> UnreachableEdges;
+  SmallVector<unsigned, 4> ReachableEdges;
+
+  for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
+    if (PostDominatedByUnreachable.count(*I))
+      UnreachableEdges.push_back(I.getSuccessorIndex());
+    else
+      ReachableEdges.push_back(I.getSuccessorIndex());
+
+  // Skip probabilities if all were reachable.
+  if (UnreachableEdges.empty())
+    return false;
+
+  if (ReachableEdges.empty()) {
+    BranchProbability Prob(1, UnreachableEdges.size());
+    for (unsigned SuccIdx : UnreachableEdges)
+      setEdgeProbability(BB, SuccIdx, Prob);
+    return true;
+  }
+
+  auto UnreachableProb = UR_TAKEN_PROB;
+  auto ReachableProb =
+      (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) /
+      ReachableEdges.size();
+
+  for (unsigned SuccIdx : UnreachableEdges)
+    setEdgeProbability(BB, SuccIdx, UnreachableProb);
+  for (unsigned SuccIdx : ReachableEdges)
+    setEdgeProbability(BB, SuccIdx, ReachableProb);
+
+  return true;
+}
+
+// Propagate existing explicit probabilities from either profile data or
+// 'expect' intrinsic processing. Examine the metadata against the unreachable
+// heuristic: the probability of an edge leading to an unreachable block is
+// set to the minimum of the metadata weight and the unreachable heuristic.
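+// For example (an illustrative snippet, not from this file): given
+//   br i1 %c, label %a, label %b, !prof !0
+//   !0 = !{!"branch_weights", i32 7, i32 1}
+// the edge to %a gets probability 7/8; if %b is post-dominated by unreachable,
+// its edge is further capped by the unreachable heuristic and the difference
+// is redistributed to the reachable edges.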
+bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
+  const Instruction *TI = BB->getTerminator();
+  assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
+  if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI)))
+    return false;
+
+  MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
+  if (!WeightsNode)
+    return false;
+
+  // Check that the number of successors is manageable.
+  assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
+
+  // Ensure there are weights for all of the successors. Note that the first
+  // operand to the metadata node is a name, not a weight.
+  if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
+    return false;
+
+  // Build up the final weights that will be used in a temporary buffer.
+  // Compute the sum of all weights to later decide whether they need to
+  // be scaled to fit in 32 bits.
+  uint64_t WeightSum = 0;
+  SmallVector<uint32_t, 2> Weights;
+  SmallVector<unsigned, 2> UnreachableIdxs;
+  SmallVector<unsigned, 2> ReachableIdxs;
+  Weights.reserve(TI->getNumSuccessors());
+  for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
+    ConstantInt *Weight =
+        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
+    if (!Weight)
+      return false;
+    assert(Weight->getValue().getActiveBits() <= 32 &&
+           "Too many bits for uint32_t");
+    Weights.push_back(Weight->getZExtValue());
+    WeightSum += Weights.back();
+    if (PostDominatedByUnreachable.count(TI->getSuccessor(i - 1)))
+      UnreachableIdxs.push_back(i - 1);
+    else
+      ReachableIdxs.push_back(i - 1);
+  }
+  assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
+
+  // If the sum of weights does not fit in 32 bits, scale every weight down
+  // accordingly.
+  uint64_t ScalingFactor =
+      (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
+
+  if (ScalingFactor > 1) {
+    WeightSum = 0;
+    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
+      Weights[i] /= ScalingFactor;
+      WeightSum += Weights[i];
+    }
+  }
+  assert(WeightSum <= UINT32_MAX &&
+         "Expected weights to scale down to 32 bits");
+
+  if (WeightSum == 0 || ReachableIdxs.size() == 0) {
+    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
+      Weights[i] = 1;
+    WeightSum = TI->getNumSuccessors();
+  }
+
+  // Set the probability.
+  SmallVector<BranchProbability, 2> BP;
+  for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
+    BP.push_back({ Weights[i], static_cast<uint32_t>(WeightSum) });
+
+  // Examine the metadata against the unreachable heuristic.
+  // If the unreachable heuristic is stronger, we use it for this edge.
+  if (UnreachableIdxs.size() > 0 && ReachableIdxs.size() > 0) {
+    auto ToDistribute = BranchProbability::getZero();
+    auto UnreachableProb = UR_TAKEN_PROB;
+    for (auto i : UnreachableIdxs)
+      if (UnreachableProb < BP[i]) {
+        ToDistribute += BP[i] - UnreachableProb;
+        BP[i] = UnreachableProb;
+      }
+
+    // If we modified the probability of some edges then we must distribute
+    // the difference between reachable blocks.
+    if (ToDistribute > BranchProbability::getZero()) {
+      BranchProbability PerEdge = ToDistribute / ReachableIdxs.size();
+      for (auto i : ReachableIdxs)
+        BP[i] += PerEdge;
+    }
+  }
+
+  for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
+    setEdgeProbability(BB, i, BP[i]);
+
+  return true;
+}
+
+/// Calculate edge weights for edges leading to cold blocks.
+///
+/// A cold block is one post-dominated by a block with a call to a
+/// cold function.
+/// Those edges are unlikely to be taken, so we give them relatively low
+/// weight.
+///
+/// Return true if we could compute the weights for cold edges.
+/// Return false otherwise.
+bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
+  const Instruction *TI = BB->getTerminator();
+  (void) TI;
+  assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
+  assert(!isa<InvokeInst>(TI) &&
+         "Invokes should have already been handled by calcInvokeHeuristics");
+
+  // Determine which successors are post-dominated by a cold block.
+  SmallVector<unsigned, 4> ColdEdges;
+  SmallVector<unsigned, 4> NormalEdges;
+  for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
+    if (PostDominatedByColdCall.count(*I))
+      ColdEdges.push_back(I.getSuccessorIndex());
+    else
+      NormalEdges.push_back(I.getSuccessorIndex());
+
+  // Skip probabilities if no cold edges.
+  if (ColdEdges.empty())
+    return false;
+
+  if (NormalEdges.empty()) {
+    BranchProbability Prob(1, ColdEdges.size());
+    for (unsigned SuccIdx : ColdEdges)
+      setEdgeProbability(BB, SuccIdx, Prob);
+    return true;
+  }
+
+  auto ColdProb = BranchProbability::getBranchProbability(
+      CC_TAKEN_WEIGHT,
+      (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
+  auto NormalProb = BranchProbability::getBranchProbability(
+      CC_NONTAKEN_WEIGHT,
+      (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));
+
+  for (unsigned SuccIdx : ColdEdges)
+    setEdgeProbability(BB, SuccIdx, ColdProb);
+  for (unsigned SuccIdx : NormalEdges)
+    setEdgeProbability(BB, SuccIdx, NormalProb);
+
+  return true;
+}
+
+// Calculate Edge Weights using "Pointer Heuristics". Predict that a comparison
+// between two pointers, or between a pointer and NULL, will fail.
+bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
+  const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
+  if (!BI || !BI->isConditional())
+    return false;
+
+  Value *Cond = BI->getCondition();
+  ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
+  if (!CI || !CI->isEquality())
+    return false;
+
+  Value *LHS = CI->getOperand(0);
+
+  if (!LHS->getType()->isPointerTy())
+    return false;
+
+  assert(CI->getOperand(1)->getType()->isPointerTy());
+
+  // p != 0   -> isProb = true
+  // p == 0   -> isProb = false
+  // p != q   -> isProb = true
+  // p == q   -> isProb = false
+  unsigned TakenIdx = 0, NonTakenIdx = 1;
+  bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE;
+  if (!isProb)
+    std::swap(TakenIdx, NonTakenIdx);
+
+  BranchProbability TakenProb(PH_TAKEN_WEIGHT,
+                              PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
+  setEdgeProbability(BB, TakenIdx, TakenProb);
+  setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
+  return true;
+}
+
+static int getSCCNum(const BasicBlock *BB,
+                     const BranchProbabilityInfo::SccInfo &SccI) {
+  auto SccIt = SccI.SccNums.find(BB);
+  if (SccIt == SccI.SccNums.end())
+    return -1;
+  return SccIt->second;
+}
+
+// Consider any block that is an entry point to the SCC as a header.
+static bool isSCCHeader(const BasicBlock *BB, int SccNum,
+                        BranchProbabilityInfo::SccInfo &SccI) {
+  assert(getSCCNum(BB, SccI) == SccNum);
+
+  // Lazily compute the set of headers for a given SCC and cache the results
+  // in the SccHeaderMap.
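+  // For illustration, in an irreducible region such as
+  //   entry -> A, entry -> B, A -> B, B -> A
+  // the SCC {A, B} has two entry points: both A and B have a predecessor
+  // (entry) outside the SCC, so both count as headers, unlike a natural
+  // loop, which has exactly one.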
+  if (SccI.SccHeaders.size() <= static_cast<unsigned>(SccNum))
+    SccI.SccHeaders.resize(SccNum + 1);
+  auto &HeaderMap = SccI.SccHeaders[SccNum];
+  bool Inserted;
+  BranchProbabilityInfo::SccHeaderMap::iterator HeaderMapIt;
+  std::tie(HeaderMapIt, Inserted) = HeaderMap.insert(std::make_pair(BB, false));
+  if (Inserted) {
+    bool IsHeader = llvm::any_of(make_range(pred_begin(BB), pred_end(BB)),
+                                 [&](const BasicBlock *Pred) {
+                                   return getSCCNum(Pred, SccI) != SccNum;
+                                 });
+    HeaderMapIt->second = IsHeader;
+    return IsHeader;
+  } else
+    return HeaderMapIt->second;
+}
+
+// Compute the unlikely successors to the block BB in the loop L, specifically
+// those that are unlikely because this is a loop, and add them to the
+// UnlikelyBlocks set.
+static void
+computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
+                          SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
+  // Sometimes in a loop we have a branch whose condition is made false by
+  // taking it. This is typically something like
+  //  int n = 0;
+  //  while (...) {
+  //    if (++n >= MAX) {
+  //      n = 0;
+  //    }
+  //  }
+  // In this sort of situation taking the branch means that at the very least
+  // it won't be taken again in the next iteration of the loop, so we should
+  // consider it less likely than a typical branch.
+  //
+  // We detect this by looking back through the graph of PHI nodes that set the
+  // value that the condition depends on, and seeing if we can reach a
+  // successor block which can be determined to make the condition false.
+  //
+  // FIXME: We currently consider unlikely blocks to be half as likely as other
+  // blocks, but if we consider the example above the likelihood is actually
+  // 1/MAX. We could therefore be more precise in how unlikely we consider
+  // blocks to be, but it would require more careful examination of the form
+  // of the comparison expression.
+  const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
+  if (!BI || !BI->isConditional())
+    return;
+
+  // Check if the branch is based on an instruction compared with a constant.
+  CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
+  if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
+      !isa<Constant>(CI->getOperand(1)))
+    return;
+
+  // Either the instruction must be a PHI, or a chain of operations involving
+  // constants that ends in a PHI which we can then collapse into a single
+  // value if the PHI value is known.
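+  // A typical shape this matches (names are illustrative) is
+  //   %n   = phi i32 [ 0, %entry ], [ %inc, %latch ], [ 0, %reset ]
+  //   %inc = add i32 %n, 1
+  //   %cmp = icmp sge i32 %inc, 100
+  // Walking %cmp's LHS through the add reaches the PHI; the constant 0
+  // incoming from %reset can then be folded through the chain to show the
+  // compare evaluates false after passing through %reset, marking the
+  // corresponding successor unlikely.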
+ Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0)); + PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS); + Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1)); + // Collect the instructions until we hit a PHI + SmallVector<BinaryOperator *, 1> InstChain; + while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) && + isa<Constant>(CmpLHS->getOperand(1))) { + // Stop if the chain extends outside of the loop + if (!L->contains(CmpLHS)) + return; + InstChain.push_back(cast<BinaryOperator>(CmpLHS)); + CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0)); + if (CmpLHS) + CmpPHI = dyn_cast<PHINode>(CmpLHS); + } + if (!CmpPHI || !L->contains(CmpPHI)) + return; + + // Trace the phi node to find all values that come from successors of BB + SmallPtrSet<PHINode*, 8> VisitedInsts; + SmallVector<PHINode*, 8> WorkList; + WorkList.push_back(CmpPHI); + VisitedInsts.insert(CmpPHI); + while (!WorkList.empty()) { + PHINode *P = WorkList.back(); + WorkList.pop_back(); + for (BasicBlock *B : P->blocks()) { + // Skip blocks that aren't part of the loop + if (!L->contains(B)) + continue; + Value *V = P->getIncomingValueForBlock(B); + // If the source is a PHI add it to the work list if we haven't + // already visited it. + if (PHINode *PN = dyn_cast<PHINode>(V)) { + if (VisitedInsts.insert(PN).second) + WorkList.push_back(PN); + continue; + } + // If this incoming value is a constant and B is a successor of BB, then + // we can constant-evaluate the compare to see if it makes the branch be + // taken or not. + Constant *CmpLHSConst = dyn_cast<Constant>(V); + if (!CmpLHSConst || + std::find(succ_begin(BB), succ_end(BB), B) == succ_end(BB)) + continue; + // First collapse InstChain + for (Instruction *I : llvm::reverse(InstChain)) { + CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst, + cast<Constant>(I->getOperand(1)), true); + if (!CmpLHSConst) + break; + } + if (!CmpLHSConst) + continue; + // Now constant-evaluate the compare + Constant *Result = ConstantExpr::getCompare(CI->getPredicate(), + CmpLHSConst, CmpConst, true); + // If the result means we don't branch to the block then that block is + // unlikely. + if (Result && + ((Result->isZeroValue() && B == BI->getSuccessor(0)) || + (Result->isOneValue() && B == BI->getSuccessor(1)))) + UnlikelyBlocks.insert(B); + } + } +} + +// Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges +// as taken, exiting edges as not-taken. +bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB, + const LoopInfo &LI, + SccInfo &SccI) { + int SccNum; + Loop *L = LI.getLoopFor(BB); + if (!L) { + SccNum = getSCCNum(BB, SccI); + if (SccNum < 0) + return false; + } + + SmallPtrSet<const BasicBlock*, 8> UnlikelyBlocks; + if (L) + computeUnlikelySuccessors(BB, L, UnlikelyBlocks); + + SmallVector<unsigned, 8> BackEdges; + SmallVector<unsigned, 8> ExitingEdges; + SmallVector<unsigned, 8> InEdges; // Edges from header to the loop. + SmallVector<unsigned, 8> UnlikelyEdges; + + for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { + // Use LoopInfo if we have it, otherwise fall-back to SCC info to catch + // irreducible loops. 
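+    // E.g. in a canonical loop  preheader -> header -> body -> latch, with
+    // the latch branching to {header, exit}: latch->header is a back edge,
+    // latch->exit an exiting edge, and header->body an in-edge, weighted
+    // LBH_TAKEN_WEIGHT, LBH_NONTAKEN_WEIGHT and LBH_TAKEN_WEIGHT below
+    // (before normalization). The SCC path mirrors this classification for
+    // irreducible cycles.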
+ if (L) { + if (UnlikelyBlocks.count(*I) != 0) + UnlikelyEdges.push_back(I.getSuccessorIndex()); + else if (!L->contains(*I)) + ExitingEdges.push_back(I.getSuccessorIndex()); + else if (L->getHeader() == *I) + BackEdges.push_back(I.getSuccessorIndex()); + else + InEdges.push_back(I.getSuccessorIndex()); + } else { + if (getSCCNum(*I, SccI) != SccNum) + ExitingEdges.push_back(I.getSuccessorIndex()); + else if (isSCCHeader(*I, SccNum, SccI)) + BackEdges.push_back(I.getSuccessorIndex()); + else + InEdges.push_back(I.getSuccessorIndex()); + } + } + + if (BackEdges.empty() && ExitingEdges.empty() && UnlikelyEdges.empty()) + return false; + + // Collect the sum of probabilities of back-edges/in-edges/exiting-edges, and + // normalize them so that they sum up to one. + unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) + + (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) + + (UnlikelyEdges.empty() ? 0 : LBH_UNLIKELY_WEIGHT) + + (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT); + + if (uint32_t numBackEdges = BackEdges.size()) { + BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom); + auto Prob = TakenProb / numBackEdges; + for (unsigned SuccIdx : BackEdges) + setEdgeProbability(BB, SuccIdx, Prob); + } + + if (uint32_t numInEdges = InEdges.size()) { + BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom); + auto Prob = TakenProb / numInEdges; + for (unsigned SuccIdx : InEdges) + setEdgeProbability(BB, SuccIdx, Prob); + } + + if (uint32_t numExitingEdges = ExitingEdges.size()) { + BranchProbability NotTakenProb = BranchProbability(LBH_NONTAKEN_WEIGHT, + Denom); + auto Prob = NotTakenProb / numExitingEdges; + for (unsigned SuccIdx : ExitingEdges) + setEdgeProbability(BB, SuccIdx, Prob); + } + + if (uint32_t numUnlikelyEdges = UnlikelyEdges.size()) { + BranchProbability UnlikelyProb = BranchProbability(LBH_UNLIKELY_WEIGHT, + Denom); + auto Prob = UnlikelyProb / numUnlikelyEdges; + for (unsigned SuccIdx : UnlikelyEdges) + setEdgeProbability(BB, SuccIdx, Prob); + } + + return true; +} + +bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB, + const TargetLibraryInfo *TLI) { + const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); + if (!BI || !BI->isConditional()) + return false; + + Value *Cond = BI->getCondition(); + ICmpInst *CI = dyn_cast<ICmpInst>(Cond); + if (!CI) + return false; + + auto GetConstantInt = [](Value *V) { + if (auto *I = dyn_cast<BitCastInst>(V)) + return dyn_cast<ConstantInt>(I->getOperand(0)); + return dyn_cast<ConstantInt>(V); + }; + + Value *RHS = CI->getOperand(1); + ConstantInt *CV = GetConstantInt(RHS); + if (!CV) + return false; + + // If the LHS is the result of AND'ing a value with a single bit bitmask, + // we don't have information about probabilities. 
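+  // E.g. for C code like  if (flags & 0x8)  the truth value depends on a
+  // single arbitrary bit, so unlike the comparisons against 0, 1 and -1
+  // handled below, neither outcome is a better guess.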
+ if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0))) + if (LHS->getOpcode() == Instruction::And) + if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(LHS->getOperand(1))) + if (AndRHS->getValue().isPowerOf2()) + return false; + + // Check if the LHS is the return value of a library function + LibFunc Func = NumLibFuncs; + if (TLI) + if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0))) + if (Function *CalledFn = Call->getCalledFunction()) + TLI->getLibFunc(*CalledFn, Func); + + bool isProb; + if (Func == LibFunc_strcasecmp || + Func == LibFunc_strcmp || + Func == LibFunc_strncasecmp || + Func == LibFunc_strncmp || + Func == LibFunc_memcmp) { + // strcmp and similar functions return zero, negative, or positive, if the + // first string is equal, less, or greater than the second. We consider it + // likely that the strings are not equal, so a comparison with zero is + // probably false, but also a comparison with any other number is also + // probably false given that what exactly is returned for nonzero values is + // not specified. Any kind of comparison other than equality we know + // nothing about. + switch (CI->getPredicate()) { + case CmpInst::ICMP_EQ: + isProb = false; + break; + case CmpInst::ICMP_NE: + isProb = true; + break; + default: + return false; + } + } else if (CV->isZero()) { + switch (CI->getPredicate()) { + case CmpInst::ICMP_EQ: + // X == 0 -> Unlikely + isProb = false; + break; + case CmpInst::ICMP_NE: + // X != 0 -> Likely + isProb = true; + break; + case CmpInst::ICMP_SLT: + // X < 0 -> Unlikely + isProb = false; + break; + case CmpInst::ICMP_SGT: + // X > 0 -> Likely + isProb = true; + break; + default: + return false; + } + } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) { + // InstCombine canonicalizes X <= 0 into X < 1. + // X <= 0 -> Unlikely + isProb = false; + } else if (CV->isMinusOne()) { + switch (CI->getPredicate()) { + case CmpInst::ICMP_EQ: + // X == -1 -> Unlikely + isProb = false; + break; + case CmpInst::ICMP_NE: + // X != -1 -> Likely + isProb = true; + break; + case CmpInst::ICMP_SGT: + // InstCombine canonicalizes X >= 0 into X > -1. 
+ // X >= 0 -> Likely + isProb = true; + break; + default: + return false; + } + } else { + return false; + } + + unsigned TakenIdx = 0, NonTakenIdx = 1; + + if (!isProb) + std::swap(TakenIdx, NonTakenIdx); + + BranchProbability TakenProb(ZH_TAKEN_WEIGHT, + ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT); + setEdgeProbability(BB, TakenIdx, TakenProb); + setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl()); + return true; +} + +bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) { + const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); + if (!BI || !BI->isConditional()) + return false; + + Value *Cond = BI->getCondition(); + FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond); + if (!FCmp) + return false; + + uint32_t TakenWeight = FPH_TAKEN_WEIGHT; + uint32_t NontakenWeight = FPH_NONTAKEN_WEIGHT; + bool isProb; + if (FCmp->isEquality()) { + // f1 == f2 -> Unlikely + // f1 != f2 -> Likely + isProb = !FCmp->isTrueWhenEqual(); + } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) { + // !isnan -> Likely + isProb = true; + TakenWeight = FPH_ORD_WEIGHT; + NontakenWeight = FPH_UNO_WEIGHT; + } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) { + // isnan -> Unlikely + isProb = false; + TakenWeight = FPH_ORD_WEIGHT; + NontakenWeight = FPH_UNO_WEIGHT; + } else { + return false; + } + + unsigned TakenIdx = 0, NonTakenIdx = 1; + + if (!isProb) + std::swap(TakenIdx, NonTakenIdx); + + BranchProbability TakenProb(TakenWeight, TakenWeight + NontakenWeight); + setEdgeProbability(BB, TakenIdx, TakenProb); + setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl()); + return true; +} + +bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) { + const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator()); + if (!II) + return false; + + BranchProbability TakenProb(IH_TAKEN_WEIGHT, + IH_TAKEN_WEIGHT + IH_NONTAKEN_WEIGHT); + setEdgeProbability(BB, 0 /*Index for Normal*/, TakenProb); + setEdgeProbability(BB, 1 /*Index for Unwind*/, TakenProb.getCompl()); + return true; +} + +void BranchProbabilityInfo::releaseMemory() { + Probs.clear(); +} + +void BranchProbabilityInfo::print(raw_ostream &OS) const { + OS << "---- Branch Probabilities ----\n"; + // We print the probabilities from the last function the analysis ran over, + // or the function it is currently running over. + assert(LastF && "Cannot print prior to running over a function"); + for (const auto &BI : *LastF) { + for (succ_const_iterator SI = succ_begin(&BI), SE = succ_end(&BI); SI != SE; + ++SI) { + printEdgeProbability(OS << " ", &BI, *SI); + } + } +} + +bool BranchProbabilityInfo:: +isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const { + // Hot probability is at least 4/5 = 80% + // FIXME: Compare against a static "hot" BranchProbability. + return getEdgeProbability(Src, Dst) > BranchProbability(4, 5); +} + +const BasicBlock * +BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const { + auto MaxProb = BranchProbability::getZero(); + const BasicBlock *MaxSucc = nullptr; + + for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { + const BasicBlock *Succ = *I; + auto Prob = getEdgeProbability(BB, Succ); + if (Prob > MaxProb) { + MaxProb = Prob; + MaxSucc = Succ; + } + } + + // Hot probability is at least 4/5 = 80% + if (MaxProb > BranchProbability(4, 5)) + return MaxSucc; + + return nullptr; +} + +/// Get the raw edge probability for the edge. If can't find it, return a +/// default probability 1/N where N is the number of successors. 
Here an edge is +/// specified using PredBlock and an +/// index to the successors. +BranchProbability +BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src, + unsigned IndexInSuccessors) const { + auto I = Probs.find(std::make_pair(Src, IndexInSuccessors)); + + if (I != Probs.end()) + return I->second; + + return {1, static_cast<uint32_t>(succ_size(Src))}; +} + +BranchProbability +BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src, + succ_const_iterator Dst) const { + return getEdgeProbability(Src, Dst.getSuccessorIndex()); +} + +/// Get the raw edge probability calculated for the block pair. This returns the +/// sum of all raw edge probabilities from Src to Dst. +BranchProbability +BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src, + const BasicBlock *Dst) const { + auto Prob = BranchProbability::getZero(); + bool FoundProb = false; + for (succ_const_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I) + if (*I == Dst) { + auto MapI = Probs.find(std::make_pair(Src, I.getSuccessorIndex())); + if (MapI != Probs.end()) { + FoundProb = true; + Prob += MapI->second; + } + } + uint32_t succ_num = std::distance(succ_begin(Src), succ_end(Src)); + return FoundProb ? Prob : BranchProbability(1, succ_num); +} + +/// Set the edge probability for a given edge specified by PredBlock and an +/// index to the successors. +void BranchProbabilityInfo::setEdgeProbability(const BasicBlock *Src, + unsigned IndexInSuccessors, + BranchProbability Prob) { + Probs[std::make_pair(Src, IndexInSuccessors)] = Prob; + Handles.insert(BasicBlockCallbackVH(Src, this)); + LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> " + << IndexInSuccessors << " successor probability to " << Prob + << "\n"); +} + +raw_ostream & +BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS, + const BasicBlock *Src, + const BasicBlock *Dst) const { + const BranchProbability Prob = getEdgeProbability(Src, Dst); + OS << "edge " << Src->getName() << " -> " << Dst->getName() + << " probability is " << Prob + << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n"); + + return OS; +} + +void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) { + for (auto I = Probs.begin(), E = Probs.end(); I != E; ++I) { + auto Key = I->first; + if (Key.first == BB) + Probs.erase(Key); + } +} + +void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI, + const TargetLibraryInfo *TLI) { + LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName() + << " ----\n\n"); + LastF = &F; // Store the last function we ran on for printing. + assert(PostDominatedByUnreachable.empty()); + assert(PostDominatedByColdCall.empty()); + + // Record SCC numbers of blocks in the CFG to identify irreducible loops. + // FIXME: We could only calculate this if the CFG is known to be irreducible + // (perhaps cache this info in LoopInfo if we can easily calculate it there?). + int SccNum = 0; + SccInfo SccI; + for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd(); + ++It, ++SccNum) { + // Ignore single-block SCCs since they either aren't loops or LoopInfo will + // catch them. + const std::vector<const BasicBlock *> &Scc = *It; + if (Scc.size() == 1) + continue; + + LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":"); + for (auto *BB : Scc) { + LLVM_DEBUG(dbgs() << " " << BB->getName()); + SccI.SccNums[BB] = SccNum; + } + LLVM_DEBUG(dbgs() << "\n"); + } + + // Walk the basic blocks in post-order so that we can build up state about + // the successors of a block iteratively. 
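+  // In a post-order walk every (non-backedge) successor of a block is
+  // visited before the block itself, so by the time BB is processed the two
+  // updatePostDominatedBy* calls below already know the status of each of
+  // its successors.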
+  for (auto BB : post_order(&F.getEntryBlock())) {
+    LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName()
+                      << "\n");
+    updatePostDominatedByUnreachable(BB);
+    updatePostDominatedByColdCall(BB);
+    // If there are fewer than two successors, there is no probability to set.
+    if (BB->getTerminator()->getNumSuccessors() < 2)
+      continue;
+    if (calcMetadataWeights(BB))
+      continue;
+    if (calcInvokeHeuristics(BB))
+      continue;
+    if (calcUnreachableHeuristics(BB))
+      continue;
+    if (calcColdCallHeuristics(BB))
+      continue;
+    if (calcLoopBranchHeuristics(BB, LI, SccI))
+      continue;
+    if (calcPointerHeuristics(BB))
+      continue;
+    if (calcZeroHeuristics(BB, TLI))
+      continue;
+    if (calcFloatingPointHeuristics(BB))
+      continue;
+  }
+
+  PostDominatedByUnreachable.clear();
+  PostDominatedByColdCall.clear();
+
+  if (PrintBranchProb &&
+      (PrintBranchProbFuncName.empty() ||
+       F.getName().equals(PrintBranchProbFuncName))) {
+    print(dbgs());
+  }
+}
+
+void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
+    AnalysisUsage &AU) const {
+  // We require DT so it's available when LI is available. The LI updating code
+  // asserts that DT is also present so if we don't make sure that we have DT
+  // here, that assert will trigger.
+  AU.addRequired<DominatorTreeWrapperPass>();
+  AU.addRequired<LoopInfoWrapperPass>();
+  AU.addRequired<TargetLibraryInfoWrapperPass>();
+  AU.setPreservesAll();
+}
+
+bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
+  const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  const TargetLibraryInfo &TLI =
+      getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+  BPI.calculate(F, LI, &TLI);
+  return false;
+}
+
+void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
+
+void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
+                                             const Module *) const {
+  BPI.print(OS);
+}
+
+AnalysisKey BranchProbabilityAnalysis::Key;
+BranchProbabilityInfo
+BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
+  BranchProbabilityInfo BPI;
+  BPI.calculate(F, AM.getResult<LoopAnalysis>(F),
+                &AM.getResult<TargetLibraryAnalysis>(F));
+  return BPI;
+}
+
+PreservedAnalyses
+BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
+  OS << "Printing analysis results of BPI for function "
+     << "'" << F.getName() << "':"
+     << "\n";
+  AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
+  return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Analysis/CFG.cpp b/llvm/lib/Analysis/CFG.cpp
new file mode 100644
index 000000000000..8215b4ecbb03
--- /dev/null
+++ b/llvm/lib/Analysis/CFG.cpp
@@ -0,0 +1,279 @@
+//===-- CFG.cpp - BasicBlock analysis --------------------------------------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This family of functions performs analyses on basic blocks, and instructions
+// contained within basic blocks.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CFG.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/Dominators.h"
+
+using namespace llvm;
+
+/// FindFunctionBackedges - Analyze the specified function to find all of the
+/// loop backedges in the function and return them.
This is a relatively cheap +/// (compared to computing dominators and loop info) analysis. +/// +/// The output is added to Result, as pairs of <from,to> edge info. +void llvm::FindFunctionBackedges(const Function &F, + SmallVectorImpl<std::pair<const BasicBlock*,const BasicBlock*> > &Result) { + const BasicBlock *BB = &F.getEntryBlock(); + if (succ_empty(BB)) + return; + + SmallPtrSet<const BasicBlock*, 8> Visited; + SmallVector<std::pair<const BasicBlock*, succ_const_iterator>, 8> VisitStack; + SmallPtrSet<const BasicBlock*, 8> InStack; + + Visited.insert(BB); + VisitStack.push_back(std::make_pair(BB, succ_begin(BB))); + InStack.insert(BB); + do { + std::pair<const BasicBlock*, succ_const_iterator> &Top = VisitStack.back(); + const BasicBlock *ParentBB = Top.first; + succ_const_iterator &I = Top.second; + + bool FoundNew = false; + while (I != succ_end(ParentBB)) { + BB = *I++; + if (Visited.insert(BB).second) { + FoundNew = true; + break; + } + // Successor is in VisitStack, it's a back edge. + if (InStack.count(BB)) + Result.push_back(std::make_pair(ParentBB, BB)); + } + + if (FoundNew) { + // Go down one level if there is a unvisited successor. + InStack.insert(BB); + VisitStack.push_back(std::make_pair(BB, succ_begin(BB))); + } else { + // Go up one level. + InStack.erase(VisitStack.pop_back_val().first); + } + } while (!VisitStack.empty()); +} + +/// GetSuccessorNumber - Search for the specified successor of basic block BB +/// and return its position in the terminator instruction's list of +/// successors. It is an error to call this with a block that is not a +/// successor. +unsigned llvm::GetSuccessorNumber(const BasicBlock *BB, + const BasicBlock *Succ) { + const Instruction *Term = BB->getTerminator(); +#ifndef NDEBUG + unsigned e = Term->getNumSuccessors(); +#endif + for (unsigned i = 0; ; ++i) { + assert(i != e && "Didn't find edge?"); + if (Term->getSuccessor(i) == Succ) + return i; + } +} + +/// isCriticalEdge - Return true if the specified edge is a critical edge. +/// Critical edges are edges from a block with multiple successors to a block +/// with multiple predecessors. +bool llvm::isCriticalEdge(const Instruction *TI, unsigned SuccNum, + bool AllowIdenticalEdges) { + assert(SuccNum < TI->getNumSuccessors() && "Illegal edge specification!"); + return isCriticalEdge(TI, TI->getSuccessor(SuccNum), AllowIdenticalEdges); +} + +bool llvm::isCriticalEdge(const Instruction *TI, const BasicBlock *Dest, + bool AllowIdenticalEdges) { + assert(TI->isTerminator() && "Must be a terminator to have successors!"); + if (TI->getNumSuccessors() == 1) return false; + + assert(find(predecessors(Dest), TI->getParent()) != pred_end(Dest) && + "No edge between TI's block and Dest."); + + const_pred_iterator I = pred_begin(Dest), E = pred_end(Dest); + + // If there is more than one predecessor, this is a critical edge... + assert(I != E && "No preds, but we have an edge to the block?"); + const BasicBlock *FirstPred = *I; + ++I; // Skip one edge due to the incoming arc from TI. + if (!AllowIdenticalEdges) + return I != E; + + // If AllowIdenticalEdges is true, then we allow this edge to be considered + // non-critical iff all preds come from TI's block. + for (; I != E; ++I) + if (*I != FirstPred) + return true; + return false; +} + +// LoopInfo contains a mapping from basic block to the innermost loop. Find +// the outermost loop in the loop nest that contains BB. 
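+// E.g. (illustrative) for a block BB inside the j-loop of
+//   for (i ...) { for (j ...) { BB } }
+// getLoopFor(BB) is the j-loop and this walks up to and returns the i-loop;
+// it returns null when BB is not in any loop.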
+static const Loop *getOutermostLoop(const LoopInfo *LI, const BasicBlock *BB) { + const Loop *L = LI->getLoopFor(BB); + if (L) { + while (const Loop *Parent = L->getParentLoop()) + L = Parent; + } + return L; +} + +bool llvm::isPotentiallyReachableFromMany( + SmallVectorImpl<BasicBlock *> &Worklist, BasicBlock *StopBB, + const SmallPtrSetImpl<BasicBlock *> *ExclusionSet, const DominatorTree *DT, + const LoopInfo *LI) { + // When the stop block is unreachable, it's dominated from everywhere, + // regardless of whether there's a path between the two blocks. + if (DT && !DT->isReachableFromEntry(StopBB)) + DT = nullptr; + + // We can't skip directly from a block that dominates the stop block if the + // exclusion block is potentially in between. + if (ExclusionSet && !ExclusionSet->empty()) + DT = nullptr; + + // Normally any block in a loop is reachable from any other block in a loop, + // however excluded blocks might partition the body of a loop to make that + // untrue. + SmallPtrSet<const Loop *, 8> LoopsWithHoles; + if (LI && ExclusionSet) { + for (auto BB : *ExclusionSet) { + if (const Loop *L = getOutermostLoop(LI, BB)) + LoopsWithHoles.insert(L); + } + } + + const Loop *StopLoop = LI ? getOutermostLoop(LI, StopBB) : nullptr; + + // Limit the number of blocks we visit. The goal is to avoid run-away compile + // times on large CFGs without hampering sensible code. Arbitrarily chosen. + unsigned Limit = 32; + SmallPtrSet<const BasicBlock*, 32> Visited; + do { + BasicBlock *BB = Worklist.pop_back_val(); + if (!Visited.insert(BB).second) + continue; + if (BB == StopBB) + return true; + if (ExclusionSet && ExclusionSet->count(BB)) + continue; + if (DT && DT->dominates(BB, StopBB)) + return true; + + const Loop *Outer = nullptr; + if (LI) { + Outer = getOutermostLoop(LI, BB); + // If we're in a loop with a hole, not all blocks in the loop are + // reachable from all other blocks. That implies we can't simply jump to + // the loop's exit blocks, as that exit might need to pass through an + // excluded block. Clear Outer so we process BB's successors. + if (LoopsWithHoles.count(Outer)) + Outer = nullptr; + if (StopLoop && Outer == StopLoop) + return true; + } + + if (!--Limit) { + // We haven't been able to prove it one way or the other. Conservatively + // answer true -- that there is potentially a path. + return true; + } + + if (Outer) { + // All blocks in a single loop are reachable from all other blocks. From + // any of these blocks, we can skip directly to the exits of the loop, + // ignoring any other blocks inside the loop body. + Outer->getExitBlocks(Worklist); + } else { + Worklist.append(succ_begin(BB), succ_end(BB)); + } + } while (!Worklist.empty()); + + // We have exhausted all possible paths and are certain that 'To' can not be + // reached from 'From'. 
+ return false; +} + +bool llvm::isPotentiallyReachable(const BasicBlock *A, const BasicBlock *B, + const DominatorTree *DT, const LoopInfo *LI) { + assert(A->getParent() == B->getParent() && + "This analysis is function-local!"); + + SmallVector<BasicBlock*, 32> Worklist; + Worklist.push_back(const_cast<BasicBlock*>(A)); + + return isPotentiallyReachableFromMany(Worklist, const_cast<BasicBlock *>(B), + nullptr, DT, LI); +} + +bool llvm::isPotentiallyReachable( + const Instruction *A, const Instruction *B, + const SmallPtrSetImpl<BasicBlock *> *ExclusionSet, const DominatorTree *DT, + const LoopInfo *LI) { + assert(A->getParent()->getParent() == B->getParent()->getParent() && + "This analysis is function-local!"); + + SmallVector<BasicBlock*, 32> Worklist; + + if (A->getParent() == B->getParent()) { + // The same block case is special because it's the only time we're looking + // within a single block to see which instruction comes first. Once we + // start looking at multiple blocks, the first instruction of the block is + // reachable, so we only need to determine reachability between whole + // blocks. + BasicBlock *BB = const_cast<BasicBlock *>(A->getParent()); + + // If the block is in a loop then we can reach any instruction in the block + // from any other instruction in the block by going around a backedge. + if (LI && LI->getLoopFor(BB) != nullptr) + return true; + + // Linear scan, start at 'A', see whether we hit 'B' or the end first. + for (BasicBlock::const_iterator I = A->getIterator(), E = BB->end(); I != E; + ++I) { + if (&*I == B) + return true; + } + + // Can't be in a loop if it's the entry block -- the entry block may not + // have predecessors. + if (BB == &BB->getParent()->getEntryBlock()) + return false; + + // Otherwise, continue doing the normal per-BB CFG walk. + Worklist.append(succ_begin(BB), succ_end(BB)); + + if (Worklist.empty()) { + // We've proven that there's no path! + return false; + } + } else { + Worklist.push_back(const_cast<BasicBlock*>(A->getParent())); + } + + if (DT) { + if (DT->isReachableFromEntry(A->getParent()) && + !DT->isReachableFromEntry(B->getParent())) + return false; + if (!ExclusionSet || ExclusionSet->empty()) { + if (A->getParent() == &A->getParent()->getParent()->getEntryBlock() && + DT->isReachableFromEntry(B->getParent())) + return true; + if (B->getParent() == &A->getParent()->getParent()->getEntryBlock() && + DT->isReachableFromEntry(A->getParent())) + return false; + } + } + + return isPotentiallyReachableFromMany( + Worklist, const_cast<BasicBlock *>(B->getParent()), ExclusionSet, DT, LI); +} diff --git a/llvm/lib/Analysis/CFGPrinter.cpp b/llvm/lib/Analysis/CFGPrinter.cpp new file mode 100644 index 000000000000..4f4103fefa25 --- /dev/null +++ b/llvm/lib/Analysis/CFGPrinter.cpp @@ -0,0 +1,200 @@ +//===- CFGPrinter.cpp - DOT printer for the control flow graph ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a `-dot-cfg` analysis pass, which emits the +// `<prefix>.<fnname>.dot` file for each function in the program, with a graph +// of the CFG for that function. The default value for `<prefix>` is `cfg` but +// can be customized as needed. 
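+//
+// For instance (illustrative invocation; 'main' stands for any function in
+// the module):
+//   opt -dot-cfg -disable-output input.ll
+// writes 'cfg.main.dot', and adding -cfg-dot-filename-prefix=my changes that
+// to 'my.main.dot'. The resulting file can be rendered with Graphviz.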
+// +// The other main feature of this file is that it implements the +// Function::viewCFG method, which is useful for debugging passes which operate +// on the CFG. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CFGPrinter.h" +#include "llvm/Pass.h" +#include "llvm/Support/FileSystem.h" +using namespace llvm; + +static cl::opt<std::string> CFGFuncName( + "cfg-func-name", cl::Hidden, + cl::desc("The name of a function (or its substring)" + " whose CFG is viewed/printed.")); + +static cl::opt<std::string> CFGDotFilenamePrefix( + "cfg-dot-filename-prefix", cl::Hidden, + cl::desc("The prefix used for the CFG dot file names.")); + +namespace { + struct CFGViewerLegacyPass : public FunctionPass { + static char ID; // Pass identifcation, replacement for typeid + CFGViewerLegacyPass() : FunctionPass(ID) { + initializeCFGViewerLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + F.viewCFG(); + return false; + } + + void print(raw_ostream &OS, const Module* = nullptr) const override {} + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + } + }; +} + +char CFGViewerLegacyPass::ID = 0; +INITIALIZE_PASS(CFGViewerLegacyPass, "view-cfg", "View CFG of function", false, true) + +PreservedAnalyses CFGViewerPass::run(Function &F, + FunctionAnalysisManager &AM) { + F.viewCFG(); + return PreservedAnalyses::all(); +} + + +namespace { + struct CFGOnlyViewerLegacyPass : public FunctionPass { + static char ID; // Pass identifcation, replacement for typeid + CFGOnlyViewerLegacyPass() : FunctionPass(ID) { + initializeCFGOnlyViewerLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + F.viewCFGOnly(); + return false; + } + + void print(raw_ostream &OS, const Module* = nullptr) const override {} + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + } + }; +} + +char CFGOnlyViewerLegacyPass::ID = 0; +INITIALIZE_PASS(CFGOnlyViewerLegacyPass, "view-cfg-only", + "View CFG of function (with no function bodies)", false, true) + +PreservedAnalyses CFGOnlyViewerPass::run(Function &F, + FunctionAnalysisManager &AM) { + F.viewCFGOnly(); + return PreservedAnalyses::all(); +} + +static void writeCFGToDotFile(Function &F, bool CFGOnly = false) { + if (!CFGFuncName.empty() && !F.getName().contains(CFGFuncName)) + return; + std::string Filename = + (CFGDotFilenamePrefix + "." 
+ F.getName() + ".dot").str(); + errs() << "Writing '" << Filename << "'..."; + + std::error_code EC; + raw_fd_ostream File(Filename, EC, sys::fs::OF_Text); + + if (!EC) + WriteGraph(File, (const Function*)&F, CFGOnly); + else + errs() << " error opening file for writing!"; + errs() << "\n"; +} + +namespace { + struct CFGPrinterLegacyPass : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + CFGPrinterLegacyPass() : FunctionPass(ID) { + initializeCFGPrinterLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + writeCFGToDotFile(F); + return false; + } + + void print(raw_ostream &OS, const Module* = nullptr) const override {} + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + } + }; +} + +char CFGPrinterLegacyPass::ID = 0; +INITIALIZE_PASS(CFGPrinterLegacyPass, "dot-cfg", "Print CFG of function to 'dot' file", + false, true) + +PreservedAnalyses CFGPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + writeCFGToDotFile(F); + return PreservedAnalyses::all(); +} + +namespace { + struct CFGOnlyPrinterLegacyPass : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + CFGOnlyPrinterLegacyPass() : FunctionPass(ID) { + initializeCFGOnlyPrinterLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + writeCFGToDotFile(F, /*CFGOnly=*/true); + return false; + } + void print(raw_ostream &OS, const Module* = nullptr) const override {} + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + } + }; +} + +char CFGOnlyPrinterLegacyPass::ID = 0; +INITIALIZE_PASS(CFGOnlyPrinterLegacyPass, "dot-cfg-only", + "Print CFG of function to 'dot' file (with no function bodies)", + false, true) + +PreservedAnalyses CFGOnlyPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + writeCFGToDotFile(F, /*CFGOnly=*/true); + return PreservedAnalyses::all(); +} + +/// viewCFG - This function is meant for use from the debugger. You can just +/// say 'call F->viewCFG()' and a ghostview window should pop up from the +/// program, displaying the CFG of the current function. This depends on there +/// being a 'dot' and 'gv' program in your path. +/// +void Function::viewCFG() const { + if (!CFGFuncName.empty() && !getName().contains(CFGFuncName)) + return; + ViewGraph(this, "cfg" + getName()); +} + +/// viewCFGOnly - This function is meant for use from the debugger. It works +/// just like viewCFG, but it does not include the contents of basic blocks +/// into the nodes, just the label. If you are only interested in the CFG +/// this can make the graph smaller. +/// +void Function::viewCFGOnly() const { + if (!CFGFuncName.empty() && !getName().contains(CFGFuncName)) + return; + ViewGraph(this, "cfg" + getName(), true); +} + +FunctionPass *llvm::createCFGPrinterLegacyPassPass () { + return new CFGPrinterLegacyPass(); +} + +FunctionPass *llvm::createCFGOnlyPrinterLegacyPassPass () { + return new CFGOnlyPrinterLegacyPass(); +} + diff --git a/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp b/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp new file mode 100644 index 000000000000..fd90bd1521d6 --- /dev/null +++ b/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp @@ -0,0 +1,931 @@ +//===- CFLAndersAliasAnalysis.cpp - Unification-based Alias Analysis ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a CFL-based, summary-based alias analysis algorithm. It
+// differs from CFLSteensAliasAnalysis in its inclusion-based nature while
+// CFLSteensAliasAnalysis is unification-based. This pass has worse performance
+// than CFLSteensAliasAnalysis (the worst case complexity of
+// CFLAndersAliasAnalysis is cubic, while the worst case complexity of
+// CFLSteensAliasAnalysis is almost linear), but it is able to yield more
+// precise analysis results. The precision of this analysis is roughly the same
+// as that of a one-level context-sensitive Andersen's algorithm.
+//
+// The algorithm used here is based on the recursive state machine matching
+// scheme proposed in "Demand-driven alias analysis for C" by Xin Zheng and
+// Radu Rugina. The general idea is to extend the traditional transitive
+// closure algorithm to perform CFL matching along the way: instead of
+// recording "whether X is reachable from Y", we keep track of "whether X is
+// reachable from Y at state Z", where the "state" field indicates where we are
+// in the CFL matching process. To understand the matching better, it is
+// advisable to have the state machine shown in Figure 3 of the paper available
+// when reading the code: all we do here is selectively expand the transitive
+// closure by discarding edges that are not recognized by the state machine.
+//
+// There are two differences between our current implementation and the one
+// described in the paper:
+// - Our algorithm eagerly computes all alias pairs after the CFLGraph is
+// built, while in the paper the authors did the computation in a demand-driven
+// fashion. We did not implement the demand-driven algorithm due to the
+// additional coding complexity and higher memory profile, but if we find it
+// necessary we may switch to it eventually.
+// - In the paper the authors use a state machine that does not distinguish
+// value reads from value writes. For example, if Y is reachable from X at
+// state S3, it may be the case that X is written into Y, or it may be the case
+// that there's a third value Z that writes into both X and Y. To make that
+// distinction (which is crucial in building function summaries as well as
+// retrieving mod-ref info), we choose to duplicate some of the states in the
+// paper's proposed state machine. The duplication does not change the set the
+// machine accepts. Given a pair of reachable values, it only provides more
+// detailed information on which value is being written into and which is being
+// read from.
+//
+//===----------------------------------------------------------------------===//
+
+// N.B. AliasAnalysis as a whole is phrased as a FunctionPass at the moment,
+// and CFLAndersAA is interprocedural. This is *technically* A Bad Thing,
+// because FunctionPasses are only allowed to inspect the Function that they're
+// being run on. Realistically, this likely isn't a problem until we allow
+// FunctionPasses to run concurrently.
+ +#include "llvm/Analysis/CFLAndersAliasAnalysis.h" +#include "AliasAnalysisSummary.h" +#include "CFLGraph.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <bitset> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <functional> +#include <utility> +#include <vector> + +using namespace llvm; +using namespace llvm::cflaa; + +#define DEBUG_TYPE "cfl-anders-aa" + +CFLAndersAAResult::CFLAndersAAResult( + std::function<const TargetLibraryInfo &(Function &F)> GetTLI) + : GetTLI(std::move(GetTLI)) {} +CFLAndersAAResult::CFLAndersAAResult(CFLAndersAAResult &&RHS) + : AAResultBase(std::move(RHS)), GetTLI(std::move(RHS.GetTLI)) {} +CFLAndersAAResult::~CFLAndersAAResult() = default; + +namespace { + +enum class MatchState : uint8_t { + // The following state represents S1 in the paper. + FlowFromReadOnly = 0, + // The following two states together represent S2 in the paper. + // The 'NoReadWrite' suffix indicates that there exists an alias path that + // does not contain assignment and reverse assignment edges. + // The 'ReadOnly' suffix indicates that there exists an alias path that + // contains reverse assignment edges only. + FlowFromMemAliasNoReadWrite, + FlowFromMemAliasReadOnly, + // The following two states together represent S3 in the paper. + // The 'WriteOnly' suffix indicates that there exists an alias path that + // contains assignment edges only. + // The 'ReadWrite' suffix indicates that there exists an alias path that + // contains both assignment and reverse assignment edges. Note that if X and Y + // are reachable at 'ReadWrite' state, it does NOT mean X is both read from + // and written to Y. Instead, it means that a third value Z is written to both + // X and Y. + FlowToWriteOnly, + FlowToReadWrite, + // The following two states together represent S4 in the paper. 
+ FlowToMemAliasWriteOnly, + FlowToMemAliasReadWrite, +}; + +using StateSet = std::bitset<7>; + +const unsigned ReadOnlyStateMask = + (1U << static_cast<uint8_t>(MatchState::FlowFromReadOnly)) | + (1U << static_cast<uint8_t>(MatchState::FlowFromMemAliasReadOnly)); +const unsigned WriteOnlyStateMask = + (1U << static_cast<uint8_t>(MatchState::FlowToWriteOnly)) | + (1U << static_cast<uint8_t>(MatchState::FlowToMemAliasWriteOnly)); + +// A pair that consists of a value and an offset +struct OffsetValue { + const Value *Val; + int64_t Offset; +}; + +bool operator==(OffsetValue LHS, OffsetValue RHS) { + return LHS.Val == RHS.Val && LHS.Offset == RHS.Offset; +} +bool operator<(OffsetValue LHS, OffsetValue RHS) { + return std::less<const Value *>()(LHS.Val, RHS.Val) || + (LHS.Val == RHS.Val && LHS.Offset < RHS.Offset); +} + +// A pair that consists of an InstantiatedValue and an offset +struct OffsetInstantiatedValue { + InstantiatedValue IVal; + int64_t Offset; +}; + +bool operator==(OffsetInstantiatedValue LHS, OffsetInstantiatedValue RHS) { + return LHS.IVal == RHS.IVal && LHS.Offset == RHS.Offset; +} + +// We use ReachabilitySet to keep track of value aliases (The nonterminal "V" in +// the paper) during the analysis. +class ReachabilitySet { + using ValueStateMap = DenseMap<InstantiatedValue, StateSet>; + using ValueReachMap = DenseMap<InstantiatedValue, ValueStateMap>; + + ValueReachMap ReachMap; + +public: + using const_valuestate_iterator = ValueStateMap::const_iterator; + using const_value_iterator = ValueReachMap::const_iterator; + + // Insert edge 'From->To' at state 'State' + bool insert(InstantiatedValue From, InstantiatedValue To, MatchState State) { + assert(From != To); + auto &States = ReachMap[To][From]; + auto Idx = static_cast<size_t>(State); + if (!States.test(Idx)) { + States.set(Idx); + return true; + } + return false; + } + + // Return the set of all ('From', 'State') pair for a given node 'To' + iterator_range<const_valuestate_iterator> + reachableValueAliases(InstantiatedValue V) const { + auto Itr = ReachMap.find(V); + if (Itr == ReachMap.end()) + return make_range<const_valuestate_iterator>(const_valuestate_iterator(), + const_valuestate_iterator()); + return make_range<const_valuestate_iterator>(Itr->second.begin(), + Itr->second.end()); + } + + iterator_range<const_value_iterator> value_mappings() const { + return make_range<const_value_iterator>(ReachMap.begin(), ReachMap.end()); + } +}; + +// We use AliasMemSet to keep track of all memory aliases (the nonterminal "M" +// in the paper) during the analysis. +class AliasMemSet { + using MemSet = DenseSet<InstantiatedValue>; + using MemMapType = DenseMap<InstantiatedValue, MemSet>; + + MemMapType MemMap; + +public: + using const_mem_iterator = MemSet::const_iterator; + + bool insert(InstantiatedValue LHS, InstantiatedValue RHS) { + // Top-level values can never be memory aliases because one cannot take the + // addresses of them + assert(LHS.DerefLevel > 0 && RHS.DerefLevel > 0); + return MemMap[LHS].insert(RHS).second; + } + + const MemSet *getMemoryAliases(InstantiatedValue V) const { + auto Itr = MemMap.find(V); + if (Itr == MemMap.end()) + return nullptr; + return &Itr->second; + } +}; + +// We use AliasAttrMap to keep track of the AliasAttr of each node. 
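+// Attributes only ever grow: add() ORs the new bits into the stored set and
+// reports whether anything changed, which is what lets the fixed-point
+// propagation over the graph terminate.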
+class AliasAttrMap { + using MapType = DenseMap<InstantiatedValue, AliasAttrs>; + + MapType AttrMap; + +public: + using const_iterator = MapType::const_iterator; + + bool add(InstantiatedValue V, AliasAttrs Attr) { + auto &OldAttr = AttrMap[V]; + auto NewAttr = OldAttr | Attr; + if (OldAttr == NewAttr) + return false; + OldAttr = NewAttr; + return true; + } + + AliasAttrs getAttrs(InstantiatedValue V) const { + AliasAttrs Attr; + auto Itr = AttrMap.find(V); + if (Itr != AttrMap.end()) + Attr = Itr->second; + return Attr; + } + + iterator_range<const_iterator> mappings() const { + return make_range<const_iterator>(AttrMap.begin(), AttrMap.end()); + } +}; + +struct WorkListItem { + InstantiatedValue From; + InstantiatedValue To; + MatchState State; +}; + +struct ValueSummary { + struct Record { + InterfaceValue IValue; + unsigned DerefLevel; + }; + SmallVector<Record, 4> FromRecords, ToRecords; +}; + +} // end anonymous namespace + +namespace llvm { + +// Specialize DenseMapInfo for OffsetValue. +template <> struct DenseMapInfo<OffsetValue> { + static OffsetValue getEmptyKey() { + return OffsetValue{DenseMapInfo<const Value *>::getEmptyKey(), + DenseMapInfo<int64_t>::getEmptyKey()}; + } + + static OffsetValue getTombstoneKey() { + return OffsetValue{DenseMapInfo<const Value *>::getTombstoneKey(), + DenseMapInfo<int64_t>::getEmptyKey()}; + } + + static unsigned getHashValue(const OffsetValue &OVal) { + return DenseMapInfo<std::pair<const Value *, int64_t>>::getHashValue( + std::make_pair(OVal.Val, OVal.Offset)); + } + + static bool isEqual(const OffsetValue &LHS, const OffsetValue &RHS) { + return LHS == RHS; + } +}; + +// Specialize DenseMapInfo for OffsetInstantiatedValue. +template <> struct DenseMapInfo<OffsetInstantiatedValue> { + static OffsetInstantiatedValue getEmptyKey() { + return OffsetInstantiatedValue{ + DenseMapInfo<InstantiatedValue>::getEmptyKey(), + DenseMapInfo<int64_t>::getEmptyKey()}; + } + + static OffsetInstantiatedValue getTombstoneKey() { + return OffsetInstantiatedValue{ + DenseMapInfo<InstantiatedValue>::getTombstoneKey(), + DenseMapInfo<int64_t>::getEmptyKey()}; + } + + static unsigned getHashValue(const OffsetInstantiatedValue &OVal) { + return DenseMapInfo<std::pair<InstantiatedValue, int64_t>>::getHashValue( + std::make_pair(OVal.IVal, OVal.Offset)); + } + + static bool isEqual(const OffsetInstantiatedValue &LHS, + const OffsetInstantiatedValue &RHS) { + return LHS == RHS; + } +}; + +} // end namespace llvm + +class CFLAndersAAResult::FunctionInfo { + /// Map a value to other values that may alias it + /// Since the alias relation is symmetric, to save some space we assume values + /// are properly ordered: if a and b alias each other, and a < b, then b is in + /// AliasMap[a] but not vice versa. + DenseMap<const Value *, std::vector<OffsetValue>> AliasMap; + + /// Map a value to its corresponding AliasAttrs + DenseMap<const Value *, AliasAttrs> AttrMap; + + /// Summary of externally visible effects. 
+  AliasSummary Summary;
+
+  Optional<AliasAttrs> getAttrs(const Value *) const;
+
+public:
+  FunctionInfo(const Function &, const SmallVectorImpl<Value *> &,
+               const ReachabilitySet &, const AliasAttrMap &);
+
+  bool mayAlias(const Value *, LocationSize, const Value *, LocationSize) const;
+  const AliasSummary &getAliasSummary() const { return Summary; }
+};
+
+static bool hasReadOnlyState(StateSet Set) {
+  return (Set & StateSet(ReadOnlyStateMask)).any();
+}
+
+static bool hasWriteOnlyState(StateSet Set) {
+  return (Set & StateSet(WriteOnlyStateMask)).any();
+}
+
+static Optional<InterfaceValue>
+getInterfaceValue(InstantiatedValue IValue,
+                  const SmallVectorImpl<Value *> &RetVals) {
+  auto Val = IValue.Val;
+
+  Optional<unsigned> Index;
+  if (auto Arg = dyn_cast<Argument>(Val))
+    Index = Arg->getArgNo() + 1;
+  else if (is_contained(RetVals, Val))
+    Index = 0;
+
+  if (Index)
+    return InterfaceValue{*Index, IValue.DerefLevel};
+  return None;
+}
+
+static void populateAttrMap(DenseMap<const Value *, AliasAttrs> &AttrMap,
+                            const AliasAttrMap &AMap) {
+  for (const auto &Mapping : AMap.mappings()) {
+    auto IVal = Mapping.first;
+
+    // Insert IVal into the map
+    auto &Attr = AttrMap[IVal.Val];
+    // AttrMap only cares about top-level values
+    if (IVal.DerefLevel == 0)
+      Attr |= Mapping.second;
+  }
+}
+
+static void
+populateAliasMap(DenseMap<const Value *, std::vector<OffsetValue>> &AliasMap,
+                 const ReachabilitySet &ReachSet) {
+  for (const auto &OuterMapping : ReachSet.value_mappings()) {
+    // AliasMap only cares about top-level values
+    if (OuterMapping.first.DerefLevel > 0)
+      continue;
+
+    auto Val = OuterMapping.first.Val;
+    auto &AliasList = AliasMap[Val];
+    for (const auto &InnerMapping : OuterMapping.second) {
+      // Again, AliasMap only cares about top-level values
+      if (InnerMapping.first.DerefLevel == 0)
+        AliasList.push_back(OffsetValue{InnerMapping.first.Val, UnknownOffset});
+    }
+
+    // Sort AliasList for faster lookup
+    llvm::sort(AliasList);
+  }
+}
+
+static void populateExternalRelations(
+    SmallVectorImpl<ExternalRelation> &ExtRelations, const Function &Fn,
+    const SmallVectorImpl<Value *> &RetVals, const ReachabilitySet &ReachSet) {
+  // If a function only returns one of its arguments X, then X will be both an
+  // argument and a return value at the same time. This is an edge case that
+  // needs special handling here.
+  for (const auto &Arg : Fn.args()) {
+    if (is_contained(RetVals, &Arg)) {
+      auto ArgVal = InterfaceValue{Arg.getArgNo() + 1, 0};
+      auto RetVal = InterfaceValue{0, 0};
+      ExtRelations.push_back(ExternalRelation{ArgVal, RetVal, 0});
+    }
+  }
+
+  // Below is the core summary construction logic.
+  // A naive solution of adding only the value aliases that are parameters or
+  // return values in ReachSet to the summary won't work: It is possible that a
+  // parameter P is written into an intermediate value I, and the function
+  // subsequently returns *I. In that case, *I does not value-alias anything
+  // in ReachSet, and the naive solution will miss a summary edge from (P, 1)
+  // to (I, 1).
+  // To account for the aforementioned case, we need to check each
+  // non-parameter and non-return value for the possibility of acting as an
+  // intermediate. 'ValueMap' here records, for each value, which
+  // InterfaceValues read from or write into it. If both the read list and the
+  // write list of a given value are non-empty, we know that a particular value
+  // is an intermediate and we need to add summary edges from the writes to the
+  // reads.
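+  // For example (illustrative IR):
+  //   store i32* %p, i32** %i
+  //   %r = load i32*, i32** %i
+  //   ret i32* %r
+  // Here %i is neither a parameter nor a return value, yet it carries %p to
+  // the returned value one dereference level down; the records collected on
+  // %i below recover the summary edge that the naive approach would miss.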
+ DenseMap<Value *, ValueSummary> ValueMap; + for (const auto &OuterMapping : ReachSet.value_mappings()) { + if (auto Dst = getInterfaceValue(OuterMapping.first, RetVals)) { + for (const auto &InnerMapping : OuterMapping.second) { + // If Src is a param/return value, we get a same-level assignment. + if (auto Src = getInterfaceValue(InnerMapping.first, RetVals)) { + // This may happen if both Dst and Src are return values + if (*Dst == *Src) + continue; + + if (hasReadOnlyState(InnerMapping.second)) + ExtRelations.push_back(ExternalRelation{*Dst, *Src, UnknownOffset}); + // No need to check for WriteOnly state, since ReachSet is symmetric + } else { + // If Src is not a param/return, add it to ValueMap + auto SrcIVal = InnerMapping.first; + if (hasReadOnlyState(InnerMapping.second)) + ValueMap[SrcIVal.Val].FromRecords.push_back( + ValueSummary::Record{*Dst, SrcIVal.DerefLevel}); + if (hasWriteOnlyState(InnerMapping.second)) + ValueMap[SrcIVal.Val].ToRecords.push_back( + ValueSummary::Record{*Dst, SrcIVal.DerefLevel}); + } + } + } + } + + for (const auto &Mapping : ValueMap) { + for (const auto &FromRecord : Mapping.second.FromRecords) { + for (const auto &ToRecord : Mapping.second.ToRecords) { + auto ToLevel = ToRecord.DerefLevel; + auto FromLevel = FromRecord.DerefLevel; + // Same-level assignments should have already been processed by now + if (ToLevel == FromLevel) + continue; + + auto SrcIndex = FromRecord.IValue.Index; + auto SrcLevel = FromRecord.IValue.DerefLevel; + auto DstIndex = ToRecord.IValue.Index; + auto DstLevel = ToRecord.IValue.DerefLevel; + if (ToLevel > FromLevel) + SrcLevel += ToLevel - FromLevel; + else + DstLevel += FromLevel - ToLevel; + + ExtRelations.push_back(ExternalRelation{ + InterfaceValue{SrcIndex, SrcLevel}, + InterfaceValue{DstIndex, DstLevel}, UnknownOffset}); + } + } + } + + // Remove duplicates in ExtRelations + llvm::sort(ExtRelations); + ExtRelations.erase(std::unique(ExtRelations.begin(), ExtRelations.end()), + ExtRelations.end()); +} + +static void populateExternalAttributes( + SmallVectorImpl<ExternalAttribute> &ExtAttributes, const Function &Fn, + const SmallVectorImpl<Value *> &RetVals, const AliasAttrMap &AMap) { + for (const auto &Mapping : AMap.mappings()) { + if (auto IVal = getInterfaceValue(Mapping.first, RetVals)) { + auto Attr = getExternallyVisibleAttrs(Mapping.second); + if (Attr.any()) + ExtAttributes.push_back(ExternalAttribute{*IVal, Attr}); + } + } +} + +CFLAndersAAResult::FunctionInfo::FunctionInfo( + const Function &Fn, const SmallVectorImpl<Value *> &RetVals, + const ReachabilitySet &ReachSet, const AliasAttrMap &AMap) { + populateAttrMap(AttrMap, AMap); + populateExternalAttributes(Summary.RetParamAttributes, Fn, RetVals, AMap); + populateAliasMap(AliasMap, ReachSet); + populateExternalRelations(Summary.RetParamRelations, Fn, RetVals, ReachSet); +} + +Optional<AliasAttrs> +CFLAndersAAResult::FunctionInfo::getAttrs(const Value *V) const { + assert(V != nullptr); + + auto Itr = AttrMap.find(V); + if (Itr != AttrMap.end()) + return Itr->second; + return None; +} + +bool CFLAndersAAResult::FunctionInfo::mayAlias( + const Value *LHS, LocationSize MaybeLHSSize, const Value *RHS, + LocationSize MaybeRHSSize) const { + assert(LHS && RHS); + + // Check if we've seen LHS and RHS before. Sometimes LHS or RHS can be created + // after the analysis gets executed, and we want to be conservative in those + // cases. 
+ auto MaybeAttrsA = getAttrs(LHS); + auto MaybeAttrsB = getAttrs(RHS); + if (!MaybeAttrsA || !MaybeAttrsB) + return true; + + // Check AliasAttrs before AliasMap lookup since it's cheaper + auto AttrsA = *MaybeAttrsA; + auto AttrsB = *MaybeAttrsB; + if (hasUnknownOrCallerAttr(AttrsA)) + return AttrsB.any(); + if (hasUnknownOrCallerAttr(AttrsB)) + return AttrsA.any(); + if (isGlobalOrArgAttr(AttrsA)) + return isGlobalOrArgAttr(AttrsB); + if (isGlobalOrArgAttr(AttrsB)) + return isGlobalOrArgAttr(AttrsA); + + // At this point both LHS and RHS should point to locally allocated objects + + auto Itr = AliasMap.find(LHS); + if (Itr != AliasMap.end()) { + + // Find out all (X, Offset) where X == RHS + auto Comparator = [](OffsetValue LHS, OffsetValue RHS) { + return std::less<const Value *>()(LHS.Val, RHS.Val); + }; +#ifdef EXPENSIVE_CHECKS + assert(std::is_sorted(Itr->second.begin(), Itr->second.end(), Comparator)); +#endif + auto RangePair = std::equal_range(Itr->second.begin(), Itr->second.end(), + OffsetValue{RHS, 0}, Comparator); + + if (RangePair.first != RangePair.second) { + // Be conservative about unknown sizes + if (MaybeLHSSize == LocationSize::unknown() || + MaybeRHSSize == LocationSize::unknown()) + return true; + + const uint64_t LHSSize = MaybeLHSSize.getValue(); + const uint64_t RHSSize = MaybeRHSSize.getValue(); + + for (const auto &OVal : make_range(RangePair)) { + // Be conservative about UnknownOffset + if (OVal.Offset == UnknownOffset) + return true; + + // We know that LHS aliases (RHS + OVal.Offset) if the control flow + // reaches here. The may-alias query essentially becomes integer + // range-overlap queries over two ranges [OVal.Offset, OVal.Offset + + // LHSSize) and [0, RHSSize). + + // Try to be conservative on super large offsets + if (LLVM_UNLIKELY(LHSSize > INT64_MAX || RHSSize > INT64_MAX)) + return true; + + auto LHSStart = OVal.Offset; + // FIXME: Do we need to guard against integer overflow? 
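+          // The overlap check below, with illustrative numbers: if
+          // OVal.Offset == 4, LHSSize == 8 and RHSSize == 8, then LHS
+          // occupies [4, 12) and RHS occupies [0, 8); the ranges overlap,
+          // so we conservatively answer "may alias".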
+          auto LHSEnd = OVal.Offset + static_cast<int64_t>(LHSSize);
+          auto RHSStart = 0;
+          auto RHSEnd = static_cast<int64_t>(RHSSize);
+          if (LHSEnd > RHSStart && LHSStart < RHSEnd)
+            return true;
+        }
+      }
+    }
+  }
+
+  return false;
+}
+
+static void propagate(InstantiatedValue From, InstantiatedValue To,
+                      MatchState State, ReachabilitySet &ReachSet,
+                      std::vector<WorkListItem> &WorkList) {
+  if (From == To)
+    return;
+  if (ReachSet.insert(From, To, State))
+    WorkList.push_back(WorkListItem{From, To, State});
+}
+
+static void initializeWorkList(std::vector<WorkListItem> &WorkList,
+                               ReachabilitySet &ReachSet,
+                               const CFLGraph &Graph) {
+  for (const auto &Mapping : Graph.value_mappings()) {
+    auto Val = Mapping.first;
+    auto &ValueInfo = Mapping.second;
+    assert(ValueInfo.getNumLevels() > 0);
+
+    // Insert all immediate assignment neighbors to the worklist
+    for (unsigned I = 0, E = ValueInfo.getNumLevels(); I < E; ++I) {
+      auto Src = InstantiatedValue{Val, I};
+      // If there's an assignment edge from X to Y, it means Y is reachable from
+      // X at S3 and X is reachable from Y at S1
+      for (auto &Edge : ValueInfo.getNodeInfoAtLevel(I).Edges) {
+        propagate(Edge.Other, Src, MatchState::FlowFromReadOnly, ReachSet,
+                  WorkList);
+        propagate(Src, Edge.Other, MatchState::FlowToWriteOnly, ReachSet,
+                  WorkList);
+      }
+    }
+  }
+}
+
+static Optional<InstantiatedValue> getNodeBelow(const CFLGraph &Graph,
+                                                InstantiatedValue V) {
+  auto NodeBelow = InstantiatedValue{V.Val, V.DerefLevel + 1};
+  if (Graph.getNode(NodeBelow))
+    return NodeBelow;
+  return None;
+}
+
+static void processWorkListItem(const WorkListItem &Item, const CFLGraph &Graph,
+                                ReachabilitySet &ReachSet, AliasMemSet &MemSet,
+                                std::vector<WorkListItem> &WorkList) {
+  auto FromNode = Item.From;
+  auto ToNode = Item.To;
+
+  auto NodeInfo = Graph.getNode(ToNode);
+  assert(NodeInfo != nullptr);
+
+  // TODO: propagate field offsets
+
+  // FIXME: Here is a neat trick we can do: since both ReachSet and MemSet hold
+  // relations that are symmetric, we could actually cut the storage by half by
+  // sorting FromNode and ToNode before insertion happens.
+
+  // The newly added value alias pair may potentially generate more memory
+  // alias pairs. Check for them here.
+  auto FromNodeBelow = getNodeBelow(Graph, FromNode);
+  auto ToNodeBelow = getNodeBelow(Graph, ToNode);
+  if (FromNodeBelow && ToNodeBelow &&
+      MemSet.insert(*FromNodeBelow, *ToNodeBelow)) {
+    propagate(*FromNodeBelow, *ToNodeBelow,
+              MatchState::FlowFromMemAliasNoReadWrite, ReachSet, WorkList);
+    for (const auto &Mapping : ReachSet.reachableValueAliases(*FromNodeBelow)) {
+      auto Src = Mapping.first;
+      auto MemAliasPropagate = [&](MatchState FromState, MatchState ToState) {
+        if (Mapping.second.test(static_cast<size_t>(FromState)))
+          propagate(Src, *ToNodeBelow, ToState, ReachSet, WorkList);
+      };
+
+      MemAliasPropagate(MatchState::FlowFromReadOnly,
+                        MatchState::FlowFromMemAliasReadOnly);
+      MemAliasPropagate(MatchState::FlowToWriteOnly,
+                        MatchState::FlowToMemAliasWriteOnly);
+      MemAliasPropagate(MatchState::FlowToReadWrite,
+                        MatchState::FlowToMemAliasReadWrite);
+    }
+  }
+
+  // This is the core of the state machine walking algorithm. 
We expand ReachSet + // based on which state we are at (which in turn dictates what edges we + // should examine) + // From a high-level point of view, the state machine here guarantees two + // properties: + // - If *X and *Y are memory aliases, then X and Y are value aliases + // - If Y is an alias of X, then reverse assignment edges (if there is any) + // should precede any assignment edges on the path from X to Y. + auto NextAssignState = [&](MatchState State) { + for (const auto &AssignEdge : NodeInfo->Edges) + propagate(FromNode, AssignEdge.Other, State, ReachSet, WorkList); + }; + auto NextRevAssignState = [&](MatchState State) { + for (const auto &RevAssignEdge : NodeInfo->ReverseEdges) + propagate(FromNode, RevAssignEdge.Other, State, ReachSet, WorkList); + }; + auto NextMemState = [&](MatchState State) { + if (auto AliasSet = MemSet.getMemoryAliases(ToNode)) { + for (const auto &MemAlias : *AliasSet) + propagate(FromNode, MemAlias, State, ReachSet, WorkList); + } + }; + + switch (Item.State) { + case MatchState::FlowFromReadOnly: + NextRevAssignState(MatchState::FlowFromReadOnly); + NextAssignState(MatchState::FlowToReadWrite); + NextMemState(MatchState::FlowFromMemAliasReadOnly); + break; + + case MatchState::FlowFromMemAliasNoReadWrite: + NextRevAssignState(MatchState::FlowFromReadOnly); + NextAssignState(MatchState::FlowToWriteOnly); + break; + + case MatchState::FlowFromMemAliasReadOnly: + NextRevAssignState(MatchState::FlowFromReadOnly); + NextAssignState(MatchState::FlowToReadWrite); + break; + + case MatchState::FlowToWriteOnly: + NextAssignState(MatchState::FlowToWriteOnly); + NextMemState(MatchState::FlowToMemAliasWriteOnly); + break; + + case MatchState::FlowToReadWrite: + NextAssignState(MatchState::FlowToReadWrite); + NextMemState(MatchState::FlowToMemAliasReadWrite); + break; + + case MatchState::FlowToMemAliasWriteOnly: + NextAssignState(MatchState::FlowToWriteOnly); + break; + + case MatchState::FlowToMemAliasReadWrite: + NextAssignState(MatchState::FlowToReadWrite); + break; + } +} + +static AliasAttrMap buildAttrMap(const CFLGraph &Graph, + const ReachabilitySet &ReachSet) { + AliasAttrMap AttrMap; + std::vector<InstantiatedValue> WorkList, NextList; + + // Initialize each node with its original AliasAttrs in CFLGraph + for (const auto &Mapping : Graph.value_mappings()) { + auto Val = Mapping.first; + auto &ValueInfo = Mapping.second; + for (unsigned I = 0, E = ValueInfo.getNumLevels(); I < E; ++I) { + auto Node = InstantiatedValue{Val, I}; + AttrMap.add(Node, ValueInfo.getNodeInfoAtLevel(I).Attr); + WorkList.push_back(Node); + } + } + + while (!WorkList.empty()) { + for (const auto &Dst : WorkList) { + auto DstAttr = AttrMap.getAttrs(Dst); + if (DstAttr.none()) + continue; + + // Propagate attr on the same level + for (const auto &Mapping : ReachSet.reachableValueAliases(Dst)) { + auto Src = Mapping.first; + if (AttrMap.add(Src, DstAttr)) + NextList.push_back(Src); + } + + // Propagate attr to the levels below + auto DstBelow = getNodeBelow(Graph, Dst); + while (DstBelow) { + if (AttrMap.add(*DstBelow, DstAttr)) { + NextList.push_back(*DstBelow); + break; + } + DstBelow = getNodeBelow(Graph, *DstBelow); + } + } + WorkList.swap(NextList); + NextList.clear(); + } + + return AttrMap; +} + +CFLAndersAAResult::FunctionInfo +CFLAndersAAResult::buildInfoFrom(const Function &Fn) { + CFLGraphBuilder<CFLAndersAAResult> GraphBuilder( + *this, GetTLI(const_cast<Function &>(Fn)), + // Cast away the constness here due to GraphBuilder's API requirement + const_cast<Function 
&>(Fn)); + auto &Graph = GraphBuilder.getCFLGraph(); + + ReachabilitySet ReachSet; + AliasMemSet MemSet; + + std::vector<WorkListItem> WorkList, NextList; + initializeWorkList(WorkList, ReachSet, Graph); + // TODO: make sure we don't stop before the fix point is reached + while (!WorkList.empty()) { + for (const auto &Item : WorkList) + processWorkListItem(Item, Graph, ReachSet, MemSet, NextList); + + NextList.swap(WorkList); + NextList.clear(); + } + + // Now that we have all the reachability info, propagate AliasAttrs according + // to it + auto IValueAttrMap = buildAttrMap(Graph, ReachSet); + + return FunctionInfo(Fn, GraphBuilder.getReturnValues(), ReachSet, + std::move(IValueAttrMap)); +} + +void CFLAndersAAResult::scan(const Function &Fn) { + auto InsertPair = Cache.insert(std::make_pair(&Fn, Optional<FunctionInfo>())); + (void)InsertPair; + assert(InsertPair.second && + "Trying to scan a function that has already been cached"); + + // Note that we can't do Cache[Fn] = buildSetsFrom(Fn) here: the function call + // may get evaluated after operator[], potentially triggering a DenseMap + // resize and invalidating the reference returned by operator[] + auto FunInfo = buildInfoFrom(Fn); + Cache[&Fn] = std::move(FunInfo); + Handles.emplace_front(const_cast<Function *>(&Fn), this); +} + +void CFLAndersAAResult::evict(const Function *Fn) { Cache.erase(Fn); } + +const Optional<CFLAndersAAResult::FunctionInfo> & +CFLAndersAAResult::ensureCached(const Function &Fn) { + auto Iter = Cache.find(&Fn); + if (Iter == Cache.end()) { + scan(Fn); + Iter = Cache.find(&Fn); + assert(Iter != Cache.end()); + assert(Iter->second.hasValue()); + } + return Iter->second; +} + +const AliasSummary *CFLAndersAAResult::getAliasSummary(const Function &Fn) { + auto &FunInfo = ensureCached(Fn); + if (FunInfo.hasValue()) + return &FunInfo->getAliasSummary(); + else + return nullptr; +} + +AliasResult CFLAndersAAResult::query(const MemoryLocation &LocA, + const MemoryLocation &LocB) { + auto *ValA = LocA.Ptr; + auto *ValB = LocB.Ptr; + + if (!ValA->getType()->isPointerTy() || !ValB->getType()->isPointerTy()) + return NoAlias; + + auto *Fn = parentFunctionOfValue(ValA); + if (!Fn) { + Fn = parentFunctionOfValue(ValB); + if (!Fn) { + // The only times this is known to happen are when globals + InlineAsm are + // involved + LLVM_DEBUG( + dbgs() + << "CFLAndersAA: could not extract parent function information.\n"); + return MayAlias; + } + } else { + assert(!parentFunctionOfValue(ValB) || parentFunctionOfValue(ValB) == Fn); + } + + assert(Fn != nullptr); + auto &FunInfo = ensureCached(*Fn); + + // AliasMap lookup + if (FunInfo->mayAlias(ValA, LocA.Size, ValB, LocB.Size)) + return MayAlias; + return NoAlias; +} + +AliasResult CFLAndersAAResult::alias(const MemoryLocation &LocA, + const MemoryLocation &LocB, + AAQueryInfo &AAQI) { + if (LocA.Ptr == LocB.Ptr) + return MustAlias; + + // Comparisons between global variables and other constants should be + // handled by BasicAA. + // CFLAndersAA may report NoAlias when comparing a GlobalValue and + // ConstantExpr, but every query needs to have at least one Value tied to a + // Function, and neither GlobalValues nor ConstantExprs are. 
+  if (isa<Constant>(LocA.Ptr) && isa<Constant>(LocB.Ptr))
+    return AAResultBase::alias(LocA, LocB, AAQI);
+
+  AliasResult QueryResult = query(LocA, LocB);
+  if (QueryResult == MayAlias)
+    return AAResultBase::alias(LocA, LocB, AAQI);
+
+  return QueryResult;
+}
+
+AnalysisKey CFLAndersAA::Key;
+
+CFLAndersAAResult CFLAndersAA::run(Function &F, FunctionAnalysisManager &AM) {
+  auto GetTLI = [&AM](Function &F) -> TargetLibraryInfo & {
+    return AM.getResult<TargetLibraryAnalysis>(F);
+  };
+  return CFLAndersAAResult(GetTLI);
+}
+
+char CFLAndersAAWrapperPass::ID = 0;
+INITIALIZE_PASS(CFLAndersAAWrapperPass, "cfl-anders-aa",
+                "Inclusion-Based CFL Alias Analysis", false, true)
+
+ImmutablePass *llvm::createCFLAndersAAWrapperPass() {
+  return new CFLAndersAAWrapperPass();
+}
+
+CFLAndersAAWrapperPass::CFLAndersAAWrapperPass() : ImmutablePass(ID) {
+  initializeCFLAndersAAWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+void CFLAndersAAWrapperPass::initializePass() {
+  auto GetTLI = [this](Function &F) -> TargetLibraryInfo & {
+    return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+  };
+  Result.reset(new CFLAndersAAResult(GetTLI));
+}
+
+void CFLAndersAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.setPreservesAll();
+  AU.addRequired<TargetLibraryInfoWrapperPass>();
+}
diff --git a/llvm/lib/Analysis/CFLGraph.h b/llvm/lib/Analysis/CFLGraph.h
new file mode 100644
index 000000000000..21842ed36487
--- /dev/null
+++ b/llvm/lib/Analysis/CFLGraph.h
@@ -0,0 +1,660 @@
+//===- CFLGraph.h - Abstract stratified sets implementation. -----*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This file defines CFLGraph, an auxiliary data structure used by CFL-based
+/// alias analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_ANALYSIS_CFLGRAPH_H
+#define LLVM_LIB_ANALYSIS_CFLGRAPH_H
+
+#include "AliasAnalysisSummary.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <cassert>
+#include <cstdint>
+#include <vector>
+
+namespace llvm {
+namespace cflaa {
+
+/// The Program Expression Graph (PEG) of CFL analysis.
+/// CFLGraph is an auxiliary data structure used by CFL-based alias analysis to
+/// describe flow-insensitive pointer-related behaviors. Given an LLVM function,
+/// the main purpose of this graph is to abstract away unrelated facts and
+/// translate the rest into a form that can be easily digested by CFL analyses.
+/// Each Node in the graph is an InstantiatedValue, and each edge represents a
+/// pointer assignment between InstantiatedValues. Pointer
+/// references/dereferences are not explicitly stored in the graph: we
+/// implicitly assume that each node (X, I) has a dereference edge to (X, I+1)
+/// and a reference edge to (X, I-1).
+class CFLGraph {
+public:
+  using Node = InstantiatedValue;
+
+  struct Edge {
+    Node Other;
+    int64_t Offset;
+  };
+
+  using EdgeList = std::vector<Edge>;
+
+  struct NodeInfo {
+    EdgeList Edges, ReverseEdges;
+    AliasAttrs Attr;
+  };
+
+  class ValueInfo {
+    std::vector<NodeInfo> Levels;
+
+  public:
+    bool addNodeToLevel(unsigned Level) {
+      auto NumLevels = Levels.size();
+      if (NumLevels > Level)
+        return false;
+      Levels.resize(Level + 1);
+      return true;
+    }
+
+    NodeInfo &getNodeInfoAtLevel(unsigned Level) {
+      assert(Level < Levels.size());
+      return Levels[Level];
+    }
+    const NodeInfo &getNodeInfoAtLevel(unsigned Level) const {
+      assert(Level < Levels.size());
+      return Levels[Level];
+    }
+
+    unsigned getNumLevels() const { return Levels.size(); }
+  };
+
+private:
+  using ValueMap = DenseMap<Value *, ValueInfo>;
+
+  ValueMap ValueImpls;
+
+  NodeInfo *getNode(Node N) {
+    auto Itr = ValueImpls.find(N.Val);
+    if (Itr == ValueImpls.end() || Itr->second.getNumLevels() <= N.DerefLevel)
+      return nullptr;
+    return &Itr->second.getNodeInfoAtLevel(N.DerefLevel);
+  }
+
+public:
+  using const_value_iterator = ValueMap::const_iterator;
+
+  bool addNode(Node N, AliasAttrs Attr = AliasAttrs()) {
+    assert(N.Val != nullptr);
+    auto &ValInfo = ValueImpls[N.Val];
+    auto Changed = ValInfo.addNodeToLevel(N.DerefLevel);
+    ValInfo.getNodeInfoAtLevel(N.DerefLevel).Attr |= Attr;
+    return Changed;
+  }
+
+  void addAttr(Node N, AliasAttrs Attr) {
+    auto *Info = getNode(N);
+    assert(Info != nullptr);
+    Info->Attr |= Attr;
+  }
+
+  void addEdge(Node From, Node To, int64_t Offset = 0) {
+    auto *FromInfo = getNode(From);
+    assert(FromInfo != nullptr);
+    auto *ToInfo = getNode(To);
+    assert(ToInfo != nullptr);
+
+    FromInfo->Edges.push_back(Edge{To, Offset});
+    ToInfo->ReverseEdges.push_back(Edge{From, Offset});
+  }
+
+  const NodeInfo *getNode(Node N) const {
+    auto Itr = ValueImpls.find(N.Val);
+    if (Itr == ValueImpls.end() || Itr->second.getNumLevels() <= N.DerefLevel)
+      return nullptr;
+    return &Itr->second.getNodeInfoAtLevel(N.DerefLevel);
+  }
+
+  AliasAttrs attrFor(Node N) const {
+    auto *Info = getNode(N);
+    assert(Info != nullptr);
+    return Info->Attr;
+  }
+
+  iterator_range<const_value_iterator> value_mappings() const {
+    return make_range<const_value_iterator>(ValueImpls.begin(),
+                                            ValueImpls.end());
+  }
+};
+
+/// A builder class used to create a CFLGraph instance from a given function.
+/// The CFL-AA that uses this builder must provide its own type as a template
+/// argument. This is necessary for interprocedural processing: CFLGraphBuilder
+/// needs a way of obtaining the summary of other functions when callinsts are
+/// encountered.
+/// As a result, we expect said CFL-AA to expose a getAliasSummary() public
+/// member function that takes a Function& and returns the corresponding summary
+/// as a const AliasSummary*.
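+///
+/// For illustration, a minimal client of this builder might look like the
+/// following sketch (hypothetical type name, not part of the actual analyses):
+/// \code
+///   struct MyCFLAA {
+///     const AliasSummary *getAliasSummary(const Function &Fn);
+///   };
+///   // Given MyCFLAA &AA, const TargetLibraryInfo &TLI and Function &Fn:
+///   CFLGraphBuilder<MyCFLAA> Builder(AA, TLI, Fn);
+///   const CFLGraph &Graph = Builder.getCFLGraph();
+/// \endcode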
+template <typename CFLAA> class CFLGraphBuilder {
+  // Input of the builder
+  CFLAA &Analysis;
+  const TargetLibraryInfo &TLI;
+
+  // Output of the builder
+  CFLGraph Graph;
+  SmallVector<Value *, 4> ReturnedValues;
+
+  // Helper class
+  /// Gets the edges our graph should have, based on an Instruction*
+  class GetEdgesVisitor : public InstVisitor<GetEdgesVisitor, void> {
+    CFLAA &AA;
+    const DataLayout &DL;
+    const TargetLibraryInfo &TLI;
+
+    CFLGraph &Graph;
+    SmallVectorImpl<Value *> &ReturnValues;
+
+    static bool hasUsefulEdges(ConstantExpr *CE) {
+      // ConstantExpr doesn't have terminators, invokes, or fences, so we only
+      // need to check for compares.
+      return CE->getOpcode() != Instruction::ICmp &&
+             CE->getOpcode() != Instruction::FCmp;
+    }
+
+    // Returns possible functions called by Call into the given
+    // SmallVectorImpl. Returns true if targets found, false otherwise.
+    static bool getPossibleTargets(CallBase &Call,
+                                   SmallVectorImpl<Function *> &Output) {
+      if (auto *Fn = Call.getCalledFunction()) {
+        Output.push_back(Fn);
+        return true;
+      }
+
+      // TODO: If the call is indirect, we might be able to enumerate all
+      // potential targets of the call and return them, rather than just
+      // failing.
+      return false;
+    }
+
+    void addNode(Value *Val, AliasAttrs Attr = AliasAttrs()) {
+      assert(Val != nullptr && Val->getType()->isPointerTy());
+      if (auto GVal = dyn_cast<GlobalValue>(Val)) {
+        if (Graph.addNode(InstantiatedValue{GVal, 0},
+                          getGlobalOrArgAttrFromValue(*GVal)))
+          Graph.addNode(InstantiatedValue{GVal, 1}, getAttrUnknown());
+      } else if (auto CExpr = dyn_cast<ConstantExpr>(Val)) {
+        if (hasUsefulEdges(CExpr)) {
+          if (Graph.addNode(InstantiatedValue{CExpr, 0}))
+            visitConstantExpr(CExpr);
+        }
+      } else
+        Graph.addNode(InstantiatedValue{Val, 0}, Attr);
+    }
+
+    void addAssignEdge(Value *From, Value *To, int64_t Offset = 0) {
+      assert(From != nullptr && To != nullptr);
+      if (!From->getType()->isPointerTy() || !To->getType()->isPointerTy())
+        return;
+      addNode(From);
+      if (To != From) {
+        addNode(To);
+        Graph.addEdge(InstantiatedValue{From, 0}, InstantiatedValue{To, 0},
+                      Offset);
+      }
+    }
+
+    void addDerefEdge(Value *From, Value *To, bool IsRead) {
+      assert(From != nullptr && To != nullptr);
+      // FIXME: This is subtly broken, due to how we model some instructions
+      // (e.g. extractvalue, extractelement) as loads. Since those take
+      // non-pointer operands, we'll entirely skip adding edges for those.
+      //
+      // addAssignEdge seems to have a similar issue with insertvalue, etc.
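+      //
+      // A load reads the pointee of its pointer operand and a store writes
+      // it, so the edges below cross one dereference level. For example
+      // (illustrative IR only):
+      //   %v = load i32*, i32** %p    ; adds edge (p, 1) -> (v, 0)
+      //   store i32* %v, i32** %q     ; adds edge (v, 0) -> (q, 1)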
+      if (!From->getType()->isPointerTy() || !To->getType()->isPointerTy())
+        return;
+      addNode(From);
+      addNode(To);
+      if (IsRead) {
+        Graph.addNode(InstantiatedValue{From, 1});
+        Graph.addEdge(InstantiatedValue{From, 1}, InstantiatedValue{To, 0});
+      } else {
+        Graph.addNode(InstantiatedValue{To, 1});
+        Graph.addEdge(InstantiatedValue{From, 0}, InstantiatedValue{To, 1});
+      }
+    }
+
+    void addLoadEdge(Value *From, Value *To) { addDerefEdge(From, To, true); }
+    void addStoreEdge(Value *From, Value *To) { addDerefEdge(From, To, false); }
+
+  public:
+    GetEdgesVisitor(CFLGraphBuilder &Builder, const DataLayout &DL)
+        : AA(Builder.Analysis), DL(DL), TLI(Builder.TLI), Graph(Builder.Graph),
+          ReturnValues(Builder.ReturnedValues) {}
+
+    void visitInstruction(Instruction &) {
+      llvm_unreachable("Unsupported instruction encountered");
+    }
+
+    void visitReturnInst(ReturnInst &Inst) {
+      if (auto RetVal = Inst.getReturnValue()) {
+        if (RetVal->getType()->isPointerTy()) {
+          addNode(RetVal);
+          ReturnValues.push_back(RetVal);
+        }
+      }
+    }
+
+    void visitPtrToIntInst(PtrToIntInst &Inst) {
+      auto *Ptr = Inst.getOperand(0);
+      addNode(Ptr, getAttrEscaped());
+    }
+
+    void visitIntToPtrInst(IntToPtrInst &Inst) {
+      auto *Ptr = &Inst;
+      addNode(Ptr, getAttrUnknown());
+    }
+
+    void visitCastInst(CastInst &Inst) {
+      auto *Src = Inst.getOperand(0);
+      addAssignEdge(Src, &Inst);
+    }
+
+    void visitBinaryOperator(BinaryOperator &Inst) {
+      auto *Op1 = Inst.getOperand(0);
+      auto *Op2 = Inst.getOperand(1);
+      addAssignEdge(Op1, &Inst);
+      addAssignEdge(Op2, &Inst);
+    }
+
+    void visitUnaryOperator(UnaryOperator &Inst) {
+      auto *Src = Inst.getOperand(0);
+      addAssignEdge(Src, &Inst);
+    }
+
+    void visitAtomicCmpXchgInst(AtomicCmpXchgInst &Inst) {
+      auto *Ptr = Inst.getPointerOperand();
+      auto *Val = Inst.getNewValOperand();
+      addStoreEdge(Val, Ptr);
+    }
+
+    void visitAtomicRMWInst(AtomicRMWInst &Inst) {
+      auto *Ptr = Inst.getPointerOperand();
+      auto *Val = Inst.getValOperand();
+      addStoreEdge(Val, Ptr);
+    }
+
+    void visitPHINode(PHINode &Inst) {
+      for (Value *Val : Inst.incoming_values())
+        addAssignEdge(Val, &Inst);
+    }
+
+    void visitGEP(GEPOperator &GEPOp) {
+      uint64_t Offset = UnknownOffset;
+      APInt APOffset(DL.getPointerSizeInBits(GEPOp.getPointerAddressSpace()),
+                     0);
+      if (GEPOp.accumulateConstantOffset(DL, APOffset))
+        Offset = APOffset.getSExtValue();
+
+      auto *Op = GEPOp.getPointerOperand();
+      addAssignEdge(Op, &GEPOp, Offset);
+    }
+
+    void visitGetElementPtrInst(GetElementPtrInst &Inst) {
+      auto *GEPOp = cast<GEPOperator>(&Inst);
+      visitGEP(*GEPOp);
+    }
+
+    void visitSelectInst(SelectInst &Inst) {
+      // Condition is not processed here (The actual statement producing
+      // the condition result is processed elsewhere). For select, the
+      // condition is evaluated, but not loaded, stored, or assigned
+      // simply as a result of being the condition of a select.
+
+      auto *TrueVal = Inst.getTrueValue();
+      auto *FalseVal = Inst.getFalseValue();
+      addAssignEdge(TrueVal, &Inst);
+      addAssignEdge(FalseVal, &Inst);
+    }
+
+    void visitAllocaInst(AllocaInst &Inst) { addNode(&Inst); }
+
+    void visitLoadInst(LoadInst &Inst) {
+      auto *Ptr = Inst.getPointerOperand();
+      auto *Val = &Inst;
+      addLoadEdge(Ptr, Val);
+    }
+
+    void visitStoreInst(StoreInst &Inst) {
+      auto *Ptr = Inst.getPointerOperand();
+      auto *Val = Inst.getValueOperand();
+      addStoreEdge(Val, Ptr);
+    }
+
+    void visitVAArgInst(VAArgInst &Inst) {
+      // We can't fully model va_arg here. For *Ptr = Inst.getOperand(0), it
+      // does two things:
+      //  1. Loads a value from *((T*)*Ptr).
+      //  2. Increments (stores to) *Ptr by some target-specific amount.
+      // For now, we'll handle this like a landingpad instruction (by placing
+      // the result in its own group, and having that group alias externals).
+      if (Inst.getType()->isPointerTy())
+        addNode(&Inst, getAttrUnknown());
+    }
+
+    static bool isFunctionExternal(Function *Fn) {
+      return !Fn->hasExactDefinition();
+    }
+
+    bool tryInterproceduralAnalysis(CallBase &Call,
+                                    const SmallVectorImpl<Function *> &Fns) {
+      assert(Fns.size() > 0);
+
+      if (Call.arg_size() > MaxSupportedArgsInSummary)
+        return false;
+
+      // Exit early if we'll fail anyway
+      for (auto *Fn : Fns) {
+        if (isFunctionExternal(Fn) || Fn->isVarArg())
+          return false;
+        // Fail if the caller does not provide enough arguments
+        assert(Fn->arg_size() <= Call.arg_size());
+        if (!AA.getAliasSummary(*Fn))
+          return false;
+      }
+
+      for (auto *Fn : Fns) {
+        auto Summary = AA.getAliasSummary(*Fn);
+        assert(Summary != nullptr);
+
+        auto &RetParamRelations = Summary->RetParamRelations;
+        for (auto &Relation : RetParamRelations) {
+          auto IRelation = instantiateExternalRelation(Relation, Call);
+          if (IRelation.hasValue()) {
+            Graph.addNode(IRelation->From);
+            Graph.addNode(IRelation->To);
+            Graph.addEdge(IRelation->From, IRelation->To);
+          }
+        }
+
+        auto &RetParamAttributes = Summary->RetParamAttributes;
+        for (auto &Attribute : RetParamAttributes) {
+          auto IAttr = instantiateExternalAttribute(Attribute, Call);
+          if (IAttr.hasValue())
+            Graph.addNode(IAttr->IValue, IAttr->Attr);
+        }
+      }
+
+      return true;
+    }
+
+    void visitCallBase(CallBase &Call) {
+      // Make sure all arguments and return value are added to the graph first
+      for (Value *V : Call.args())
+        if (V->getType()->isPointerTy())
+          addNode(V);
+      if (Call.getType()->isPointerTy())
+        addNode(&Call);
+
+      // Check if Inst is a call to a library function that
+      // allocates/deallocates on the heap. Those kinds of functions do not
+      // introduce any aliases.
+      // TODO: address other common library functions such as realloc(),
+      // strdup(), etc.
+      if (isMallocOrCallocLikeFn(&Call, &TLI) || isFreeCall(&Call, &TLI))
+        return;
+
+      // TODO: Add support for noalias args/all the other fun function
+      // attributes that we can tack on.
+      SmallVector<Function *, 4> Targets;
+      if (getPossibleTargets(Call, Targets))
+        if (tryInterproceduralAnalysis(Call, Targets))
+          return;
+
+      // Because the function is opaque, we need to note that anything
+      // could have happened to the arguments (unless the function is marked
+      // readonly or readnone), and that the result could alias just about
+      // anything, too (unless the result is marked noalias).
+      if (!Call.onlyReadsMemory())
+        for (Value *V : Call.args()) {
+          if (V->getType()->isPointerTy()) {
+            // The argument itself escapes.
+            Graph.addAttr(InstantiatedValue{V, 0}, getAttrEscaped());
+            // The fate of argument memory is unknown. Note that since
+            // AliasAttrs is transitive with respect to dereference, we only
+            // need to specify it for the first-level memory.
+            Graph.addNode(InstantiatedValue{V, 1}, getAttrUnknown());
+          }
+        }
+
+      if (Call.getType()->isPointerTy()) {
+        auto *Fn = Call.getCalledFunction();
+        if (Fn == nullptr || !Fn->returnDoesNotAlias())
+          // No need to call addNode() since we've added Inst at the
+          // beginning of this function and we know it is not a global.
+          Graph.addAttr(InstantiatedValue{&Call, 0}, getAttrUnknown());
+      }
+    }
+
+    /// Because vectors/aggregates are immutable and unaddressable, there's
+    /// nothing we can do to coax a value out of them, other than calling
+    /// Extract{Element,Value}. We can effectively treat them as pointers to
+    /// arbitrary memory locations we can store in and load from.
+    void visitExtractElementInst(ExtractElementInst &Inst) {
+      auto *Ptr = Inst.getVectorOperand();
+      auto *Val = &Inst;
+      addLoadEdge(Ptr, Val);
+    }
+
+    void visitInsertElementInst(InsertElementInst &Inst) {
+      auto *Vec = Inst.getOperand(0);
+      auto *Val = Inst.getOperand(1);
+      addAssignEdge(Vec, &Inst);
+      addStoreEdge(Val, &Inst);
+    }
+
+    void visitLandingPadInst(LandingPadInst &Inst) {
+      // Exceptions come from "nowhere", from our analysis' perspective.
+      // So we place the instruction in its own group, noting that said group
+      // may alias externals
+      if (Inst.getType()->isPointerTy())
+        addNode(&Inst, getAttrUnknown());
+    }
+
+    void visitInsertValueInst(InsertValueInst &Inst) {
+      auto *Agg = Inst.getOperand(0);
+      auto *Val = Inst.getOperand(1);
+      addAssignEdge(Agg, &Inst);
+      addStoreEdge(Val, &Inst);
+    }
+
+    void visitExtractValueInst(ExtractValueInst &Inst) {
+      auto *Ptr = Inst.getAggregateOperand();
+      addLoadEdge(Ptr, &Inst);
+    }
+
+    void visitShuffleVectorInst(ShuffleVectorInst &Inst) {
+      auto *From1 = Inst.getOperand(0);
+      auto *From2 = Inst.getOperand(1);
+      addAssignEdge(From1, &Inst);
+      addAssignEdge(From2, &Inst);
+    }
+
+    void visitConstantExpr(ConstantExpr *CE) {
+      switch (CE->getOpcode()) {
+      case Instruction::GetElementPtr: {
+        auto GEPOp = cast<GEPOperator>(CE);
+        visitGEP(*GEPOp);
+        break;
+      }
+
+      case Instruction::PtrToInt: {
+        addNode(CE->getOperand(0), getAttrEscaped());
+        break;
+      }
+
+      case Instruction::IntToPtr: {
+        addNode(CE, getAttrUnknown());
+        break;
+      }
+
+      case Instruction::BitCast:
+      case Instruction::AddrSpaceCast:
+      case Instruction::Trunc:
+      case Instruction::ZExt:
+      case Instruction::SExt:
+      case Instruction::FPExt:
+      case Instruction::FPTrunc:
+      case Instruction::UIToFP:
+      case Instruction::SIToFP:
+      case Instruction::FPToUI:
+      case Instruction::FPToSI: {
+        addAssignEdge(CE->getOperand(0), CE);
+        break;
+      }
+
+      case Instruction::Select: {
+        addAssignEdge(CE->getOperand(1), CE);
+        addAssignEdge(CE->getOperand(2), CE);
+        break;
+      }
+
+      case Instruction::InsertElement:
+      case Instruction::InsertValue: {
+        addAssignEdge(CE->getOperand(0), CE);
+        addStoreEdge(CE->getOperand(1), CE);
+        break;
+      }
+
+      case Instruction::ExtractElement:
+      case Instruction::ExtractValue: {
+        addLoadEdge(CE->getOperand(0), CE);
+        break;
+      }
+
+      case Instruction::Add:
+      case Instruction::FAdd:
+      case Instruction::Sub:
+      case Instruction::FSub:
+      case Instruction::Mul:
+      case Instruction::FMul:
+      case Instruction::UDiv:
+      case Instruction::SDiv:
+      case Instruction::FDiv:
+      case Instruction::URem:
+      case Instruction::SRem:
+      case Instruction::FRem:
+      case Instruction::And:
+      case Instruction::Or:
+      case Instruction::Xor:
+      case Instruction::Shl:
+      case Instruction::LShr:
+      case Instruction::AShr:
+      case Instruction::ICmp:
+      case Instruction::FCmp:
+      case Instruction::ShuffleVector: {
+        addAssignEdge(CE->getOperand(0), CE);
+        addAssignEdge(CE->getOperand(1), CE);
+        break;
+      }
+
+      case Instruction::FNeg: {
+        addAssignEdge(CE->getOperand(0), CE);
+        break;
+      }
+
+      default:
+        llvm_unreachable("Unknown instruction type encountered!");
+      }
+    }
+  };
+
+  // Helper functions
+
+  // Determines whether or not an instruction is useless to us (e.g.
+  // FenceInst)
+  static bool hasUsefulEdges(Instruction *Inst) {
+    bool IsNonInvokeRetTerminator = Inst->isTerminator() &&
+                                    !isa<InvokeInst>(Inst) &&
+                                    !isa<ReturnInst>(Inst);
+    return !isa<CmpInst>(Inst) && !isa<FenceInst>(Inst) &&
+           !IsNonInvokeRetTerminator;
+  }
+
+  void addArgumentToGraph(Argument &Arg) {
+    if (Arg.getType()->isPointerTy()) {
+      Graph.addNode(InstantiatedValue{&Arg, 0},
+                    getGlobalOrArgAttrFromValue(Arg));
+      // Pointees of a formal parameter are known to the caller
+      Graph.addNode(InstantiatedValue{&Arg, 1}, getAttrCaller());
+    }
+  }
+
+  // Given an Instruction, this will add it to the graph, along with any
+  // Instructions that are potentially only available from said Instruction.
+  // For example, given the following line:
+  //   %0 = load i16* getelementptr ([1 x i16]* @a, 0, 0), align 2
+  // addInstructionToGraph would add both the `load` and `getelementptr`
+  // instructions to the graph appropriately.
+  void addInstructionToGraph(GetEdgesVisitor &Visitor, Instruction &Inst) {
+    if (!hasUsefulEdges(&Inst))
+      return;
+
+    Visitor.visit(Inst);
+  }
+
+  // Builds the graph needed for constructing the StratifiedSets for the given
+  // function
+  void buildGraphFrom(Function &Fn) {
+    GetEdgesVisitor Visitor(*this, Fn.getParent()->getDataLayout());
+
+    for (auto &Bb : Fn.getBasicBlockList())
+      for (auto &Inst : Bb.getInstList())
+        addInstructionToGraph(Visitor, Inst);
+
+    for (auto &Arg : Fn.args())
+      addArgumentToGraph(Arg);
+  }
+
+public:
+  CFLGraphBuilder(CFLAA &Analysis, const TargetLibraryInfo &TLI, Function &Fn)
+      : Analysis(Analysis), TLI(TLI) {
+    buildGraphFrom(Fn);
+  }
+
+  const CFLGraph &getCFLGraph() const { return Graph; }
+  const SmallVector<Value *, 4> &getReturnValues() const {
+    return ReturnedValues;
+  }
+};
+
+} // end namespace cflaa
+} // end namespace llvm
+
+#endif // LLVM_LIB_ANALYSIS_CFLGRAPH_H
diff --git a/llvm/lib/Analysis/CFLSteensAliasAnalysis.cpp b/llvm/lib/Analysis/CFLSteensAliasAnalysis.cpp
new file mode 100644
index 000000000000..b87aa4065392
--- /dev/null
+++ b/llvm/lib/Analysis/CFLSteensAliasAnalysis.cpp
@@ -0,0 +1,363 @@
+//===- CFLSteensAliasAnalysis.cpp - Unification-based Alias Analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a CFL-based, summary-based alias analysis algorithm. It
+// does not depend on types. The algorithm is a mixture of the one described in
+// "Demand-driven alias analysis for C" by Xin Zheng and Radu Rugina, and "Fast
+// algorithms for Dyck-CFL-reachability with applications to Alias Analysis" by
+// Zhang Q, Lyu M R, Yuan H, and Su Z. To summarize the papers, we build a
+// graph of the uses of a variable, where each node is a memory location, and
+// each edge is an action that happened on that memory location. The "actions"
+// can be one of Dereference, Reference, or Assign. The precision of this
+// analysis is roughly the same as that of a one-level context-sensitive
+// Steensgaard's algorithm.
+//
+// Two variables are considered aliasing iff you can reach one value's node
+// from the other value's node and the language formed by concatenating all of
+// the edge labels (actions) conforms to a context-free grammar.
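+//
+// As a rough illustration, given:
+//   store i32* %x, i32** %p
+//   %y = load i32*, i32** %p
+// the graph records the assignment of %x into the pointee of %p (a
+// Dereference of %p) and of that pointee into %y; the concatenated edge
+// labels form a string accepted by the grammar, so %x and %y are treated as
+// may-aliasing.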
+//
+// Because this algorithm requires a graph search on each query, we execute the
+// algorithm outlined in "Fast algorithms..." (mentioned above)
+// in order to transform the graph into sets of variables that may alias in
+// ~nlogn time (n = number of variables), which makes queries take constant
+// time.
+//===----------------------------------------------------------------------===//
+
+// N.B. AliasAnalysis as a whole is phrased as a FunctionPass at the moment, and
+// CFLSteensAA is interprocedural. This is *technically* A Bad Thing, because
+// FunctionPasses are only allowed to inspect the Function that they're being
+// run on. Realistically, this likely isn't a problem until we allow
+// FunctionPasses to run concurrently.
+
+#include "llvm/Analysis/CFLSteensAliasAnalysis.h"
+#include "AliasAnalysisSummary.h"
+#include "CFLGraph.h"
+#include "StratifiedSets.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <limits>
+#include <memory>
+#include <utility>
+
+using namespace llvm;
+using namespace llvm::cflaa;
+
+#define DEBUG_TYPE "cfl-steens-aa"
+
+CFLSteensAAResult::CFLSteensAAResult(
+    std::function<const TargetLibraryInfo &(Function &F)> GetTLI)
+    : AAResultBase(), GetTLI(std::move(GetTLI)) {}
+CFLSteensAAResult::CFLSteensAAResult(CFLSteensAAResult &&Arg)
+    : AAResultBase(std::move(Arg)), GetTLI(std::move(Arg.GetTLI)) {}
+CFLSteensAAResult::~CFLSteensAAResult() = default;
+
+/// Information we have about a function and would like to keep around.
+class CFLSteensAAResult::FunctionInfo {
+  StratifiedSets<InstantiatedValue> Sets;
+  AliasSummary Summary;
+
+public:
+  FunctionInfo(Function &Fn, const SmallVectorImpl<Value *> &RetVals,
+               StratifiedSets<InstantiatedValue> S);
+
+  const StratifiedSets<InstantiatedValue> &getStratifiedSets() const {
+    return Sets;
+  }
+
+  const AliasSummary &getAliasSummary() const { return Summary; }
+};
+
+const StratifiedIndex StratifiedLink::SetSentinel =
+    std::numeric_limits<StratifiedIndex>::max();
+
+//===----------------------------------------------------------------------===//
+// Function declarations that require types defined in the namespace above
+//===----------------------------------------------------------------------===//
+
+/// Determines whether it would be pointless to add the given Value to our sets.
+static bool canSkipAddingToSets(Value *Val) {
+  // Constants can share instances, which may falsely unify multiple
+  // sets, e.g. in
+  //   store i32* null, i32** %ptr1
+  //   store i32* null, i32** %ptr2
+  // clearly ptr1 and ptr2 should not be unified into the same set, so
+  // we should filter out the (potentially shared) instance of
+  // i32* null.
+  if (isa<Constant>(Val)) {
+    // TODO: Because all of these things are constant, we can determine whether
+    // the data is *actually* mutable at graph building time. This will probably
+    // come for free/cheap with offset awareness.
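+    // For example (illustrative): a GlobalVariable such as @g can hold data
+    // that is mutated through it, so it must stay in the sets, while a plain
+    // i32* null can never store anything and is safe to skip.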
+    bool CanStoreMutableData = isa<GlobalValue>(Val) ||
+                               isa<ConstantExpr>(Val) ||
+                               isa<ConstantAggregate>(Val);
+    return !CanStoreMutableData;
+  }
+
+  return false;
+}
+
+CFLSteensAAResult::FunctionInfo::FunctionInfo(
+    Function &Fn, const SmallVectorImpl<Value *> &RetVals,
+    StratifiedSets<InstantiatedValue> S)
+    : Sets(std::move(S)) {
+  // Historically, an arbitrary upper bound of 50 args was selected. We may
+  // want to remove this if it doesn't really matter in practice.
+  if (Fn.arg_size() > MaxSupportedArgsInSummary)
+    return;
+
+  DenseMap<StratifiedIndex, InterfaceValue> InterfaceMap;
+
+  // Our intention here is to record all InterfaceValues that share the same
+  // StratifiedIndex in RetParamRelations. For each valid InterfaceValue, we
+  // have its StratifiedIndex scanned here and check if the index is present
+  // in InterfaceMap: if it is not, we add the correspondence to the map;
+  // otherwise, an aliasing relation is found and we add it to
+  // RetParamRelations.
+
+  auto AddToRetParamRelations = [&](unsigned InterfaceIndex,
+                                    StratifiedIndex SetIndex) {
+    unsigned Level = 0;
+    while (true) {
+      InterfaceValue CurrValue{InterfaceIndex, Level};
+
+      auto Itr = InterfaceMap.find(SetIndex);
+      if (Itr != InterfaceMap.end()) {
+        if (CurrValue != Itr->second)
+          Summary.RetParamRelations.push_back(
+              ExternalRelation{CurrValue, Itr->second, UnknownOffset});
+        break;
+      }
+
+      auto &Link = Sets.getLink(SetIndex);
+      InterfaceMap.insert(std::make_pair(SetIndex, CurrValue));
+      auto ExternalAttrs = getExternallyVisibleAttrs(Link.Attrs);
+      if (ExternalAttrs.any())
+        Summary.RetParamAttributes.push_back(
+            ExternalAttribute{CurrValue, ExternalAttrs});
+
+      if (!Link.hasBelow())
+        break;
+
+      ++Level;
+      SetIndex = Link.Below;
+    }
+  };
+
+  // Populate RetParamRelations for return values
+  for (auto *RetVal : RetVals) {
+    assert(RetVal != nullptr);
+    assert(RetVal->getType()->isPointerTy());
+    auto RetInfo = Sets.find(InstantiatedValue{RetVal, 0});
+    if (RetInfo.hasValue())
+      AddToRetParamRelations(0, RetInfo->Index);
+  }
+
+  // Populate RetParamRelations for parameters
+  unsigned I = 0;
+  for (auto &Param : Fn.args()) {
+    if (Param.getType()->isPointerTy()) {
+      auto ParamInfo = Sets.find(InstantiatedValue{&Param, 0});
+      if (ParamInfo.hasValue())
+        AddToRetParamRelations(I + 1, ParamInfo->Index);
+    }
+    ++I;
+  }
+}
+
+// Builds the graph + StratifiedSets for a function.
+CFLSteensAAResult::FunctionInfo CFLSteensAAResult::buildSetsFrom(Function *Fn) { + CFLGraphBuilder<CFLSteensAAResult> GraphBuilder(*this, GetTLI(*Fn), *Fn); + StratifiedSetsBuilder<InstantiatedValue> SetBuilder; + + // Add all CFLGraph nodes and all Dereference edges to StratifiedSets + auto &Graph = GraphBuilder.getCFLGraph(); + for (const auto &Mapping : Graph.value_mappings()) { + auto Val = Mapping.first; + if (canSkipAddingToSets(Val)) + continue; + auto &ValueInfo = Mapping.second; + + assert(ValueInfo.getNumLevels() > 0); + SetBuilder.add(InstantiatedValue{Val, 0}); + SetBuilder.noteAttributes(InstantiatedValue{Val, 0}, + ValueInfo.getNodeInfoAtLevel(0).Attr); + for (unsigned I = 0, E = ValueInfo.getNumLevels() - 1; I < E; ++I) { + SetBuilder.add(InstantiatedValue{Val, I + 1}); + SetBuilder.noteAttributes(InstantiatedValue{Val, I + 1}, + ValueInfo.getNodeInfoAtLevel(I + 1).Attr); + SetBuilder.addBelow(InstantiatedValue{Val, I}, + InstantiatedValue{Val, I + 1}); + } + } + + // Add all assign edges to StratifiedSets + for (const auto &Mapping : Graph.value_mappings()) { + auto Val = Mapping.first; + if (canSkipAddingToSets(Val)) + continue; + auto &ValueInfo = Mapping.second; + + for (unsigned I = 0, E = ValueInfo.getNumLevels(); I < E; ++I) { + auto Src = InstantiatedValue{Val, I}; + for (auto &Edge : ValueInfo.getNodeInfoAtLevel(I).Edges) + SetBuilder.addWith(Src, Edge.Other); + } + } + + return FunctionInfo(*Fn, GraphBuilder.getReturnValues(), SetBuilder.build()); +} + +void CFLSteensAAResult::scan(Function *Fn) { + auto InsertPair = Cache.insert(std::make_pair(Fn, Optional<FunctionInfo>())); + (void)InsertPair; + assert(InsertPair.second && + "Trying to scan a function that has already been cached"); + + // Note that we can't do Cache[Fn] = buildSetsFrom(Fn) here: the function call + // may get evaluated after operator[], potentially triggering a DenseMap + // resize and invalidating the reference returned by operator[] + auto FunInfo = buildSetsFrom(Fn); + Cache[Fn] = std::move(FunInfo); + + Handles.emplace_front(Fn, this); +} + +void CFLSteensAAResult::evict(Function *Fn) { Cache.erase(Fn); } + +/// Ensures that the given function is available in the cache, and returns the +/// entry. 
+const Optional<CFLSteensAAResult::FunctionInfo> & +CFLSteensAAResult::ensureCached(Function *Fn) { + auto Iter = Cache.find(Fn); + if (Iter == Cache.end()) { + scan(Fn); + Iter = Cache.find(Fn); + assert(Iter != Cache.end()); + assert(Iter->second.hasValue()); + } + return Iter->second; +} + +const AliasSummary *CFLSteensAAResult::getAliasSummary(Function &Fn) { + auto &FunInfo = ensureCached(&Fn); + if (FunInfo.hasValue()) + return &FunInfo->getAliasSummary(); + else + return nullptr; +} + +AliasResult CFLSteensAAResult::query(const MemoryLocation &LocA, + const MemoryLocation &LocB) { + auto *ValA = const_cast<Value *>(LocA.Ptr); + auto *ValB = const_cast<Value *>(LocB.Ptr); + + if (!ValA->getType()->isPointerTy() || !ValB->getType()->isPointerTy()) + return NoAlias; + + Function *Fn = nullptr; + Function *MaybeFnA = const_cast<Function *>(parentFunctionOfValue(ValA)); + Function *MaybeFnB = const_cast<Function *>(parentFunctionOfValue(ValB)); + if (!MaybeFnA && !MaybeFnB) { + // The only times this is known to happen are when globals + InlineAsm are + // involved + LLVM_DEBUG( + dbgs() + << "CFLSteensAA: could not extract parent function information.\n"); + return MayAlias; + } + + if (MaybeFnA) { + Fn = MaybeFnA; + assert((!MaybeFnB || MaybeFnB == MaybeFnA) && + "Interprocedural queries not supported"); + } else { + Fn = MaybeFnB; + } + + assert(Fn != nullptr); + auto &MaybeInfo = ensureCached(Fn); + assert(MaybeInfo.hasValue()); + + auto &Sets = MaybeInfo->getStratifiedSets(); + auto MaybeA = Sets.find(InstantiatedValue{ValA, 0}); + if (!MaybeA.hasValue()) + return MayAlias; + + auto MaybeB = Sets.find(InstantiatedValue{ValB, 0}); + if (!MaybeB.hasValue()) + return MayAlias; + + auto SetA = *MaybeA; + auto SetB = *MaybeB; + auto AttrsA = Sets.getLink(SetA.Index).Attrs; + auto AttrsB = Sets.getLink(SetB.Index).Attrs; + + // If both values are local (meaning the corresponding set has attribute + // AttrNone or AttrEscaped), then we know that CFLSteensAA fully models them: + // they may-alias each other if and only if they are in the same set. + // If at least one value is non-local (meaning it either is global/argument or + // it comes from unknown sources like integer cast), the situation becomes a + // bit more interesting. 
We follow three general rules described below: + // - Non-local values may alias each other + // - AttrNone values do not alias any non-local values + // - AttrEscaped do not alias globals/arguments, but they may alias + // AttrUnknown values + if (SetA.Index == SetB.Index) + return MayAlias; + if (AttrsA.none() || AttrsB.none()) + return NoAlias; + if (hasUnknownOrCallerAttr(AttrsA) || hasUnknownOrCallerAttr(AttrsB)) + return MayAlias; + if (isGlobalOrArgAttr(AttrsA) && isGlobalOrArgAttr(AttrsB)) + return MayAlias; + return NoAlias; +} + +AnalysisKey CFLSteensAA::Key; + +CFLSteensAAResult CFLSteensAA::run(Function &F, FunctionAnalysisManager &AM) { + auto GetTLI = [&AM](Function &F) -> const TargetLibraryInfo & { + return AM.getResult<TargetLibraryAnalysis>(F); + }; + return CFLSteensAAResult(GetTLI); +} + +char CFLSteensAAWrapperPass::ID = 0; +INITIALIZE_PASS(CFLSteensAAWrapperPass, "cfl-steens-aa", + "Unification-Based CFL Alias Analysis", false, true) + +ImmutablePass *llvm::createCFLSteensAAWrapperPass() { + return new CFLSteensAAWrapperPass(); +} + +CFLSteensAAWrapperPass::CFLSteensAAWrapperPass() : ImmutablePass(ID) { + initializeCFLSteensAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void CFLSteensAAWrapperPass::initializePass() { + auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { + return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + }; + Result.reset(new CFLSteensAAResult(GetTLI)); +} + +void CFLSteensAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); +} diff --git a/llvm/lib/Analysis/CGSCCPassManager.cpp b/llvm/lib/Analysis/CGSCCPassManager.cpp new file mode 100644 index 000000000000..a0b3f83cca6a --- /dev/null +++ b/llvm/lib/Analysis/CGSCCPassManager.cpp @@ -0,0 +1,709 @@ +//===- CGSCCPassManager.cpp - Managing & running CGSCC passes -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <iterator> + +#define DEBUG_TYPE "cgscc" + +using namespace llvm; + +// Explicit template instantiations and specialization definitions for core +// template typedefs. +namespace llvm { + +// Explicit instantiations for the core proxy templates. 
+template class AllAnalysesOn<LazyCallGraph::SCC>; +template class AnalysisManager<LazyCallGraph::SCC, LazyCallGraph &>; +template class PassManager<LazyCallGraph::SCC, CGSCCAnalysisManager, + LazyCallGraph &, CGSCCUpdateResult &>; +template class InnerAnalysisManagerProxy<CGSCCAnalysisManager, Module>; +template class OuterAnalysisManagerProxy<ModuleAnalysisManager, + LazyCallGraph::SCC, LazyCallGraph &>; +template class OuterAnalysisManagerProxy<CGSCCAnalysisManager, Function>; + +/// Explicitly specialize the pass manager run method to handle call graph +/// updates. +template <> +PreservedAnalyses +PassManager<LazyCallGraph::SCC, CGSCCAnalysisManager, LazyCallGraph &, + CGSCCUpdateResult &>::run(LazyCallGraph::SCC &InitialC, + CGSCCAnalysisManager &AM, + LazyCallGraph &G, CGSCCUpdateResult &UR) { + // Request PassInstrumentation from analysis manager, will use it to run + // instrumenting callbacks for the passes later. + PassInstrumentation PI = + AM.getResult<PassInstrumentationAnalysis>(InitialC, G); + + PreservedAnalyses PA = PreservedAnalyses::all(); + + if (DebugLogging) + dbgs() << "Starting CGSCC pass manager run.\n"; + + // The SCC may be refined while we are running passes over it, so set up + // a pointer that we can update. + LazyCallGraph::SCC *C = &InitialC; + + for (auto &Pass : Passes) { + if (DebugLogging) + dbgs() << "Running pass: " << Pass->name() << " on " << *C << "\n"; + + // Check the PassInstrumentation's BeforePass callbacks before running the + // pass, skip its execution completely if asked to (callback returns false). + if (!PI.runBeforePass(*Pass, *C)) + continue; + + PreservedAnalyses PassPA = Pass->run(*C, AM, G, UR); + + if (UR.InvalidatedSCCs.count(C)) + PI.runAfterPassInvalidated<LazyCallGraph::SCC>(*Pass); + else + PI.runAfterPass<LazyCallGraph::SCC>(*Pass, *C); + + // Update the SCC if necessary. + C = UR.UpdatedC ? UR.UpdatedC : C; + + // If the CGSCC pass wasn't able to provide a valid updated SCC, the + // current SCC may simply need to be skipped if invalid. + if (UR.InvalidatedSCCs.count(C)) { + LLVM_DEBUG(dbgs() << "Skipping invalidated root or island SCC!\n"); + break; + } + // Check that we didn't miss any update scenario. + assert(C->begin() != C->end() && "Cannot have an empty SCC!"); + + // Update the analysis manager as each pass runs and potentially + // invalidates analyses. + AM.invalidate(*C, PassPA); + + // Finally, we intersect the final preserved analyses to compute the + // aggregate preserved set for this pass manager. + PA.intersect(std::move(PassPA)); + + // FIXME: Historically, the pass managers all called the LLVM context's + // yield function here. We don't have a generic way to acquire the + // context and it isn't yet clear what the right pattern is for yielding + // in the new pass manager so it is currently omitted. + // ...getContext().yield(); + } + + // Before we mark all of *this* SCC's analyses as preserved below, intersect + // this with the cross-SCC preserved analysis set. This is used to allow + // CGSCC passes to mutate ancestor SCCs and still trigger proper invalidation + // for them. + UR.CrossSCCPA.intersect(PA); + + // Invalidation was handled after each pass in the above loop for the current + // SCC. Therefore, the remaining analysis results in the AnalysisManager are + // preserved. We mark this with a set so that we don't need to inspect each + // one individually. 
+ PA.preserveSet<AllAnalysesOn<LazyCallGraph::SCC>>(); + + if (DebugLogging) + dbgs() << "Finished CGSCC pass manager run.\n"; + + return PA; +} + +bool CGSCCAnalysisManagerModuleProxy::Result::invalidate( + Module &M, const PreservedAnalyses &PA, + ModuleAnalysisManager::Invalidator &Inv) { + // If literally everything is preserved, we're done. + if (PA.areAllPreserved()) + return false; // This is still a valid proxy. + + // If this proxy or the call graph is going to be invalidated, we also need + // to clear all the keys coming from that analysis. + // + // We also directly invalidate the FAM's module proxy if necessary, and if + // that proxy isn't preserved we can't preserve this proxy either. We rely on + // it to handle module -> function analysis invalidation in the face of + // structural changes and so if it's unavailable we conservatively clear the + // entire SCC layer as well rather than trying to do invalidation ourselves. + auto PAC = PA.getChecker<CGSCCAnalysisManagerModuleProxy>(); + if (!(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Module>>()) || + Inv.invalidate<LazyCallGraphAnalysis>(M, PA) || + Inv.invalidate<FunctionAnalysisManagerModuleProxy>(M, PA)) { + InnerAM->clear(); + + // And the proxy itself should be marked as invalid so that we can observe + // the new call graph. This isn't strictly necessary because we cheat + // above, but is still useful. + return true; + } + + // Directly check if the relevant set is preserved so we can short circuit + // invalidating SCCs below. + bool AreSCCAnalysesPreserved = + PA.allAnalysesInSetPreserved<AllAnalysesOn<LazyCallGraph::SCC>>(); + + // Ok, we have a graph, so we can propagate the invalidation down into it. + G->buildRefSCCs(); + for (auto &RC : G->postorder_ref_sccs()) + for (auto &C : RC) { + Optional<PreservedAnalyses> InnerPA; + + // Check to see whether the preserved set needs to be adjusted based on + // module-level analysis invalidation triggering deferred invalidation + // for this SCC. + if (auto *OuterProxy = + InnerAM->getCachedResult<ModuleAnalysisManagerCGSCCProxy>(C)) + for (const auto &OuterInvalidationPair : + OuterProxy->getOuterInvalidations()) { + AnalysisKey *OuterAnalysisID = OuterInvalidationPair.first; + const auto &InnerAnalysisIDs = OuterInvalidationPair.second; + if (Inv.invalidate(OuterAnalysisID, M, PA)) { + if (!InnerPA) + InnerPA = PA; + for (AnalysisKey *InnerAnalysisID : InnerAnalysisIDs) + InnerPA->abandon(InnerAnalysisID); + } + } + + // Check if we needed a custom PA set. If so we'll need to run the inner + // invalidation. + if (InnerPA) { + InnerAM->invalidate(C, *InnerPA); + continue; + } + + // Otherwise we only need to do invalidation if the original PA set didn't + // preserve all SCC analyses. + if (!AreSCCAnalysesPreserved) + InnerAM->invalidate(C, PA); + } + + // Return false to indicate that this result is still a valid proxy. + return false; +} + +template <> +CGSCCAnalysisManagerModuleProxy::Result +CGSCCAnalysisManagerModuleProxy::run(Module &M, ModuleAnalysisManager &AM) { + // Force the Function analysis manager to also be available so that it can + // be accessed in an SCC analysis and proxied onward to function passes. + // FIXME: It is pretty awkward to just drop the result here and assert that + // we can find it again later. 
+  (void)AM.getResult<FunctionAnalysisManagerModuleProxy>(M);
+
+  return Result(*InnerAM, AM.getResult<LazyCallGraphAnalysis>(M));
+}
+
+AnalysisKey FunctionAnalysisManagerCGSCCProxy::Key;
+
+FunctionAnalysisManagerCGSCCProxy::Result
+FunctionAnalysisManagerCGSCCProxy::run(LazyCallGraph::SCC &C,
+                                       CGSCCAnalysisManager &AM,
+                                       LazyCallGraph &CG) {
+  // Collect the FunctionAnalysisManager from the Module layer and use that to
+  // build the proxy result.
+  //
+  // This allows us to rely on the FunctionAnalysisManagerModuleProxy to
+  // invalidate the function analyses.
+  auto &MAM = AM.getResult<ModuleAnalysisManagerCGSCCProxy>(C, CG).getManager();
+  Module &M = *C.begin()->getFunction().getParent();
+  auto *FAMProxy = MAM.getCachedResult<FunctionAnalysisManagerModuleProxy>(M);
+  assert(FAMProxy && "The CGSCC pass manager requires that the FAM module "
+                     "proxy is run on the module prior to entering the CGSCC "
+                     "walk.");
+
+  // Note that we special-case invalidation handling of this proxy in the CGSCC
+  // analysis manager's Module proxy. This avoids the need to do anything
+  // special here to recompute all of this if ever the FAM's module proxy goes
+  // away.
+  return Result(FAMProxy->getManager());
+}
+
+bool FunctionAnalysisManagerCGSCCProxy::Result::invalidate(
+    LazyCallGraph::SCC &C, const PreservedAnalyses &PA,
+    CGSCCAnalysisManager::Invalidator &Inv) {
+  // If literally everything is preserved, we're done.
+  if (PA.areAllPreserved())
+    return false; // This is still a valid proxy.
+
+  // If this proxy isn't marked as preserved, then even if the result remains
+  // valid, the key itself may no longer be valid, so we clear everything.
+  //
+  // Note that in order to preserve this proxy, a module pass must ensure that
+  // the FAM has been completely updated to handle the deletion of functions.
+  // Specifically, any FAM-cached results for those functions need to have been
+  // forcibly cleared. When preserved, this proxy will only invalidate results
+  // cached on functions *still in the module* at the end of the module pass.
+  auto PAC = PA.getChecker<FunctionAnalysisManagerCGSCCProxy>();
+  if (!PAC.preserved() && !PAC.preservedSet<AllAnalysesOn<LazyCallGraph::SCC>>()) {
+    for (LazyCallGraph::Node &N : C)
+      FAM->clear(N.getFunction(), N.getFunction().getName());
+
+    return true;
+  }
+
+  // Directly check if the relevant set is preserved.
+  bool AreFunctionAnalysesPreserved =
+      PA.allAnalysesInSetPreserved<AllAnalysesOn<Function>>();
+
+  // Now walk all the functions to see if any inner analysis invalidation is
+  // necessary.
+  for (LazyCallGraph::Node &N : C) {
+    Function &F = N.getFunction();
+    Optional<PreservedAnalyses> FunctionPA;
+
+    // Check to see whether the preserved set needs to be pruned based on
+    // SCC-level analysis invalidation that triggers deferred invalidation
+    // registered with the outer analysis manager proxy for this function.
+    if (auto *OuterProxy =
+            FAM->getCachedResult<CGSCCAnalysisManagerFunctionProxy>(F))
+      for (const auto &OuterInvalidationPair :
+           OuterProxy->getOuterInvalidations()) {
+        AnalysisKey *OuterAnalysisID = OuterInvalidationPair.first;
+        const auto &InnerAnalysisIDs = OuterInvalidationPair.second;
+        if (Inv.invalidate(OuterAnalysisID, C, PA)) {
+          if (!FunctionPA)
+            FunctionPA = PA;
+          for (AnalysisKey *InnerAnalysisID : InnerAnalysisIDs)
+            FunctionPA->abandon(InnerAnalysisID);
+        }
+      }
+
+    // Check if we needed a custom PA set, and if so we'll need to run the
+    // inner invalidation.
+ if (FunctionPA) {
+ FAM->invalidate(F, *FunctionPA);
+ continue;
+ }
+
+ // Otherwise we only need to do invalidation if the original PA set didn't
+ // preserve all function analyses.
+ if (!AreFunctionAnalysesPreserved)
+ FAM->invalidate(F, PA);
+ }
+
+ // Return false to indicate that this result is still a valid proxy.
+ return false;
+}
+
+} // end namespace llvm
+
+/// When a new SCC is created for the graph and there might be function
+/// analysis results cached for the functions now in that SCC, two forms of
+/// updates are required.
+///
+/// First, a proxy from the SCC to the FunctionAnalysisManager needs to be
+/// created so that any subsequent invalidation events to the SCC are
+/// propagated to the function analysis results cached for functions within it.
+///
+/// Second, if any of the functions within the SCC have analysis results with
+/// outer analysis dependencies, then those dependencies would point to the
+/// *wrong* SCC's analysis result. We forcibly invalidate the necessary
+/// function analyses so that they don't retain stale handles.
+static void updateNewSCCFunctionAnalyses(LazyCallGraph::SCC &C,
+ LazyCallGraph &G,
+ CGSCCAnalysisManager &AM) {
+ // Get the relevant function analysis manager.
+ auto &FAM =
+ AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, G).getManager();
+
+ // Now walk the functions in this SCC and invalidate any function analysis
+ // results that might have outer dependencies on an SCC analysis.
+ for (LazyCallGraph::Node &N : C) {
+ Function &F = N.getFunction();
+
+ auto *OuterProxy =
+ FAM.getCachedResult<CGSCCAnalysisManagerFunctionProxy>(F);
+ if (!OuterProxy)
+ // No outer analyses were queried, nothing to do.
+ continue;
+
+ // Forcibly abandon all the inner analyses with dependencies, but
+ // invalidate nothing else.
+ auto PA = PreservedAnalyses::all();
+ for (const auto &OuterInvalidationPair :
+ OuterProxy->getOuterInvalidations()) {
+ const auto &InnerAnalysisIDs = OuterInvalidationPair.second;
+ for (AnalysisKey *InnerAnalysisID : InnerAnalysisIDs)
+ PA.abandon(InnerAnalysisID);
+ }
+
+ // Now invalidate anything we found.
+ FAM.invalidate(F, PA);
+ }
+}
+
+/// Helper function to update both the \c CGSCCAnalysisManager \p AM and the \c
+/// CGSCCPassManager's \c CGSCCUpdateResult \p UR based on a range of newly
+/// added SCCs.
+///
+/// The range of new SCCs must be in postorder already. The SCC they were split
+/// out of must be provided as \p C. The current node being mutated and
+/// triggering updates must be passed as \p N.
+///
+/// This function returns the SCC containing \p N. This will be either \p C if
+/// no new SCCs have been split out, or it will be the new SCC containing \p N.
+template <typename SCCRangeT>
+static LazyCallGraph::SCC *
+incorporateNewSCCRange(const SCCRangeT &NewSCCRange, LazyCallGraph &G,
+ LazyCallGraph::Node &N, LazyCallGraph::SCC *C,
+ CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR) {
+ using SCC = LazyCallGraph::SCC;
+
+ if (NewSCCRange.begin() == NewSCCRange.end())
+ return C;
+
+ // Add the current SCC to the worklist as its shape has changed.
+ UR.CWorklist.insert(C);
+ LLVM_DEBUG(dbgs() << "Enqueuing the existing SCC in the worklist:" << *C
+ << "\n");
+
+ SCC *OldC = C;
+
+ // Update the current SCC. Note that if we have new SCCs, this must actually
+ // change the SCC.
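+ // The SCC that now contains N is always the first element of the new
+ // postorder range, which is what the asserts below check.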
+ assert(C != &*NewSCCRange.begin() &&
+ "Cannot insert new SCCs without changing current SCC!");
+ C = &*NewSCCRange.begin();
+ assert(G.lookupSCC(N) == C && "Failed to update current SCC!");
+
+ // If we had a cached FAM proxy originally, we will want to create more of
+ // them for each SCC that was split off.
+ bool NeedFAMProxy =
+ AM.getCachedResult<FunctionAnalysisManagerCGSCCProxy>(*OldC) != nullptr;
+
+ // We need to propagate an invalidation call to all but the newly current SCC
+ // because the outer pass manager won't do that for us after splitting them.
+ // FIXME: We should accept a PreservedAnalyses from the CG updater so that if
+ // there are preserved analyses we can avoid invalidating them here for
+ // split-off SCCs.
+ // We know, however, that this will preserve any FAM proxy, so go ahead and
+ // mark that.
+ PreservedAnalyses PA;
+ PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
+ AM.invalidate(*OldC, PA);
+
+ // Ensure the now-current SCC's function analyses are updated.
+ if (NeedFAMProxy)
+ updateNewSCCFunctionAnalyses(*C, G, AM);
+
+ for (SCC &NewC : llvm::reverse(make_range(std::next(NewSCCRange.begin()),
+ NewSCCRange.end()))) {
+ assert(C != &NewC && "No need to re-visit the current SCC!");
+ assert(OldC != &NewC && "Already handled the original SCC!");
+ UR.CWorklist.insert(&NewC);
+ LLVM_DEBUG(dbgs() << "Enqueuing a newly formed SCC:" << NewC << "\n");
+
+ // Ensure new SCCs' function analyses are updated.
+ if (NeedFAMProxy)
+ updateNewSCCFunctionAnalyses(NewC, G, AM);
+
+ // Also propagate a normal invalidation to the new SCC as only the current
+ // will get one from the pass manager infrastructure.
+ AM.invalidate(NewC, PA);
+ }
+ return C;
+}
+
+LazyCallGraph::SCC &llvm::updateCGAndAnalysisManagerForFunctionPass(
+ LazyCallGraph &G, LazyCallGraph::SCC &InitialC, LazyCallGraph::Node &N,
+ CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR) {
+ using Node = LazyCallGraph::Node;
+ using Edge = LazyCallGraph::Edge;
+ using SCC = LazyCallGraph::SCC;
+ using RefSCC = LazyCallGraph::RefSCC;
+
+ RefSCC &InitialRC = InitialC.getOuterRefSCC();
+ SCC *C = &InitialC;
+ RefSCC *RC = &InitialRC;
+ Function &F = N.getFunction();
+
+ // Walk the function body and build up the set of retained, promoted, and
+ // demoted edges.
+ SmallVector<Constant *, 16> Worklist;
+ SmallPtrSet<Constant *, 16> Visited;
+ SmallPtrSet<Node *, 16> RetainedEdges;
+ SmallSetVector<Node *, 4> PromotedRefTargets;
+ SmallSetVector<Node *, 4> DemotedCallTargets;
+
+ // First walk the function and handle all called functions. We do this first
+ // because if there is a single call edge, whether there are ref edges is
+ // irrelevant.
+ for (Instruction &I : instructions(F))
+ if (auto CS = CallSite(&I))
+ if (Function *Callee = CS.getCalledFunction())
+ if (Visited.insert(Callee).second && !Callee->isDeclaration()) {
+ Node &CalleeN = *G.lookup(*Callee);
+ Edge *E = N->lookup(CalleeN);
+ // FIXME: We should really handle adding new calls. While it will
+ // make downstream usage more complex, there is no fundamental
+ // limitation and it will allow passes within the CGSCC to be a bit
+ // more flexible in what transforms they can do. Until then, we
+ // verify that new calls haven't been introduced.
+ assert(E && "No function transformations should introduce *new* "
+ "call edges! Any new calls should be modeled as "
+ "promoted existing ref edges!");
+ bool Inserted = RetainedEdges.insert(&CalleeN).second;
+ (void)Inserted;
+ assert(Inserted && "We should never visit a function twice.");
+ if (!E->isCall())
+ PromotedRefTargets.insert(&CalleeN);
+ }
+
+ // Now walk all references.
+ for (Instruction &I : instructions(F))
+ for (Value *Op : I.operand_values())
+ if (auto *C = dyn_cast<Constant>(Op))
+ if (Visited.insert(C).second)
+ Worklist.push_back(C);
+
+ auto VisitRef = [&](Function &Referee) {
+ Node &RefereeN = *G.lookup(Referee);
+ Edge *E = N->lookup(RefereeN);
+ // FIXME: Similarly to new calls, we also currently preclude
+ // introducing new references. See above for details.
+ assert(E && "No function transformations should introduce *new* ref "
+ "edges! Any new ref edges would require IPO which "
+ "function passes aren't allowed to do!");
+ bool Inserted = RetainedEdges.insert(&RefereeN).second;
+ (void)Inserted;
+ assert(Inserted && "We should never visit a function twice.");
+ if (E->isCall())
+ DemotedCallTargets.insert(&RefereeN);
+ };
+ LazyCallGraph::visitReferences(Worklist, Visited, VisitRef);
+
+ // Include synthetic reference edges to known, defined lib functions.
+ for (auto *F : G.getLibFunctions())
+ // While the list of lib functions doesn't have repeats, don't re-visit
+ // anything handled above.
+ if (!Visited.count(F))
+ VisitRef(*F);
+
+ // First remove all of the edges that are no longer present in this function.
+ // The first step makes these edges uniformly ref edges and accumulates them
+ // into a separate data structure so removal doesn't invalidate anything.
+ SmallVector<Node *, 4> DeadTargets;
+ for (Edge &E : *N) {
+ if (RetainedEdges.count(&E.getNode()))
+ continue;
+
+ SCC &TargetC = *G.lookupSCC(E.getNode());
+ RefSCC &TargetRC = TargetC.getOuterRefSCC();
+ if (&TargetRC == RC && E.isCall()) {
+ if (C != &TargetC) {
+ // For separate SCCs this is trivial.
+ RC->switchTrivialInternalEdgeToRef(N, E.getNode());
+ } else {
+ // Now update the call graph.
+ C = incorporateNewSCCRange(RC->switchInternalEdgeToRef(N, E.getNode()),
+ G, N, C, AM, UR);
+ }
+ }
+
+ // Now that this is ready for actual removal, put it into our list.
+ DeadTargets.push_back(&E.getNode());
+ }
+ // Remove the easy cases quickly and actually pull them out of our list.
+ DeadTargets.erase(
+ llvm::remove_if(DeadTargets,
+ [&](Node *TargetN) {
+ SCC &TargetC = *G.lookupSCC(*TargetN);
+ RefSCC &TargetRC = TargetC.getOuterRefSCC();
+
+ // We can't trivially remove internal targets, so skip
+ // those.
+ if (&TargetRC == RC)
+ return false;
+
+ RC->removeOutgoingEdge(N, *TargetN);
+ LLVM_DEBUG(dbgs() << "Deleting outgoing edge from '"
+ << N << "' to '" << *TargetN << "'\n");
+ return true;
+ }),
+ DeadTargets.end());
+
+ // Now do a batch removal of the internal ref edges left.
+ auto NewRefSCCs = RC->removeInternalRefEdge(N, DeadTargets);
+ if (!NewRefSCCs.empty()) {
+ // The old RefSCC is dead, mark it as such.
+ UR.InvalidatedRefSCCs.insert(RC);
+
+ // Note that we don't bother to invalidate analyses as ref-edge
+ // connectivity is not really observable in any way and is intended
+ // exclusively to be used for ordering of transforms rather than for
+ // analysis conclusions.
+
+ // Update RC to the "bottom".
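+ // ("Bottom" is the RefSCC that still contains N; removeInternalRefEdge
+ // returns it as the first element of the new postorder list, as asserted
+ // below.)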
+ assert(G.lookupSCC(N) == C && "Changed the SCC when splitting RefSCCs!"); + RC = &C->getOuterRefSCC(); + assert(G.lookupRefSCC(N) == RC && "Failed to update current RefSCC!"); + + // The RC worklist is in reverse postorder, so we enqueue the new ones in + // RPO except for the one which contains the source node as that is the + // "bottom" we will continue processing in the bottom-up walk. + assert(NewRefSCCs.front() == RC && + "New current RefSCC not first in the returned list!"); + for (RefSCC *NewRC : llvm::reverse(make_range(std::next(NewRefSCCs.begin()), + NewRefSCCs.end()))) { + assert(NewRC != RC && "Should not encounter the current RefSCC further " + "in the postorder list of new RefSCCs."); + UR.RCWorklist.insert(NewRC); + LLVM_DEBUG(dbgs() << "Enqueuing a new RefSCC in the update worklist: " + << *NewRC << "\n"); + } + } + + // Next demote all the call edges that are now ref edges. This helps make + // the SCCs small which should minimize the work below as we don't want to + // form cycles that this would break. + for (Node *RefTarget : DemotedCallTargets) { + SCC &TargetC = *G.lookupSCC(*RefTarget); + RefSCC &TargetRC = TargetC.getOuterRefSCC(); + + // The easy case is when the target RefSCC is not this RefSCC. This is + // only supported when the target RefSCC is a child of this RefSCC. + if (&TargetRC != RC) { + assert(RC->isAncestorOf(TargetRC) && + "Cannot potentially form RefSCC cycles here!"); + RC->switchOutgoingEdgeToRef(N, *RefTarget); + LLVM_DEBUG(dbgs() << "Switch outgoing call edge to a ref edge from '" << N + << "' to '" << *RefTarget << "'\n"); + continue; + } + + // We are switching an internal call edge to a ref edge. This may split up + // some SCCs. + if (C != &TargetC) { + // For separate SCCs this is trivial. + RC->switchTrivialInternalEdgeToRef(N, *RefTarget); + continue; + } + + // Now update the call graph. + C = incorporateNewSCCRange(RC->switchInternalEdgeToRef(N, *RefTarget), G, N, + C, AM, UR); + } + + // Now promote ref edges into call edges. + for (Node *CallTarget : PromotedRefTargets) { + SCC &TargetC = *G.lookupSCC(*CallTarget); + RefSCC &TargetRC = TargetC.getOuterRefSCC(); + + // The easy case is when the target RefSCC is not this RefSCC. This is + // only supported when the target RefSCC is a child of this RefSCC. + if (&TargetRC != RC) { + assert(RC->isAncestorOf(TargetRC) && + "Cannot potentially form RefSCC cycles here!"); + RC->switchOutgoingEdgeToCall(N, *CallTarget); + LLVM_DEBUG(dbgs() << "Switch outgoing ref edge to a call edge from '" << N + << "' to '" << *CallTarget << "'\n"); + continue; + } + LLVM_DEBUG(dbgs() << "Switch an internal ref edge to a call edge from '" + << N << "' to '" << *CallTarget << "'\n"); + + // Otherwise we are switching an internal ref edge to a call edge. This + // may merge away some SCCs, and we add those to the UpdateResult. We also + // need to make sure to update the worklist in the event SCCs have moved + // before the current one in the post-order sequence + bool HasFunctionAnalysisProxy = false; + auto InitialSCCIndex = RC->find(*C) - RC->begin(); + bool FormedCycle = RC->switchInternalEdgeToCall( + N, *CallTarget, [&](ArrayRef<SCC *> MergedSCCs) { + for (SCC *MergedC : MergedSCCs) { + assert(MergedC != &TargetC && "Cannot merge away the target SCC!"); + + HasFunctionAnalysisProxy |= + AM.getCachedResult<FunctionAnalysisManagerCGSCCProxy>( + *MergedC) != nullptr; + + // Mark that this SCC will no longer be valid. 
+ UR.InvalidatedSCCs.insert(MergedC);
+
+ // FIXME: We should really do a 'clear' here to forcibly release
+ // memory, but we don't have a good way of doing that and
+ // preserving the function analyses.
+ auto PA = PreservedAnalyses::allInSet<AllAnalysesOn<Function>>();
+ PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
+ AM.invalidate(*MergedC, PA);
+ }
+ });
+
+ // If we formed a cycle by creating this call, we need to update more data
+ // structures.
+ if (FormedCycle) {
+ C = &TargetC;
+ assert(G.lookupSCC(N) == C && "Failed to update current SCC!");
+
+ // If one of the invalidated SCCs had a cached proxy to a function
+ // analysis manager, we need to create a proxy in the new current SCC as
+ // the invalidated SCCs had their functions moved.
+ if (HasFunctionAnalysisProxy)
+ AM.getResult<FunctionAnalysisManagerCGSCCProxy>(*C, G);
+
+ // Any analyses cached for this SCC are no longer precise as the shape
+ // has changed by introducing this cycle. However, we have taken care to
+ // update the proxies so they remain valid.
+ auto PA = PreservedAnalyses::allInSet<AllAnalysesOn<Function>>();
+ PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
+ AM.invalidate(*C, PA);
+ }
+ auto NewSCCIndex = RC->find(*C) - RC->begin();
+ // If we have actually moved an SCC to be topologically "below" the current
+ // one due to merging, we will need to revisit the current SCC after
+ // visiting those moved SCCs.
+ //
+ // It is critical that we *do not* revisit the current SCC unless we
+ // actually move SCCs in the process of merging because otherwise we may
+ // form a cycle where an SCC is split apart, merged, split, merged and so
+ // on infinitely.
+ if (InitialSCCIndex < NewSCCIndex) {
+ // Put our current SCC back onto the worklist as we'll visit other SCCs
+ // that are now definitively ordered prior to the current one in the
+ // post-order sequence, and may end up observing more precise context to
+ // optimize the current SCC.
+ UR.CWorklist.insert(C);
+ LLVM_DEBUG(dbgs() << "Enqueuing the existing SCC in the worklist: " << *C
+ << "\n");
+ // Enqueue in reverse order as we pop off the back of the worklist.
+ for (SCC &MovedC : llvm::reverse(make_range(RC->begin() + InitialSCCIndex,
+ RC->begin() + NewSCCIndex))) {
+ UR.CWorklist.insert(&MovedC);
+ LLVM_DEBUG(dbgs() << "Enqueuing a newly earlier in post-order SCC: "
+ << MovedC << "\n");
+ }
+ }
+ }
+
+ assert(!UR.InvalidatedSCCs.count(C) && "Invalidated the current SCC!");
+ assert(!UR.InvalidatedRefSCCs.count(RC) && "Invalidated the current RefSCC!");
+ assert(&C->getOuterRefSCC() == RC && "Current SCC not in current RefSCC!");
+
+ // Record the current RefSCC and SCC for higher layers of the CGSCC pass
+ // manager now that all the updates have been applied.
+ if (RC != &InitialRC)
+ UR.UpdatedRC = RC;
+ if (C != &InitialC)
+ UR.UpdatedC = C;
+
+ return *C;
+}
diff --git a/llvm/lib/Analysis/CallGraph.cpp b/llvm/lib/Analysis/CallGraph.cpp
new file mode 100644
index 000000000000..70aeb1a688ee
--- /dev/null
+++ b/llvm/lib/Analysis/CallGraph.cpp
@@ -0,0 +1,326 @@
+//===- CallGraph.cpp - Build a Module's call graph ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+
+using namespace llvm;
+
+//===----------------------------------------------------------------------===//
+// Implementations of the CallGraph class methods.
+//
+
+CallGraph::CallGraph(Module &M)
+ : M(M), ExternalCallingNode(getOrInsertFunction(nullptr)),
+ CallsExternalNode(std::make_unique<CallGraphNode>(nullptr)) {
+ // Add every function to the call graph.
+ for (Function &F : M)
+ addToCallGraph(&F);
+}
+
+CallGraph::CallGraph(CallGraph &&Arg)
+ : M(Arg.M), FunctionMap(std::move(Arg.FunctionMap)),
+ ExternalCallingNode(Arg.ExternalCallingNode),
+ CallsExternalNode(std::move(Arg.CallsExternalNode)) {
+ Arg.FunctionMap.clear();
+ Arg.ExternalCallingNode = nullptr;
+}
+
+CallGraph::~CallGraph() {
+ // CallsExternalNode is not in the function map, delete it explicitly.
+ if (CallsExternalNode)
+ CallsExternalNode->allReferencesDropped();
+
+// Reset all nodes' use counts to zero before deleting them to prevent an
+// assertion from firing.
+#ifndef NDEBUG
+ for (auto &I : FunctionMap)
+ I.second->allReferencesDropped();
+#endif
+}
+
+void CallGraph::addToCallGraph(Function *F) {
+ CallGraphNode *Node = getOrInsertFunction(F);
+
+ // If this function has external linkage or has its address taken, anything
+ // could call it.
+ if (!F->hasLocalLinkage() || F->hasAddressTaken())
+ ExternalCallingNode->addCalledFunction(nullptr, Node);
+
+ // If this function is not defined in this translation unit, it could call
+ // anything.
+ if (F->isDeclaration() && !F->isIntrinsic())
+ Node->addCalledFunction(nullptr, CallsExternalNode.get());
+
+ // Look for calls by this function.
+ for (BasicBlock &BB : *F)
+ for (Instruction &I : BB) {
+ if (auto *Call = dyn_cast<CallBase>(&I)) {
+ const Function *Callee = Call->getCalledFunction();
+ if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID()))
+ // Indirect calls of intrinsics are not allowed so no need to check.
+ // We can be more precise here by using TargetArg returned by
+ // Intrinsic::isLeaf.
+ Node->addCalledFunction(Call, CallsExternalNode.get());
+ else if (!Callee->isIntrinsic())
+ Node->addCalledFunction(Call, getOrInsertFunction(Callee));
+ }
+ }
+}
+
+void CallGraph::print(raw_ostream &OS) const {
+ // Print in a deterministic order by sorting CallGraphNodes by name. We do
+ // this here to avoid slowing down the non-printing fast path.
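+ // The external calling node has no associated function and sorts ahead of
+ // all named nodes under the comparator below.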
+
+ SmallVector<CallGraphNode *, 16> Nodes;
+ Nodes.reserve(FunctionMap.size());
+
+ for (const auto &I : *this)
+ Nodes.push_back(I.second.get());
+
+ llvm::sort(Nodes, [](CallGraphNode *LHS, CallGraphNode *RHS) {
+ if (Function *LF = LHS->getFunction())
+ if (Function *RF = RHS->getFunction())
+ return LF->getName() < RF->getName();
+
+ return RHS->getFunction() != nullptr;
+ });
+
+ for (CallGraphNode *CN : Nodes)
+ CN->print(OS);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void CallGraph::dump() const { print(dbgs()); }
+#endif
+
+// removeFunctionFromModule - Unlink the function from this module, returning
+// it. Because this removes the function from the module, the call graph node
+// is destroyed. This is only valid if the function does not call any other
+// functions (i.e., there are no edges in its CGN). The easiest way to do this
+// is to dropAllReferences before calling this.
+//
+Function *CallGraph::removeFunctionFromModule(CallGraphNode *CGN) {
+ assert(CGN->empty() && "Cannot remove function from call "
+ "graph if it references other functions!");
+ Function *F = CGN->getFunction(); // Get the function for the call graph node
+ FunctionMap.erase(F); // Remove the call graph node from the map
+
+ M.getFunctionList().remove(F);
+ return F;
+}
+
+/// spliceFunction - Replace the function represented by this node by another.
+/// This does not rescan the body of the function, so it is suitable when
+/// splicing the body of the old function to the new while also updating all
+/// callers from old to new.
+void CallGraph::spliceFunction(const Function *From, const Function *To) {
+ assert(FunctionMap.count(From) && "No CallGraphNode for function!");
+ assert(!FunctionMap.count(To) &&
+ "Pointing CallGraphNode at a function that already exists");
+ FunctionMapTy::iterator I = FunctionMap.find(From);
+ I->second->F = const_cast<Function*>(To);
+ FunctionMap[To] = std::move(I->second);
+ FunctionMap.erase(I);
+}
+
+// getOrInsertFunction - This method is identical to calling operator[], but
+// it will insert a new CallGraphNode for the specified function if one does
+// not already exist.
+CallGraphNode *CallGraph::getOrInsertFunction(const Function *F) {
+ auto &CGN = FunctionMap[F];
+ if (CGN)
+ return CGN.get();
+
+ assert((!F || F->getParent() == &M) && "Function not in current module!");
+ CGN = std::make_unique<CallGraphNode>(const_cast<Function *>(F));
+ return CGN.get();
+}
+
+//===----------------------------------------------------------------------===//
+// Implementations of the CallGraphNode class methods.
+//
+
+void CallGraphNode::print(raw_ostream &OS) const {
+ if (Function *F = getFunction())
+ OS << "Call graph node for function: '" << F->getName() << "'";
+ else
+ OS << "Call graph node <<null function>>";
+
+ OS << "<<" << this << ">> #uses=" << getNumReferences() << '\n';
+
+ for (const auto &I : *this) {
+ OS << " CS<" << I.first << "> calls ";
+ if (Function *FI = I.second->getFunction())
+ OS << "function '" << FI->getName() << "'\n";
+ else
+ OS << "external node\n";
+ }
+ OS << '\n';
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void CallGraphNode::dump() const { print(dbgs()); }
+#endif
+
+/// removeCallEdgeFor - This method removes the edge in the node for the
+/// specified call site. Note that this method takes linear time, so it
+/// should be used sparingly.
+void CallGraphNode::removeCallEdgeFor(CallBase &Call) { + for (CalledFunctionsVector::iterator I = CalledFunctions.begin(); ; ++I) { + assert(I != CalledFunctions.end() && "Cannot find callsite to remove!"); + if (I->first == &Call) { + I->second->DropRef(); + *I = CalledFunctions.back(); + CalledFunctions.pop_back(); + return; + } + } +} + +// removeAnyCallEdgeTo - This method removes any call edges from this node to +// the specified callee function. This takes more time to execute than +// removeCallEdgeTo, so it should not be used unless necessary. +void CallGraphNode::removeAnyCallEdgeTo(CallGraphNode *Callee) { + for (unsigned i = 0, e = CalledFunctions.size(); i != e; ++i) + if (CalledFunctions[i].second == Callee) { + Callee->DropRef(); + CalledFunctions[i] = CalledFunctions.back(); + CalledFunctions.pop_back(); + --i; --e; + } +} + +/// removeOneAbstractEdgeTo - Remove one edge associated with a null callsite +/// from this node to the specified callee function. +void CallGraphNode::removeOneAbstractEdgeTo(CallGraphNode *Callee) { + for (CalledFunctionsVector::iterator I = CalledFunctions.begin(); ; ++I) { + assert(I != CalledFunctions.end() && "Cannot find callee to remove!"); + CallRecord &CR = *I; + if (CR.second == Callee && CR.first == nullptr) { + Callee->DropRef(); + *I = CalledFunctions.back(); + CalledFunctions.pop_back(); + return; + } + } +} + +/// replaceCallEdge - This method replaces the edge in the node for the +/// specified call site with a new one. Note that this method takes linear +/// time, so it should be used sparingly. +void CallGraphNode::replaceCallEdge(CallBase &Call, CallBase &NewCall, + CallGraphNode *NewNode) { + for (CalledFunctionsVector::iterator I = CalledFunctions.begin(); ; ++I) { + assert(I != CalledFunctions.end() && "Cannot find callsite to remove!"); + if (I->first == &Call) { + I->second->DropRef(); + I->first = &NewCall; + I->second = NewNode; + NewNode->AddRef(); + return; + } + } +} + +// Provide an explicit template instantiation for the static ID. +AnalysisKey CallGraphAnalysis::Key; + +PreservedAnalyses CallGraphPrinterPass::run(Module &M, + ModuleAnalysisManager &AM) { + AM.getResult<CallGraphAnalysis>(M).print(OS); + return PreservedAnalyses::all(); +} + +//===----------------------------------------------------------------------===// +// Out-of-line definitions of CallGraphAnalysis class members. +// + +//===----------------------------------------------------------------------===// +// Implementations of the CallGraphWrapperPass class methods. +// + +CallGraphWrapperPass::CallGraphWrapperPass() : ModulePass(ID) { + initializeCallGraphWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +CallGraphWrapperPass::~CallGraphWrapperPass() = default; + +void CallGraphWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); +} + +bool CallGraphWrapperPass::runOnModule(Module &M) { + // All the real work is done in the constructor for the CallGraph. + G.reset(new CallGraph(M)); + return false; +} + +INITIALIZE_PASS(CallGraphWrapperPass, "basiccg", "CallGraph Construction", + false, true) + +char CallGraphWrapperPass::ID = 0; + +void CallGraphWrapperPass::releaseMemory() { G.reset(); } + +void CallGraphWrapperPass::print(raw_ostream &OS, const Module *) const { + if (!G) { + OS << "No call graph has been built!\n"; + return; + } + + // Just delegate. 
+ G->print(OS);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD
+void CallGraphWrapperPass::dump() const { print(dbgs(), nullptr); }
+#endif
+
+namespace {
+
+struct CallGraphPrinterLegacyPass : public ModulePass {
+ static char ID; // Pass ID, replacement for typeid
+
+ CallGraphPrinterLegacyPass() : ModulePass(ID) {
+ initializeCallGraphPrinterLegacyPassPass(*PassRegistry::getPassRegistry());
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ AU.addRequiredTransitive<CallGraphWrapperPass>();
+ }
+
+ bool runOnModule(Module &M) override {
+ getAnalysis<CallGraphWrapperPass>().print(errs(), &M);
+ return false;
+ }
+};
+
+} // end anonymous namespace
+
+char CallGraphPrinterLegacyPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(CallGraphPrinterLegacyPass, "print-callgraph",
+ "Print a call graph", true, true)
+INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
+INITIALIZE_PASS_END(CallGraphPrinterLegacyPass, "print-callgraph",
+ "Print a call graph", true, true)
diff --git a/llvm/lib/Analysis/CallGraphSCCPass.cpp b/llvm/lib/Analysis/CallGraphSCCPass.cpp
new file mode 100644
index 000000000000..196ef400bc4e
--- /dev/null
+++ b/llvm/lib/Analysis/CallGraphSCCPass.cpp
@@ -0,0 +1,711 @@
+//===- CallGraphSCCPass.cpp - Pass that operates BU on call graph ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the CallGraphSCCPass class, which is used for passes
+// which are implemented as bottom-up traversals on the call graph. Because
+// there may be cycles in the call graph, passes of this type operate on the
+// call-graph in SCC order: that is, they process functions bottom-up, except
+// for recursive functions, which they process all at once.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CallGraphSCCPass.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SCCIterator.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRPrintingPasses.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassManagers.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/OptBisect.h"
+#include "llvm/IR/PassTimingInfo.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Timer.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+#include <string>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "cgscc-passmgr"
+
+static cl::opt<unsigned>
+MaxIterations("max-cg-scc-iterations", cl::ReallyHidden, cl::init(4));
+
+STATISTIC(MaxSCCIterations, "Maximum CGSCCPassMgr iterations on one SCC");
+
+//===----------------------------------------------------------------------===//
+// CGPassManager
+//
+/// CGPassManager manages FPPassManagers and CallGraphSCCPasses.
+
+namespace {
+
+class CGPassManager : public ModulePass, public PMDataManager {
+public:
+ static char ID;
+
+ explicit CGPassManager() : ModulePass(ID), PMDataManager() {}
+
+ /// Execute all of the passes scheduled for execution. Keep track of
+ /// whether any of the passes modifies the module, and if so, return true.
+ bool runOnModule(Module &M) override;
+
+ using ModulePass::doInitialization;
+ using ModulePass::doFinalization;
+
+ bool doInitialization(CallGraph &CG);
+ bool doFinalization(CallGraph &CG);
+
+ /// Pass Manager itself does not invalidate any analysis info.
+ void getAnalysisUsage(AnalysisUsage &Info) const override {
+ // CGPassManager walks SCC and it needs CallGraph.
+ Info.addRequired<CallGraphWrapperPass>();
+ Info.setPreservesAll();
+ }
+
+ StringRef getPassName() const override { return "CallGraph Pass Manager"; }
+
+ PMDataManager *getAsPMDataManager() override { return this; }
+ Pass *getAsPass() override { return this; }
+
+ // Print passes managed by this manager.
+ void dumpPassStructure(unsigned Offset) override {
+ errs().indent(Offset*2) << "Call Graph SCC Pass Manager\n";
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ Pass *P = getContainedPass(Index);
+ P->dumpPassStructure(Offset + 1);
+ dumpLastUses(P, Offset+1);
+ }
+ }
+
+ Pass *getContainedPass(unsigned N) {
+ assert(N < PassVector.size() && "Pass number out of range!");
+ return static_cast<Pass *>(PassVector[N]);
+ }
+
+ PassManagerType getPassManagerType() const override {
+ return PMT_CallGraphPassManager;
+ }
+
+private:
+ bool RunAllPassesOnSCC(CallGraphSCC &CurSCC, CallGraph &CG,
+ bool &DevirtualizedCall);
+
+ bool RunPassOnSCC(Pass *P, CallGraphSCC &CurSCC,
+ CallGraph &CG, bool &CallGraphUpToDate,
+ bool &DevirtualizedCall);
+ bool RefreshCallGraph(const CallGraphSCC &CurSCC, CallGraph &CG,
+ bool CheckingMode);
+};
+
+} // end anonymous namespace.
+
+char CGPassManager::ID = 0;
+
+bool CGPassManager::RunPassOnSCC(Pass *P, CallGraphSCC &CurSCC,
+ CallGraph &CG, bool &CallGraphUpToDate,
+ bool &DevirtualizedCall) {
+ bool Changed = false;
+ PMDataManager *PM = P->getAsPMDataManager();
+ Module &M = CG.getModule();
+
+ if (!PM) {
+ CallGraphSCCPass *CGSP = (CallGraphSCCPass *)P;
+ if (!CallGraphUpToDate) {
+ DevirtualizedCall |= RefreshCallGraph(CurSCC, CG, false);
+ CallGraphUpToDate = true;
+ }
+
+ {
+ unsigned InstrCount = 0, SCCCount = 0;
+ StringMap<std::pair<unsigned, unsigned>> FunctionToInstrCount;
+ bool EmitICRemark = M.shouldEmitInstrCountChangedRemark();
+ TimeRegion PassTimer(getPassTimer(CGSP));
+ if (EmitICRemark)
+ InstrCount = initSizeRemarkInfo(M, FunctionToInstrCount);
+ Changed = CGSP->runOnSCC(CurSCC);
+
+ if (EmitICRemark) {
+ // FIXME: Add getInstructionCount to CallGraphSCC.
+ SCCCount = M.getInstructionCount();
+ // Is there a difference in the number of instructions in the module?
+ if (SCCCount != InstrCount) {
+ // Yep. Emit a remark and update InstrCount.
+ int64_t Delta =
+ static_cast<int64_t>(SCCCount) - static_cast<int64_t>(InstrCount);
+ emitInstrCountChangedRemark(P, M, Delta, InstrCount,
+ FunctionToInstrCount);
+ InstrCount = SCCCount;
+ }
+ }
+ }
+
+ // After the CGSCCPass is done, when assertions are enabled, use
+ // RefreshCallGraph to verify that the callgraph was correctly updated.
+#ifndef NDEBUG
+ if (Changed)
+ RefreshCallGraph(CurSCC, CG, true);
+#endif
+
+ return Changed;
+ }
+
+ assert(PM->getPassManagerType() == PMT_FunctionPassManager &&
+ "Invalid CGPassManager member");
+ FPPassManager *FPP = (FPPassManager*)P;
+
+ // Run pass P on all functions in the current SCC.
+ for (CallGraphNode *CGN : CurSCC) {
+ if (Function *F = CGN->getFunction()) {
+ dumpPassInfo(P, EXECUTION_MSG, ON_FUNCTION_MSG, F->getName());
+ {
+ TimeRegion PassTimer(getPassTimer(FPP));
+ Changed |= FPP->runOnFunction(*F);
+ }
+ F->getContext().yield();
+ }
+ }
+
+ // The function pass(es) modified the IR, they may have clobbered the
+ // callgraph.
+ if (Changed && CallGraphUpToDate) {
+ LLVM_DEBUG(dbgs() << "CGSCCPASSMGR: Pass Dirtied SCC: " << P->getPassName()
+ << '\n');
+ CallGraphUpToDate = false;
+ }
+ return Changed;
+}
+
+/// Scan the functions in the specified SCC and resync the
+/// callgraph with the call sites found in it. This is used after
+/// FunctionPasses have potentially munged the callgraph, and can be used after
+/// CallGraphSCC passes to verify that they correctly updated the callgraph.
+///
+/// This function returns true if it devirtualized an existing function call,
+/// meaning it turned an indirect call into a direct call. This happens when
+/// a function pass like GVN optimizes away stuff feeding the indirect call.
+/// This never happens in checking mode.
+bool CGPassManager::RefreshCallGraph(const CallGraphSCC &CurSCC, CallGraph &CG,
+ bool CheckingMode) {
+ DenseMap<Value *, CallGraphNode *> Calls;
+
+ LLVM_DEBUG(dbgs() << "CGSCCPASSMGR: Refreshing SCC with " << CurSCC.size()
+ << " nodes:\n";
+ for (CallGraphNode *CGN : CurSCC)
+ CGN->dump(););
+
+ bool MadeChange = false;
+ bool DevirtualizedCall = false;
+
+ // Scan all functions in the SCC.
+ unsigned FunctionNo = 0;
+ for (CallGraphSCC::iterator SCCIdx = CurSCC.begin(), E = CurSCC.end();
+ SCCIdx != E; ++SCCIdx, ++FunctionNo) {
+ CallGraphNode *CGN = *SCCIdx;
+ Function *F = CGN->getFunction();
+ if (!F || F->isDeclaration()) continue;
+
+ // Walk the function body looking for call sites. Sync up the call sites in
+ // CGN with those actually in the function.
+
+ // Keep track of the number of direct and indirect calls that were
+ // invalidated and removed.
+ unsigned NumDirectRemoved = 0, NumIndirectRemoved = 0;
+
+ // Get the set of call sites currently in the function.
+ for (CallGraphNode::iterator I = CGN->begin(), E = CGN->end(); I != E; ) {
+ // If this call site is null, then the function pass deleted the call
+ // entirely and the WeakTrackingVH nulled it out.
+ auto *Call = dyn_cast_or_null<CallBase>(I->first);
+ if (!I->first ||
+ // If we've already seen this call site, then the FunctionPass RAUW'd
+ // one call with another, which resulted in two "uses" in the edge
+ // list of the same call.
+ Calls.count(I->first) ||
+
+ // If the call edge is not from a call or invoke, or it is an
+ // intrinsic call, then the function pass RAUW'd a call with
+ // another value. This can happen when constant folding of
+ // well-known functions occurs.
+ !Call ||
+ (Call->getCalledFunction() &&
+ Call->getCalledFunction()->isIntrinsic() &&
+ Intrinsic::isLeaf(Call->getCalledFunction()->getIntrinsicID()))) {
+ assert(!CheckingMode &&
+ "CallGraphSCCPass did not update the CallGraph correctly!");
+
+ // If this was an indirect call site, count it.
+ if (!I->second->getFunction())
+ ++NumIndirectRemoved;
+ else
+ ++NumDirectRemoved;
+
+ // Just remove the edge from the set of callees, keep track of whether
+ // I points to the last element of the vector.
+ bool WasLast = I + 1 == E;
+ CGN->removeCallEdge(I);
+
+ // If I pointed to the last element of the vector, we have to bail out:
+ // iterator checking rejects comparisons of the resultant pointer with
+ // end.
+ if (WasLast) + break; + E = CGN->end(); + continue; + } + + assert(!Calls.count(I->first) && + "Call site occurs in node multiple times"); + + if (Call) { + Function *Callee = Call->getCalledFunction(); + // Ignore intrinsics because they're not really function calls. + if (!Callee || !(Callee->isIntrinsic())) + Calls.insert(std::make_pair(I->first, I->second)); + } + ++I; + } + + // Loop over all of the instructions in the function, getting the callsites. + // Keep track of the number of direct/indirect calls added. + unsigned NumDirectAdded = 0, NumIndirectAdded = 0; + + for (BasicBlock &BB : *F) + for (Instruction &I : BB) { + auto *Call = dyn_cast<CallBase>(&I); + if (!Call) + continue; + Function *Callee = Call->getCalledFunction(); + if (Callee && Callee->isIntrinsic()) + continue; + + // If this call site already existed in the callgraph, just verify it + // matches up to expectations and remove it from Calls. + DenseMap<Value *, CallGraphNode *>::iterator ExistingIt = + Calls.find(Call); + if (ExistingIt != Calls.end()) { + CallGraphNode *ExistingNode = ExistingIt->second; + + // Remove from Calls since we have now seen it. + Calls.erase(ExistingIt); + + // Verify that the callee is right. + if (ExistingNode->getFunction() == Call->getCalledFunction()) + continue; + + // If we are in checking mode, we are not allowed to actually mutate + // the callgraph. If this is a case where we can infer that the + // callgraph is less precise than it could be (e.g. an indirect call + // site could be turned direct), don't reject it in checking mode, and + // don't tweak it to be more precise. + if (CheckingMode && Call->getCalledFunction() && + ExistingNode->getFunction() == nullptr) + continue; + + assert(!CheckingMode && + "CallGraphSCCPass did not update the CallGraph correctly!"); + + // If not, we either went from a direct call to indirect, indirect to + // direct, or direct to different direct. + CallGraphNode *CalleeNode; + if (Function *Callee = Call->getCalledFunction()) { + CalleeNode = CG.getOrInsertFunction(Callee); + // Keep track of whether we turned an indirect call into a direct + // one. + if (!ExistingNode->getFunction()) { + DevirtualizedCall = true; + LLVM_DEBUG(dbgs() << " CGSCCPASSMGR: Devirtualized call to '" + << Callee->getName() << "'\n"); + } + } else { + CalleeNode = CG.getCallsExternalNode(); + } + + // Update the edge target in CGN. + CGN->replaceCallEdge(*Call, *Call, CalleeNode); + MadeChange = true; + continue; + } + + assert(!CheckingMode && + "CallGraphSCCPass did not update the CallGraph correctly!"); + + // If the call site didn't exist in the CGN yet, add it. + CallGraphNode *CalleeNode; + if (Function *Callee = Call->getCalledFunction()) { + CalleeNode = CG.getOrInsertFunction(Callee); + ++NumDirectAdded; + } else { + CalleeNode = CG.getCallsExternalNode(); + ++NumIndirectAdded; + } + + CGN->addCalledFunction(Call, CalleeNode); + MadeChange = true; + } + + // We scanned the old callgraph node, removing invalidated call sites and + // then added back newly found call sites. One thing that can happen is + // that an old indirect call site was deleted and replaced with a new direct + // call. In this case, we have devirtualized a call, and CGSCCPM would like + // to iteratively optimize the new code. Unfortunately, we don't really + // have a great way to detect when this happens. As an approximation, we + // just look at whether the number of indirect calls is reduced and the + // number of direct calls is increased. 
There are tons of ways to fool this
+ // (e.g. DCE'ing an indirect call and duplicating an unrelated block with a
+ // direct call) but this is close enough.
+ if (NumIndirectRemoved > NumIndirectAdded &&
+ NumDirectRemoved < NumDirectAdded)
+ DevirtualizedCall = true;
+
+ // After scanning this function, if we still have entries in Calls, then
+ // they are dangling pointers. WeakTrackingVH should save us from this, so
+ // abort if this happens.
+ assert(Calls.empty() && "Dangling pointers found in call sites map");
+
+ // Periodically do an explicit clear to remove tombstones when processing
+ // large SCCs.
+ if ((FunctionNo & 15) == 15)
+ Calls.clear();
+ }
+
+ LLVM_DEBUG(if (MadeChange) {
+ dbgs() << "CGSCCPASSMGR: Refreshed SCC is now:\n";
+ for (CallGraphNode *CGN : CurSCC)
+ CGN->dump();
+ if (DevirtualizedCall)
+ dbgs() << "CGSCCPASSMGR: Refresh devirtualized a call!\n";
+ } else {
+ dbgs() << "CGSCCPASSMGR: SCC Refresh didn't change call graph.\n";
+ });
+ (void)MadeChange;
+
+ return DevirtualizedCall;
+}
+
+/// Execute the body of the entire pass manager on the specified SCC.
+/// This keeps track of whether a function pass devirtualizes
+/// any calls and returns it in DevirtualizedCall.
+bool CGPassManager::RunAllPassesOnSCC(CallGraphSCC &CurSCC, CallGraph &CG,
+ bool &DevirtualizedCall) {
+ bool Changed = false;
+
+ // Keep track of whether the callgraph is known to be up-to-date or not.
+ // The CGSCC pass manager runs two types of passes:
+ // CallGraphSCC Passes and other random function passes. Because other
+ // random function passes are not CallGraph aware, they may clobber the
+ // call graph by introducing new calls or deleting other ones. This flag
+ // is set to false when we run a function pass so that we know to clean up
+ // the callgraph when we need to run a CGSCCPass again.
+ bool CallGraphUpToDate = true;
+
+ // Run all passes on current SCC.
+ for (unsigned PassNo = 0, e = getNumContainedPasses();
+ PassNo != e; ++PassNo) {
+ Pass *P = getContainedPass(PassNo);
+
+ // If we're in -debug-pass=Executions mode, construct the SCC node list,
+ // otherwise avoid constructing this string as it is expensive.
+ if (isPassDebuggingExecutionsOrMore()) {
+ std::string Functions;
+ #ifndef NDEBUG
+ raw_string_ostream OS(Functions);
+ for (CallGraphSCC::iterator I = CurSCC.begin(), E = CurSCC.end();
+ I != E; ++I) {
+ if (I != CurSCC.begin()) OS << ", ";
+ (*I)->print(OS);
+ }
+ OS.flush();
+ #endif
+ dumpPassInfo(P, EXECUTION_MSG, ON_CG_MSG, Functions);
+ }
+ dumpRequiredSet(P);
+
+ initializeAnalysisImpl(P);
+
+ // Actually run this pass on the current SCC.
+ Changed |= RunPassOnSCC(P, CurSCC, CG,
+ CallGraphUpToDate, DevirtualizedCall);
+
+ if (Changed)
+ dumpPassInfo(P, MODIFICATION_MSG, ON_CG_MSG, "");
+ dumpPreservedSet(P);
+
+ verifyPreservedAnalysis(P);
+ removeNotPreservedAnalysis(P);
+ recordAvailableAnalysis(P);
+ removeDeadPasses(P, "", ON_CG_MSG);
+ }
+
+ // If the callgraph was left out of date (because the last pass run was a
+ // function pass), refresh it before we move on to the next SCC.
+ if (!CallGraphUpToDate)
+ DevirtualizedCall |= RefreshCallGraph(CurSCC, CG, false);
+ return Changed;
+}
+
+/// Execute all of the passes scheduled for execution. Keep track of
+/// whether any of the passes modifies the module, and if so, return true.
+bool CGPassManager::runOnModule(Module &M) {
+ CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
+ bool Changed = doInitialization(CG);
+
+ // Walk the callgraph in bottom-up SCC order.
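+ // For example, with a call graph main -> helper -> leaf and no cycles, the
+ // SCCs are visited as {leaf}, {helper}, {main}, so callees are processed
+ // before their callers.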
+ scc_iterator<CallGraph*> CGI = scc_begin(&CG); + + CallGraphSCC CurSCC(CG, &CGI); + while (!CGI.isAtEnd()) { + // Copy the current SCC and increment past it so that the pass can hack + // on the SCC if it wants to without invalidating our iterator. + const std::vector<CallGraphNode *> &NodeVec = *CGI; + CurSCC.initialize(NodeVec); + ++CGI; + + // At the top level, we run all the passes in this pass manager on the + // functions in this SCC. However, we support iterative compilation in the + // case where a function pass devirtualizes a call to a function. For + // example, it is very common for a function pass (often GVN or instcombine) + // to eliminate the addressing that feeds into a call. With that improved + // information, we would like the call to be an inline candidate, infer + // mod-ref information etc. + // + // Because of this, we allow iteration up to a specified iteration count. + // This only happens in the case of a devirtualized call, so we only burn + // compile time in the case that we're making progress. We also have a hard + // iteration count limit in case there is crazy code. + unsigned Iteration = 0; + bool DevirtualizedCall = false; + do { + LLVM_DEBUG(if (Iteration) dbgs() + << " SCCPASSMGR: Re-visiting SCC, iteration #" << Iteration + << '\n'); + DevirtualizedCall = false; + Changed |= RunAllPassesOnSCC(CurSCC, CG, DevirtualizedCall); + } while (Iteration++ < MaxIterations && DevirtualizedCall); + + if (DevirtualizedCall) + LLVM_DEBUG(dbgs() << " CGSCCPASSMGR: Stopped iteration after " + << Iteration + << " times, due to -max-cg-scc-iterations\n"); + + MaxSCCIterations.updateMax(Iteration); + } + Changed |= doFinalization(CG); + return Changed; +} + +/// Initialize CG +bool CGPassManager::doInitialization(CallGraph &CG) { + bool Changed = false; + for (unsigned i = 0, e = getNumContainedPasses(); i != e; ++i) { + if (PMDataManager *PM = getContainedPass(i)->getAsPMDataManager()) { + assert(PM->getPassManagerType() == PMT_FunctionPassManager && + "Invalid CGPassManager member"); + Changed |= ((FPPassManager*)PM)->doInitialization(CG.getModule()); + } else { + Changed |= ((CallGraphSCCPass*)getContainedPass(i))->doInitialization(CG); + } + } + return Changed; +} + +/// Finalize CG +bool CGPassManager::doFinalization(CallGraph &CG) { + bool Changed = false; + for (unsigned i = 0, e = getNumContainedPasses(); i != e; ++i) { + if (PMDataManager *PM = getContainedPass(i)->getAsPMDataManager()) { + assert(PM->getPassManagerType() == PMT_FunctionPassManager && + "Invalid CGPassManager member"); + Changed |= ((FPPassManager*)PM)->doFinalization(CG.getModule()); + } else { + Changed |= ((CallGraphSCCPass*)getContainedPass(i))->doFinalization(CG); + } + } + return Changed; +} + +//===----------------------------------------------------------------------===// +// CallGraphSCC Implementation +//===----------------------------------------------------------------------===// + +/// This informs the SCC and the pass manager that the specified +/// Old node has been deleted, and New is to be used in its place. +void CallGraphSCC::ReplaceNode(CallGraphNode *Old, CallGraphNode *New) { + assert(Old != New && "Should not replace node with self"); + for (unsigned i = 0; ; ++i) { + assert(i != Nodes.size() && "Node not in SCC"); + if (Nodes[i] != Old) continue; + Nodes[i] = New; + break; + } + + // Update the active scc_iterator so that it doesn't contain dangling + // pointers to the old CallGraphNode. 
+ scc_iterator<CallGraph*> *CGI = (scc_iterator<CallGraph*>*)Context; + CGI->ReplaceNode(Old, New); +} + +//===----------------------------------------------------------------------===// +// CallGraphSCCPass Implementation +//===----------------------------------------------------------------------===// + +/// Assign pass manager to manage this pass. +void CallGraphSCCPass::assignPassManager(PMStack &PMS, + PassManagerType PreferredType) { + // Find CGPassManager + while (!PMS.empty() && + PMS.top()->getPassManagerType() > PMT_CallGraphPassManager) + PMS.pop(); + + assert(!PMS.empty() && "Unable to handle Call Graph Pass"); + CGPassManager *CGP; + + if (PMS.top()->getPassManagerType() == PMT_CallGraphPassManager) + CGP = (CGPassManager*)PMS.top(); + else { + // Create new Call Graph SCC Pass Manager if it does not exist. + assert(!PMS.empty() && "Unable to create Call Graph Pass Manager"); + PMDataManager *PMD = PMS.top(); + + // [1] Create new Call Graph Pass Manager + CGP = new CGPassManager(); + + // [2] Set up new manager's top level manager + PMTopLevelManager *TPM = PMD->getTopLevelManager(); + TPM->addIndirectPassManager(CGP); + + // [3] Assign manager to manage this new manager. This may create + // and push new managers into PMS + Pass *P = CGP; + TPM->schedulePass(P); + + // [4] Push new manager into PMS + PMS.push(CGP); + } + + CGP->add(this); +} + +/// For this class, we declare that we require and preserve the call graph. +/// If the derived class implements this method, it should +/// always explicitly call the implementation here. +void CallGraphSCCPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<CallGraphWrapperPass>(); + AU.addPreserved<CallGraphWrapperPass>(); +} + +//===----------------------------------------------------------------------===// +// PrintCallGraphPass Implementation +//===----------------------------------------------------------------------===// + +namespace { + + /// PrintCallGraphPass - Print a Module corresponding to a call graph. + /// + class PrintCallGraphPass : public CallGraphSCCPass { + std::string Banner; + raw_ostream &OS; // raw_ostream to print on. + + public: + static char ID; + + PrintCallGraphPass(const std::string &B, raw_ostream &OS) + : CallGraphSCCPass(ID), Banner(B), OS(OS) {} + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + } + + bool runOnSCC(CallGraphSCC &SCC) override { + bool BannerPrinted = false; + auto PrintBannerOnce = [&]() { + if (BannerPrinted) + return; + OS << Banner; + BannerPrinted = true; + }; + + bool NeedModule = llvm::forcePrintModuleIR(); + if (isFunctionInPrintList("*") && NeedModule) { + PrintBannerOnce(); + OS << "\n"; + SCC.getCallGraph().getModule().print(OS, nullptr); + return false; + } + bool FoundFunction = false; + for (CallGraphNode *CGN : SCC) { + if (Function *F = CGN->getFunction()) { + if (!F->isDeclaration() && isFunctionInPrintList(F->getName())) { + FoundFunction = true; + if (!NeedModule) { + PrintBannerOnce(); + F->print(OS); + } + } + } else if (isFunctionInPrintList("*")) { + PrintBannerOnce(); + OS << "\nPrinting <null> Function\n"; + } + } + if (NeedModule && FoundFunction) { + PrintBannerOnce(); + OS << "\n"; + SCC.getCallGraph().getModule().print(OS, nullptr); + } + return false; + } + + StringRef getPassName() const override { return "Print CallGraph IR"; } + }; + +} // end anonymous namespace. 
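+
+// The legacy pass manager instantiates PrintCallGraphPass through the
+// createPrinterPass hook below, e.g. to implement IR printing around CGSCC
+// passes.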
+
+char PrintCallGraphPass::ID = 0;
+
+Pass *CallGraphSCCPass::createPrinterPass(raw_ostream &OS,
+ const std::string &Banner) const {
+ return new PrintCallGraphPass(Banner, OS);
+}
+
+static std::string getDescription(const CallGraphSCC &SCC) {
+ std::string Desc = "SCC (";
+ bool First = true;
+ for (CallGraphNode *CGN : SCC) {
+ if (First)
+ First = false;
+ else
+ Desc += ", ";
+ Function *F = CGN->getFunction();
+ if (F)
+ Desc += F->getName();
+ else
+ Desc += "<<null function>>";
+ }
+ Desc += ")";
+ return Desc;
+}
+
+bool CallGraphSCCPass::skipSCC(CallGraphSCC &SCC) const {
+ OptPassGate &Gate =
+ SCC.getCallGraph().getModule().getContext().getOptPassGate();
+ return Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(SCC));
+}
+
+char DummyCGSCCPass::ID = 0;
+
+INITIALIZE_PASS(DummyCGSCCPass, "DummyCGSCCPass", "DummyCGSCCPass", false,
+ false)
diff --git a/llvm/lib/Analysis/CallPrinter.cpp b/llvm/lib/Analysis/CallPrinter.cpp
new file mode 100644
index 000000000000..d24cbd104bf6
--- /dev/null
+++ b/llvm/lib/Analysis/CallPrinter.cpp
@@ -0,0 +1,91 @@
+//===- CallPrinter.cpp - DOT printer for call graph -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines '-dot-callgraph', which emits a callgraph.<fnname>.dot
+// containing the call graph of a module.
+//
+// There is also a pass available to directly call dotty ('-view-callgraph').
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CallPrinter.h"
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/Analysis/DOTGraphTraitsPass.h"
+
+using namespace llvm;
+
+namespace llvm {
+
+template <> struct DOTGraphTraits<CallGraph *> : public DefaultDOTGraphTraits {
+ DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
+
+ static std::string getGraphName(CallGraph *Graph) { return "Call graph"; }
+
+ std::string getNodeLabel(CallGraphNode *Node, CallGraph *Graph) {
+ if (Function *Func = Node->getFunction())
+ return Func->getName();
+
+ return "external node";
+ }
+};
+
+struct AnalysisCallGraphWrapperPassTraits {
+ static CallGraph *getGraph(CallGraphWrapperPass *P) {
+ return &P->getCallGraph();
+ }
+};
+
+} // end llvm namespace
+
+namespace {
+
+struct CallGraphViewer
+ : public DOTGraphTraitsModuleViewer<CallGraphWrapperPass, true, CallGraph *,
+ AnalysisCallGraphWrapperPassTraits> {
+ static char ID;
+
+ CallGraphViewer()
+ : DOTGraphTraitsModuleViewer<CallGraphWrapperPass, true, CallGraph *,
+ AnalysisCallGraphWrapperPassTraits>(
+ "callgraph", ID) {
+ initializeCallGraphViewerPass(*PassRegistry::getPassRegistry());
+ }
+};
+
+struct CallGraphDOTPrinter : public DOTGraphTraitsModulePrinter<
+ CallGraphWrapperPass, true, CallGraph *,
+ AnalysisCallGraphWrapperPassTraits> {
+ static char ID;
+
+ CallGraphDOTPrinter()
+ : DOTGraphTraitsModulePrinter<CallGraphWrapperPass, true, CallGraph *,
+ AnalysisCallGraphWrapperPassTraits>(
+ "callgraph", ID) {
+ initializeCallGraphDOTPrinterPass(*PassRegistry::getPassRegistry());
+ }
+};
+
+} // end anonymous namespace
+
+char CallGraphViewer::ID = 0;
+INITIALIZE_PASS(CallGraphViewer, "view-callgraph", "View call graph", false,
+ false)
+
+char CallGraphDOTPrinter::ID = 0;
+INITIALIZE_PASS(CallGraphDOTPrinter, "dot-callgraph",
"dot-callgraph", + "Print call graph to 'dot' file", false, false) + +// Create methods available outside of this file, to use them +// "include/llvm/LinkAllPasses.h". Otherwise the pass would be deleted by +// the link time optimization. + +ModulePass *llvm::createCallGraphViewerPass() { return new CallGraphViewer(); } + +ModulePass *llvm::createCallGraphDOTPrinterPass() { + return new CallGraphDOTPrinter(); +} diff --git a/llvm/lib/Analysis/CaptureTracking.cpp b/llvm/lib/Analysis/CaptureTracking.cpp new file mode 100644 index 000000000000..20e2f06540a3 --- /dev/null +++ b/llvm/lib/Analysis/CaptureTracking.cpp @@ -0,0 +1,389 @@ +//===--- CaptureTracking.cpp - Determine whether a pointer is captured ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains routines that help determine which pointers are captured. +// A pointer value is captured if the function makes a copy of any part of the +// pointer that outlives the call. Not being captured means, more or less, that +// the pointer is only dereferenced and not stored in a global. Returning part +// of the pointer as the function return value may or may not count as capturing +// the pointer, depending on the context. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/OrderedBasicBlock.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" + +using namespace llvm; + +CaptureTracker::~CaptureTracker() {} + +bool CaptureTracker::shouldExplore(const Use *U) { return true; } + +bool CaptureTracker::isDereferenceableOrNull(Value *O, const DataLayout &DL) { + // An inbounds GEP can either be a valid pointer (pointing into + // or to the end of an allocation), or be null in the default + // address space. So for an inbounds GEP there is no way to let + // the pointer escape using clever GEP hacking because doing so + // would make the pointer point outside of the allocated object + // and thus make the GEP result a poison value. Similarly, other + // dereferenceable pointers cannot be manipulated without producing + // poison. + if (auto *GEP = dyn_cast<GetElementPtrInst>(O)) + if (GEP->isInBounds()) + return true; + bool CanBeNull; + return O->getPointerDereferenceableBytes(DL, CanBeNull); +} + +namespace { + struct SimpleCaptureTracker : public CaptureTracker { + explicit SimpleCaptureTracker(bool ReturnCaptures) + : ReturnCaptures(ReturnCaptures), Captured(false) {} + + void tooManyUses() override { Captured = true; } + + bool captured(const Use *U) override { + if (isa<ReturnInst>(U->getUser()) && !ReturnCaptures) + return false; + + Captured = true; + return true; + } + + bool ReturnCaptures; + + bool Captured; + }; + + /// Only find pointer captures which happen before the given instruction. Uses + /// the dominator tree to determine whether one instruction is before another. + /// Only support the case where the Value is defined in the same basic block + /// as the given instruction and the use. 
+  /// Only find pointer captures which happen before the given instruction.
+  /// Uses the dominator tree to determine whether one instruction is before
+  /// another. Only supports the case where the Value is defined in the same
+  /// basic block as the given instruction and the use.
+  struct CapturesBefore : public CaptureTracker {
+
+    CapturesBefore(bool ReturnCaptures, const Instruction *I,
+                   const DominatorTree *DT, bool IncludeI,
+                   OrderedBasicBlock *IC)
+        : OrderedBB(IC), BeforeHere(I), DT(DT),
+          ReturnCaptures(ReturnCaptures), IncludeI(IncludeI), Captured(false) {}
+
+    void tooManyUses() override { Captured = true; }
+
+    bool isSafeToPrune(Instruction *I) {
+      BasicBlock *BB = I->getParent();
+      // We explore this usage only if the usage can reach "BeforeHere".
+      // If the use is not reachable from entry, there is no need to explore.
+      if (BeforeHere != I && !DT->isReachableFromEntry(BB))
+        return true;
+
+      // Compute the case where both instructions are inside the same basic
+      // block. Since instructions in the same BB as BeforeHere are numbered in
+      // 'OrderedBB', avoid using 'dominates' and 'isPotentiallyReachable'
+      // which are very expensive for large basic blocks.
+      if (BB == BeforeHere->getParent()) {
+        // 'I' dominates 'BeforeHere' => not safe to prune.
+        //
+        // The value defined by an invoke dominates an instruction only
+        // if it dominates every instruction in UseBB. A PHI is dominated only
+        // if the instruction dominates every possible use in the UseBB. Since
+        // UseBB == BB, avoid pruning.
+        if (isa<InvokeInst>(BeforeHere) || isa<PHINode>(I) || I == BeforeHere)
+          return false;
+        if (!OrderedBB->dominates(BeforeHere, I))
+          return false;
+
+        // 'BeforeHere' comes before 'I', so it's safe to prune if we also
+        // guarantee that 'I' never reaches 'BeforeHere' through a back-edge or
+        // by its successors, i.e., prune if:
+        //
+        // (1) BB is an entry block or has no successors.
+        // (2) There's no path coming back through BB successors.
+        if (BB == &BB->getParent()->getEntryBlock() ||
+            !BB->getTerminator()->getNumSuccessors())
+          return true;
+
+        SmallVector<BasicBlock*, 32> Worklist;
+        Worklist.append(succ_begin(BB), succ_end(BB));
+        return !isPotentiallyReachableFromMany(Worklist, BB, nullptr, DT);
+      }
+
+      // If the value is defined in the same basic block as the use and
+      // BeforeHere, there is no need to explore the use if BeforeHere
+      // dominates it. Check whether there is a path from I to BeforeHere.
+      if (BeforeHere != I && DT->dominates(BeforeHere, I) &&
+          !isPotentiallyReachable(I, BeforeHere, nullptr, DT))
+        return true;
+
+      return false;
+    }
+
+    bool shouldExplore(const Use *U) override {
+      Instruction *I = cast<Instruction>(U->getUser());
+
+      if (BeforeHere == I && !IncludeI)
+        return false;
+
+      if (isSafeToPrune(I))
+        return false;
+
+      return true;
+    }
+
+    bool captured(const Use *U) override {
+      if (isa<ReturnInst>(U->getUser()) && !ReturnCaptures)
+        return false;
+
+      if (!shouldExplore(U))
+        return false;
+
+      Captured = true;
+      return true;
+    }
+
+    OrderedBasicBlock *OrderedBB;
+    const Instruction *BeforeHere;
+    const DominatorTree *DT;
+
+    bool ReturnCaptures;
+    bool IncludeI;
+
+    bool Captured;
+  };
+}
+
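A hedged usage sketch of the two public entry points defined below; the
pointer V, the instruction I, and the dominator tree DT are assumed to be
supplied by the caller, and the declaration defaults for the remaining
parameters are assumed to apply.

  static bool escapesBeforeInst(const Value *V, const Instruction *I,
                                const DominatorTree *DT) {
    // The flow-insensitive query is cheaper; if nothing captures V at all,
    // nothing can capture it before I either.
    if (!PointerMayBeCaptured(V, /*ReturnCaptures=*/true,
                              /*StoreCaptures=*/true))
      return false;
    return PointerMayBeCapturedBefore(V, /*ReturnCaptures=*/true,
                                      /*StoreCaptures=*/true, I, DT);
  }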
+/// PointerMayBeCaptured - Return true if this pointer value may be captured
+/// by the enclosing function (which is required to exist). This routine can
+/// be expensive, so consider caching the results. The boolean ReturnCaptures
+/// specifies whether returning the value (or part of it) from the function
+/// counts as capturing it or not. The boolean StoreCaptures specifies whether
+/// storing the value (or part of it) into memory anywhere automatically
+/// counts as capturing it or not.
+bool llvm::PointerMayBeCaptured(const Value *V,
+                                bool ReturnCaptures, bool StoreCaptures,
+                                unsigned MaxUsesToExplore) {
+  assert(!isa<GlobalValue>(V) &&
+         "It doesn't make sense to ask whether a global is captured.");
+
+  // TODO: If StoreCaptures is not true, we could do fancier analysis
+  // to determine whether this store is not actually an escape point.
+  // In that case, BasicAliasAnalysis should be updated as well to
+  // take advantage of this.
+  (void)StoreCaptures;
+
+  SimpleCaptureTracker SCT(ReturnCaptures);
+  PointerMayBeCaptured(V, &SCT, MaxUsesToExplore);
+  return SCT.Captured;
+}
+
+/// PointerMayBeCapturedBefore - Return true if this pointer value may be
+/// captured by the enclosing function (which is required to exist). If a
+/// DominatorTree is provided, only captures which happen before the given
+/// instruction are considered. This routine can be expensive, so consider
+/// caching the results. The boolean ReturnCaptures specifies whether
+/// returning the value (or part of it) from the function counts as capturing
+/// it or not. The boolean StoreCaptures specifies whether storing the value
+/// (or part of it) into memory anywhere automatically counts as capturing it
+/// or not. An ordered basic block \p OBB can be used in order to speed up
+/// queries about relative order among instructions in the same basic block.
+bool llvm::PointerMayBeCapturedBefore(const Value *V, bool ReturnCaptures,
+                                      bool StoreCaptures, const Instruction *I,
+                                      const DominatorTree *DT, bool IncludeI,
+                                      OrderedBasicBlock *OBB,
+                                      unsigned MaxUsesToExplore) {
+  assert(!isa<GlobalValue>(V) &&
+         "It doesn't make sense to ask whether a global is captured.");
+  bool UseNewOBB = OBB == nullptr;
+
+  if (!DT)
+    return PointerMayBeCaptured(V, ReturnCaptures, StoreCaptures,
+                                MaxUsesToExplore);
+  if (UseNewOBB)
+    OBB = new OrderedBasicBlock(I->getParent());
+
+  // TODO: See comment in PointerMayBeCaptured regarding what could be done
+  // with StoreCaptures.
+
+  CapturesBefore CB(ReturnCaptures, I, DT, IncludeI, OBB);
+  PointerMayBeCaptured(V, &CB, MaxUsesToExplore);
+
+  if (UseNewOBB)
+    delete OBB;
+  return CB.Captured;
+}
+
+void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker,
+                                unsigned MaxUsesToExplore) {
+  assert(V->getType()->isPointerTy() && "Capture is for pointers only!");
+  SmallVector<const Use *, DefaultMaxUsesToExplore> Worklist;
+  SmallSet<const Use *, DefaultMaxUsesToExplore> Visited;
+
+  auto AddUses = [&](const Value *V) {
+    unsigned Count = 0;
+    for (const Use &U : V->uses()) {
+      // If there are lots of uses, conservatively say that the value
+      // is captured to avoid taking too much compile time.
+      if (Count++ >= MaxUsesToExplore)
+        return Tracker->tooManyUses();
+      if (!Visited.insert(&U).second)
+        continue;
+      if (!Tracker->shouldExplore(&U))
+        continue;
+      Worklist.push_back(&U);
+    }
+  };
+  AddUses(V);
+
+  while (!Worklist.empty()) {
+    const Use *U = Worklist.pop_back_val();
+    Instruction *I = cast<Instruction>(U->getUser());
+    V = U->get();
+
+    switch (I->getOpcode()) {
+    case Instruction::Call:
+    case Instruction::Invoke: {
+      auto *Call = cast<CallBase>(I);
+      // Not captured if the callee is readonly, doesn't return a copy through
+      // its return value and doesn't unwind (a readonly function can leak bits
+      // by throwing an exception or not depending on the input value).
+      if (Call->onlyReadsMemory() && Call->doesNotThrow() &&
+          Call->getType()->isVoidTy())
+        break;
+
+      // The pointer is not captured if the returned pointer is not captured.
+ // NOTE: CaptureTracking users should not assume that only functions + // marked with nocapture do not capture. This means that places like + // GetUnderlyingObject in ValueTracking or DecomposeGEPExpression + // in BasicAA also need to know about this property. + if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(Call, + true)) { + AddUses(Call); + break; + } + + // Volatile operations effectively capture the memory location that they + // load and store to. + if (auto *MI = dyn_cast<MemIntrinsic>(Call)) + if (MI->isVolatile()) + if (Tracker->captured(U)) + return; + + // Not captured if only passed via 'nocapture' arguments. Note that + // calling a function pointer does not in itself cause the pointer to + // be captured. This is a subtle point considering that (for example) + // the callee might return its own address. It is analogous to saying + // that loading a value from a pointer does not cause the pointer to be + // captured, even though the loaded value might be the pointer itself + // (think of self-referential objects). + for (auto IdxOpPair : enumerate(Call->data_ops())) { + int Idx = IdxOpPair.index(); + Value *A = IdxOpPair.value(); + if (A == V && !Call->doesNotCapture(Idx)) + // The parameter is not marked 'nocapture' - captured. + if (Tracker->captured(U)) + return; + } + break; + } + case Instruction::Load: + // Volatile loads make the address observable. + if (cast<LoadInst>(I)->isVolatile()) + if (Tracker->captured(U)) + return; + break; + case Instruction::VAArg: + // "va-arg" from a pointer does not cause it to be captured. + break; + case Instruction::Store: + // Stored the pointer - conservatively assume it may be captured. + // Volatile stores make the address observable. + if (V == I->getOperand(0) || cast<StoreInst>(I)->isVolatile()) + if (Tracker->captured(U)) + return; + break; + case Instruction::AtomicRMW: { + // atomicrmw conceptually includes both a load and store from + // the same location. + // As with a store, the location being accessed is not captured, + // but the value being stored is. + // Volatile stores make the address observable. + auto *ARMWI = cast<AtomicRMWInst>(I); + if (ARMWI->getValOperand() == V || ARMWI->isVolatile()) + if (Tracker->captured(U)) + return; + break; + } + case Instruction::AtomicCmpXchg: { + // cmpxchg conceptually includes both a load and store from + // the same location. + // As with a store, the location being accessed is not captured, + // but the value being stored is. + // Volatile stores make the address observable. + auto *ACXI = cast<AtomicCmpXchgInst>(I); + if (ACXI->getCompareOperand() == V || ACXI->getNewValOperand() == V || + ACXI->isVolatile()) + if (Tracker->captured(U)) + return; + break; + } + case Instruction::BitCast: + case Instruction::GetElementPtr: + case Instruction::PHI: + case Instruction::Select: + case Instruction::AddrSpaceCast: + // The original value is not captured via this if the new value isn't. + AddUses(I); + break; + case Instruction::ICmp: { + unsigned Idx = (I->getOperand(0) == V) ? 0 : 1; + unsigned OtherIdx = 1 - Idx; + if (auto *CPN = dyn_cast<ConstantPointerNull>(I->getOperand(OtherIdx))) { + // Don't count comparisons of a no-alias return value against null as + // captures. This allows us to ignore comparisons of malloc results + // with null, for example. 
+        if (CPN->getType()->getAddressSpace() == 0)
+          if (isNoAliasCall(V->stripPointerCasts()))
+            break;
+        if (!I->getFunction()->nullPointerIsDefined()) {
+          auto *O = I->getOperand(Idx)->stripPointerCastsSameRepresentation();
+          // Comparing a dereferenceable_or_null pointer against null cannot
+          // lead to pointer escapes, because if it is not null it must be a
+          // valid (in-bounds) pointer.
+          if (Tracker->isDereferenceableOrNull(O, I->getModule()->getDataLayout()))
+            break;
+        }
+      }
+      // Comparison against a value loaded from a global variable. Since the
+      // pointer has not escaped, its value cannot have been guessed and
+      // stored separately in a global variable, so the comparison reveals
+      // nothing about it.
+      auto *LI = dyn_cast<LoadInst>(I->getOperand(OtherIdx));
+      if (LI && isa<GlobalVariable>(LI->getPointerOperand()))
+        break;
+      // Otherwise, be conservative. There are crazy ways to capture pointers
+      // using comparisons.
+      if (Tracker->captured(U))
+        return;
+      break;
+    }
+    default:
+      // Something else - be conservative and say it is captured.
+      if (Tracker->captured(U))
+        return;
+      break;
+    }
+  }
+
+  // All uses examined.
+}
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
new file mode 100644
index 000000000000..a5757be2c4f4
--- /dev/null
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -0,0 +1,143 @@
+//===- CmpInstAnalysis.cpp - Utils to help fold compares ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file holds routines to help analyze compare instructions
+// and fold them into constants or other compare instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CmpInstAnalysis.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/PatternMatch.h"
+
+using namespace llvm;
+
+unsigned llvm::getICmpCode(const ICmpInst *ICI, bool InvertPred) {
+  ICmpInst::Predicate Pred = InvertPred ? ICI->getInversePredicate()
+                                        : ICI->getPredicate();
+  switch (Pred) {
+      // False -> 0
+    case ICmpInst::ICMP_UGT: return 1;  // 001
+    case ICmpInst::ICMP_SGT: return 1;  // 001
+    case ICmpInst::ICMP_EQ:  return 2;  // 010
+    case ICmpInst::ICMP_UGE: return 3;  // 011
+    case ICmpInst::ICMP_SGE: return 3;  // 011
+    case ICmpInst::ICMP_ULT: return 4;  // 100
+    case ICmpInst::ICMP_SLT: return 4;  // 100
+    case ICmpInst::ICMP_NE:  return 5;  // 101
+    case ICmpInst::ICMP_ULE: return 6;  // 110
+    case ICmpInst::ICMP_SLE: return 6;  // 110
+      // True -> 7
+    default:
+      llvm_unreachable("Invalid ICmp predicate!");
+  }
+}
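A hedged sketch (hypothetical helper, not part of this file) of how a client
such as InstCombine can combine two comparisons of identical operands using
these codes. Bit 2 of the code means "less", bit 1 means "equal", and bit 0
means "greater", so OR-ing two codes yields the code of the disjunction,
e.g. ult (100) | eq (010) = 110 = ule; compatible signedness is assumed (see
predicatesFoldable below).

  static Constant *foldOrOfICmpsSketch(const ICmpInst *LHS, const ICmpInst *RHS,
                                       CmpInst::Predicate &JoinedPred) {
    assert(LHS->getOperand(0) == RHS->getOperand(0) &&
           LHS->getOperand(1) == RHS->getOperand(1) && "operands must match");
    unsigned Code = getICmpCode(LHS) | getICmpCode(RHS);
    // Codes 0 and 7 fold to a constant false/true; otherwise JoinedPred is
    // set and the caller builds the replacement icmp itself.
    return getPredForICmpCode(Code, LHS->isSigned() || RHS->isSigned(),
                              LHS->getOperand(0)->getType(), JoinedPred);
  }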
+Constant *llvm::getPredForICmpCode(unsigned Code, bool Sign, Type *OpTy,
+                                   CmpInst::Predicate &Pred) {
+  switch (Code) {
+    default: llvm_unreachable("Illegal ICmp code!");
+    case 0: // False.
+      return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 0);
+    case 1: Pred = Sign ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; break;
+    case 2: Pred = ICmpInst::ICMP_EQ; break;
+    case 3: Pred = Sign ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; break;
+    case 4: Pred = Sign ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; break;
+    case 5: Pred = ICmpInst::ICMP_NE; break;
+    case 6: Pred = Sign ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; break;
+    case 7: // True.
+      return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 1);
+  }
+  return nullptr;
+}
+
+bool llvm::predicatesFoldable(ICmpInst::Predicate P1, ICmpInst::Predicate P2) {
+  return (CmpInst::isSigned(P1) == CmpInst::isSigned(P2)) ||
+         (CmpInst::isSigned(P1) && ICmpInst::isEquality(P2)) ||
+         (CmpInst::isSigned(P2) && ICmpInst::isEquality(P1));
+}
+
+bool llvm::decomposeBitTestICmp(Value *LHS, Value *RHS,
+                                CmpInst::Predicate &Pred,
+                                Value *&X, APInt &Mask, bool LookThruTrunc) {
+  using namespace PatternMatch;
+
+  const APInt *C;
+  if (!match(RHS, m_APInt(C)))
+    return false;
+
+  switch (Pred) {
+  default:
+    return false;
+  case ICmpInst::ICMP_SLT:
+    // X < 0 is equivalent to (X & SignMask) != 0.
+    if (!C->isNullValue())
+      return false;
+    Mask = APInt::getSignMask(C->getBitWidth());
+    Pred = ICmpInst::ICMP_NE;
+    break;
+  case ICmpInst::ICMP_SLE:
+    // X <= -1 is equivalent to (X & SignMask) != 0.
+    if (!C->isAllOnesValue())
+      return false;
+    Mask = APInt::getSignMask(C->getBitWidth());
+    Pred = ICmpInst::ICMP_NE;
+    break;
+  case ICmpInst::ICMP_SGT:
+    // X > -1 is equivalent to (X & SignMask) == 0.
+    if (!C->isAllOnesValue())
+      return false;
+    Mask = APInt::getSignMask(C->getBitWidth());
+    Pred = ICmpInst::ICMP_EQ;
+    break;
+  case ICmpInst::ICMP_SGE:
+    // X >= 0 is equivalent to (X & SignMask) == 0.
+    if (!C->isNullValue())
+      return false;
+    Mask = APInt::getSignMask(C->getBitWidth());
+    Pred = ICmpInst::ICMP_EQ;
+    break;
+  case ICmpInst::ICMP_ULT:
+    // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
+    if (!C->isPowerOf2())
+      return false;
+    Mask = -*C;
+    Pred = ICmpInst::ICMP_EQ;
+    break;
+  case ICmpInst::ICMP_ULE:
+    // X <=u 2^n-1 is equivalent to (X & ~(2^n-1)) == 0.
+    if (!(*C + 1).isPowerOf2())
+      return false;
+    Mask = ~*C;
+    Pred = ICmpInst::ICMP_EQ;
+    break;
+  case ICmpInst::ICMP_UGT:
+    // X >u 2^n-1 is equivalent to (X & ~(2^n-1)) != 0.
+    if (!(*C + 1).isPowerOf2())
+      return false;
+    Mask = ~*C;
+    Pred = ICmpInst::ICMP_NE;
+    break;
+  case ICmpInst::ICMP_UGE:
+    // X >=u 2^n is equivalent to (X & ~(2^n-1)) != 0.
+    if (!C->isPowerOf2())
+      return false;
+    Mask = -*C;
+    Pred = ICmpInst::ICMP_NE;
+    break;
+  }
+
+  if (LookThruTrunc && match(LHS, m_Trunc(m_Value(X)))) {
+    Mask = Mask.zext(X->getType()->getScalarSizeInBits());
+  } else {
+    X = LHS;
+  }
+
+  return true;
+}
diff --git a/llvm/lib/Analysis/CodeMetrics.cpp b/llvm/lib/Analysis/CodeMetrics.cpp
new file mode 100644
index 000000000000..627d955c865f
--- /dev/null
+++ b/llvm/lib/Analysis/CodeMetrics.cpp
@@ -0,0 +1,195 @@
+//===- CodeMetrics.cpp - Code cost measurements ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements code cost measurement utilities.
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Function.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "code-metrics" + +using namespace llvm; + +static void +appendSpeculatableOperands(const Value *V, + SmallPtrSetImpl<const Value *> &Visited, + SmallVectorImpl<const Value *> &Worklist) { + const User *U = dyn_cast<User>(V); + if (!U) + return; + + for (const Value *Operand : U->operands()) + if (Visited.insert(Operand).second) + if (isSafeToSpeculativelyExecute(Operand)) + Worklist.push_back(Operand); +} + +static void completeEphemeralValues(SmallPtrSetImpl<const Value *> &Visited, + SmallVectorImpl<const Value *> &Worklist, + SmallPtrSetImpl<const Value *> &EphValues) { + // Note: We don't speculate PHIs here, so we'll miss instruction chains kept + // alive only by ephemeral values. + + // Walk the worklist using an index but without caching the size so we can + // append more entries as we process the worklist. This forms a queue without + // quadratic behavior by just leaving processed nodes at the head of the + // worklist forever. + for (int i = 0; i < (int)Worklist.size(); ++i) { + const Value *V = Worklist[i]; + + assert(Visited.count(V) && + "Failed to add a worklist entry to our visited set!"); + + // If all uses of this value are ephemeral, then so is this value. + if (!all_of(V->users(), [&](const User *U) { return EphValues.count(U); })) + continue; + + EphValues.insert(V); + LLVM_DEBUG(dbgs() << "Ephemeral Value: " << *V << "\n"); + + // Append any more operands to consider. + appendSpeculatableOperands(V, Visited, Worklist); + } +} + +// Find all ephemeral values. +void CodeMetrics::collectEphemeralValues( + const Loop *L, AssumptionCache *AC, + SmallPtrSetImpl<const Value *> &EphValues) { + SmallPtrSet<const Value *, 32> Visited; + SmallVector<const Value *, 16> Worklist; + + for (auto &AssumeVH : AC->assumptions()) { + if (!AssumeVH) + continue; + Instruction *I = cast<Instruction>(AssumeVH); + + // Filter out call sites outside of the loop so we don't do a function's + // worth of work for each of its loops (and, in the common case, ephemeral + // values in the loop are likely due to @llvm.assume calls in the loop). + if (!L->contains(I->getParent())) + continue; + + if (EphValues.insert(I).second) + appendSpeculatableOperands(I, Visited, Worklist); + } + + completeEphemeralValues(Visited, Worklist, EphValues); +} + +void CodeMetrics::collectEphemeralValues( + const Function *F, AssumptionCache *AC, + SmallPtrSetImpl<const Value *> &EphValues) { + SmallPtrSet<const Value *, 32> Visited; + SmallVector<const Value *, 16> Worklist; + + for (auto &AssumeVH : AC->assumptions()) { + if (!AssumeVH) + continue; + Instruction *I = cast<Instruction>(AssumeVH); + assert(I->getParent()->getParent() == F && + "Found assumption for the wrong function!"); + + if (EphValues.insert(I).second) + appendSpeculatableOperands(I, Visited, Worklist); + } + + completeEphemeralValues(Visited, Worklist, EphValues); +} + +/// Fill in the current structure with information gleaned from the specified +/// block. 
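A hedged usage sketch (F, AC, and TTI are assumed to be provided by the
caller): a typical client collects the ephemeral values once per function and
then folds every block into the running totals of a CodeMetrics instance.

  CodeMetrics Metrics;
  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(&F, &AC, EphValues);
  for (const BasicBlock &BB : F)
    Metrics.analyzeBasicBlock(&BB, TTI, EphValues);
  // Metrics.NumInsts, Metrics.notDuplicatable, etc. are now populated.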
+void CodeMetrics::analyzeBasicBlock(const BasicBlock *BB,
+                                    const TargetTransformInfo &TTI,
+                                    const SmallPtrSetImpl<const Value*> &EphValues) {
+  ++NumBlocks;
+  unsigned NumInstsBeforeThisBB = NumInsts;
+  for (const Instruction &I : *BB) {
+    // Skip ephemeral values.
+    if (EphValues.count(&I))
+      continue;
+
+    // Special handling for calls.
+    if (const auto *Call = dyn_cast<CallBase>(&I)) {
+      if (const Function *F = Call->getCalledFunction()) {
+        // If a function is both internal and has a single use, then it is
+        // extremely likely to get inlined in the future (it was probably
+        // exposed by an interleaved devirtualization pass).
+        if (!Call->isNoInline() && F->hasInternalLinkage() && F->hasOneUse())
+          ++NumInlineCandidates;
+
+        // If this call is to the function itself, then the function is
+        // recursive. Inlining it into other functions is a bad idea, because
+        // this is basically just a form of loop peeling, and our metrics
+        // aren't useful for that case.
+        if (F == BB->getParent())
+          isRecursive = true;
+
+        if (TTI.isLoweredToCall(F))
+          ++NumCalls;
+      } else {
+        // We don't want inline asm to count as a call - that would prevent
+        // loop unrolling. The argument setup cost is still real, though.
+        if (!Call->isInlineAsm())
+          ++NumCalls;
+      }
+    }
+
+    if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
+      if (!AI->isStaticAlloca())
+        this->usesDynamicAlloca = true;
+    }
+
+    if (isa<ExtractElementInst>(I) || I.getType()->isVectorTy())
+      ++NumVectorInsts;
+
+    if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB))
+      notDuplicatable = true;
+
+    if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
+      if (CI->cannotDuplicate())
+        notDuplicatable = true;
+      if (CI->isConvergent())
+        convergent = true;
+    }
+
+    if (const InvokeInst *InvI = dyn_cast<InvokeInst>(&I))
+      if (InvI->cannotDuplicate())
+        notDuplicatable = true;
+
+    NumInsts += TTI.getUserCost(&I);
+  }
+
+  if (isa<ReturnInst>(BB->getTerminator()))
+    ++NumRets;
+
+  // We never want to inline functions that contain an indirectbr. Inlining
+  // them would be incorrect because all the blockaddresses (in static global
+  // initializers, for example) would still refer to the original function,
+  // and the indirect jump would jump from the inlined copy of the function
+  // into the original function, which is extremely undefined behavior.
+  // FIXME: This logic isn't really right; we can safely inline functions
+  // with indirectbr's as long as no other function or global references the
+  // blockaddress of a block within the current function. And as a QOI issue,
+  // if someone is using a blockaddress without an indirectbr, and that
+  // reference somehow ends up in another function or global, we probably
+  // don't want to inline this function.
+  notDuplicatable |= isa<IndirectBrInst>(BB->getTerminator());
+
+  // Remember NumInsts for this BB.
+  NumBBInsts[BB] = NumInsts - NumInstsBeforeThisBB;
+}
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
new file mode 100644
index 000000000000..8dbcf7034fda
--- /dev/null
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -0,0 +1,2637 @@
+//===-- ConstantFolding.cpp - Fold instructions into constants ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines routines for folding instructions into constants. +// +// Also, to supplement the basic IR ConstantExpr simplifications, +// this file defines some additional folding routines that can make use of +// DataLayout information. These functions cannot go in IR due to library +// dependency issues. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/Config/config.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Support/MathExtras.h" +#include <cassert> +#include <cerrno> +#include <cfenv> +#include <cmath> +#include <cstddef> +#include <cstdint> + +using namespace llvm; + +namespace { + +//===----------------------------------------------------------------------===// +// Constant Folding internal helper functions +//===----------------------------------------------------------------------===// + +static Constant *foldConstVectorToAPInt(APInt &Result, Type *DestTy, + Constant *C, Type *SrcEltTy, + unsigned NumSrcElts, + const DataLayout &DL) { + // Now that we know that the input value is a vector of integers, just shift + // and insert them into our result. + unsigned BitShift = DL.getTypeSizeInBits(SrcEltTy); + for (unsigned i = 0; i != NumSrcElts; ++i) { + Constant *Element; + if (DL.isLittleEndian()) + Element = C->getAggregateElement(NumSrcElts - i - 1); + else + Element = C->getAggregateElement(i); + + if (Element && isa<UndefValue>(Element)) { + Result <<= BitShift; + continue; + } + + auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element); + if (!ElementCI) + return ConstantExpr::getBitCast(C, DestTy); + + Result <<= BitShift; + Result |= ElementCI->getValue().zextOrSelf(Result.getBitWidth()); + } + + return nullptr; +} + +/// Constant fold bitcast, symbolically evaluating it with DataLayout. +/// This always returns a non-null constant, but it may be a +/// ConstantExpr if unfoldable. +Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &DL) { + assert(CastInst::castIsValid(Instruction::BitCast, C, DestTy) && + "Invalid constantexpr bitcast!"); + + // Catch the obvious splat cases. + if (C->isNullValue() && !DestTy->isX86_MMXTy()) + return Constant::getNullValue(DestTy); + if (C->isAllOnesValue() && !DestTy->isX86_MMXTy() && + !DestTy->isPtrOrPtrVectorTy()) // Don't get ones for ptr types! + return Constant::getAllOnesValue(DestTy); + + if (auto *VTy = dyn_cast<VectorType>(C->getType())) { + // Handle a vector->scalar integer/fp cast. 
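For instance (illustrative values): bitcasting <2 x i16> <i16 1, i16 2> to i32
folds to 0x00020001 on a little-endian target and to 0x00010002 on a
big-endian one; foldConstVectorToAPInt above visits the lanes in the order
that makes this come out right.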
+    if (isa<IntegerType>(DestTy) || DestTy->isFloatingPointTy()) {
+      unsigned NumSrcElts = VTy->getNumElements();
+      Type *SrcEltTy = VTy->getElementType();
+
+      // If the vector is a vector of floating-point values, convert it to a
+      // vector of integers to simplify things.
+      if (SrcEltTy->isFloatingPointTy()) {
+        unsigned FPWidth = SrcEltTy->getPrimitiveSizeInBits();
+        Type *SrcIVTy = VectorType::get(
+            IntegerType::get(C->getContext(), FPWidth), NumSrcElts);
+        // Ask IR to do the conversion now that #elts line up.
+        C = ConstantExpr::getBitCast(C, SrcIVTy);
+      }
+
+      APInt Result(DL.getTypeSizeInBits(DestTy), 0);
+      if (Constant *CE = foldConstVectorToAPInt(Result, DestTy, C,
+                                                SrcEltTy, NumSrcElts, DL))
+        return CE;
+
+      if (isa<IntegerType>(DestTy))
+        return ConstantInt::get(DestTy, Result);
+
+      APFloat FP(DestTy->getFltSemantics(), Result);
+      return ConstantFP::get(DestTy->getContext(), FP);
+    }
+  }
+
+  // The code below only handles casts to vectors currently.
+  auto *DestVTy = dyn_cast<VectorType>(DestTy);
+  if (!DestVTy)
+    return ConstantExpr::getBitCast(C, DestTy);
+
+  // If this is a scalar -> vector cast, convert the input into a <1 x scalar>
+  // vector so the code below can handle it uniformly.
+  if (isa<ConstantFP>(C) || isa<ConstantInt>(C)) {
+    Constant *Ops = C; // don't take the address of C!
+    return FoldBitCast(ConstantVector::get(Ops), DestTy, DL);
+  }
+
+  // If this is a bitcast from constant vector -> vector, fold it.
+  if (!isa<ConstantDataVector>(C) && !isa<ConstantVector>(C))
+    return ConstantExpr::getBitCast(C, DestTy);
+
+  // If the element types match, IR can fold it.
+  unsigned NumDstElt = DestVTy->getNumElements();
+  unsigned NumSrcElt = C->getType()->getVectorNumElements();
+  if (NumDstElt == NumSrcElt)
+    return ConstantExpr::getBitCast(C, DestTy);
+
+  Type *SrcEltTy = C->getType()->getVectorElementType();
+  Type *DstEltTy = DestVTy->getElementType();
+
+  // Otherwise, we're changing the number of elements in a vector, which
+  // requires endianness information to do the right thing. For example,
+  //    bitcast (<2 x i64> <i64 0, i64 1> to <4 x i32>)
+  // folds to (little endian):
+  //    <4 x i32> <i32 0, i32 0, i32 1, i32 0>
+  // and to (big endian):
+  //    <4 x i32> <i32 0, i32 0, i32 0, i32 1>
+
+  // First things first: we only want to think about integers here, so if
+  // we have something in FP form, recast it as integer.
+  if (DstEltTy->isFloatingPointTy()) {
+    // Fold to a vector of integers with the same size as our FP type.
+    unsigned FPWidth = DstEltTy->getPrimitiveSizeInBits();
+    Type *DestIVTy = VectorType::get(
+        IntegerType::get(C->getContext(), FPWidth), NumDstElt);
+    // Recursively handle this integer conversion, if possible.
+    C = FoldBitCast(C, DestIVTy, DL);
+
+    // Finally, IR can handle this now that #elts line up.
+    return ConstantExpr::getBitCast(C, DestTy);
+  }
+
+  // Okay, we know the destination is integer; if the input is FP, convert
+  // it to integer first.
+  if (SrcEltTy->isFloatingPointTy()) {
+    unsigned FPWidth = SrcEltTy->getPrimitiveSizeInBits();
+    Type *SrcIVTy = VectorType::get(
+        IntegerType::get(C->getContext(), FPWidth), NumSrcElt);
+    // Ask IR to do the conversion now that #elts line up.
+    C = ConstantExpr::getBitCast(C, SrcIVTy);
+    // If IR wasn't able to fold it, bail out.
+    if (!isa<ConstantVector>(C) && // FIXME: Remove ConstantVector.
+        !isa<ConstantDataVector>(C))
+      return C;
+  }
+
+  // Now we know that the input and output vectors are both integer vectors
+  // of the same size, and that their #elements is not the same. Do the
+  // conversion here, which depends on whether the input or output has
+  // more elements.
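For example (illustrative values): bitcasting <4 x i16> <i16 1, i16 2, i16 3,
i16 4> to <1 x i64> on a little-endian target starts the shift amount at bit 0
and steps it up by 16 per lane, producing 0x0004000300020001; a big-endian
target starts at bit 48 and steps down, giving 0x0001000200030004.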
+  bool isLittleEndian = DL.isLittleEndian();
+
+  SmallVector<Constant*, 32> Result;
+  if (NumDstElt < NumSrcElt) {
+    // Handle: bitcast (<4 x i32> <i32 0, i32 1, i32 2, i32 3> to <2 x i64>)
+    Constant *Zero = Constant::getNullValue(DstEltTy);
+    unsigned Ratio = NumSrcElt/NumDstElt;
+    unsigned SrcBitSize = SrcEltTy->getPrimitiveSizeInBits();
+    unsigned SrcElt = 0;
+    for (unsigned i = 0; i != NumDstElt; ++i) {
+      // Build each element of the result.
+      Constant *Elt = Zero;
+      unsigned ShiftAmt = isLittleEndian ? 0 : SrcBitSize*(Ratio-1);
+      for (unsigned j = 0; j != Ratio; ++j) {
+        Constant *Src = C->getAggregateElement(SrcElt++);
+        if (Src && isa<UndefValue>(Src))
+          Src = Constant::getNullValue(C->getType()->getVectorElementType());
+        else
+          Src = dyn_cast_or_null<ConstantInt>(Src);
+        if (!Src)  // Reject constantexpr elements.
+          return ConstantExpr::getBitCast(C, DestTy);
+
+        // Zero extend the element to the right size.
+        Src = ConstantExpr::getZExt(Src, Elt->getType());
+
+        // Shift it to the right place, depending on endianness.
+        Src = ConstantExpr::getShl(Src,
+                                   ConstantInt::get(Src->getType(), ShiftAmt));
+        ShiftAmt += isLittleEndian ? SrcBitSize : -SrcBitSize;
+
+        // Mix it in.
+        Elt = ConstantExpr::getOr(Elt, Src);
+      }
+      Result.push_back(Elt);
+    }
+    return ConstantVector::get(Result);
+  }
+
+  // Handle: bitcast (<2 x i64> <i64 0, i64 1> to <4 x i32>)
+  unsigned Ratio = NumDstElt/NumSrcElt;
+  unsigned DstBitSize = DL.getTypeSizeInBits(DstEltTy);
+
+  // Loop over each source value, expanding into multiple results.
+  for (unsigned i = 0; i != NumSrcElt; ++i) {
+    auto *Element = C->getAggregateElement(i);
+
+    if (!Element) // Reject constantexpr elements.
+      return ConstantExpr::getBitCast(C, DestTy);
+
+    if (isa<UndefValue>(Element)) {
+      // Correctly propagate undef values.
+      Result.append(Ratio, UndefValue::get(DstEltTy));
+      continue;
+    }
+
+    auto *Src = dyn_cast<ConstantInt>(Element);
+    if (!Src)
+      return ConstantExpr::getBitCast(C, DestTy);
+
+    unsigned ShiftAmt = isLittleEndian ? 0 : DstBitSize*(Ratio-1);
+    for (unsigned j = 0; j != Ratio; ++j) {
+      // Shift the piece of the value into the right place, depending on
+      // endianness.
+      Constant *Elt = ConstantExpr::getLShr(Src,
+                                  ConstantInt::get(Src->getType(), ShiftAmt));
+      ShiftAmt += isLittleEndian ? DstBitSize : -DstBitSize;
+
+      // If the destination element is a pointer, truncate the piece to an
+      // integer of pointer width and convert it back to a pointer using an
+      // inttoptr.
+      if (DstEltTy->isPointerTy()) {
+        IntegerType *DstIntTy = Type::getIntNTy(C->getContext(), DstBitSize);
+        Constant *CE = ConstantExpr::getTrunc(Elt, DstIntTy);
+        Result.push_back(ConstantExpr::getIntToPtr(CE, DstEltTy));
+        continue;
+      }
+
+      // Truncate and remember this piece.
+      Result.push_back(ConstantExpr::getTrunc(Elt, DstEltTy));
+    }
+  }
+
+  return ConstantVector::get(Result);
+}
+
+} // end anonymous namespace
+
+/// If this constant is a constant offset from a global, return the global and
+/// the constant. Because of constantexprs, this function is recursive.
+bool llvm::IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV,
+                                      APInt &Offset, const DataLayout &DL) {
+  // Trivial case, constant is the global.
+  if ((GV = dyn_cast<GlobalValue>(C))) {
+    unsigned BitWidth = DL.getIndexTypeSizeInBits(GV->getType());
+    Offset = APInt(BitWidth, 0);
+    return true;
+  }
+
+  // Otherwise, if this isn't a constant expr, bail out.
+ auto *CE = dyn_cast<ConstantExpr>(C); + if (!CE) return false; + + // Look through ptr->int and ptr->ptr casts. + if (CE->getOpcode() == Instruction::PtrToInt || + CE->getOpcode() == Instruction::BitCast) + return IsConstantOffsetFromGlobal(CE->getOperand(0), GV, Offset, DL); + + // i32* getelementptr ([5 x i32]* @a, i32 0, i32 5) + auto *GEP = dyn_cast<GEPOperator>(CE); + if (!GEP) + return false; + + unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP->getType()); + APInt TmpOffset(BitWidth, 0); + + // If the base isn't a global+constant, we aren't either. + if (!IsConstantOffsetFromGlobal(CE->getOperand(0), GV, TmpOffset, DL)) + return false; + + // Otherwise, add any offset that our operands provide. + if (!GEP->accumulateConstantOffset(DL, TmpOffset)) + return false; + + Offset = TmpOffset; + return true; +} + +Constant *llvm::ConstantFoldLoadThroughBitcast(Constant *C, Type *DestTy, + const DataLayout &DL) { + do { + Type *SrcTy = C->getType(); + + // If the type sizes are the same and a cast is legal, just directly + // cast the constant. + if (DL.getTypeSizeInBits(DestTy) == DL.getTypeSizeInBits(SrcTy)) { + Instruction::CastOps Cast = Instruction::BitCast; + // If we are going from a pointer to int or vice versa, we spell the cast + // differently. + if (SrcTy->isIntegerTy() && DestTy->isPointerTy()) + Cast = Instruction::IntToPtr; + else if (SrcTy->isPointerTy() && DestTy->isIntegerTy()) + Cast = Instruction::PtrToInt; + + if (CastInst::castIsValid(Cast, C, DestTy)) + return ConstantExpr::getCast(Cast, C, DestTy); + } + + // If this isn't an aggregate type, there is nothing we can do to drill down + // and find a bitcastable constant. + if (!SrcTy->isAggregateType()) + return nullptr; + + // We're simulating a load through a pointer that was bitcast to point to + // a different type, so we can try to walk down through the initial + // elements of an aggregate to see if some part of the aggregate is + // castable to implement the "load" semantic model. + if (SrcTy->isStructTy()) { + // Struct types might have leading zero-length elements like [0 x i32], + // which are certainly not what we are looking for, so skip them. + unsigned Elem = 0; + Constant *ElemC; + do { + ElemC = C->getAggregateElement(Elem++); + } while (ElemC && DL.getTypeSizeInBits(ElemC->getType()) == 0); + C = ElemC; + } else { + C = C->getAggregateElement(0u); + } + } while (C); + + return nullptr; +} + +namespace { + +/// Recursive helper to read bits out of global. C is the constant being copied +/// out of. ByteOffset is an offset into C. CurPtr is the pointer to copy +/// results into and BytesLeft is the number of bytes left in +/// the CurPtr buffer. DL is the DataLayout. +bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, unsigned char *CurPtr, + unsigned BytesLeft, const DataLayout &DL) { + assert(ByteOffset <= DL.getTypeAllocSize(C->getType()) && + "Out of range access"); + + // If this element is zero or undefined, we can just return since *CurPtr is + // zero initialized. 
+ if (isa<ConstantAggregateZero>(C) || isa<UndefValue>(C)) + return true; + + if (auto *CI = dyn_cast<ConstantInt>(C)) { + if (CI->getBitWidth() > 64 || + (CI->getBitWidth() & 7) != 0) + return false; + + uint64_t Val = CI->getZExtValue(); + unsigned IntBytes = unsigned(CI->getBitWidth()/8); + + for (unsigned i = 0; i != BytesLeft && ByteOffset != IntBytes; ++i) { + int n = ByteOffset; + if (!DL.isLittleEndian()) + n = IntBytes - n - 1; + CurPtr[i] = (unsigned char)(Val >> (n * 8)); + ++ByteOffset; + } + return true; + } + + if (auto *CFP = dyn_cast<ConstantFP>(C)) { + if (CFP->getType()->isDoubleTy()) { + C = FoldBitCast(C, Type::getInt64Ty(C->getContext()), DL); + return ReadDataFromGlobal(C, ByteOffset, CurPtr, BytesLeft, DL); + } + if (CFP->getType()->isFloatTy()){ + C = FoldBitCast(C, Type::getInt32Ty(C->getContext()), DL); + return ReadDataFromGlobal(C, ByteOffset, CurPtr, BytesLeft, DL); + } + if (CFP->getType()->isHalfTy()){ + C = FoldBitCast(C, Type::getInt16Ty(C->getContext()), DL); + return ReadDataFromGlobal(C, ByteOffset, CurPtr, BytesLeft, DL); + } + return false; + } + + if (auto *CS = dyn_cast<ConstantStruct>(C)) { + const StructLayout *SL = DL.getStructLayout(CS->getType()); + unsigned Index = SL->getElementContainingOffset(ByteOffset); + uint64_t CurEltOffset = SL->getElementOffset(Index); + ByteOffset -= CurEltOffset; + + while (true) { + // If the element access is to the element itself and not to tail padding, + // read the bytes from the element. + uint64_t EltSize = DL.getTypeAllocSize(CS->getOperand(Index)->getType()); + + if (ByteOffset < EltSize && + !ReadDataFromGlobal(CS->getOperand(Index), ByteOffset, CurPtr, + BytesLeft, DL)) + return false; + + ++Index; + + // Check to see if we read from the last struct element, if so we're done. + if (Index == CS->getType()->getNumElements()) + return true; + + // If we read all of the bytes we needed from this element we're done. + uint64_t NextEltOffset = SL->getElementOffset(Index); + + if (BytesLeft <= NextEltOffset - CurEltOffset - ByteOffset) + return true; + + // Move to the next element of the struct. + CurPtr += NextEltOffset - CurEltOffset - ByteOffset; + BytesLeft -= NextEltOffset - CurEltOffset - ByteOffset; + ByteOffset = 0; + CurEltOffset = NextEltOffset; + } + // not reached. + } + + if (isa<ConstantArray>(C) || isa<ConstantVector>(C) || + isa<ConstantDataSequential>(C)) { + Type *EltTy = C->getType()->getSequentialElementType(); + uint64_t EltSize = DL.getTypeAllocSize(EltTy); + uint64_t Index = ByteOffset / EltSize; + uint64_t Offset = ByteOffset - Index * EltSize; + uint64_t NumElts; + if (auto *AT = dyn_cast<ArrayType>(C->getType())) + NumElts = AT->getNumElements(); + else + NumElts = C->getType()->getVectorNumElements(); + + for (; Index != NumElts; ++Index) { + if (!ReadDataFromGlobal(C->getAggregateElement(Index), Offset, CurPtr, + BytesLeft, DL)) + return false; + + uint64_t BytesWritten = EltSize - Offset; + assert(BytesWritten <= EltSize && "Not indexing into this element?"); + if (BytesWritten >= BytesLeft) + return true; + + Offset = 0; + BytesLeft -= BytesWritten; + CurPtr += BytesWritten; + } + return true; + } + + if (auto *CE = dyn_cast<ConstantExpr>(C)) { + if (CE->getOpcode() == Instruction::IntToPtr && + CE->getOperand(0)->getType() == DL.getIntPtrType(CE->getType())) { + return ReadDataFromGlobal(CE->getOperand(0), ByteOffset, CurPtr, + BytesLeft, DL); + } + } + + // Otherwise, unknown initializer type. 
+ return false; +} + +Constant *FoldReinterpretLoadFromConstPtr(Constant *C, Type *LoadTy, + const DataLayout &DL) { + auto *PTy = cast<PointerType>(C->getType()); + auto *IntType = dyn_cast<IntegerType>(LoadTy); + + // If this isn't an integer load we can't fold it directly. + if (!IntType) { + unsigned AS = PTy->getAddressSpace(); + + // If this is a float/double load, we can try folding it as an int32/64 load + // and then bitcast the result. This can be useful for union cases. Note + // that address spaces don't matter here since we're not going to result in + // an actual new load. + Type *MapTy; + if (LoadTy->isHalfTy()) + MapTy = Type::getInt16Ty(C->getContext()); + else if (LoadTy->isFloatTy()) + MapTy = Type::getInt32Ty(C->getContext()); + else if (LoadTy->isDoubleTy()) + MapTy = Type::getInt64Ty(C->getContext()); + else if (LoadTy->isVectorTy()) { + MapTy = PointerType::getIntNTy(C->getContext(), + DL.getTypeSizeInBits(LoadTy)); + } else + return nullptr; + + C = FoldBitCast(C, MapTy->getPointerTo(AS), DL); + if (Constant *Res = FoldReinterpretLoadFromConstPtr(C, MapTy, DL)) { + if (Res->isNullValue() && !LoadTy->isX86_MMXTy()) + // Materializing a zero can be done trivially without a bitcast + return Constant::getNullValue(LoadTy); + Type *CastTy = LoadTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(LoadTy) : LoadTy; + Res = FoldBitCast(Res, CastTy, DL); + if (LoadTy->isPtrOrPtrVectorTy()) { + // For vector of pointer, we needed to first convert to a vector of integer, then do vector inttoptr + if (Res->isNullValue() && !LoadTy->isX86_MMXTy()) + return Constant::getNullValue(LoadTy); + if (DL.isNonIntegralPointerType(LoadTy->getScalarType())) + // Be careful not to replace a load of an addrspace value with an inttoptr here + return nullptr; + Res = ConstantExpr::getCast(Instruction::IntToPtr, Res, LoadTy); + } + return Res; + } + return nullptr; + } + + unsigned BytesLoaded = (IntType->getBitWidth() + 7) / 8; + if (BytesLoaded > 32 || BytesLoaded == 0) + return nullptr; + + GlobalValue *GVal; + APInt OffsetAI; + if (!IsConstantOffsetFromGlobal(C, GVal, OffsetAI, DL)) + return nullptr; + + auto *GV = dyn_cast<GlobalVariable>(GVal); + if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() || + !GV->getInitializer()->getType()->isSized()) + return nullptr; + + int64_t Offset = OffsetAI.getSExtValue(); + int64_t InitializerSize = DL.getTypeAllocSize(GV->getInitializer()->getType()); + + // If we're not accessing anything in this constant, the result is undefined. + if (Offset <= -1 * static_cast<int64_t>(BytesLoaded)) + return UndefValue::get(IntType); + + // If we're not accessing anything in this constant, the result is undefined. + if (Offset >= InitializerSize) + return UndefValue::get(IntType); + + unsigned char RawBytes[32] = {0}; + unsigned char *CurPtr = RawBytes; + unsigned BytesLeft = BytesLoaded; + + // If we're loading off the beginning of the global, some bytes may be valid. 
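Worked example (illustrative values): a 4-byte load at Offset -2 overlaps the
initializer by two bytes, so CurPtr advances by 2 and BytesLeft drops to 2
before the read below; the two bytes that precede the global stay zero in
RawBytes.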
+ if (Offset < 0) { + CurPtr += -Offset; + BytesLeft += Offset; + Offset = 0; + } + + if (!ReadDataFromGlobal(GV->getInitializer(), Offset, CurPtr, BytesLeft, DL)) + return nullptr; + + APInt ResultVal = APInt(IntType->getBitWidth(), 0); + if (DL.isLittleEndian()) { + ResultVal = RawBytes[BytesLoaded - 1]; + for (unsigned i = 1; i != BytesLoaded; ++i) { + ResultVal <<= 8; + ResultVal |= RawBytes[BytesLoaded - 1 - i]; + } + } else { + ResultVal = RawBytes[0]; + for (unsigned i = 1; i != BytesLoaded; ++i) { + ResultVal <<= 8; + ResultVal |= RawBytes[i]; + } + } + + return ConstantInt::get(IntType->getContext(), ResultVal); +} + +Constant *ConstantFoldLoadThroughBitcastExpr(ConstantExpr *CE, Type *DestTy, + const DataLayout &DL) { + auto *SrcPtr = CE->getOperand(0); + auto *SrcPtrTy = dyn_cast<PointerType>(SrcPtr->getType()); + if (!SrcPtrTy) + return nullptr; + Type *SrcTy = SrcPtrTy->getPointerElementType(); + + Constant *C = ConstantFoldLoadFromConstPtr(SrcPtr, SrcTy, DL); + if (!C) + return nullptr; + + return llvm::ConstantFoldLoadThroughBitcast(C, DestTy, DL); +} + +} // end anonymous namespace + +Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, Type *Ty, + const DataLayout &DL) { + // First, try the easy cases: + if (auto *GV = dyn_cast<GlobalVariable>(C)) + if (GV->isConstant() && GV->hasDefinitiveInitializer()) + return GV->getInitializer(); + + if (auto *GA = dyn_cast<GlobalAlias>(C)) + if (GA->getAliasee() && !GA->isInterposable()) + return ConstantFoldLoadFromConstPtr(GA->getAliasee(), Ty, DL); + + // If the loaded value isn't a constant expr, we can't handle it. + auto *CE = dyn_cast<ConstantExpr>(C); + if (!CE) + return nullptr; + + if (CE->getOpcode() == Instruction::GetElementPtr) { + if (auto *GV = dyn_cast<GlobalVariable>(CE->getOperand(0))) { + if (GV->isConstant() && GV->hasDefinitiveInitializer()) { + if (Constant *V = + ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE)) + return V; + } + } + } + + if (CE->getOpcode() == Instruction::BitCast) + if (Constant *LoadedC = ConstantFoldLoadThroughBitcastExpr(CE, Ty, DL)) + return LoadedC; + + // Instead of loading constant c string, use corresponding integer value + // directly if string length is small enough. + StringRef Str; + if (getConstantStringInfo(CE, Str) && !Str.empty()) { + size_t StrLen = Str.size(); + unsigned NumBits = Ty->getPrimitiveSizeInBits(); + // Replace load with immediate integer if the result is an integer or fp + // value. + if ((NumBits >> 3) == StrLen + 1 && (NumBits & 7) == 0 && + (isa<IntegerType>(Ty) || Ty->isFloatingPointTy())) { + APInt StrVal(NumBits, 0); + APInt SingleChar(NumBits, 0); + if (DL.isLittleEndian()) { + for (unsigned char C : reverse(Str.bytes())) { + SingleChar = static_cast<uint64_t>(C); + StrVal = (StrVal << 8) | SingleChar; + } + } else { + for (unsigned char C : Str.bytes()) { + SingleChar = static_cast<uint64_t>(C); + StrVal = (StrVal << 8) | SingleChar; + } + // Append NULL at the end. + SingleChar = 0; + StrVal = (StrVal << 8) | SingleChar; + } + + Constant *Res = ConstantInt::get(CE->getContext(), StrVal); + if (Ty->isFloatingPointTy()) + Res = ConstantExpr::getBitCast(Res, Ty); + return Res; + } + } + + // If this load comes from anywhere in a constant global, and if the global + // is all undef or zero, we know what it loads. 
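For instance (illustrative IR): a load of any sized type, at any offset, from
  @g = internal constant [64 x i8] zeroinitializer
folds to the zero value of the loaded type, and likewise to undef when the
definitive initializer is undef.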
+  if (auto *GV = dyn_cast<GlobalVariable>(GetUnderlyingObject(CE, DL))) {
+    if (GV->isConstant() && GV->hasDefinitiveInitializer()) {
+      if (GV->getInitializer()->isNullValue())
+        return Constant::getNullValue(Ty);
+      if (isa<UndefValue>(GV->getInitializer()))
+        return UndefValue::get(Ty);
+    }
+  }
+
+  // Try hard to fold loads from bitcasted strange and non-type-safe things.
+  return FoldReinterpretLoadFromConstPtr(CE, Ty, DL);
+}
+
+namespace {
+
+Constant *ConstantFoldLoadInst(const LoadInst *LI, const DataLayout &DL) {
+  if (LI->isVolatile()) return nullptr;
+
+  if (auto *C = dyn_cast<Constant>(LI->getOperand(0)))
+    return ConstantFoldLoadFromConstPtr(C, LI->getType(), DL);
+
+  return nullptr;
+}
+
+/// One of Op0/Op1 is a constant expression.
+/// Attempt to symbolically evaluate the result of a binary operator merging
+/// these together. If target data info is available, it is provided as DL,
+/// otherwise DL is null.
+Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0, Constant *Op1,
+                                    const DataLayout &DL) {
+  // SROA
+
+  // Fold (and 0xffffffff00000000, (shl x, 32)) -> shl.
+  // Fold (lshr (or X, Y), 32) -> (lshr [X/Y], 32) if one doesn't contribute
+  // bits.
+
+  if (Opc == Instruction::And) {
+    KnownBits Known0 = computeKnownBits(Op0, DL);
+    KnownBits Known1 = computeKnownBits(Op1, DL);
+    if ((Known1.One | Known0.Zero).isAllOnesValue()) {
+      // All the bits of Op0 that the 'and' could be masking are already zero.
+      return Op0;
+    }
+    if ((Known0.One | Known1.Zero).isAllOnesValue()) {
+      // All the bits of Op1 that the 'and' could be masking are already zero.
+      return Op1;
+    }
+
+    Known0.Zero |= Known1.Zero;
+    Known0.One &= Known1.One;
+    if (Known0.isConstant())
+      return ConstantInt::get(Op0->getType(), Known0.getConstant());
+  }
+
+  // If the constant expr is something like &A[123] - &A[4].f, fold this into a
+  // constant. This happens frequently when iterating over a global array.
+  if (Opc == Instruction::Sub) {
+    GlobalValue *GV1, *GV2;
+    APInt Offs1, Offs2;
+
+    if (IsConstantOffsetFromGlobal(Op0, GV1, Offs1, DL))
+      if (IsConstantOffsetFromGlobal(Op1, GV2, Offs2, DL) && GV1 == GV2) {
+        unsigned OpSize = DL.getTypeSizeInBits(Op0->getType());
+
+        // (&GV+C1) - (&GV+C2) -> C1-C2, pointer arithmetic cannot overflow.
+        // PtrToInt may change the bitwidth, so we have to convert to the
+        // right size first.
+        return ConstantInt::get(Op0->getType(), Offs1.zextOrTrunc(OpSize) -
+                                                Offs2.zextOrTrunc(OpSize));
+      }
+  }
+
+  return nullptr;
+}
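Worked example for the Sub case (illustrative values): with i32 elements,
&A[123] - &A[4] decomposes to the same global with byte offsets 492 and 16,
so the whole expression folds to the integer constant 476 without
materializing either address.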
+
+/// If array indices are not pointer-sized integers, explicitly cast them so
+/// that they aren't implicitly casted by the getelementptr.
+Constant *CastGEPIndices(Type *SrcElemTy, ArrayRef<Constant *> Ops,
+                         Type *ResultTy, Optional<unsigned> InRangeIndex,
+                         const DataLayout &DL, const TargetLibraryInfo *TLI) {
+  Type *IntPtrTy = DL.getIntPtrType(ResultTy);
+  Type *IntPtrScalarTy = IntPtrTy->getScalarType();
+
+  bool Any = false;
+  SmallVector<Constant*, 32> NewIdxs;
+  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
+    if ((i == 1 ||
+         !isa<StructType>(GetElementPtrInst::getIndexedType(
+             SrcElemTy, Ops.slice(1, i - 1)))) &&
+        Ops[i]->getType()->getScalarType() != IntPtrScalarTy) {
+      Any = true;
+      Type *NewType = Ops[i]->getType()->isVectorTy()
+                          ? IntPtrTy
+                          : IntPtrTy->getScalarType();
+      NewIdxs.push_back(ConstantExpr::getCast(CastInst::getCastOpcode(Ops[i],
+                                                                      true,
+                                                                      NewType,
+                                                                      true),
+                                              Ops[i], NewType));
+    } else
+      NewIdxs.push_back(Ops[i]);
+  }
+
+  if (!Any)
+    return nullptr;
+
+  Constant *C = ConstantExpr::getGetElementPtr(
+      SrcElemTy, Ops[0], NewIdxs, /*InBounds=*/false, InRangeIndex);
+  if (Constant *Folded = ConstantFoldConstant(C, DL, TLI))
+    C = Folded;
+
+  return C;
+}
+
+/// Strip the pointer casts, but preserve the address space information.
+Constant *StripPtrCastKeepAS(Constant *Ptr, Type *&ElemTy) {
+  assert(Ptr->getType()->isPointerTy() && "Not a pointer type");
+  auto *OldPtrTy = cast<PointerType>(Ptr->getType());
+  Ptr = cast<Constant>(Ptr->stripPointerCasts());
+  auto *NewPtrTy = cast<PointerType>(Ptr->getType());
+
+  ElemTy = NewPtrTy->getPointerElementType();
+
+  // Preserve the address space number of the pointer.
+  if (NewPtrTy->getAddressSpace() != OldPtrTy->getAddressSpace()) {
+    NewPtrTy = ElemTy->getPointerTo(OldPtrTy->getAddressSpace());
+    Ptr = ConstantExpr::getPointerCast(Ptr, NewPtrTy);
+  }
+  return Ptr;
+}
+
+/// If we can symbolically evaluate the GEP constant expression, do so.
+Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP,
+                                  ArrayRef<Constant *> Ops,
+                                  const DataLayout &DL,
+                                  const TargetLibraryInfo *TLI) {
+  const GEPOperator *InnermostGEP = GEP;
+  bool InBounds = GEP->isInBounds();
+
+  Type *SrcElemTy = GEP->getSourceElementType();
+  Type *ResElemTy = GEP->getResultElementType();
+  Type *ResTy = GEP->getType();
+  if (!SrcElemTy->isSized())
+    return nullptr;
+
+  if (Constant *C = CastGEPIndices(SrcElemTy, Ops, ResTy,
+                                   GEP->getInRangeIndex(), DL, TLI))
+    return C;
+
+  Constant *Ptr = Ops[0];
+  if (!Ptr->getType()->isPointerTy())
+    return nullptr;
+
+  Type *IntPtrTy = DL.getIntPtrType(Ptr->getType());
+
+  // If this is a constant expr gep that is effectively computing an
+  // "offsetof", fold it into 'cast int Size to T*' instead of 'gep 0, 0, 12'.
+  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
+    if (!isa<ConstantInt>(Ops[i])) {
+
+      // If this is "gep i8* Ptr, (sub 0, V)", fold this as:
+      // "inttoptr (sub (ptrtoint Ptr), V)"
+      if (Ops.size() == 2 && ResElemTy->isIntegerTy(8)) {
+        auto *CE = dyn_cast<ConstantExpr>(Ops[1]);
+        assert((!CE || CE->getType() == IntPtrTy) &&
+               "CastGEPIndices didn't canonicalize index types!");
+        if (CE && CE->getOpcode() == Instruction::Sub &&
+            CE->getOperand(0)->isNullValue()) {
+          Constant *Res = ConstantExpr::getPtrToInt(Ptr, CE->getType());
+          Res = ConstantExpr::getSub(Res, CE->getOperand(1));
+          Res = ConstantExpr::getIntToPtr(Res, ResTy);
+          if (auto *FoldedRes = ConstantFoldConstant(Res, DL, TLI))
+            Res = FoldedRes;
+          return Res;
+        }
+      }
+      return nullptr;
+    }
+
+  unsigned BitWidth = DL.getTypeSizeInBits(IntPtrTy);
+  APInt Offset =
+      APInt(BitWidth,
+            DL.getIndexedOffsetInType(
+                SrcElemTy,
+                makeArrayRef((Value * const *)Ops.data() + 1, Ops.size() - 1)));
+  Ptr = StripPtrCastKeepAS(Ptr, SrcElemTy);
+
+  // If this is a GEP of a GEP, fold it all into a single GEP.
+  while (auto *GEP = dyn_cast<GEPOperator>(Ptr)) {
+    InnermostGEP = GEP;
+    InBounds &= GEP->isInBounds();
+
+    SmallVector<Value *, 4> NestedOps(GEP->op_begin() + 1, GEP->op_end());
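Illustrative walk (assuming 4-byte i32 elements): for a chain such as
gep(i32* gep([8 x i32]* @g, 0, 4), 2), the outer index has already
contributed 8 bytes to 'Offset'; each loop iteration below peels one inner
GEP, adds its 16 bytes, and leaves @g as the base pointer with a combined
offset of 24.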
+    // Do not try to incorporate the sub-GEP if some index is not a number.
+    bool AllConstantInt = true;
+    for (Value *NestedOp : NestedOps)
+      if (!isa<ConstantInt>(NestedOp)) {
+        AllConstantInt = false;
+        break;
+      }
+    if (!AllConstantInt)
+      break;
+
+    Ptr = cast<Constant>(GEP->getOperand(0));
+    SrcElemTy = GEP->getSourceElementType();
+    Offset += APInt(BitWidth, DL.getIndexedOffsetInType(SrcElemTy, NestedOps));
+    Ptr = StripPtrCastKeepAS(Ptr, SrcElemTy);
+  }
+
+  // If the base value for this address is a literal integer value, fold the
+  // getelementptr to the resulting integer value casted to the pointer type.
+  APInt BasePtr(BitWidth, 0);
+  if (auto *CE = dyn_cast<ConstantExpr>(Ptr)) {
+    if (CE->getOpcode() == Instruction::IntToPtr) {
+      if (auto *Base = dyn_cast<ConstantInt>(CE->getOperand(0)))
+        BasePtr = Base->getValue().zextOrTrunc(BitWidth);
+    }
+  }
+
+  auto *PTy = cast<PointerType>(Ptr->getType());
+  if ((Ptr->isNullValue() || BasePtr != 0) &&
+      !DL.isNonIntegralPointerType(PTy)) {
+    Constant *C = ConstantInt::get(Ptr->getContext(), Offset + BasePtr);
+    return ConstantExpr::getIntToPtr(C, ResTy);
+  }
+
+  // Otherwise form a regular getelementptr. Recompute the indices so that
+  // we eliminate over-indexing of the notional static type array bounds.
+  // This makes it easy to determine if the getelementptr is "inbounds".
+  // Also, this helps GlobalOpt do SROA on GlobalVariables.
+  Type *Ty = PTy;
+  SmallVector<Constant *, 32> NewIdxs;
+
+  do {
+    if (!Ty->isStructTy()) {
+      if (Ty->isPointerTy()) {
+        // The only pointer indexing we'll do is on the first index of the GEP.
+        if (!NewIdxs.empty())
+          break;
+
+        Ty = SrcElemTy;
+
+        // Only handle pointers to sized types, not pointers to functions.
+        if (!Ty->isSized())
+          return nullptr;
+      } else if (auto *ATy = dyn_cast<SequentialType>(Ty)) {
+        Ty = ATy->getElementType();
+      } else {
+        // We've reached some non-indexable type.
+        break;
+      }
+
+      // Determine which element of the array the offset points into.
+      APInt ElemSize(BitWidth, DL.getTypeAllocSize(Ty));
+      if (ElemSize == 0) {
+        // The element size is 0. This may be [0 x Ty]*, so just use a zero
+        // index for this level and proceed to the next level to see if it can
+        // accommodate the offset.
+        NewIdxs.push_back(ConstantInt::get(IntPtrTy, 0));
+      } else {
+        // The element size is non-zero; divide the offset by the element
+        // size (rounding down) to compute the index at this level.
+        bool Overflow;
+        APInt NewIdx = Offset.sdiv_ov(ElemSize, Overflow);
+        if (Overflow)
+          break;
+        Offset -= NewIdx * ElemSize;
+        NewIdxs.push_back(ConstantInt::get(IntPtrTy, NewIdx));
+      }
+    } else {
+      auto *STy = cast<StructType>(Ty);
+      // If we end up with an offset that isn't valid for this struct type, we
+      // can't re-form this GEP in a regular form, so bail out. The pointer
+      // operand likely went through casts that are necessary to make the GEP
+      // sensible.
+      const StructLayout &SL = *DL.getStructLayout(STy);
+      if (Offset.isNegative() || Offset.uge(SL.getSizeInBytes()))
+        break;
+
+      // Determine which field of the struct the offset points into. The
+      // getZExtValue is fine as we've already ensured that the offset is
+      // within the range representable by the StructLayout API.
+ unsigned ElIdx = SL.getElementContainingOffset(Offset.getZExtValue()); + NewIdxs.push_back(ConstantInt::get(Type::getInt32Ty(Ty->getContext()), + ElIdx)); + Offset -= APInt(BitWidth, SL.getElementOffset(ElIdx)); + Ty = STy->getTypeAtIndex(ElIdx); + } + } while (Ty != ResElemTy); + + // If we haven't used up the entire offset by descending the static + // type, then the offset is pointing into the middle of an indivisible + // member, so we can't simplify it. + if (Offset != 0) + return nullptr; + + // Preserve the inrange index from the innermost GEP if possible. We must + // have calculated the same indices up to and including the inrange index. + Optional<unsigned> InRangeIndex; + if (Optional<unsigned> LastIRIndex = InnermostGEP->getInRangeIndex()) + if (SrcElemTy == InnermostGEP->getSourceElementType() && + NewIdxs.size() > *LastIRIndex) { + InRangeIndex = LastIRIndex; + for (unsigned I = 0; I <= *LastIRIndex; ++I) + if (NewIdxs[I] != InnermostGEP->getOperand(I + 1)) + return nullptr; + } + + // Create a GEP. + Constant *C = ConstantExpr::getGetElementPtr(SrcElemTy, Ptr, NewIdxs, + InBounds, InRangeIndex); + assert(C->getType()->getPointerElementType() == Ty && + "Computed GetElementPtr has unexpected type!"); + + // If we ended up indexing a member with a type that doesn't match + // the type of what the original indices indexed, add a cast. + if (Ty != ResElemTy) + C = FoldBitCast(C, ResTy, DL); + + return C; +} + +/// Attempt to constant fold an instruction with the +/// specified opcode and operands. If successful, the constant result is +/// returned, if not, null is returned. Note that this function can fail when +/// attempting to fold instructions like loads and stores, which have no +/// constant expression form. +Constant *ConstantFoldInstOperandsImpl(const Value *InstOrCE, unsigned Opcode, + ArrayRef<Constant *> Ops, + const DataLayout &DL, + const TargetLibraryInfo *TLI) { + Type *DestTy = InstOrCE->getType(); + + if (Instruction::isUnaryOp(Opcode)) + return ConstantFoldUnaryOpOperand(Opcode, Ops[0], DL); + + if (Instruction::isBinaryOp(Opcode)) + return ConstantFoldBinaryOpOperands(Opcode, Ops[0], Ops[1], DL); + + if (Instruction::isCast(Opcode)) + return ConstantFoldCastOperand(Opcode, Ops[0], DestTy, DL); + + if (auto *GEP = dyn_cast<GEPOperator>(InstOrCE)) { + if (Constant *C = SymbolicallyEvaluateGEP(GEP, Ops, DL, TLI)) + return C; + + return ConstantExpr::getGetElementPtr(GEP->getSourceElementType(), Ops[0], + Ops.slice(1), GEP->isInBounds(), + GEP->getInRangeIndex()); + } + + if (auto *CE = dyn_cast<ConstantExpr>(InstOrCE)) + return CE->getWithOperands(Ops); + + switch (Opcode) { + default: return nullptr; + case Instruction::ICmp: + case Instruction::FCmp: llvm_unreachable("Invalid for compares"); + case Instruction::Call: + if (auto *F = dyn_cast<Function>(Ops.back())) { + const auto *Call = cast<CallBase>(InstOrCE); + if (canConstantFoldCallTo(Call, F)) + return ConstantFoldCall(Call, F, Ops.slice(0, Ops.size() - 1), TLI); + } + return nullptr; + case Instruction::Select: + return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2]); + case Instruction::ExtractElement: + return ConstantExpr::getExtractElement(Ops[0], Ops[1]); + case Instruction::ExtractValue: + return ConstantExpr::getExtractValue( + Ops[0], cast<ExtractValueInst>(InstOrCE)->getIndices()); + case Instruction::InsertElement: + return ConstantExpr::getInsertElement(Ops[0], Ops[1], Ops[2]); + case Instruction::ShuffleVector: + return ConstantExpr::getShuffleVector(Ops[0], Ops[1], Ops[2]); + } +} + +} // 
end anonymous namespace + +//===----------------------------------------------------------------------===// +// Constant Folding public APIs +//===----------------------------------------------------------------------===// + +namespace { + +Constant * +ConstantFoldConstantImpl(const Constant *C, const DataLayout &DL, + const TargetLibraryInfo *TLI, + SmallDenseMap<Constant *, Constant *> &FoldedOps) { + if (!isa<ConstantVector>(C) && !isa<ConstantExpr>(C)) + return nullptr; + + SmallVector<Constant *, 8> Ops; + for (const Use &NewU : C->operands()) { + auto *NewC = cast<Constant>(&NewU); + // Recursively fold the ConstantExpr's operands. If we have already folded + // a ConstantExpr, we don't have to process it again. + if (isa<ConstantVector>(NewC) || isa<ConstantExpr>(NewC)) { + auto It = FoldedOps.find(NewC); + if (It == FoldedOps.end()) { + if (auto *FoldedC = + ConstantFoldConstantImpl(NewC, DL, TLI, FoldedOps)) { + FoldedOps.insert({NewC, FoldedC}); + NewC = FoldedC; + } else { + FoldedOps.insert({NewC, NewC}); + } + } else { + NewC = It->second; + } + } + Ops.push_back(NewC); + } + + if (auto *CE = dyn_cast<ConstantExpr>(C)) { + if (CE->isCompare()) + return ConstantFoldCompareInstOperands(CE->getPredicate(), Ops[0], Ops[1], + DL, TLI); + + return ConstantFoldInstOperandsImpl(CE, CE->getOpcode(), Ops, DL, TLI); + } + + assert(isa<ConstantVector>(C)); + return ConstantVector::get(Ops); +} + +} // end anonymous namespace + +Constant *llvm::ConstantFoldInstruction(Instruction *I, const DataLayout &DL, + const TargetLibraryInfo *TLI) { + // Handle PHI nodes quickly here... + if (auto *PN = dyn_cast<PHINode>(I)) { + Constant *CommonValue = nullptr; + + SmallDenseMap<Constant *, Constant *> FoldedOps; + for (Value *Incoming : PN->incoming_values()) { + // If the incoming value is undef then skip it. Note that while we could + // skip the value if it is equal to the phi node itself we choose not to + // because that would break the rule that constant folding only applies if + // all operands are constants. + if (isa<UndefValue>(Incoming)) + continue; + // If the incoming value is not a constant, then give up. + auto *C = dyn_cast<Constant>(Incoming); + if (!C) + return nullptr; + // Fold the PHI's operands. + if (auto *FoldedC = ConstantFoldConstantImpl(C, DL, TLI, FoldedOps)) + C = FoldedC; + // If the incoming value is a different constant to + // the one we saw previously, then give up. + if (CommonValue && C != CommonValue) + return nullptr; + CommonValue = C; + } + + // If we reach here, all incoming values are the same constant or undef. + return CommonValue ? CommonValue : UndefValue::get(PN->getType()); + } + + // Scan the operand list, checking to see if they are all constants, if so, + // hand off to ConstantFoldInstOperandsImpl. + if (!all_of(I->operands(), [](Use &U) { return isa<Constant>(U); })) + return nullptr; + + SmallDenseMap<Constant *, Constant *> FoldedOps; + SmallVector<Constant *, 8> Ops; + for (const Use &OpU : I->operands()) { + auto *Op = cast<Constant>(&OpU); + // Fold the Instruction's operands. 
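+    // (ConstantFoldConstantImpl memoizes already-folded constants in
+    // FoldedOps, so subexpressions shared between operands are folded only
+    // once.)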
+    if (auto *FoldedOp = ConstantFoldConstantImpl(Op, DL, TLI, FoldedOps))
+      Op = FoldedOp;
+
+    Ops.push_back(Op);
+  }
+
+  if (const auto *CI = dyn_cast<CmpInst>(I))
+    return ConstantFoldCompareInstOperands(CI->getPredicate(), Ops[0], Ops[1],
+                                           DL, TLI);
+
+  if (const auto *LI = dyn_cast<LoadInst>(I))
+    return ConstantFoldLoadInst(LI, DL);
+
+  if (auto *IVI = dyn_cast<InsertValueInst>(I)) {
+    return ConstantExpr::getInsertValue(
+        cast<Constant>(IVI->getAggregateOperand()),
+        cast<Constant>(IVI->getInsertedValueOperand()),
+        IVI->getIndices());
+  }
+
+  if (auto *EVI = dyn_cast<ExtractValueInst>(I)) {
+    return ConstantExpr::getExtractValue(
+        cast<Constant>(EVI->getAggregateOperand()),
+        EVI->getIndices());
+  }
+
+  return ConstantFoldInstOperands(I, Ops, DL, TLI);
+}
+
+Constant *llvm::ConstantFoldConstant(const Constant *C, const DataLayout &DL,
+                                     const TargetLibraryInfo *TLI) {
+  SmallDenseMap<Constant *, Constant *> FoldedOps;
+  return ConstantFoldConstantImpl(C, DL, TLI, FoldedOps);
+}
+
+Constant *llvm::ConstantFoldInstOperands(Instruction *I,
+                                         ArrayRef<Constant *> Ops,
+                                         const DataLayout &DL,
+                                         const TargetLibraryInfo *TLI) {
+  return ConstantFoldInstOperandsImpl(I, I->getOpcode(), Ops, DL, TLI);
+}
+
+Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
+                                                Constant *Ops0, Constant *Ops1,
+                                                const DataLayout &DL,
+                                                const TargetLibraryInfo *TLI) {
+  // fold: icmp (inttoptr x), null -> icmp x, 0
+  // fold: icmp null, (inttoptr x) -> icmp 0, x
+  // fold: icmp (ptrtoint x), 0 -> icmp x, null
+  // fold: icmp 0, (ptrtoint x) -> icmp null, x
+  // fold: icmp (inttoptr x), (inttoptr y) -> icmp trunc/zext x, trunc/zext y
+  // fold: icmp (ptrtoint x), (ptrtoint y) -> icmp x, y
+  //
+  // FIXME: The following comment is out of date; the DataLayout is available
+  // here now.
+  // ConstantExpr::getCompare cannot do this, because it doesn't have DL
+  // around to know if bit truncation is happening.
+  if (auto *CE0 = dyn_cast<ConstantExpr>(Ops0)) {
+    if (Ops1->isNullValue()) {
+      if (CE0->getOpcode() == Instruction::IntToPtr) {
+        Type *IntPtrTy = DL.getIntPtrType(CE0->getType());
+        // Convert the integer value to the right size to ensure we get the
+        // proper extension or truncation.
+        Constant *C = ConstantExpr::getIntegerCast(CE0->getOperand(0),
+                                                   IntPtrTy, false);
+        Constant *Null = Constant::getNullValue(C->getType());
+        return ConstantFoldCompareInstOperands(Predicate, C, Null, DL, TLI);
+      }
+
+      // Only do this transformation if the int is the same width as IntPtrTy;
+      // otherwise there is a truncation or extension that we aren't modeling.
+      if (CE0->getOpcode() == Instruction::PtrToInt) {
+        Type *IntPtrTy = DL.getIntPtrType(CE0->getOperand(0)->getType());
+        if (CE0->getType() == IntPtrTy) {
+          Constant *C = CE0->getOperand(0);
+          Constant *Null = Constant::getNullValue(C->getType());
+          return ConstantFoldCompareInstOperands(Predicate, C, Null, DL, TLI);
+        }
+      }
+    }
+
+    if (auto *CE1 = dyn_cast<ConstantExpr>(Ops1)) {
+      if (CE0->getOpcode() == CE1->getOpcode()) {
+        if (CE0->getOpcode() == Instruction::IntToPtr) {
+          Type *IntPtrTy = DL.getIntPtrType(CE0->getType());
+
+          // Convert the integer value to the right size to ensure we get the
+          // proper extension or truncation.
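+          // For example, with 32-bit pointers the expression
+          //   icmp eq (inttoptr i64 %x to i8*), (inttoptr i64 %y to i8*)
+          // compares only the low 32 bits of %x and %y, so both operands are
+          // first cast to the pointer-width integer type: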
+          Constant *C0 = ConstantExpr::getIntegerCast(CE0->getOperand(0),
+                                                      IntPtrTy, false);
+          Constant *C1 = ConstantExpr::getIntegerCast(CE1->getOperand(0),
+                                                      IntPtrTy, false);
+          return ConstantFoldCompareInstOperands(Predicate, C0, C1, DL, TLI);
+        }
+
+        // Only do this transformation if the int is the same width as
+        // IntPtrTy; otherwise there is a truncation or extension that we
+        // aren't modeling.
+        if (CE0->getOpcode() == Instruction::PtrToInt) {
+          Type *IntPtrTy = DL.getIntPtrType(CE0->getOperand(0)->getType());
+          if (CE0->getType() == IntPtrTy &&
+              CE0->getOperand(0)->getType() == CE1->getOperand(0)->getType()) {
+            return ConstantFoldCompareInstOperands(
+                Predicate, CE0->getOperand(0), CE1->getOperand(0), DL, TLI);
+          }
+        }
+      }
+    }
+
+    // icmp eq (or x, y), 0 -> (icmp eq x, 0) & (icmp eq y, 0)
+    // icmp ne (or x, y), 0 -> (icmp ne x, 0) | (icmp ne y, 0)
+    if ((Predicate == ICmpInst::ICMP_EQ || Predicate == ICmpInst::ICMP_NE) &&
+        CE0->getOpcode() == Instruction::Or && Ops1->isNullValue()) {
+      Constant *LHS = ConstantFoldCompareInstOperands(
+          Predicate, CE0->getOperand(0), Ops1, DL, TLI);
+      Constant *RHS = ConstantFoldCompareInstOperands(
+          Predicate, CE0->getOperand(1), Ops1, DL, TLI);
+      unsigned OpC =
+          Predicate == ICmpInst::ICMP_EQ ? Instruction::And : Instruction::Or;
+      return ConstantFoldBinaryOpOperands(OpC, LHS, RHS, DL);
+    }
+  } else if (isa<ConstantExpr>(Ops1)) {
+    // If RHS is a constant expression, but the left side isn't, swap the
+    // operands and try again.
+    Predicate = ICmpInst::getSwappedPredicate((ICmpInst::Predicate)Predicate);
+    return ConstantFoldCompareInstOperands(Predicate, Ops1, Ops0, DL, TLI);
+  }
+
+  return ConstantExpr::getCompare(Predicate, Ops0, Ops1);
+}
+
+Constant *llvm::ConstantFoldUnaryOpOperand(unsigned Opcode, Constant *Op,
+                                           const DataLayout &DL) {
+  assert(Instruction::isUnaryOp(Opcode));
+
+  return ConstantExpr::get(Opcode, Op);
+}
+
+Constant *llvm::ConstantFoldBinaryOpOperands(unsigned Opcode, Constant *LHS,
+                                             Constant *RHS,
+                                             const DataLayout &DL) {
+  assert(Instruction::isBinaryOp(Opcode));
+  if (isa<ConstantExpr>(LHS) || isa<ConstantExpr>(RHS))
+    if (Constant *C = SymbolicallyEvaluateBinop(Opcode, LHS, RHS, DL))
+      return C;
+
+  return ConstantExpr::get(Opcode, LHS, RHS);
+}
+
+Constant *llvm::ConstantFoldCastOperand(unsigned Opcode, Constant *C,
+                                        Type *DestTy, const DataLayout &DL) {
+  assert(Instruction::isCast(Opcode));
+  switch (Opcode) {
+  default:
+    llvm_unreachable("Missing case");
+  case Instruction::PtrToInt:
+    // If the input is an inttoptr, eliminate the pair. This requires knowing
+    // the width of a pointer, so it can't be done in ConstantExpr::getCast.
+    if (auto *CE = dyn_cast<ConstantExpr>(C)) {
+      if (CE->getOpcode() == Instruction::IntToPtr) {
+        Constant *Input = CE->getOperand(0);
+        unsigned InWidth = Input->getType()->getScalarSizeInBits();
+        unsigned PtrWidth = DL.getPointerTypeSizeInBits(CE->getType());
+        if (PtrWidth < InWidth) {
+          Constant *Mask =
+              ConstantInt::get(CE->getContext(),
+                               APInt::getLowBitsSet(InWidth, PtrWidth));
+          Input = ConstantExpr::getAnd(Input, Mask);
+        }
+        // Do a zext or trunc to get to the dest size.
+        return ConstantExpr::getIntegerCast(Input, DestTy, false);
+      }
+    }
+    return ConstantExpr::getCast(Opcode, C, DestTy);
+  case Instruction::IntToPtr:
+    // If the input is a ptrtoint, turn the pair into a ptr to ptr bitcast if
+    // the int size is >= the ptr size and the address spaces are the same.
+    // This requires knowing the width of a pointer, so it can't be done in
+    // ConstantExpr::getCast.
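+    // For example, `inttoptr (ptrtoint i8* @g to i64) to i8*` folds to a
+    // bitcast of @g on a 64-bit target, since no pointer bits are lost going
+    // through the wider i64; with an i32 in the middle, the round-trip would
+    // truncate the pointer and must be left alone.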
+ if (auto *CE = dyn_cast<ConstantExpr>(C)) { + if (CE->getOpcode() == Instruction::PtrToInt) { + Constant *SrcPtr = CE->getOperand(0); + unsigned SrcPtrSize = DL.getPointerTypeSizeInBits(SrcPtr->getType()); + unsigned MidIntSize = CE->getType()->getScalarSizeInBits(); + + if (MidIntSize >= SrcPtrSize) { + unsigned SrcAS = SrcPtr->getType()->getPointerAddressSpace(); + if (SrcAS == DestTy->getPointerAddressSpace()) + return FoldBitCast(CE->getOperand(0), DestTy, DL); + } + } + } + + return ConstantExpr::getCast(Opcode, C, DestTy); + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPTrunc: + case Instruction::FPExt: + case Instruction::UIToFP: + case Instruction::SIToFP: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::AddrSpaceCast: + return ConstantExpr::getCast(Opcode, C, DestTy); + case Instruction::BitCast: + return FoldBitCast(C, DestTy, DL); + } +} + +Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C, + ConstantExpr *CE) { + if (!CE->getOperand(1)->isNullValue()) + return nullptr; // Do not allow stepping over the value! + + // Loop over all of the operands, tracking down which value we are + // addressing. + for (unsigned i = 2, e = CE->getNumOperands(); i != e; ++i) { + C = C->getAggregateElement(CE->getOperand(i)); + if (!C) + return nullptr; + } + return C; +} + +Constant * +llvm::ConstantFoldLoadThroughGEPIndices(Constant *C, + ArrayRef<Constant *> Indices) { + // Loop over all of the operands, tracking down which value we are + // addressing. + for (Constant *Index : Indices) { + C = C->getAggregateElement(Index); + if (!C) + return nullptr; + } + return C; +} + +//===----------------------------------------------------------------------===// +// Constant Folding for Calls +// + +bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) { + if (Call->isNoBuiltin() || Call->isStrictFP()) + return false; + switch (F->getIntrinsicID()) { + case Intrinsic::fabs: + case Intrinsic::minnum: + case Intrinsic::maxnum: + case Intrinsic::minimum: + case Intrinsic::maximum: + case Intrinsic::log: + case Intrinsic::log2: + case Intrinsic::log10: + case Intrinsic::exp: + case Intrinsic::exp2: + case Intrinsic::floor: + case Intrinsic::ceil: + case Intrinsic::sqrt: + case Intrinsic::sin: + case Intrinsic::cos: + case Intrinsic::trunc: + case Intrinsic::rint: + case Intrinsic::nearbyint: + case Intrinsic::pow: + case Intrinsic::powi: + case Intrinsic::bswap: + case Intrinsic::ctpop: + case Intrinsic::ctlz: + case Intrinsic::cttz: + case Intrinsic::fshl: + case Intrinsic::fshr: + case Intrinsic::fma: + case Intrinsic::fmuladd: + case Intrinsic::copysign: + case Intrinsic::launder_invariant_group: + case Intrinsic::strip_invariant_group: + case Intrinsic::round: + case Intrinsic::masked_load: + case Intrinsic::sadd_with_overflow: + case Intrinsic::uadd_with_overflow: + case Intrinsic::ssub_with_overflow: + case Intrinsic::usub_with_overflow: + case Intrinsic::smul_with_overflow: + case Intrinsic::umul_with_overflow: + case Intrinsic::sadd_sat: + case Intrinsic::uadd_sat: + case Intrinsic::ssub_sat: + case Intrinsic::usub_sat: + case Intrinsic::smul_fix: + case Intrinsic::smul_fix_sat: + case Intrinsic::convert_from_fp16: + case Intrinsic::convert_to_fp16: + case Intrinsic::bitreverse: + case Intrinsic::x86_sse_cvtss2si: + case Intrinsic::x86_sse_cvtss2si64: + case Intrinsic::x86_sse_cvttss2si: + case Intrinsic::x86_sse_cvttss2si64: + case Intrinsic::x86_sse2_cvtsd2si: + case 
Intrinsic::x86_sse2_cvtsd2si64: + case Intrinsic::x86_sse2_cvttsd2si: + case Intrinsic::x86_sse2_cvttsd2si64: + case Intrinsic::x86_avx512_vcvtss2si32: + case Intrinsic::x86_avx512_vcvtss2si64: + case Intrinsic::x86_avx512_cvttss2si: + case Intrinsic::x86_avx512_cvttss2si64: + case Intrinsic::x86_avx512_vcvtsd2si32: + case Intrinsic::x86_avx512_vcvtsd2si64: + case Intrinsic::x86_avx512_cvttsd2si: + case Intrinsic::x86_avx512_cvttsd2si64: + case Intrinsic::x86_avx512_vcvtss2usi32: + case Intrinsic::x86_avx512_vcvtss2usi64: + case Intrinsic::x86_avx512_cvttss2usi: + case Intrinsic::x86_avx512_cvttss2usi64: + case Intrinsic::x86_avx512_vcvtsd2usi32: + case Intrinsic::x86_avx512_vcvtsd2usi64: + case Intrinsic::x86_avx512_cvttsd2usi: + case Intrinsic::x86_avx512_cvttsd2usi64: + case Intrinsic::is_constant: + return true; + default: + return false; + case Intrinsic::not_intrinsic: break; + } + + if (!F->hasName()) + return false; + + // In these cases, the check of the length is required. We don't want to + // return true for a name like "cos\0blah" which strcmp would return equal to + // "cos", but has length 8. + StringRef Name = F->getName(); + switch (Name[0]) { + default: + return false; + case 'a': + return Name == "acos" || Name == "acosf" || + Name == "asin" || Name == "asinf" || + Name == "atan" || Name == "atanf" || + Name == "atan2" || Name == "atan2f"; + case 'c': + return Name == "ceil" || Name == "ceilf" || + Name == "cos" || Name == "cosf" || + Name == "cosh" || Name == "coshf"; + case 'e': + return Name == "exp" || Name == "expf" || + Name == "exp2" || Name == "exp2f"; + case 'f': + return Name == "fabs" || Name == "fabsf" || + Name == "floor" || Name == "floorf" || + Name == "fmod" || Name == "fmodf"; + case 'l': + return Name == "log" || Name == "logf" || + Name == "log2" || Name == "log2f" || + Name == "log10" || Name == "log10f"; + case 'n': + return Name == "nearbyint" || Name == "nearbyintf"; + case 'p': + return Name == "pow" || Name == "powf"; + case 'r': + return Name == "rint" || Name == "rintf" || + Name == "round" || Name == "roundf"; + case 's': + return Name == "sin" || Name == "sinf" || + Name == "sinh" || Name == "sinhf" || + Name == "sqrt" || Name == "sqrtf"; + case 't': + return Name == "tan" || Name == "tanf" || + Name == "tanh" || Name == "tanhf" || + Name == "trunc" || Name == "truncf"; + case '_': + // Check for various function names that get used for the math functions + // when the header files are preprocessed with the macro + // __FINITE_MATH_ONLY__ enabled. + // The '12' here is the length of the shortest name that can match. + // We need to check the size before looking at Name[1] and Name[2] + // so we may as well check a limit that will eliminate mismatches. 
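+    // ("__exp_finite", "__log_finite" and "__pow_finite" are the shortest
+    // such names, at exactly 12 characters.)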
+ if (Name.size() < 12 || Name[1] != '_') + return false; + switch (Name[2]) { + default: + return false; + case 'a': + return Name == "__acos_finite" || Name == "__acosf_finite" || + Name == "__asin_finite" || Name == "__asinf_finite" || + Name == "__atan2_finite" || Name == "__atan2f_finite"; + case 'c': + return Name == "__cosh_finite" || Name == "__coshf_finite"; + case 'e': + return Name == "__exp_finite" || Name == "__expf_finite" || + Name == "__exp2_finite" || Name == "__exp2f_finite"; + case 'l': + return Name == "__log_finite" || Name == "__logf_finite" || + Name == "__log10_finite" || Name == "__log10f_finite"; + case 'p': + return Name == "__pow_finite" || Name == "__powf_finite"; + case 's': + return Name == "__sinh_finite" || Name == "__sinhf_finite"; + } + } +} + +namespace { + +Constant *GetConstantFoldFPValue(double V, Type *Ty) { + if (Ty->isHalfTy() || Ty->isFloatTy()) { + APFloat APF(V); + bool unused; + APF.convert(Ty->getFltSemantics(), APFloat::rmNearestTiesToEven, &unused); + return ConstantFP::get(Ty->getContext(), APF); + } + if (Ty->isDoubleTy()) + return ConstantFP::get(Ty->getContext(), APFloat(V)); + llvm_unreachable("Can only constant fold half/float/double"); +} + +/// Clear the floating-point exception state. +inline void llvm_fenv_clearexcept() { +#if defined(HAVE_FENV_H) && HAVE_DECL_FE_ALL_EXCEPT + feclearexcept(FE_ALL_EXCEPT); +#endif + errno = 0; +} + +/// Test if a floating-point exception was raised. +inline bool llvm_fenv_testexcept() { + int errno_val = errno; + if (errno_val == ERANGE || errno_val == EDOM) + return true; +#if defined(HAVE_FENV_H) && HAVE_DECL_FE_ALL_EXCEPT && HAVE_DECL_FE_INEXACT + if (fetestexcept(FE_ALL_EXCEPT & ~FE_INEXACT)) + return true; +#endif + return false; +} + +Constant *ConstantFoldFP(double (*NativeFP)(double), double V, Type *Ty) { + llvm_fenv_clearexcept(); + V = NativeFP(V); + if (llvm_fenv_testexcept()) { + llvm_fenv_clearexcept(); + return nullptr; + } + + return GetConstantFoldFPValue(V, Ty); +} + +Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double), double V, + double W, Type *Ty) { + llvm_fenv_clearexcept(); + V = NativeFP(V, W); + if (llvm_fenv_testexcept()) { + llvm_fenv_clearexcept(); + return nullptr; + } + + return GetConstantFoldFPValue(V, Ty); +} + +/// Attempt to fold an SSE floating point to integer conversion of a constant +/// floating point. If roundTowardZero is false, the default IEEE rounding is +/// used (toward nearest, ties to even). This matches the behavior of the +/// non-truncating SSE instructions in the default rounding mode. The desired +/// integer type Ty is used to select how many bits are available for the +/// result. Returns null if the conversion cannot be performed, otherwise +/// returns the Constant value resulting from the conversion. +Constant *ConstantFoldSSEConvertToInt(const APFloat &Val, bool roundTowardZero, + Type *Ty, bool IsSigned) { + // All of these conversion intrinsics form an integer of at most 64bits. + unsigned ResultWidth = Ty->getIntegerBitWidth(); + assert(ResultWidth <= 64 && + "Can only constant fold conversions to 64 and 32 bit ints"); + + uint64_t UIntVal; + bool isExact = false; + APFloat::roundingMode mode = roundTowardZero? 
APFloat::rmTowardZero + : APFloat::rmNearestTiesToEven; + APFloat::opStatus status = + Val.convertToInteger(makeMutableArrayRef(UIntVal), ResultWidth, + IsSigned, mode, &isExact); + if (status != APFloat::opOK && + (!roundTowardZero || status != APFloat::opInexact)) + return nullptr; + return ConstantInt::get(Ty, UIntVal, IsSigned); +} + +double getValueAsDouble(ConstantFP *Op) { + Type *Ty = Op->getType(); + + if (Ty->isFloatTy()) + return Op->getValueAPF().convertToFloat(); + + if (Ty->isDoubleTy()) + return Op->getValueAPF().convertToDouble(); + + bool unused; + APFloat APF = Op->getValueAPF(); + APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &unused); + return APF.convertToDouble(); +} + +static bool isManifestConstant(const Constant *c) { + if (isa<ConstantData>(c)) { + return true; + } else if (isa<ConstantAggregate>(c) || isa<ConstantExpr>(c)) { + for (const Value *subc : c->operand_values()) { + if (!isManifestConstant(cast<Constant>(subc))) + return false; + } + return true; + } + return false; +} + +static bool getConstIntOrUndef(Value *Op, const APInt *&C) { + if (auto *CI = dyn_cast<ConstantInt>(Op)) { + C = &CI->getValue(); + return true; + } + if (isa<UndefValue>(Op)) { + C = nullptr; + return true; + } + return false; +} + +static Constant *ConstantFoldScalarCall1(StringRef Name, + Intrinsic::ID IntrinsicID, + Type *Ty, + ArrayRef<Constant *> Operands, + const TargetLibraryInfo *TLI, + const CallBase *Call) { + assert(Operands.size() == 1 && "Wrong number of operands."); + + if (IntrinsicID == Intrinsic::is_constant) { + // We know we have a "Constant" argument. But we want to only + // return true for manifest constants, not those that depend on + // constants with unknowable values, e.g. GlobalValue or BlockAddress. + if (isManifestConstant(Operands[0])) + return ConstantInt::getTrue(Ty->getContext()); + return nullptr; + } + if (isa<UndefValue>(Operands[0])) { + // cosine(arg) is between -1 and 1. cosine(invalid arg) is NaN. + // ctpop() is between 0 and bitwidth, pick 0 for undef. + if (IntrinsicID == Intrinsic::cos || + IntrinsicID == Intrinsic::ctpop) + return Constant::getNullValue(Ty); + if (IntrinsicID == Intrinsic::bswap || + IntrinsicID == Intrinsic::bitreverse || + IntrinsicID == Intrinsic::launder_invariant_group || + IntrinsicID == Intrinsic::strip_invariant_group) + return Operands[0]; + } + + if (isa<ConstantPointerNull>(Operands[0])) { + // launder(null) == null == strip(null) iff in addrspace 0 + if (IntrinsicID == Intrinsic::launder_invariant_group || + IntrinsicID == Intrinsic::strip_invariant_group) { + // If instruction is not yet put in a basic block (e.g. when cloning + // a function during inlining), Call's caller may not be available. + // So check Call's BB first before querying Call->getCaller. + const Function *Caller = + Call->getParent() ? Call->getCaller() : nullptr; + if (Caller && + !NullPointerIsDefined( + Caller, Operands[0]->getType()->getPointerAddressSpace())) { + return Operands[0]; + } + return nullptr; + } + } + + if (auto *Op = dyn_cast<ConstantFP>(Operands[0])) { + if (IntrinsicID == Intrinsic::convert_to_fp16) { + APFloat Val(Op->getValueAPF()); + + bool lost = false; + Val.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &lost); + + return ConstantInt::get(Ty->getContext(), Val.bitcastToAPInt()); + } + + if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy()) + return nullptr; + + // Use internal versions of these intrinsics. 
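+    // ("Internal" here means APFloat's own roundToIntegral/clearSign
+    // operations below, which are exact and avoid the host libm entirely.)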
+    APFloat U = Op->getValueAPF();
+
+    if (IntrinsicID == Intrinsic::nearbyint || IntrinsicID == Intrinsic::rint) {
+      U.roundToIntegral(APFloat::rmNearestTiesToEven);
+      return ConstantFP::get(Ty->getContext(), U);
+    }
+
+    if (IntrinsicID == Intrinsic::round) {
+      U.roundToIntegral(APFloat::rmNearestTiesToAway);
+      return ConstantFP::get(Ty->getContext(), U);
+    }
+
+    if (IntrinsicID == Intrinsic::ceil) {
+      U.roundToIntegral(APFloat::rmTowardPositive);
+      return ConstantFP::get(Ty->getContext(), U);
+    }
+
+    if (IntrinsicID == Intrinsic::floor) {
+      U.roundToIntegral(APFloat::rmTowardNegative);
+      return ConstantFP::get(Ty->getContext(), U);
+    }
+
+    if (IntrinsicID == Intrinsic::trunc) {
+      U.roundToIntegral(APFloat::rmTowardZero);
+      return ConstantFP::get(Ty->getContext(), U);
+    }
+
+    if (IntrinsicID == Intrinsic::fabs) {
+      U.clearSign();
+      return ConstantFP::get(Ty->getContext(), U);
+    }
+
+    // We only fold functions with finite arguments. Folding NaN or infinity
+    // would most likely be rejected by the floating-point exception checks
+    // anyway, and some host libms have known bugs when raising exceptions.
+    if (Op->getValueAPF().isNaN() || Op->getValueAPF().isInfinity())
+      return nullptr;
+
+    // APFloat versions of these functions do not exist yet, so we use the
+    // host's native double versions. The float versions are never called
+    // directly; for all of these functions (float)(f((double)arg)) equals
+    // f((float)arg). Long double is not supported yet.
+    double V = getValueAsDouble(Op);
+
+    switch (IntrinsicID) {
+    default: break;
+    case Intrinsic::log:
+      return ConstantFoldFP(log, V, Ty);
+    case Intrinsic::log2:
+      // TODO: What about hosts that lack a C99 library?
+      return ConstantFoldFP(Log2, V, Ty);
+    case Intrinsic::log10:
+      // TODO: What about hosts that lack a C99 library?
+      return ConstantFoldFP(log10, V, Ty);
+    case Intrinsic::exp:
+      return ConstantFoldFP(exp, V, Ty);
+    case Intrinsic::exp2:
+      // Fold exp2(x) as pow(2, x), in case the host lacks a C99 library.
+ return ConstantFoldBinaryFP(pow, 2.0, V, Ty); + case Intrinsic::sin: + return ConstantFoldFP(sin, V, Ty); + case Intrinsic::cos: + return ConstantFoldFP(cos, V, Ty); + case Intrinsic::sqrt: + return ConstantFoldFP(sqrt, V, Ty); + } + + if (!TLI) + return nullptr; + + LibFunc Func = NotLibFunc; + TLI->getLibFunc(Name, Func); + switch (Func) { + default: + break; + case LibFunc_acos: + case LibFunc_acosf: + case LibFunc_acos_finite: + case LibFunc_acosf_finite: + if (TLI->has(Func)) + return ConstantFoldFP(acos, V, Ty); + break; + case LibFunc_asin: + case LibFunc_asinf: + case LibFunc_asin_finite: + case LibFunc_asinf_finite: + if (TLI->has(Func)) + return ConstantFoldFP(asin, V, Ty); + break; + case LibFunc_atan: + case LibFunc_atanf: + if (TLI->has(Func)) + return ConstantFoldFP(atan, V, Ty); + break; + case LibFunc_ceil: + case LibFunc_ceilf: + if (TLI->has(Func)) { + U.roundToIntegral(APFloat::rmTowardPositive); + return ConstantFP::get(Ty->getContext(), U); + } + break; + case LibFunc_cos: + case LibFunc_cosf: + if (TLI->has(Func)) + return ConstantFoldFP(cos, V, Ty); + break; + case LibFunc_cosh: + case LibFunc_coshf: + case LibFunc_cosh_finite: + case LibFunc_coshf_finite: + if (TLI->has(Func)) + return ConstantFoldFP(cosh, V, Ty); + break; + case LibFunc_exp: + case LibFunc_expf: + case LibFunc_exp_finite: + case LibFunc_expf_finite: + if (TLI->has(Func)) + return ConstantFoldFP(exp, V, Ty); + break; + case LibFunc_exp2: + case LibFunc_exp2f: + case LibFunc_exp2_finite: + case LibFunc_exp2f_finite: + if (TLI->has(Func)) + // Fold exp2(x) as pow(2, x), in case the host lacks a C99 library. + return ConstantFoldBinaryFP(pow, 2.0, V, Ty); + break; + case LibFunc_fabs: + case LibFunc_fabsf: + if (TLI->has(Func)) { + U.clearSign(); + return ConstantFP::get(Ty->getContext(), U); + } + break; + case LibFunc_floor: + case LibFunc_floorf: + if (TLI->has(Func)) { + U.roundToIntegral(APFloat::rmTowardNegative); + return ConstantFP::get(Ty->getContext(), U); + } + break; + case LibFunc_log: + case LibFunc_logf: + case LibFunc_log_finite: + case LibFunc_logf_finite: + if (V > 0.0 && TLI->has(Func)) + return ConstantFoldFP(log, V, Ty); + break; + case LibFunc_log2: + case LibFunc_log2f: + case LibFunc_log2_finite: + case LibFunc_log2f_finite: + if (V > 0.0 && TLI->has(Func)) + // TODO: What about hosts that lack a C99 library? + return ConstantFoldFP(Log2, V, Ty); + break; + case LibFunc_log10: + case LibFunc_log10f: + case LibFunc_log10_finite: + case LibFunc_log10f_finite: + if (V > 0.0 && TLI->has(Func)) + // TODO: What about hosts that lack a C99 library? 
+ return ConstantFoldFP(log10, V, Ty); + break; + case LibFunc_nearbyint: + case LibFunc_nearbyintf: + case LibFunc_rint: + case LibFunc_rintf: + if (TLI->has(Func)) { + U.roundToIntegral(APFloat::rmNearestTiesToEven); + return ConstantFP::get(Ty->getContext(), U); + } + break; + case LibFunc_round: + case LibFunc_roundf: + if (TLI->has(Func)) { + U.roundToIntegral(APFloat::rmNearestTiesToAway); + return ConstantFP::get(Ty->getContext(), U); + } + break; + case LibFunc_sin: + case LibFunc_sinf: + if (TLI->has(Func)) + return ConstantFoldFP(sin, V, Ty); + break; + case LibFunc_sinh: + case LibFunc_sinhf: + case LibFunc_sinh_finite: + case LibFunc_sinhf_finite: + if (TLI->has(Func)) + return ConstantFoldFP(sinh, V, Ty); + break; + case LibFunc_sqrt: + case LibFunc_sqrtf: + if (V >= 0.0 && TLI->has(Func)) + return ConstantFoldFP(sqrt, V, Ty); + break; + case LibFunc_tan: + case LibFunc_tanf: + if (TLI->has(Func)) + return ConstantFoldFP(tan, V, Ty); + break; + case LibFunc_tanh: + case LibFunc_tanhf: + if (TLI->has(Func)) + return ConstantFoldFP(tanh, V, Ty); + break; + case LibFunc_trunc: + case LibFunc_truncf: + if (TLI->has(Func)) { + U.roundToIntegral(APFloat::rmTowardZero); + return ConstantFP::get(Ty->getContext(), U); + } + break; + } + return nullptr; + } + + if (auto *Op = dyn_cast<ConstantInt>(Operands[0])) { + switch (IntrinsicID) { + case Intrinsic::bswap: + return ConstantInt::get(Ty->getContext(), Op->getValue().byteSwap()); + case Intrinsic::ctpop: + return ConstantInt::get(Ty, Op->getValue().countPopulation()); + case Intrinsic::bitreverse: + return ConstantInt::get(Ty->getContext(), Op->getValue().reverseBits()); + case Intrinsic::convert_from_fp16: { + APFloat Val(APFloat::IEEEhalf(), Op->getValue()); + + bool lost = false; + APFloat::opStatus status = Val.convert( + Ty->getFltSemantics(), APFloat::rmNearestTiesToEven, &lost); + + // Conversion is always precise. + (void)status; + assert(status == APFloat::opOK && !lost && + "Precision lost during fp16 constfolding"); + + return ConstantFP::get(Ty->getContext(), Val); + } + default: + return nullptr; + } + } + + // Support ConstantVector in case we have an Undef in the top. 
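+  // (A vector constant containing an undef element is a ConstantVector
+  // rather than a ConstantDataVector; the x86 conversions below only read
+  // lane 0, so such vectors can still be folded.)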
+ if (isa<ConstantVector>(Operands[0]) || + isa<ConstantDataVector>(Operands[0])) { + auto *Op = cast<Constant>(Operands[0]); + switch (IntrinsicID) { + default: break; + case Intrinsic::x86_sse_cvtss2si: + case Intrinsic::x86_sse_cvtss2si64: + case Intrinsic::x86_sse2_cvtsd2si: + case Intrinsic::x86_sse2_cvtsd2si64: + if (ConstantFP *FPOp = + dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) + return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), + /*roundTowardZero=*/false, Ty, + /*IsSigned*/true); + break; + case Intrinsic::x86_sse_cvttss2si: + case Intrinsic::x86_sse_cvttss2si64: + case Intrinsic::x86_sse2_cvttsd2si: + case Intrinsic::x86_sse2_cvttsd2si64: + if (ConstantFP *FPOp = + dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) + return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), + /*roundTowardZero=*/true, Ty, + /*IsSigned*/true); + break; + } + } + + return nullptr; +} + +static Constant *ConstantFoldScalarCall2(StringRef Name, + Intrinsic::ID IntrinsicID, + Type *Ty, + ArrayRef<Constant *> Operands, + const TargetLibraryInfo *TLI, + const CallBase *Call) { + assert(Operands.size() == 2 && "Wrong number of operands."); + + if (auto *Op1 = dyn_cast<ConstantFP>(Operands[0])) { + if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy()) + return nullptr; + double Op1V = getValueAsDouble(Op1); + + if (auto *Op2 = dyn_cast<ConstantFP>(Operands[1])) { + if (Op2->getType() != Op1->getType()) + return nullptr; + + double Op2V = getValueAsDouble(Op2); + if (IntrinsicID == Intrinsic::pow) { + return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty); + } + if (IntrinsicID == Intrinsic::copysign) { + APFloat V1 = Op1->getValueAPF(); + const APFloat &V2 = Op2->getValueAPF(); + V1.copySign(V2); + return ConstantFP::get(Ty->getContext(), V1); + } + + if (IntrinsicID == Intrinsic::minnum) { + const APFloat &C1 = Op1->getValueAPF(); + const APFloat &C2 = Op2->getValueAPF(); + return ConstantFP::get(Ty->getContext(), minnum(C1, C2)); + } + + if (IntrinsicID == Intrinsic::maxnum) { + const APFloat &C1 = Op1->getValueAPF(); + const APFloat &C2 = Op2->getValueAPF(); + return ConstantFP::get(Ty->getContext(), maxnum(C1, C2)); + } + + if (IntrinsicID == Intrinsic::minimum) { + const APFloat &C1 = Op1->getValueAPF(); + const APFloat &C2 = Op2->getValueAPF(); + return ConstantFP::get(Ty->getContext(), minimum(C1, C2)); + } + + if (IntrinsicID == Intrinsic::maximum) { + const APFloat &C1 = Op1->getValueAPF(); + const APFloat &C2 = Op2->getValueAPF(); + return ConstantFP::get(Ty->getContext(), maximum(C1, C2)); + } + + if (!TLI) + return nullptr; + + LibFunc Func = NotLibFunc; + TLI->getLibFunc(Name, Func); + switch (Func) { + default: + break; + case LibFunc_pow: + case LibFunc_powf: + case LibFunc_pow_finite: + case LibFunc_powf_finite: + if (TLI->has(Func)) + return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty); + break; + case LibFunc_fmod: + case LibFunc_fmodf: + if (TLI->has(Func)) { + APFloat V = Op1->getValueAPF(); + if (APFloat::opStatus::opOK == V.mod(Op2->getValueAPF())) + return ConstantFP::get(Ty->getContext(), V); + } + break; + case LibFunc_atan2: + case LibFunc_atan2f: + case LibFunc_atan2_finite: + case LibFunc_atan2f_finite: + if (TLI->has(Func)) + return ConstantFoldBinaryFP(atan2, Op1V, Op2V, Ty); + break; + } + } else if (auto *Op2C = dyn_cast<ConstantInt>(Operands[1])) { + if (IntrinsicID == Intrinsic::powi && Ty->isHalfTy()) + return ConstantFP::get(Ty->getContext(), + APFloat((float)std::pow((float)Op1V, + (int)Op2C->getZExtValue()))); + if (IntrinsicID == 
Intrinsic::powi && Ty->isFloatTy()) + return ConstantFP::get(Ty->getContext(), + APFloat((float)std::pow((float)Op1V, + (int)Op2C->getZExtValue()))); + if (IntrinsicID == Intrinsic::powi && Ty->isDoubleTy()) + return ConstantFP::get(Ty->getContext(), + APFloat((double)std::pow((double)Op1V, + (int)Op2C->getZExtValue()))); + } + return nullptr; + } + + if (Operands[0]->getType()->isIntegerTy() && + Operands[1]->getType()->isIntegerTy()) { + const APInt *C0, *C1; + if (!getConstIntOrUndef(Operands[0], C0) || + !getConstIntOrUndef(Operands[1], C1)) + return nullptr; + + switch (IntrinsicID) { + default: break; + case Intrinsic::usub_with_overflow: + case Intrinsic::ssub_with_overflow: + case Intrinsic::uadd_with_overflow: + case Intrinsic::sadd_with_overflow: + // X - undef -> { undef, false } + // undef - X -> { undef, false } + // X + undef -> { undef, false } + // undef + x -> { undef, false } + if (!C0 || !C1) { + return ConstantStruct::get( + cast<StructType>(Ty), + {UndefValue::get(Ty->getStructElementType(0)), + Constant::getNullValue(Ty->getStructElementType(1))}); + } + LLVM_FALLTHROUGH; + case Intrinsic::smul_with_overflow: + case Intrinsic::umul_with_overflow: { + // undef * X -> { 0, false } + // X * undef -> { 0, false } + if (!C0 || !C1) + return Constant::getNullValue(Ty); + + APInt Res; + bool Overflow; + switch (IntrinsicID) { + default: llvm_unreachable("Invalid case"); + case Intrinsic::sadd_with_overflow: + Res = C0->sadd_ov(*C1, Overflow); + break; + case Intrinsic::uadd_with_overflow: + Res = C0->uadd_ov(*C1, Overflow); + break; + case Intrinsic::ssub_with_overflow: + Res = C0->ssub_ov(*C1, Overflow); + break; + case Intrinsic::usub_with_overflow: + Res = C0->usub_ov(*C1, Overflow); + break; + case Intrinsic::smul_with_overflow: + Res = C0->smul_ov(*C1, Overflow); + break; + case Intrinsic::umul_with_overflow: + Res = C0->umul_ov(*C1, Overflow); + break; + } + Constant *Ops[] = { + ConstantInt::get(Ty->getContext(), Res), + ConstantInt::get(Type::getInt1Ty(Ty->getContext()), Overflow) + }; + return ConstantStruct::get(cast<StructType>(Ty), Ops); + } + case Intrinsic::uadd_sat: + case Intrinsic::sadd_sat: + if (!C0 && !C1) + return UndefValue::get(Ty); + if (!C0 || !C1) + return Constant::getAllOnesValue(Ty); + if (IntrinsicID == Intrinsic::uadd_sat) + return ConstantInt::get(Ty, C0->uadd_sat(*C1)); + else + return ConstantInt::get(Ty, C0->sadd_sat(*C1)); + case Intrinsic::usub_sat: + case Intrinsic::ssub_sat: + if (!C0 && !C1) + return UndefValue::get(Ty); + if (!C0 || !C1) + return Constant::getNullValue(Ty); + if (IntrinsicID == Intrinsic::usub_sat) + return ConstantInt::get(Ty, C0->usub_sat(*C1)); + else + return ConstantInt::get(Ty, C0->ssub_sat(*C1)); + case Intrinsic::cttz: + case Intrinsic::ctlz: + assert(C1 && "Must be constant int"); + + // cttz(0, 1) and ctlz(0, 1) are undef. + if (C1->isOneValue() && (!C0 || C0->isNullValue())) + return UndefValue::get(Ty); + if (!C0) + return Constant::getNullValue(Ty); + if (IntrinsicID == Intrinsic::cttz) + return ConstantInt::get(Ty, C0->countTrailingZeros()); + else + return ConstantInt::get(Ty, C0->countLeadingZeros()); + } + + return nullptr; + } + + // Support ConstantVector in case we have an Undef in the top. + if ((isa<ConstantVector>(Operands[0]) || + isa<ConstantDataVector>(Operands[0])) && + // Check for default rounding mode. + // FIXME: Support other rounding modes? 
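+      // (The immediate 4 is _MM_FROUND_CUR_DIRECTION, i.e. "use the current
+      // MXCSR rounding mode", which defaults to round-to-nearest-even.)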
+ isa<ConstantInt>(Operands[1]) && + cast<ConstantInt>(Operands[1])->getValue() == 4) { + auto *Op = cast<Constant>(Operands[0]); + switch (IntrinsicID) { + default: break; + case Intrinsic::x86_avx512_vcvtss2si32: + case Intrinsic::x86_avx512_vcvtss2si64: + case Intrinsic::x86_avx512_vcvtsd2si32: + case Intrinsic::x86_avx512_vcvtsd2si64: + if (ConstantFP *FPOp = + dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) + return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), + /*roundTowardZero=*/false, Ty, + /*IsSigned*/true); + break; + case Intrinsic::x86_avx512_vcvtss2usi32: + case Intrinsic::x86_avx512_vcvtss2usi64: + case Intrinsic::x86_avx512_vcvtsd2usi32: + case Intrinsic::x86_avx512_vcvtsd2usi64: + if (ConstantFP *FPOp = + dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) + return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), + /*roundTowardZero=*/false, Ty, + /*IsSigned*/false); + break; + case Intrinsic::x86_avx512_cvttss2si: + case Intrinsic::x86_avx512_cvttss2si64: + case Intrinsic::x86_avx512_cvttsd2si: + case Intrinsic::x86_avx512_cvttsd2si64: + if (ConstantFP *FPOp = + dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) + return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), + /*roundTowardZero=*/true, Ty, + /*IsSigned*/true); + break; + case Intrinsic::x86_avx512_cvttss2usi: + case Intrinsic::x86_avx512_cvttss2usi64: + case Intrinsic::x86_avx512_cvttsd2usi: + case Intrinsic::x86_avx512_cvttsd2usi64: + if (ConstantFP *FPOp = + dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) + return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), + /*roundTowardZero=*/true, Ty, + /*IsSigned*/false); + break; + } + } + return nullptr; +} + +static Constant *ConstantFoldScalarCall3(StringRef Name, + Intrinsic::ID IntrinsicID, + Type *Ty, + ArrayRef<Constant *> Operands, + const TargetLibraryInfo *TLI, + const CallBase *Call) { + assert(Operands.size() == 3 && "Wrong number of operands."); + + if (const auto *Op1 = dyn_cast<ConstantFP>(Operands[0])) { + if (const auto *Op2 = dyn_cast<ConstantFP>(Operands[1])) { + if (const auto *Op3 = dyn_cast<ConstantFP>(Operands[2])) { + switch (IntrinsicID) { + default: break; + case Intrinsic::fma: + case Intrinsic::fmuladd: { + APFloat V = Op1->getValueAPF(); + V.fusedMultiplyAdd(Op2->getValueAPF(), Op3->getValueAPF(), + APFloat::rmNearestTiesToEven); + return ConstantFP::get(Ty->getContext(), V); + } + } + } + } + } + + if (const auto *Op1 = dyn_cast<ConstantInt>(Operands[0])) { + if (const auto *Op2 = dyn_cast<ConstantInt>(Operands[1])) { + if (const auto *Op3 = dyn_cast<ConstantInt>(Operands[2])) { + switch (IntrinsicID) { + default: break; + case Intrinsic::smul_fix: + case Intrinsic::smul_fix_sat: { + // This code performs rounding towards negative infinity in case the + // result cannot be represented exactly for the given scale. Targets + // that do care about rounding should use a target hook for specifying + // how rounding should be done, and provide their own folding to be + // consistent with rounding. This is the same approach as used by + // DAGTypeLegalizer::ExpandIntRes_MULFIX. 
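+          // Worked example: with scale 2, i8 operands 3 (= 0.75) and
+          // 5 (= 1.25) give a product of 15; 15 >> 2 == 3 (= 0.75), which is
+          // the exact result 0.9375 rounded toward negative infinity.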
+ APInt Lhs = Op1->getValue(); + APInt Rhs = Op2->getValue(); + unsigned Scale = Op3->getValue().getZExtValue(); + unsigned Width = Lhs.getBitWidth(); + assert(Scale < Width && "Illegal scale."); + unsigned ExtendedWidth = Width * 2; + APInt Product = (Lhs.sextOrSelf(ExtendedWidth) * + Rhs.sextOrSelf(ExtendedWidth)).ashr(Scale); + if (IntrinsicID == Intrinsic::smul_fix_sat) { + APInt MaxValue = + APInt::getSignedMaxValue(Width).sextOrSelf(ExtendedWidth); + APInt MinValue = + APInt::getSignedMinValue(Width).sextOrSelf(ExtendedWidth); + Product = APIntOps::smin(Product, MaxValue); + Product = APIntOps::smax(Product, MinValue); + } + return ConstantInt::get(Ty->getContext(), + Product.sextOrTrunc(Width)); + } + } + } + } + } + + if (IntrinsicID == Intrinsic::fshl || IntrinsicID == Intrinsic::fshr) { + const APInt *C0, *C1, *C2; + if (!getConstIntOrUndef(Operands[0], C0) || + !getConstIntOrUndef(Operands[1], C1) || + !getConstIntOrUndef(Operands[2], C2)) + return nullptr; + + bool IsRight = IntrinsicID == Intrinsic::fshr; + if (!C2) + return Operands[IsRight ? 1 : 0]; + if (!C0 && !C1) + return UndefValue::get(Ty); + + // The shift amount is interpreted as modulo the bitwidth. If the shift + // amount is effectively 0, avoid UB due to oversized inverse shift below. + unsigned BitWidth = C2->getBitWidth(); + unsigned ShAmt = C2->urem(BitWidth); + if (!ShAmt) + return Operands[IsRight ? 1 : 0]; + + // (C0 << ShlAmt) | (C1 >> LshrAmt) + unsigned LshrAmt = IsRight ? ShAmt : BitWidth - ShAmt; + unsigned ShlAmt = !IsRight ? ShAmt : BitWidth - ShAmt; + if (!C0) + return ConstantInt::get(Ty, C1->lshr(LshrAmt)); + if (!C1) + return ConstantInt::get(Ty, C0->shl(ShlAmt)); + return ConstantInt::get(Ty, C0->shl(ShlAmt) | C1->lshr(LshrAmt)); + } + + return nullptr; +} + +static Constant *ConstantFoldScalarCall(StringRef Name, + Intrinsic::ID IntrinsicID, + Type *Ty, + ArrayRef<Constant *> Operands, + const TargetLibraryInfo *TLI, + const CallBase *Call) { + if (Operands.size() == 1) + return ConstantFoldScalarCall1(Name, IntrinsicID, Ty, Operands, TLI, Call); + + if (Operands.size() == 2) + return ConstantFoldScalarCall2(Name, IntrinsicID, Ty, Operands, TLI, Call); + + if (Operands.size() == 3) + return ConstantFoldScalarCall3(Name, IntrinsicID, Ty, Operands, TLI, Call); + + return nullptr; +} + +static Constant *ConstantFoldVectorCall(StringRef Name, + Intrinsic::ID IntrinsicID, + VectorType *VTy, + ArrayRef<Constant *> Operands, + const DataLayout &DL, + const TargetLibraryInfo *TLI, + const CallBase *Call) { + SmallVector<Constant *, 4> Result(VTy->getNumElements()); + SmallVector<Constant *, 4> Lane(Operands.size()); + Type *Ty = VTy->getElementType(); + + if (IntrinsicID == Intrinsic::masked_load) { + auto *SrcPtr = Operands[0]; + auto *Mask = Operands[2]; + auto *Passthru = Operands[3]; + + Constant *VecData = ConstantFoldLoadFromConstPtr(SrcPtr, VTy, DL); + + SmallVector<Constant *, 32> NewElements; + for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) { + auto *MaskElt = Mask->getAggregateElement(I); + if (!MaskElt) + break; + auto *PassthruElt = Passthru->getAggregateElement(I); + auto *VecElt = VecData ? 
VecData->getAggregateElement(I) : nullptr;
+      if (isa<UndefValue>(MaskElt)) {
+        if (PassthruElt)
+          NewElements.push_back(PassthruElt);
+        else if (VecElt)
+          NewElements.push_back(VecElt);
+        else
+          return nullptr;
+        // This lane is handled; without this continue, an undef mask element
+        // would fall through to the checks below and always bail out.
+        continue;
+      }
+      if (MaskElt->isNullValue()) {
+        if (!PassthruElt)
+          return nullptr;
+        NewElements.push_back(PassthruElt);
+      } else if (MaskElt->isOneValue()) {
+        if (!VecElt)
+          return nullptr;
+        NewElements.push_back(VecElt);
+      } else {
+        return nullptr;
+      }
+    }
+    if (NewElements.size() != VTy->getNumElements())
+      return nullptr;
+    return ConstantVector::get(NewElements);
+  }
+
+  for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) {
+    // Gather a column of constants.
+    for (unsigned J = 0, JE = Operands.size(); J != JE; ++J) {
+      // Some intrinsics use a scalar type for certain arguments.
+      if (hasVectorInstrinsicScalarOpd(IntrinsicID, J)) {
+        Lane[J] = Operands[J];
+        continue;
+      }
+
+      Constant *Agg = Operands[J]->getAggregateElement(I);
+      if (!Agg)
+        return nullptr;
+
+      Lane[J] = Agg;
+    }
+
+    // Use the regular scalar folding to simplify this column.
+    Constant *Folded =
+        ConstantFoldScalarCall(Name, IntrinsicID, Ty, Lane, TLI, Call);
+    if (!Folded)
+      return nullptr;
+    Result[I] = Folded;
+  }
+
+  return ConstantVector::get(Result);
+}
+
+} // end anonymous namespace
+
+Constant *llvm::ConstantFoldCall(const CallBase *Call, Function *F,
+                                 ArrayRef<Constant *> Operands,
+                                 const TargetLibraryInfo *TLI) {
+  if (Call->isNoBuiltin() || Call->isStrictFP())
+    return nullptr;
+  if (!F->hasName())
+    return nullptr;
+  StringRef Name = F->getName();
+
+  Type *Ty = F->getReturnType();
+
+  if (auto *VTy = dyn_cast<VectorType>(Ty))
+    return ConstantFoldVectorCall(Name, F->getIntrinsicID(), VTy, Operands,
+                                  F->getParent()->getDataLayout(), TLI, Call);
+
+  return ConstantFoldScalarCall(Name, F->getIntrinsicID(), Ty, Operands, TLI,
+                                Call);
+}
+
+bool llvm::isMathLibCallNoop(const CallBase *Call,
+                             const TargetLibraryInfo *TLI) {
+  // FIXME: Refactor this code; this duplicates logic in LibCallsShrinkWrap
+  // (and to some extent ConstantFoldScalarCall).
+  if (Call->isNoBuiltin() || Call->isStrictFP())
+    return false;
+  Function *F = Call->getCalledFunction();
+  if (!F)
+    return false;
+
+  LibFunc Func;
+  if (!TLI || !TLI->getLibFunc(*F, Func))
+    return false;
+
+  if (Call->getNumArgOperands() == 1) {
+    if (ConstantFP *OpC = dyn_cast<ConstantFP>(Call->getArgOperand(0))) {
+      const APFloat &Op = OpC->getValueAPF();
+      switch (Func) {
+      case LibFunc_logl:
+      case LibFunc_log:
+      case LibFunc_logf:
+      case LibFunc_log2l:
+      case LibFunc_log2:
+      case LibFunc_log2f:
+      case LibFunc_log10l:
+      case LibFunc_log10:
+      case LibFunc_log10f:
+        return Op.isNaN() || (!Op.isZero() && !Op.isNegative());
+
+      case LibFunc_expl:
+      case LibFunc_exp:
+      case LibFunc_expf:
+        // FIXME: These boundaries are slightly conservative.
+        if (OpC->getType()->isDoubleTy())
+          return Op.compare(APFloat(-745.0)) != APFloat::cmpLessThan &&
+                 Op.compare(APFloat(709.0)) != APFloat::cmpGreaterThan;
+        if (OpC->getType()->isFloatTy())
+          return Op.compare(APFloat(-103.0f)) != APFloat::cmpLessThan &&
+                 Op.compare(APFloat(88.0f)) != APFloat::cmpGreaterThan;
+        break;
+
+      case LibFunc_exp2l:
+      case LibFunc_exp2:
+      case LibFunc_exp2f:
+        // FIXME: These boundaries are slightly conservative.
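+        // (For doubles, exp2 overflows at 1024, since DBL_MAX is just below
+        // 2^1024, and the smallest subnormal is 2^-1074; hence the 1023.0
+        // and -1074.0 limits below.)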
+ if (OpC->getType()->isDoubleTy()) + return Op.compare(APFloat(-1074.0)) != APFloat::cmpLessThan && + Op.compare(APFloat(1023.0)) != APFloat::cmpGreaterThan; + if (OpC->getType()->isFloatTy()) + return Op.compare(APFloat(-149.0f)) != APFloat::cmpLessThan && + Op.compare(APFloat(127.0f)) != APFloat::cmpGreaterThan; + break; + + case LibFunc_sinl: + case LibFunc_sin: + case LibFunc_sinf: + case LibFunc_cosl: + case LibFunc_cos: + case LibFunc_cosf: + return !Op.isInfinity(); + + case LibFunc_tanl: + case LibFunc_tan: + case LibFunc_tanf: { + // FIXME: Stop using the host math library. + // FIXME: The computation isn't done in the right precision. + Type *Ty = OpC->getType(); + if (Ty->isDoubleTy() || Ty->isFloatTy() || Ty->isHalfTy()) { + double OpV = getValueAsDouble(OpC); + return ConstantFoldFP(tan, OpV, Ty) != nullptr; + } + break; + } + + case LibFunc_asinl: + case LibFunc_asin: + case LibFunc_asinf: + case LibFunc_acosl: + case LibFunc_acos: + case LibFunc_acosf: + return Op.compare(APFloat(Op.getSemantics(), "-1")) != + APFloat::cmpLessThan && + Op.compare(APFloat(Op.getSemantics(), "1")) != + APFloat::cmpGreaterThan; + + case LibFunc_sinh: + case LibFunc_cosh: + case LibFunc_sinhf: + case LibFunc_coshf: + case LibFunc_sinhl: + case LibFunc_coshl: + // FIXME: These boundaries are slightly conservative. + if (OpC->getType()->isDoubleTy()) + return Op.compare(APFloat(-710.0)) != APFloat::cmpLessThan && + Op.compare(APFloat(710.0)) != APFloat::cmpGreaterThan; + if (OpC->getType()->isFloatTy()) + return Op.compare(APFloat(-89.0f)) != APFloat::cmpLessThan && + Op.compare(APFloat(89.0f)) != APFloat::cmpGreaterThan; + break; + + case LibFunc_sqrtl: + case LibFunc_sqrt: + case LibFunc_sqrtf: + return Op.isNaN() || Op.isZero() || !Op.isNegative(); + + // FIXME: Add more functions: sqrt_finite, atanh, expm1, log1p, + // maybe others? + default: + break; + } + } + } + + if (Call->getNumArgOperands() == 2) { + ConstantFP *Op0C = dyn_cast<ConstantFP>(Call->getArgOperand(0)); + ConstantFP *Op1C = dyn_cast<ConstantFP>(Call->getArgOperand(1)); + if (Op0C && Op1C) { + const APFloat &Op0 = Op0C->getValueAPF(); + const APFloat &Op1 = Op1C->getValueAPF(); + + switch (Func) { + case LibFunc_powl: + case LibFunc_pow: + case LibFunc_powf: { + // FIXME: Stop using the host math library. + // FIXME: The computation isn't done in the right precision. + Type *Ty = Op0C->getType(); + if (Ty->isDoubleTy() || Ty->isFloatTy() || Ty->isHalfTy()) { + if (Ty == Op1C->getType()) { + double Op0V = getValueAsDouble(Op0C); + double Op1V = getValueAsDouble(Op1C); + return ConstantFoldBinaryFP(pow, Op0V, Op1V, Ty) != nullptr; + } + } + break; + } + + case LibFunc_fmodl: + case LibFunc_fmod: + case LibFunc_fmodf: + return Op0.isNaN() || Op1.isNaN() || + (!Op0.isInfinity() && !Op1.isZero()); + + default: + break; + } + } + } + + return false; +} diff --git a/llvm/lib/Analysis/CostModel.cpp b/llvm/lib/Analysis/CostModel.cpp new file mode 100644 index 000000000000..bf0cdbfd0c8b --- /dev/null +++ b/llvm/lib/Analysis/CostModel.cpp @@ -0,0 +1,111 @@ +//===- CostModel.cpp ------ Cost Model Analysis ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the cost model analysis. It provides a very basic cost +// estimation for LLVM-IR. 
This analysis uses the services of the codegen +// to approximate the cost of any IR instruction when lowered to machine +// instructions. The cost results are unit-less and the cost number represents +// the throughput of the machine assuming that all loads hit the cache, all +// branches are predicted, etc. The cost numbers can be added in order to +// compare two or more transformation alternatives. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Function.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +static cl::opt<TargetTransformInfo::TargetCostKind> CostKind( + "cost-kind", cl::desc("Target cost kind"), + cl::init(TargetTransformInfo::TCK_RecipThroughput), + cl::values(clEnumValN(TargetTransformInfo::TCK_RecipThroughput, + "throughput", "Reciprocal throughput"), + clEnumValN(TargetTransformInfo::TCK_Latency, + "latency", "Instruction latency"), + clEnumValN(TargetTransformInfo::TCK_CodeSize, + "code-size", "Code size"))); + +#define CM_NAME "cost-model" +#define DEBUG_TYPE CM_NAME + +namespace { + class CostModelAnalysis : public FunctionPass { + + public: + static char ID; // Class identification, replacement for typeinfo + CostModelAnalysis() : FunctionPass(ID), F(nullptr), TTI(nullptr) { + initializeCostModelAnalysisPass( + *PassRegistry::getPassRegistry()); + } + + /// Returns the expected cost of the instruction. + /// Returns -1 if the cost is unknown. + /// Note, this method does not cache the cost calculation and it + /// can be expensive in some cases. + unsigned getInstructionCost(const Instruction *I) const { + return TTI->getInstructionCost(I, TargetTransformInfo::TCK_RecipThroughput); + } + + private: + void getAnalysisUsage(AnalysisUsage &AU) const override; + bool runOnFunction(Function &F) override; + void print(raw_ostream &OS, const Module*) const override; + + /// The function that we analyze. + Function *F; + /// Target information. + const TargetTransformInfo *TTI; + }; +} // End of anonymous namespace + +// Register this pass. +char CostModelAnalysis::ID = 0; +static const char cm_name[] = "Cost Model Analysis"; +INITIALIZE_PASS_BEGIN(CostModelAnalysis, CM_NAME, cm_name, false, true) +INITIALIZE_PASS_END (CostModelAnalysis, CM_NAME, cm_name, false, true) + +FunctionPass *llvm::createCostModelAnalysisPass() { + return new CostModelAnalysis(); +} + +void +CostModelAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); +} + +bool +CostModelAnalysis::runOnFunction(Function &F) { + this->F = &F; + auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>(); + TTI = TTIWP ? 
&TTIWP->getTTI(F) : nullptr; + + return false; +} + +void CostModelAnalysis::print(raw_ostream &OS, const Module*) const { + if (!F) + return; + + for (BasicBlock &B : *F) { + for (Instruction &Inst : B) { + unsigned Cost = TTI->getInstructionCost(&Inst, CostKind); + if (Cost != (unsigned)-1) + OS << "Cost Model: Found an estimated cost of " << Cost; + else + OS << "Cost Model: Unknown cost"; + + OS << " for instruction: " << Inst << "\n"; + } + } +} diff --git a/llvm/lib/Analysis/DDG.cpp b/llvm/lib/Analysis/DDG.cpp new file mode 100644 index 000000000000..b5c3c761ad98 --- /dev/null +++ b/llvm/lib/Analysis/DDG.cpp @@ -0,0 +1,203 @@ +//===- DDG.cpp - Data Dependence Graph -------------------------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The implementation for the data dependence graph. +//===----------------------------------------------------------------------===// +#include "llvm/Analysis/DDG.h" +#include "llvm/Analysis/LoopInfo.h" + +using namespace llvm; + +#define DEBUG_TYPE "ddg" + +template class llvm::DGEdge<DDGNode, DDGEdge>; +template class llvm::DGNode<DDGNode, DDGEdge>; +template class llvm::DirectedGraph<DDGNode, DDGEdge>; + +//===--------------------------------------------------------------------===// +// DDGNode implementation +//===--------------------------------------------------------------------===// +DDGNode::~DDGNode() {} + +bool DDGNode::collectInstructions( + llvm::function_ref<bool(Instruction *)> const &Pred, + InstructionListType &IList) const { + assert(IList.empty() && "Expected the IList to be empty on entry."); + if (isa<SimpleDDGNode>(this)) { + for (auto *I : cast<const SimpleDDGNode>(this)->getInstructions()) + if (Pred(I)) + IList.push_back(I); + } else + llvm_unreachable("unimplemented type of node"); + return !IList.empty(); +} + +raw_ostream &llvm::operator<<(raw_ostream &OS, const DDGNode::NodeKind K) { + const char *Out; + switch (K) { + case DDGNode::NodeKind::SingleInstruction: + Out = "single-instruction"; + break; + case DDGNode::NodeKind::MultiInstruction: + Out = "multi-instruction"; + break; + case DDGNode::NodeKind::Root: + Out = "root"; + break; + case DDGNode::NodeKind::Unknown: + Out = "??"; + break; + } + OS << Out; + return OS; +} + +raw_ostream &llvm::operator<<(raw_ostream &OS, const DDGNode &N) { + OS << "Node Address:" << &N << ":" << N.getKind() << "\n"; + if (isa<SimpleDDGNode>(N)) { + OS << " Instructions:\n"; + for (auto *I : cast<const SimpleDDGNode>(N).getInstructions()) + OS.indent(2) << *I << "\n"; + } else if (!isa<RootDDGNode>(N)) + llvm_unreachable("unimplemented type of node"); + + OS << (N.getEdges().empty() ? 
" Edges:none!\n" : " Edges:\n"); + for (auto &E : N.getEdges()) + OS.indent(2) << *E; + return OS; +} + +//===--------------------------------------------------------------------===// +// SimpleDDGNode implementation +//===--------------------------------------------------------------------===// + +SimpleDDGNode::SimpleDDGNode(Instruction &I) + : DDGNode(NodeKind::SingleInstruction), InstList() { + assert(InstList.empty() && "Expected empty list."); + InstList.push_back(&I); +} + +SimpleDDGNode::SimpleDDGNode(const SimpleDDGNode &N) + : DDGNode(N), InstList(N.InstList) { + assert(((getKind() == NodeKind::SingleInstruction && InstList.size() == 1) || + (getKind() == NodeKind::MultiInstruction && InstList.size() > 1)) && + "constructing from invalid simple node."); +} + +SimpleDDGNode::SimpleDDGNode(SimpleDDGNode &&N) + : DDGNode(std::move(N)), InstList(std::move(N.InstList)) { + assert(((getKind() == NodeKind::SingleInstruction && InstList.size() == 1) || + (getKind() == NodeKind::MultiInstruction && InstList.size() > 1)) && + "constructing from invalid simple node."); +} + +SimpleDDGNode::~SimpleDDGNode() { InstList.clear(); } + +//===--------------------------------------------------------------------===// +// DDGEdge implementation +//===--------------------------------------------------------------------===// + +raw_ostream &llvm::operator<<(raw_ostream &OS, const DDGEdge::EdgeKind K) { + const char *Out; + switch (K) { + case DDGEdge::EdgeKind::RegisterDefUse: + Out = "def-use"; + break; + case DDGEdge::EdgeKind::MemoryDependence: + Out = "memory"; + break; + case DDGEdge::EdgeKind::Rooted: + Out = "rooted"; + break; + case DDGEdge::EdgeKind::Unknown: + Out = "??"; + break; + } + OS << Out; + return OS; +} + +raw_ostream &llvm::operator<<(raw_ostream &OS, const DDGEdge &E) { + OS << "[" << E.getKind() << "] to " << &E.getTargetNode() << "\n"; + return OS; +} + +//===--------------------------------------------------------------------===// +// DataDependenceGraph implementation +//===--------------------------------------------------------------------===// +using BasicBlockListType = SmallVector<BasicBlock *, 8>; + +DataDependenceGraph::DataDependenceGraph(Function &F, DependenceInfo &D) + : DependenceGraphInfo(F.getName().str(), D) { + BasicBlockListType BBList; + for (auto &BB : F.getBasicBlockList()) + BBList.push_back(&BB); + DDGBuilder(*this, D, BBList).populate(); +} + +DataDependenceGraph::DataDependenceGraph(const Loop &L, DependenceInfo &D) + : DependenceGraphInfo(Twine(L.getHeader()->getParent()->getName() + "." + + L.getHeader()->getName()) + .str(), + D) { + BasicBlockListType BBList; + for (BasicBlock *BB : L.blocks()) + BBList.push_back(BB); + DDGBuilder(*this, D, BBList).populate(); +} + +DataDependenceGraph::~DataDependenceGraph() { + for (auto *N : Nodes) { + for (auto *E : *N) + delete E; + delete N; + } +} + +bool DataDependenceGraph::addNode(DDGNode &N) { + if (!DDGBase::addNode(N)) + return false; + + // In general, if the root node is already created and linked, it is not safe + // to add new nodes since they may be unreachable by the root. + // TODO: Allow adding Pi-block nodes after root is created. Pi-blocks are an + // exception because they represent components that are already reachable by + // root. + assert(!Root && "Root node is already added. 
No more nodes can be added."); + if (isa<RootDDGNode>(N)) + Root = &N; + + return true; +} + +raw_ostream &llvm::operator<<(raw_ostream &OS, const DataDependenceGraph &G) { + for (auto *Node : G) + OS << *Node << "\n"; + return OS; +} + +//===--------------------------------------------------------------------===// +// DDG Analysis Passes +//===--------------------------------------------------------------------===// + +/// DDG as a loop pass. +DDGAnalysis::Result DDGAnalysis::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR) { + Function *F = L.getHeader()->getParent(); + DependenceInfo DI(F, &AR.AA, &AR.SE, &AR.LI); + return std::make_unique<DataDependenceGraph>(L, DI); +} +AnalysisKey DDGAnalysis::Key; + +PreservedAnalyses DDGAnalysisPrinterPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &U) { + OS << "'DDG' for loop '" << L.getHeader()->getName() << "':\n"; + OS << *AM.getResult<DDGAnalysis>(L, AR); + return PreservedAnalyses::all(); +} diff --git a/llvm/lib/Analysis/Delinearization.cpp b/llvm/lib/Analysis/Delinearization.cpp new file mode 100644 index 000000000000..c1043e446beb --- /dev/null +++ b/llvm/lib/Analysis/Delinearization.cpp @@ -0,0 +1,129 @@ +//===---- Delinearization.cpp - MultiDimensional Index Delinearization ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This implements an analysis pass that tries to delinearize all GEP +// instructions in all loops using the SCEV analysis functionality. This pass is +// only used for testing purposes: if your pass needs delinearization, please +// use the on-demand SCEVAddRecExpr::delinearize() function. 
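// A rough sketch of the recommended on-demand use, mirroring the print()
// routine later in this file (SE, Inst, and L are assumed to be a
// ScalarEvolution, a load/store instruction, and a surrounding loop):
//
//   const SCEV *Fn = SE.getSCEVAtScope(getPointerOperand(Inst), L);
//   Fn = SE.getMinusSCEV(Fn, SE.getPointerBase(Fn)); // offset from the base
//   SmallVector<const SCEV *, 3> Subscripts, Sizes;
//   SE.delinearize(Fn, Subscripts, Sizes, SE.getElementSize(Inst));
//   // Success iff both vectors are non-empty and of equal length; each
//   // Subscripts[i] is then one recovered array index.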
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Type.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DL_NAME "delinearize" +#define DEBUG_TYPE DL_NAME + +namespace { + +class Delinearization : public FunctionPass { + Delinearization(const Delinearization &); // do not implement +protected: + Function *F; + LoopInfo *LI; + ScalarEvolution *SE; + +public: + static char ID; // Pass identification, replacement for typeid + + Delinearization() : FunctionPass(ID) { + initializeDelinearizationPass(*PassRegistry::getPassRegistry()); + } + bool runOnFunction(Function &F) override; + void getAnalysisUsage(AnalysisUsage &AU) const override; + void print(raw_ostream &O, const Module *M = nullptr) const override; +}; + +} // end anonymous namespace + +void Delinearization::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addRequired<ScalarEvolutionWrapperPass>(); +} + +bool Delinearization::runOnFunction(Function &F) { + this->F = &F; + SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + return false; +} + +void Delinearization::print(raw_ostream &O, const Module *) const { + O << "Delinearization on function " << F->getName() << ":\n"; + for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) { + Instruction *Inst = &(*I); + + // Only analyze loads and stores. + if (!isa<StoreInst>(Inst) && !isa<LoadInst>(Inst) && + !isa<GetElementPtrInst>(Inst)) + continue; + + const BasicBlock *BB = Inst->getParent(); + // Delinearize the memory access as analyzed in all the surrounding loops. + // Do not analyze memory accesses outside loops. + for (Loop *L = LI->getLoopFor(BB); L != nullptr; L = L->getParentLoop()) { + const SCEV *AccessFn = SE->getSCEVAtScope(getPointerOperand(Inst), L); + + const SCEVUnknown *BasePointer = + dyn_cast<SCEVUnknown>(SE->getPointerBase(AccessFn)); + // Do not delinearize if we cannot find the base pointer. 
+ if (!BasePointer) + break; + AccessFn = SE->getMinusSCEV(AccessFn, BasePointer); + + O << "\n"; + O << "Inst:" << *Inst << "\n"; + O << "In Loop with Header: " << L->getHeader()->getName() << "\n"; + O << "AccessFunction: " << *AccessFn << "\n"; + + SmallVector<const SCEV *, 3> Subscripts, Sizes; + SE->delinearize(AccessFn, Subscripts, Sizes, SE->getElementSize(Inst)); + if (Subscripts.size() == 0 || Sizes.size() == 0 || + Subscripts.size() != Sizes.size()) { + O << "failed to delinearize\n"; + continue; + } + + O << "Base offset: " << *BasePointer << "\n"; + O << "ArrayDecl[UnknownSize]"; + int Size = Subscripts.size(); + for (int i = 0; i < Size - 1; i++) + O << "[" << *Sizes[i] << "]"; + O << " with elements of " << *Sizes[Size - 1] << " bytes.\n"; + + O << "ArrayRef"; + for (int i = 0; i < Size; i++) + O << "[" << *Subscripts[i] << "]"; + O << "\n"; + } + } +} + +char Delinearization::ID = 0; +static const char delinearization_name[] = "Delinearization"; +INITIALIZE_PASS_BEGIN(Delinearization, DL_NAME, delinearization_name, true, + true) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(Delinearization, DL_NAME, delinearization_name, true, true) + +FunctionPass *llvm::createDelinearizationPass() { return new Delinearization; } diff --git a/llvm/lib/Analysis/DemandedBits.cpp b/llvm/lib/Analysis/DemandedBits.cpp new file mode 100644 index 000000000000..01b8ff10d355 --- /dev/null +++ b/llvm/lib/Analysis/DemandedBits.cpp @@ -0,0 +1,488 @@ +//===- DemandedBits.cpp - Determine demanded bits -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements a demanded bits analysis. A demanded bit is one that +// contributes to a result; bits that are not demanded can be either zero or +// one without affecting control or data flow. For example in this sequence: +// +// %1 = add i32 %x, %y +// %2 = trunc i32 %1 to i16 +// +// Only the lowest 16 bits of %1 are demanded; the rest are removed by the +// trunc. 
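// Querying the analysis for %1 would accordingly yield the mask 0x0000FFFF.
// A minimal sketch, assuming DB is a computed DemandedBits result and Add
// names the 'add i32' instruction above:
//
//   APInt Demanded = DB.getDemandedBits(Add);
//   // Demanded == APInt::getLowBitsSet(32, 16), i.e. 0x0000FFFF
//
// And since an add ripples only leftward, only the low 16 bits of %x and %y
// are demanded as well.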
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/DemandedBits.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cstdint> + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "demanded-bits" + +char DemandedBitsWrapperPass::ID = 0; + +INITIALIZE_PASS_BEGIN(DemandedBitsWrapperPass, "demanded-bits", + "Demanded bits analysis", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(DemandedBitsWrapperPass, "demanded-bits", + "Demanded bits analysis", false, false) + +DemandedBitsWrapperPass::DemandedBitsWrapperPass() : FunctionPass(ID) { + initializeDemandedBitsWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void DemandedBitsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.setPreservesAll(); +} + +void DemandedBitsWrapperPass::print(raw_ostream &OS, const Module *M) const { + DB->print(OS); +} + +static bool isAlwaysLive(Instruction *I) { + return I->isTerminator() || isa<DbgInfoIntrinsic>(I) || I->isEHPad() || + I->mayHaveSideEffects(); +} + +void DemandedBits::determineLiveOperandBits( + const Instruction *UserI, const Value *Val, unsigned OperandNo, + const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2, + bool &KnownBitsComputed) { + unsigned BitWidth = AB.getBitWidth(); + + // We're called once per operand, but for some instructions, we need to + // compute known bits of both operands in order to determine the live bits of + // either (when both operands are instructions themselves). We don't, + // however, want to do this twice, so we cache the result in APInts that live + // in the caller. For the two-relevant-operands case, both operand values are + // provided here. + auto ComputeKnownBits = + [&](unsigned BitWidth, const Value *V1, const Value *V2) { + if (KnownBitsComputed) + return; + KnownBitsComputed = true; + + const DataLayout &DL = UserI->getModule()->getDataLayout(); + Known = KnownBits(BitWidth); + computeKnownBits(V1, Known, DL, 0, &AC, UserI, &DT); + + if (V2) { + Known2 = KnownBits(BitWidth); + computeKnownBits(V2, Known2, DL, 0, &AC, UserI, &DT); + } + }; + + switch (UserI->getOpcode()) { + default: break; + case Instruction::Call: + case Instruction::Invoke: + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(UserI)) + switch (II->getIntrinsicID()) { + default: break; + case Intrinsic::bswap: + // The alive bits of the input are the swapped alive bits of + // the output. 
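// For example (32-bit): if AOut = 0x000000FF, so that only the low byte of
// the result is demanded, then AOut.byteSwap() = 0xFF000000 and only the
// high byte of the input is alive.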
+ AB = AOut.byteSwap(); + break; + case Intrinsic::bitreverse: + // The alive bits of the input are the reversed alive bits of + // the output. + AB = AOut.reverseBits(); + break; + case Intrinsic::ctlz: + if (OperandNo == 0) { + // We need some output bits, so we need all bits of the + // input to the left of, and including, the leftmost bit + // known to be one. + ComputeKnownBits(BitWidth, Val, nullptr); + AB = APInt::getHighBitsSet(BitWidth, + std::min(BitWidth, Known.countMaxLeadingZeros()+1)); + } + break; + case Intrinsic::cttz: + if (OperandNo == 0) { + // We need some output bits, so we need all bits of the + // input to the right of, and including, the rightmost bit + // known to be one. + ComputeKnownBits(BitWidth, Val, nullptr); + AB = APInt::getLowBitsSet(BitWidth, + std::min(BitWidth, Known.countMaxTrailingZeros()+1)); + } + break; + case Intrinsic::fshl: + case Intrinsic::fshr: { + const APInt *SA; + if (OperandNo == 2) { + // Shift amount is modulo the bitwidth. For powers of two we have + // SA % BW == SA & (BW - 1). + if (isPowerOf2_32(BitWidth)) + AB = BitWidth - 1; + } else if (match(II->getOperand(2), m_APInt(SA))) { + // Normalize to funnel shift left. APInt shifts of BitWidth are well- + // defined, so no need to special-case zero shifts here. + uint64_t ShiftAmt = SA->urem(BitWidth); + if (II->getIntrinsicID() == Intrinsic::fshr) + ShiftAmt = BitWidth - ShiftAmt; + + if (OperandNo == 0) + AB = AOut.lshr(ShiftAmt); + else if (OperandNo == 1) + AB = AOut.shl(BitWidth - ShiftAmt); + } + break; + } + } + break; + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + // Find the highest live output bit. We don't need any more input + // bits than that (adds, and thus subtracts, ripple only to the + // left). + AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits()); + break; + case Instruction::Shl: + if (OperandNo == 0) { + const APInt *ShiftAmtC; + if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) { + uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1); + AB = AOut.lshr(ShiftAmt); + + // If the shift is nuw/nsw, then the high bits are not dead + // (because we've promised that they *must* be zero). + const ShlOperator *S = cast<ShlOperator>(UserI); + if (S->hasNoSignedWrap()) + AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1); + else if (S->hasNoUnsignedWrap()) + AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt); + } + } + break; + case Instruction::LShr: + if (OperandNo == 0) { + const APInt *ShiftAmtC; + if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) { + uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1); + AB = AOut.shl(ShiftAmt); + + // If the shift is exact, then the low bits are not dead + // (they must be zero). + if (cast<LShrOperator>(UserI)->isExact()) + AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt); + } + } + break; + case Instruction::AShr: + if (OperandNo == 0) { + const APInt *ShiftAmtC; + if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) { + uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1); + AB = AOut.shl(ShiftAmt); + // Because the high input bit is replicated into the + // high-order bits of the result, if we need any of those + // bits, then we must keep the highest input bit. + if ((AOut & APInt::getHighBitsSet(BitWidth, ShiftAmt)) + .getBoolValue()) + AB.setSignBit(); + + // If the shift is exact, then the low bits are not dead + // (they must be zero). 
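// For example, with 'ashr exact i32 %x, 4', AOut.shl(4) alone would mark the
// four low input bits dead, but 'exact' promises they are zero, so they are
// added back as live below.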
+ if (cast<AShrOperator>(UserI)->isExact()) + AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt); + } + } + break; + case Instruction::And: + AB = AOut; + + // For bits that are known zero, the corresponding bits in the + // other operand are dead (unless they're both zero, in which + // case they can't both be dead, so just mark the LHS bits as + // dead). + ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1)); + if (OperandNo == 0) + AB &= ~Known2.Zero; + else + AB &= ~(Known.Zero & ~Known2.Zero); + break; + case Instruction::Or: + AB = AOut; + + // For bits that are known one, the corresponding bits in the + // other operand are dead (unless they're both one, in which + // case they can't both be dead, so just mark the LHS bits as + // dead). + ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1)); + if (OperandNo == 0) + AB &= ~Known2.One; + else + AB &= ~(Known.One & ~Known2.One); + break; + case Instruction::Xor: + case Instruction::PHI: + AB = AOut; + break; + case Instruction::Trunc: + AB = AOut.zext(BitWidth); + break; + case Instruction::ZExt: + AB = AOut.trunc(BitWidth); + break; + case Instruction::SExt: + AB = AOut.trunc(BitWidth); + // Because the high input bit is replicated into the + // high-order bits of the result, if we need any of those + // bits, then we must keep the highest input bit. + if ((AOut & APInt::getHighBitsSet(AOut.getBitWidth(), + AOut.getBitWidth() - BitWidth)) + .getBoolValue()) + AB.setSignBit(); + break; + case Instruction::Select: + if (OperandNo != 0) + AB = AOut; + break; + case Instruction::ExtractElement: + if (OperandNo == 0) + AB = AOut; + break; + case Instruction::InsertElement: + case Instruction::ShuffleVector: + if (OperandNo == 0 || OperandNo == 1) + AB = AOut; + break; + } +} + +bool DemandedBitsWrapperPass::runOnFunction(Function &F) { + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + DB.emplace(F, AC, DT); + return false; +} + +void DemandedBitsWrapperPass::releaseMemory() { + DB.reset(); +} + +void DemandedBits::performAnalysis() { + if (Analyzed) + // Analysis already completed for this function. + return; + Analyzed = true; + + Visited.clear(); + AliveBits.clear(); + DeadUses.clear(); + + SmallSetVector<Instruction*, 16> Worklist; + + // Collect the set of "root" instructions that are known live. + for (Instruction &I : instructions(F)) { + if (!isAlwaysLive(&I)) + continue; + + LLVM_DEBUG(dbgs() << "DemandedBits: Root: " << I << "\n"); + // For integer-valued instructions, set up an initial empty set of alive + // bits and add the instruction to the work list. For other instructions + // add their operands to the work list (for integer values operands, mark + // all bits as live). + Type *T = I.getType(); + if (T->isIntOrIntVectorTy()) { + if (AliveBits.try_emplace(&I, T->getScalarSizeInBits(), 0).second) + Worklist.insert(&I); + + continue; + } + + // Non-integer-typed instructions... + for (Use &OI : I.operands()) { + if (Instruction *J = dyn_cast<Instruction>(OI)) { + Type *T = J->getType(); + if (T->isIntOrIntVectorTy()) + AliveBits[J] = APInt::getAllOnesValue(T->getScalarSizeInBits()); + else + Visited.insert(J); + Worklist.insert(J); + } + } + // To save memory, we don't add I to the Visited set here. Instead, we + // check isAlwaysLive on every instruction when searching for dead + // instructions later (we need to check isAlwaysLive for the + // integer-typed instructions anyway). 
+ } + + // Propagate liveness backwards to operands. + while (!Worklist.empty()) { + Instruction *UserI = Worklist.pop_back_val(); + + LLVM_DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI); + APInt AOut; + bool InputIsKnownDead = false; + if (UserI->getType()->isIntOrIntVectorTy()) { + AOut = AliveBits[UserI]; + LLVM_DEBUG(dbgs() << " Alive Out: 0x" + << Twine::utohexstr(AOut.getLimitedValue())); + + // If all bits of the output are dead, then all bits of the input + // are also dead. + InputIsKnownDead = !AOut && !isAlwaysLive(UserI); + } + LLVM_DEBUG(dbgs() << "\n"); + + KnownBits Known, Known2; + bool KnownBitsComputed = false; + // Compute the set of alive bits for each operand. These are anded into the + // existing set, if any, and if that changes the set of alive bits, the + // operand is added to the work-list. + for (Use &OI : UserI->operands()) { + // We also want to detect dead uses of arguments, but will only store + // demanded bits for instructions. + Instruction *I = dyn_cast<Instruction>(OI); + if (!I && !isa<Argument>(OI)) + continue; + + Type *T = OI->getType(); + if (T->isIntOrIntVectorTy()) { + unsigned BitWidth = T->getScalarSizeInBits(); + APInt AB = APInt::getAllOnesValue(BitWidth); + if (InputIsKnownDead) { + AB = APInt(BitWidth, 0); + } else { + // Bits of each operand that are used to compute alive bits of the + // output are alive, all others are dead. + determineLiveOperandBits(UserI, OI, OI.getOperandNo(), AOut, AB, + Known, Known2, KnownBitsComputed); + + // Keep track of uses which have no demanded bits. + if (AB.isNullValue()) + DeadUses.insert(&OI); + else + DeadUses.erase(&OI); + } + + if (I) { + // If we've added to the set of alive bits (or the operand has not + // been previously visited), then re-queue the operand to be visited + // again. + auto Res = AliveBits.try_emplace(I); + if (Res.second || (AB |= Res.first->second) != Res.first->second) { + Res.first->second = std::move(AB); + Worklist.insert(I); + } + } + } else if (I && Visited.insert(I).second) { + Worklist.insert(I); + } + } + } +} + +APInt DemandedBits::getDemandedBits(Instruction *I) { + performAnalysis(); + + auto Found = AliveBits.find(I); + if (Found != AliveBits.end()) + return Found->second; + + const DataLayout &DL = I->getModule()->getDataLayout(); + return APInt::getAllOnesValue( + DL.getTypeSizeInBits(I->getType()->getScalarType())); +} + +bool DemandedBits::isInstructionDead(Instruction *I) { + performAnalysis(); + + return !Visited.count(I) && AliveBits.find(I) == AliveBits.end() && + !isAlwaysLive(I); +} + +bool DemandedBits::isUseDead(Use *U) { + // We only track integer uses, everything else is assumed live. + if (!(*U)->getType()->isIntOrIntVectorTy()) + return false; + + // Uses by always-live instructions are never dead. + Instruction *UserI = cast<Instruction>(U->getUser()); + if (isAlwaysLive(UserI)) + return false; + + performAnalysis(); + if (DeadUses.count(U)) + return true; + + // If no output bits are demanded, no input bits are demanded and the use + // is dead. These uses might not be explicitly present in the DeadUses map. 
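// (For example, an add whose alive-bit mask became all-zero because every
// result bit is truncated away downstream makes each of its integer operand
// uses dead, even though those uses were never inserted into DeadUses.)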
+ if (UserI->getType()->isIntOrIntVectorTy()) { + auto Found = AliveBits.find(UserI); + if (Found != AliveBits.end() && Found->second.isNullValue()) + return true; + } + + return false; +} + +void DemandedBits::print(raw_ostream &OS) { + performAnalysis(); + for (auto &KV : AliveBits) { + OS << "DemandedBits: 0x" << Twine::utohexstr(KV.second.getLimitedValue()) + << " for " << *KV.first << '\n'; + } +} + +FunctionPass *llvm::createDemandedBitsWrapperPass() { + return new DemandedBitsWrapperPass(); +} + +AnalysisKey DemandedBitsAnalysis::Key; + +DemandedBits DemandedBitsAnalysis::run(Function &F, + FunctionAnalysisManager &AM) { + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + return DemandedBits(F, AC, DT); +} + +PreservedAnalyses DemandedBitsPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + AM.getResult<DemandedBitsAnalysis>(F).print(OS); + return PreservedAnalyses::all(); +} diff --git a/llvm/lib/Analysis/DependenceAnalysis.cpp b/llvm/lib/Analysis/DependenceAnalysis.cpp new file mode 100644 index 000000000000..0038c9fb9ce4 --- /dev/null +++ b/llvm/lib/Analysis/DependenceAnalysis.cpp @@ -0,0 +1,4009 @@ +//===-- DependenceAnalysis.cpp - DA Implementation --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// DependenceAnalysis is an LLVM pass that analyses dependences between memory +// accesses. Currently, it is an (incomplete) implementation of the approach +// described in +// +// Practical Dependence Testing +// Goff, Kennedy, Tseng +// PLDI 1991 +// +// There's a single entry point that analyzes the dependence between a pair +// of memory references in a function, returning either NULL, for no dependence, +// or a more-or-less detailed description of the dependence between them. +// +// Currently, the implementation cannot propagate constraints between +// coupled RDIV subscripts and lacks a multi-subscript MIV test. +// Both of these are conservative weaknesses; +// that is, not a source of correctness problems. +// +// Since Clang linearizes some array subscripts, the dependence +// analysis is using SCEV->delinearize to recover the representation of multiple +// subscripts, and thus avoid the more expensive and less precise MIV tests. The +// delinearization is controlled by the flag -da-delinearize. +// +// We should pay some careful attention to the possibility of integer overflow +// in the implementation of the various tests. This could happen with Add, +// Subtract, or Multiply, with both APInt's and SCEV's. +// +// Some non-linear subscript pairs can be handled by the GCD test +// (and perhaps other tests). +// Should explore how often these things occur. +// +// Finally, it seems like certain test cases expose weaknesses in the SCEV +// simplification, especially in the handling of sign and zero extensions. +// It could be useful to spend time exploring these. +// +// Please note that this is work in progress and the interface is subject to +// change. 
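// A rough client-side sketch (DI is a DependenceInfo; Src and Dst are two
// loads/stores; this mirrors dumpExampleDependence() later in this file):
//
//   if (std::unique_ptr<Dependence> D = DI.depends(Src, Dst, true)) {
//     for (unsigned Level = 1; Level <= D->getLevels(); ++Level)
//       if (const SCEV *Dist = D->getDistance(Level))
//         dbgs() << "distance at level " << Level << " = " << *Dist << "\n";
//   } // a null result means the accesses are provably independent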
+// +//===----------------------------------------------------------------------===// +// // +// In memory of Ken Kennedy, 1945 - 2007 // +// // +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/DependenceAnalysis.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DEBUG_TYPE "da" + +//===----------------------------------------------------------------------===// +// statistics + +STATISTIC(TotalArrayPairs, "Array pairs tested"); +STATISTIC(SeparableSubscriptPairs, "Separable subscript pairs"); +STATISTIC(CoupledSubscriptPairs, "Coupled subscript pairs"); +STATISTIC(NonlinearSubscriptPairs, "Nonlinear subscript pairs"); +STATISTIC(ZIVapplications, "ZIV applications"); +STATISTIC(ZIVindependence, "ZIV independence"); +STATISTIC(StrongSIVapplications, "Strong SIV applications"); +STATISTIC(StrongSIVsuccesses, "Strong SIV successes"); +STATISTIC(StrongSIVindependence, "Strong SIV independence"); +STATISTIC(WeakCrossingSIVapplications, "Weak-Crossing SIV applications"); +STATISTIC(WeakCrossingSIVsuccesses, "Weak-Crossing SIV successes"); +STATISTIC(WeakCrossingSIVindependence, "Weak-Crossing SIV independence"); +STATISTIC(ExactSIVapplications, "Exact SIV applications"); +STATISTIC(ExactSIVsuccesses, "Exact SIV successes"); +STATISTIC(ExactSIVindependence, "Exact SIV independence"); +STATISTIC(WeakZeroSIVapplications, "Weak-Zero SIV applications"); +STATISTIC(WeakZeroSIVsuccesses, "Weak-Zero SIV successes"); +STATISTIC(WeakZeroSIVindependence, "Weak-Zero SIV independence"); +STATISTIC(ExactRDIVapplications, "Exact RDIV applications"); +STATISTIC(ExactRDIVindependence, "Exact RDIV independence"); +STATISTIC(SymbolicRDIVapplications, "Symbolic RDIV applications"); +STATISTIC(SymbolicRDIVindependence, "Symbolic RDIV independence"); +STATISTIC(DeltaApplications, "Delta applications"); +STATISTIC(DeltaSuccesses, "Delta successes"); +STATISTIC(DeltaIndependence, "Delta independence"); +STATISTIC(DeltaPropagations, "Delta propagations"); +STATISTIC(GCDapplications, "GCD applications"); +STATISTIC(GCDsuccesses, "GCD successes"); +STATISTIC(GCDindependence, "GCD independence"); +STATISTIC(BanerjeeApplications, "Banerjee applications"); +STATISTIC(BanerjeeIndependence, "Banerjee independence"); +STATISTIC(BanerjeeSuccesses, "Banerjee successes"); + +static cl::opt<bool> + Delinearize("da-delinearize", cl::init(true), cl::Hidden, cl::ZeroOrMore, + cl::desc("Try to delinearize array references.")); +static cl::opt<bool> DisableDelinearizationChecks( + "da-disable-delinearization-checks", cl::init(false), cl::Hidden, + cl::ZeroOrMore, + cl::desc( + "Disable checks that try to statically verify validity of " + "delinearized subscripts. 
Enabling this option may result in incorrect " + "dependence vectors for languages that allow the subscript of one " + "dimension to underflow or overflow into another dimension.")); + +//===----------------------------------------------------------------------===// +// basics + +DependenceAnalysis::Result +DependenceAnalysis::run(Function &F, FunctionAnalysisManager &FAM) { + auto &AA = FAM.getResult<AAManager>(F); + auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(F); + auto &LI = FAM.getResult<LoopAnalysis>(F); + return DependenceInfo(&F, &AA, &SE, &LI); +} + +AnalysisKey DependenceAnalysis::Key; + +INITIALIZE_PASS_BEGIN(DependenceAnalysisWrapperPass, "da", + "Dependence Analysis", true, true) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(DependenceAnalysisWrapperPass, "da", "Dependence Analysis", + true, true) + +char DependenceAnalysisWrapperPass::ID = 0; + +FunctionPass *llvm::createDependenceAnalysisWrapperPass() { + return new DependenceAnalysisWrapperPass(); +} + +bool DependenceAnalysisWrapperPass::runOnFunction(Function &F) { + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + info.reset(new DependenceInfo(&F, &AA, &SE, &LI)); + return false; +} + +DependenceInfo &DependenceAnalysisWrapperPass::getDI() const { return *info; } + +void DependenceAnalysisWrapperPass::releaseMemory() { info.reset(); } + +void DependenceAnalysisWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequiredTransitive<AAResultsWrapperPass>(); + AU.addRequiredTransitive<ScalarEvolutionWrapperPass>(); + AU.addRequiredTransitive<LoopInfoWrapperPass>(); +} + + +// Used to test the dependence analyzer. +// Looks through the function, noting loads and stores. +// Calls depends() on every possible pair and prints out the result. +// Ignores all other instructions. +static void dumpExampleDependence(raw_ostream &OS, DependenceInfo *DA) { + auto *F = DA->getFunction(); + for (inst_iterator SrcI = inst_begin(F), SrcE = inst_end(F); SrcI != SrcE; + ++SrcI) { + if (isa<StoreInst>(*SrcI) || isa<LoadInst>(*SrcI)) { + for (inst_iterator DstI = SrcI, DstE = inst_end(F); + DstI != DstE; ++DstI) { + if (isa<StoreInst>(*DstI) || isa<LoadInst>(*DstI)) { + OS << "da analyze - "; + if (auto D = DA->depends(&*SrcI, &*DstI, true)) { + D->dump(OS); + for (unsigned Level = 1; Level <= D->getLevels(); Level++) { + if (D->isSplitable(Level)) { + OS << "da analyze - split level = " << Level; + OS << ", iteration = " << *DA->getSplitIteration(*D, Level); + OS << "!\n"; + } + } + } + else + OS << "none!\n"; + } + } + } + } +} + +void DependenceAnalysisWrapperPass::print(raw_ostream &OS, + const Module *) const { + dumpExampleDependence(OS, info.get()); +} + +PreservedAnalyses +DependenceAnalysisPrinterPass::run(Function &F, FunctionAnalysisManager &FAM) { + OS << "'Dependence Analysis' for function '" << F.getName() << "':\n"; + dumpExampleDependence(OS, &FAM.getResult<DependenceAnalysis>(F)); + return PreservedAnalyses::all(); +} + +//===----------------------------------------------------------------------===// +// Dependence methods + +// Returns true if this is an input dependence. 
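// (Across the four predicates below, the kinds are, in terms of the source
// and destination accesses: input = read->read, output = write->write,
// flow = write->read, anti = read->write.)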
+bool Dependence::isInput() const {
+  return Src->mayReadFromMemory() && Dst->mayReadFromMemory();
+}
+
+
+// Returns true if this is an output dependence.
+bool Dependence::isOutput() const {
+  return Src->mayWriteToMemory() && Dst->mayWriteToMemory();
+}
+
+
+// Returns true if this is a flow (aka true) dependence.
+bool Dependence::isFlow() const {
+  return Src->mayWriteToMemory() && Dst->mayReadFromMemory();
+}
+
+
+// Returns true if this is an anti dependence.
+bool Dependence::isAnti() const {
+  return Src->mayReadFromMemory() && Dst->mayWriteToMemory();
+}
+
+
+// Returns true if a particular level is scalar; that is,
+// if no subscript in the source or destination mentions the induction
+// variable associated with the loop at this level.
+// Leave this out of line, so it will serve as a virtual method anchor.
+bool Dependence::isScalar(unsigned level) const {
+  return false;
+}
+
+
+//===----------------------------------------------------------------------===//
+// FullDependence methods
+
+FullDependence::FullDependence(Instruction *Source, Instruction *Destination,
+                               bool PossiblyLoopIndependent,
+                               unsigned CommonLevels)
+    : Dependence(Source, Destination), Levels(CommonLevels),
+      LoopIndependent(PossiblyLoopIndependent) {
+  Consistent = true;
+  if (CommonLevels)
+    DV = std::make_unique<DVEntry[]>(CommonLevels);
+}
+
+// The rest are simple getters that hide the implementation.
+
+// getDirection - Returns the direction associated with a particular level.
+unsigned FullDependence::getDirection(unsigned Level) const {
+  assert(0 < Level && Level <= Levels && "Level out of range");
+  return DV[Level - 1].Direction;
+}
+
+
+// Returns the distance (or NULL) associated with a particular level.
+const SCEV *FullDependence::getDistance(unsigned Level) const {
+  assert(0 < Level && Level <= Levels && "Level out of range");
+  return DV[Level - 1].Distance;
+}
+
+
+// Returns true if a particular level is scalar; that is,
+// if no subscript in the source or destination mentions the induction
+// variable associated with the loop at this level.
+bool FullDependence::isScalar(unsigned Level) const {
+  assert(0 < Level && Level <= Levels && "Level out of range");
+  return DV[Level - 1].Scalar;
+}
+
+
+// Returns true if peeling the first iteration from this loop
+// will break this dependence.
+bool FullDependence::isPeelFirst(unsigned Level) const {
+  assert(0 < Level && Level <= Levels && "Level out of range");
+  return DV[Level - 1].PeelFirst;
+}
+
+
+// Returns true if peeling the last iteration from this loop
+// will break this dependence.
+bool FullDependence::isPeelLast(unsigned Level) const {
+  assert(0 < Level && Level <= Levels && "Level out of range");
+  return DV[Level - 1].PeelLast;
+}
+
+
+// Returns true if splitting this loop will break the dependence.
+bool FullDependence::isSplitable(unsigned Level) const {
+  assert(0 < Level && Level <= Levels && "Level out of range");
+  return DV[Level - 1].Splitable;
+}
+
+
+//===----------------------------------------------------------------------===//
+// DependenceInfo::Constraint methods
+
+// If constraint is a point <X, Y>, returns X.
+// Otherwise assert.
+const SCEV *DependenceInfo::Constraint::getX() const {
+  assert(Kind == Point && "Kind should be Point");
+  return A;
+}
+
+
+// If constraint is a point <X, Y>, returns Y.
+// Otherwise assert.
+const SCEV *DependenceInfo::Constraint::getY() const {
+  assert(Kind == Point && "Kind should be Point");
+  return B;
+}
+
+
+// If constraint is a line AX + BY = C, returns A.
+// Otherwise assert. +const SCEV *DependenceInfo::Constraint::getA() const { + assert((Kind == Line || Kind == Distance) && + "Kind should be Line (or Distance)"); + return A; +} + + +// If constraint is a line AX + BY = C, returns B. +// Otherwise assert. +const SCEV *DependenceInfo::Constraint::getB() const { + assert((Kind == Line || Kind == Distance) && + "Kind should be Line (or Distance)"); + return B; +} + + +// If constraint is a line AX + BY = C, returns C. +// Otherwise assert. +const SCEV *DependenceInfo::Constraint::getC() const { + assert((Kind == Line || Kind == Distance) && + "Kind should be Line (or Distance)"); + return C; +} + + +// If constraint is a distance, returns D. +// Otherwise assert. +const SCEV *DependenceInfo::Constraint::getD() const { + assert(Kind == Distance && "Kind should be Distance"); + return SE->getNegativeSCEV(C); +} + + +// Returns the loop associated with this constraint. +const Loop *DependenceInfo::Constraint::getAssociatedLoop() const { + assert((Kind == Distance || Kind == Line || Kind == Point) && + "Kind should be Distance, Line, or Point"); + return AssociatedLoop; +} + +void DependenceInfo::Constraint::setPoint(const SCEV *X, const SCEV *Y, + const Loop *CurLoop) { + Kind = Point; + A = X; + B = Y; + AssociatedLoop = CurLoop; +} + +void DependenceInfo::Constraint::setLine(const SCEV *AA, const SCEV *BB, + const SCEV *CC, const Loop *CurLoop) { + Kind = Line; + A = AA; + B = BB; + C = CC; + AssociatedLoop = CurLoop; +} + +void DependenceInfo::Constraint::setDistance(const SCEV *D, + const Loop *CurLoop) { + Kind = Distance; + A = SE->getOne(D->getType()); + B = SE->getNegativeSCEV(A); + C = SE->getNegativeSCEV(D); + AssociatedLoop = CurLoop; +} + +void DependenceInfo::Constraint::setEmpty() { Kind = Empty; } + +void DependenceInfo::Constraint::setAny(ScalarEvolution *NewSE) { + SE = NewSE; + Kind = Any; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +// For debugging purposes. Dumps the constraint out to OS. +LLVM_DUMP_METHOD void DependenceInfo::Constraint::dump(raw_ostream &OS) const { + if (isEmpty()) + OS << " Empty\n"; + else if (isAny()) + OS << " Any\n"; + else if (isPoint()) + OS << " Point is <" << *getX() << ", " << *getY() << ">\n"; + else if (isDistance()) + OS << " Distance is " << *getD() << + " (" << *getA() << "*X + " << *getB() << "*Y = " << *getC() << ")\n"; + else if (isLine()) + OS << " Line is " << *getA() << "*X + " << + *getB() << "*Y = " << *getC() << "\n"; + else + llvm_unreachable("unknown constraint type in Constraint::dump"); +} +#endif + + +// Updates X with the intersection +// of the Constraints X and Y. Returns true if X has changed. 
+// Corresponds to Figure 4 from the paper +// +// Practical Dependence Testing +// Goff, Kennedy, Tseng +// PLDI 1991 +bool DependenceInfo::intersectConstraints(Constraint *X, const Constraint *Y) { + ++DeltaApplications; + LLVM_DEBUG(dbgs() << "\tintersect constraints\n"); + LLVM_DEBUG(dbgs() << "\t X ="; X->dump(dbgs())); + LLVM_DEBUG(dbgs() << "\t Y ="; Y->dump(dbgs())); + assert(!Y->isPoint() && "Y must not be a Point"); + if (X->isAny()) { + if (Y->isAny()) + return false; + *X = *Y; + return true; + } + if (X->isEmpty()) + return false; + if (Y->isEmpty()) { + X->setEmpty(); + return true; + } + + if (X->isDistance() && Y->isDistance()) { + LLVM_DEBUG(dbgs() << "\t intersect 2 distances\n"); + if (isKnownPredicate(CmpInst::ICMP_EQ, X->getD(), Y->getD())) + return false; + if (isKnownPredicate(CmpInst::ICMP_NE, X->getD(), Y->getD())) { + X->setEmpty(); + ++DeltaSuccesses; + return true; + } + // Hmmm, interesting situation. + // I guess if either is constant, keep it and ignore the other. + if (isa<SCEVConstant>(Y->getD())) { + *X = *Y; + return true; + } + return false; + } + + // At this point, the pseudo-code in Figure 4 of the paper + // checks if (X->isPoint() && Y->isPoint()). + // This case can't occur in our implementation, + // since a Point can only arise as the result of intersecting + // two Line constraints, and the right-hand value, Y, is never + // the result of an intersection. + assert(!(X->isPoint() && Y->isPoint()) && + "We shouldn't ever see X->isPoint() && Y->isPoint()"); + + if (X->isLine() && Y->isLine()) { + LLVM_DEBUG(dbgs() << "\t intersect 2 lines\n"); + const SCEV *Prod1 = SE->getMulExpr(X->getA(), Y->getB()); + const SCEV *Prod2 = SE->getMulExpr(X->getB(), Y->getA()); + if (isKnownPredicate(CmpInst::ICMP_EQ, Prod1, Prod2)) { + // slopes are equal, so lines are parallel + LLVM_DEBUG(dbgs() << "\t\tsame slope\n"); + Prod1 = SE->getMulExpr(X->getC(), Y->getB()); + Prod2 = SE->getMulExpr(X->getB(), Y->getC()); + if (isKnownPredicate(CmpInst::ICMP_EQ, Prod1, Prod2)) + return false; + if (isKnownPredicate(CmpInst::ICMP_NE, Prod1, Prod2)) { + X->setEmpty(); + ++DeltaSuccesses; + return true; + } + return false; + } + if (isKnownPredicate(CmpInst::ICMP_NE, Prod1, Prod2)) { + // slopes differ, so lines intersect + LLVM_DEBUG(dbgs() << "\t\tdifferent slopes\n"); + const SCEV *C1B2 = SE->getMulExpr(X->getC(), Y->getB()); + const SCEV *C1A2 = SE->getMulExpr(X->getC(), Y->getA()); + const SCEV *C2B1 = SE->getMulExpr(Y->getC(), X->getB()); + const SCEV *C2A1 = SE->getMulExpr(Y->getC(), X->getA()); + const SCEV *A1B2 = SE->getMulExpr(X->getA(), Y->getB()); + const SCEV *A2B1 = SE->getMulExpr(Y->getA(), X->getB()); + const SCEVConstant *C1A2_C2A1 = + dyn_cast<SCEVConstant>(SE->getMinusSCEV(C1A2, C2A1)); + const SCEVConstant *C1B2_C2B1 = + dyn_cast<SCEVConstant>(SE->getMinusSCEV(C1B2, C2B1)); + const SCEVConstant *A1B2_A2B1 = + dyn_cast<SCEVConstant>(SE->getMinusSCEV(A1B2, A2B1)); + const SCEVConstant *A2B1_A1B2 = + dyn_cast<SCEVConstant>(SE->getMinusSCEV(A2B1, A1B2)); + if (!C1B2_C2B1 || !C1A2_C2A1 || + !A1B2_A2B1 || !A2B1_A1B2) + return false; + APInt Xtop = C1B2_C2B1->getAPInt(); + APInt Xbot = A1B2_A2B1->getAPInt(); + APInt Ytop = C1A2_C2A1->getAPInt(); + APInt Ybot = A2B1_A1B2->getAPInt(); + LLVM_DEBUG(dbgs() << "\t\tXtop = " << Xtop << "\n"); + LLVM_DEBUG(dbgs() << "\t\tXbot = " << Xbot << "\n"); + LLVM_DEBUG(dbgs() << "\t\tYtop = " << Ytop << "\n"); + LLVM_DEBUG(dbgs() << "\t\tYbot = " << Ybot << "\n"); + APInt Xq = Xtop; // these need to be initialized, even + 
APInt Xr = Xtop; // though they're just going to be overwritten
+      APInt::sdivrem(Xtop, Xbot, Xq, Xr);
+      APInt Yq = Ytop;
+      APInt Yr = Ytop;
+      APInt::sdivrem(Ytop, Ybot, Yq, Yr);
+      if (Xr != 0 || Yr != 0) {
+        X->setEmpty();
+        ++DeltaSuccesses;
+        return true;
+      }
+      LLVM_DEBUG(dbgs() << "\t\tX = " << Xq << ", Y = " << Yq << "\n");
+      if (Xq.slt(0) || Yq.slt(0)) {
+        X->setEmpty();
+        ++DeltaSuccesses;
+        return true;
+      }
+      if (const SCEVConstant *CUB =
+          collectConstantUpperBound(X->getAssociatedLoop(), Prod1->getType())) {
+        const APInt &UpperBound = CUB->getAPInt();
+        LLVM_DEBUG(dbgs() << "\t\tupper bound = " << UpperBound << "\n");
+        if (Xq.sgt(UpperBound) || Yq.sgt(UpperBound)) {
+          X->setEmpty();
+          ++DeltaSuccesses;
+          return true;
+        }
+      }
+      X->setPoint(SE->getConstant(Xq),
+                  SE->getConstant(Yq),
+                  X->getAssociatedLoop());
+      ++DeltaSuccesses;
+      return true;
+    }
+    return false;
+  }
+
+  // if (X->isLine() && Y->isPoint()) This case can't occur.
+  assert(!(X->isLine() && Y->isPoint()) && "This case should never occur");
+
+  if (X->isPoint() && Y->isLine()) {
+    LLVM_DEBUG(dbgs() << "\t intersect Point and Line\n");
+    const SCEV *A1X1 = SE->getMulExpr(Y->getA(), X->getX());
+    const SCEV *B1Y1 = SE->getMulExpr(Y->getB(), X->getY());
+    const SCEV *Sum = SE->getAddExpr(A1X1, B1Y1);
+    if (isKnownPredicate(CmpInst::ICMP_EQ, Sum, Y->getC()))
+      return false;
+    if (isKnownPredicate(CmpInst::ICMP_NE, Sum, Y->getC())) {
+      X->setEmpty();
+      ++DeltaSuccesses;
+      return true;
+    }
+    return false;
+  }
+
+  llvm_unreachable("shouldn't reach the end of Constraint intersection");
+  return false;
+}
+
+
+//===----------------------------------------------------------------------===//
+// DependenceInfo methods
+
+// For debugging purposes. Dumps a dependence to OS.
+void Dependence::dump(raw_ostream &OS) const {
+  bool Splitable = false;
+  if (isConfused())
+    OS << "confused";
+  else {
+    if (isConsistent())
+      OS << "consistent ";
+    if (isFlow())
+      OS << "flow";
+    else if (isOutput())
+      OS << "output";
+    else if (isAnti())
+      OS << "anti";
+    else if (isInput())
+      OS << "input";
+    unsigned Levels = getLevels();
+    OS << " [";
+    for (unsigned II = 1; II <= Levels; ++II) {
+      if (isSplitable(II))
+        Splitable = true;
+      if (isPeelFirst(II))
+        OS << 'p';
+      const SCEV *Distance = getDistance(II);
+      if (Distance)
+        OS << *Distance;
+      else if (isScalar(II))
+        OS << "S";
+      else {
+        unsigned Direction = getDirection(II);
+        if (Direction == DVEntry::ALL)
+          OS << "*";
+        else {
+          if (Direction & DVEntry::LT)
+            OS << "<";
+          if (Direction & DVEntry::EQ)
+            OS << "=";
+          if (Direction & DVEntry::GT)
+            OS << ">";
+        }
+      }
+      if (isPeelLast(II))
+        OS << 'p';
+      if (II < Levels)
+        OS << " ";
+    }
+    if (isLoopIndependent())
+      OS << "|<";
+    OS << "]";
+    if (Splitable)
+      OS << " splitable";
+  }
+  OS << "!\n";
+}
+
+// Returns NoAlias/MayAlias/MustAlias for two memory locations based upon their
+// underlying objects. If LocA and LocB are known not to alias (for any reason:
+// TBAA, non-overlapping regions, etc.), then it is known there is no dependency.
+// Otherwise the underlying objects are checked to see if they point to
+// different identifiable objects.
+static AliasResult underlyingObjectsAlias(AliasAnalysis *AA,
+                                          const DataLayout &DL,
+                                          const MemoryLocation &LocA,
+                                          const MemoryLocation &LocB) {
+  // Check the original locations (minus size) for noalias, which can happen for
+  // TBAA, incompatible underlying object locations, etc.
+ MemoryLocation LocAS(LocA.Ptr, LocationSize::unknown(), LocA.AATags); + MemoryLocation LocBS(LocB.Ptr, LocationSize::unknown(), LocB.AATags); + if (AA->alias(LocAS, LocBS) == NoAlias) + return NoAlias; + + // Check the underlying objects are the same + const Value *AObj = GetUnderlyingObject(LocA.Ptr, DL); + const Value *BObj = GetUnderlyingObject(LocB.Ptr, DL); + + // If the underlying objects are the same, they must alias + if (AObj == BObj) + return MustAlias; + + // We may have hit the recursion limit for underlying objects, or have + // underlying objects where we don't know they will alias. + if (!isIdentifiedObject(AObj) || !isIdentifiedObject(BObj)) + return MayAlias; + + // Otherwise we know the objects are different and both identified objects so + // must not alias. + return NoAlias; +} + + +// Returns true if the load or store can be analyzed. Atomic and volatile +// operations have properties which this analysis does not understand. +static +bool isLoadOrStore(const Instruction *I) { + if (const LoadInst *LI = dyn_cast<LoadInst>(I)) + return LI->isUnordered(); + else if (const StoreInst *SI = dyn_cast<StoreInst>(I)) + return SI->isUnordered(); + return false; +} + + +// Examines the loop nesting of the Src and Dst +// instructions and establishes their shared loops. Sets the variables +// CommonLevels, SrcLevels, and MaxLevels. +// The source and destination instructions needn't be contained in the same +// loop. The routine establishNestingLevels finds the level of most deeply +// nested loop that contains them both, CommonLevels. An instruction that's +// not contained in a loop is at level = 0. MaxLevels is equal to the level +// of the source plus the level of the destination, minus CommonLevels. +// This lets us allocate vectors MaxLevels in length, with room for every +// distinct loop referenced in both the source and destination subscripts. +// The variable SrcLevels is the nesting depth of the source instruction. +// It's used to help calculate distinct loops referenced by the destination. +// Here's the map from loops to levels: +// 0 - unused +// 1 - outermost common loop +// ... - other common loops +// CommonLevels - innermost common loop +// ... - loops containing Src but not Dst +// SrcLevels - innermost loop containing Src but not Dst +// ... - loops containing Dst but not Src +// MaxLevels - innermost loops containing Dst but not Src +// Consider the follow code fragment: +// for (a = ...) { +// for (b = ...) { +// for (c = ...) { +// for (d = ...) { +// A[] = ...; +// } +// } +// for (e = ...) { +// for (f = ...) { +// for (g = ...) { +// ... = A[]; +// } +// } +// } +// } +// } +// If we're looking at the possibility of a dependence between the store +// to A (the Src) and the load from A (the Dst), we'll note that they +// have 2 loops in common, so CommonLevels will equal 2 and the direction +// vector for Result will have 2 entries. SrcLevels = 4 and MaxLevels = 7. 
+// A map from loop names to loop numbers would look like +// a - 1 +// b - 2 = CommonLevels +// c - 3 +// d - 4 = SrcLevels +// e - 5 +// f - 6 +// g - 7 = MaxLevels +void DependenceInfo::establishNestingLevels(const Instruction *Src, + const Instruction *Dst) { + const BasicBlock *SrcBlock = Src->getParent(); + const BasicBlock *DstBlock = Dst->getParent(); + unsigned SrcLevel = LI->getLoopDepth(SrcBlock); + unsigned DstLevel = LI->getLoopDepth(DstBlock); + const Loop *SrcLoop = LI->getLoopFor(SrcBlock); + const Loop *DstLoop = LI->getLoopFor(DstBlock); + SrcLevels = SrcLevel; + MaxLevels = SrcLevel + DstLevel; + while (SrcLevel > DstLevel) { + SrcLoop = SrcLoop->getParentLoop(); + SrcLevel--; + } + while (DstLevel > SrcLevel) { + DstLoop = DstLoop->getParentLoop(); + DstLevel--; + } + while (SrcLoop != DstLoop) { + SrcLoop = SrcLoop->getParentLoop(); + DstLoop = DstLoop->getParentLoop(); + SrcLevel--; + } + CommonLevels = SrcLevel; + MaxLevels -= CommonLevels; +} + + +// Given one of the loops containing the source, return +// its level index in our numbering scheme. +unsigned DependenceInfo::mapSrcLoop(const Loop *SrcLoop) const { + return SrcLoop->getLoopDepth(); +} + + +// Given one of the loops containing the destination, +// return its level index in our numbering scheme. +unsigned DependenceInfo::mapDstLoop(const Loop *DstLoop) const { + unsigned D = DstLoop->getLoopDepth(); + if (D > CommonLevels) + return D - CommonLevels + SrcLevels; + else + return D; +} + + +// Returns true if Expression is loop invariant in LoopNest. +bool DependenceInfo::isLoopInvariant(const SCEV *Expression, + const Loop *LoopNest) const { + if (!LoopNest) + return true; + return SE->isLoopInvariant(Expression, LoopNest) && + isLoopInvariant(Expression, LoopNest->getParentLoop()); +} + + + +// Finds the set of loops from the LoopNest that +// have a level <= CommonLevels and are referred to by the SCEV Expression. +void DependenceInfo::collectCommonLoops(const SCEV *Expression, + const Loop *LoopNest, + SmallBitVector &Loops) const { + while (LoopNest) { + unsigned Level = LoopNest->getLoopDepth(); + if (Level <= CommonLevels && !SE->isLoopInvariant(Expression, LoopNest)) + Loops.set(Level); + LoopNest = LoopNest->getParentLoop(); + } +} + +void DependenceInfo::unifySubscriptType(ArrayRef<Subscript *> Pairs) { + + unsigned widestWidthSeen = 0; + Type *widestType; + + // Go through each pair and find the widest bit to which we need + // to extend all of them. + for (Subscript *Pair : Pairs) { + const SCEV *Src = Pair->Src; + const SCEV *Dst = Pair->Dst; + IntegerType *SrcTy = dyn_cast<IntegerType>(Src->getType()); + IntegerType *DstTy = dyn_cast<IntegerType>(Dst->getType()); + if (SrcTy == nullptr || DstTy == nullptr) { + assert(SrcTy == DstTy && "This function only unify integer types and " + "expect Src and Dst share the same type " + "otherwise."); + continue; + } + if (SrcTy->getBitWidth() > widestWidthSeen) { + widestWidthSeen = SrcTy->getBitWidth(); + widestType = SrcTy; + } + if (DstTy->getBitWidth() > widestWidthSeen) { + widestWidthSeen = DstTy->getBitWidth(); + widestType = DstTy; + } + } + + + assert(widestWidthSeen > 0); + + // Now extend each pair to the widest seen. 
+ for (Subscript *Pair : Pairs) { + const SCEV *Src = Pair->Src; + const SCEV *Dst = Pair->Dst; + IntegerType *SrcTy = dyn_cast<IntegerType>(Src->getType()); + IntegerType *DstTy = dyn_cast<IntegerType>(Dst->getType()); + if (SrcTy == nullptr || DstTy == nullptr) { + assert(SrcTy == DstTy && "This function only unify integer types and " + "expect Src and Dst share the same type " + "otherwise."); + continue; + } + if (SrcTy->getBitWidth() < widestWidthSeen) + // Sign-extend Src to widestType + Pair->Src = SE->getSignExtendExpr(Src, widestType); + if (DstTy->getBitWidth() < widestWidthSeen) { + // Sign-extend Dst to widestType + Pair->Dst = SE->getSignExtendExpr(Dst, widestType); + } + } +} + +// removeMatchingExtensions - Examines a subscript pair. +// If the source and destination are identically sign (or zero) +// extended, it strips off the extension in an effect to simplify +// the actual analysis. +void DependenceInfo::removeMatchingExtensions(Subscript *Pair) { + const SCEV *Src = Pair->Src; + const SCEV *Dst = Pair->Dst; + if ((isa<SCEVZeroExtendExpr>(Src) && isa<SCEVZeroExtendExpr>(Dst)) || + (isa<SCEVSignExtendExpr>(Src) && isa<SCEVSignExtendExpr>(Dst))) { + const SCEVCastExpr *SrcCast = cast<SCEVCastExpr>(Src); + const SCEVCastExpr *DstCast = cast<SCEVCastExpr>(Dst); + const SCEV *SrcCastOp = SrcCast->getOperand(); + const SCEV *DstCastOp = DstCast->getOperand(); + if (SrcCastOp->getType() == DstCastOp->getType()) { + Pair->Src = SrcCastOp; + Pair->Dst = DstCastOp; + } + } +} + + +// Examine the scev and return true iff it's linear. +// Collect any loops mentioned in the set of "Loops". +bool DependenceInfo::checkSrcSubscript(const SCEV *Src, const Loop *LoopNest, + SmallBitVector &Loops) { + const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Src); + if (!AddRec) + return isLoopInvariant(Src, LoopNest); + const SCEV *Start = AddRec->getStart(); + const SCEV *Step = AddRec->getStepRecurrence(*SE); + const SCEV *UB = SE->getBackedgeTakenCount(AddRec->getLoop()); + if (!isa<SCEVCouldNotCompute>(UB)) { + if (SE->getTypeSizeInBits(Start->getType()) < + SE->getTypeSizeInBits(UB->getType())) { + if (!AddRec->getNoWrapFlags()) + return false; + } + } + if (!isLoopInvariant(Step, LoopNest)) + return false; + Loops.set(mapSrcLoop(AddRec->getLoop())); + return checkSrcSubscript(Start, LoopNest, Loops); +} + + + +// Examine the scev and return true iff it's linear. +// Collect any loops mentioned in the set of "Loops". +bool DependenceInfo::checkDstSubscript(const SCEV *Dst, const Loop *LoopNest, + SmallBitVector &Loops) { + const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Dst); + if (!AddRec) + return isLoopInvariant(Dst, LoopNest); + const SCEV *Start = AddRec->getStart(); + const SCEV *Step = AddRec->getStepRecurrence(*SE); + const SCEV *UB = SE->getBackedgeTakenCount(AddRec->getLoop()); + if (!isa<SCEVCouldNotCompute>(UB)) { + if (SE->getTypeSizeInBits(Start->getType()) < + SE->getTypeSizeInBits(UB->getType())) { + if (!AddRec->getNoWrapFlags()) + return false; + } + } + if (!isLoopInvariant(Step, LoopNest)) + return false; + Loops.set(mapDstLoop(AddRec->getLoop())); + return checkDstSubscript(Start, LoopNest, Loops); +} + + +// Examines the subscript pair (the Src and Dst SCEVs) +// and classifies it as either ZIV, SIV, RDIV, MIV, or Nonlinear. +// Collects the associated loops in a set. 
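// Examples, where i and j are induction variables of different loops and n
// is loop-invariant:
//   A[5]   vs A[n] -> ZIV  (no induction variable appears)
//   A[i+1] vs A[i] -> SIV  (exactly one induction variable)
//   A[i]   vs A[j] -> RDIV (one induction variable per reference, in
//                           distinct loops)
//   A[i+j] vs A[i] -> MIV  (multiple induction variables in one subscript)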
+DependenceInfo::Subscript::ClassificationKind +DependenceInfo::classifyPair(const SCEV *Src, const Loop *SrcLoopNest, + const SCEV *Dst, const Loop *DstLoopNest, + SmallBitVector &Loops) { + SmallBitVector SrcLoops(MaxLevels + 1); + SmallBitVector DstLoops(MaxLevels + 1); + if (!checkSrcSubscript(Src, SrcLoopNest, SrcLoops)) + return Subscript::NonLinear; + if (!checkDstSubscript(Dst, DstLoopNest, DstLoops)) + return Subscript::NonLinear; + Loops = SrcLoops; + Loops |= DstLoops; + unsigned N = Loops.count(); + if (N == 0) + return Subscript::ZIV; + if (N == 1) + return Subscript::SIV; + if (N == 2 && (SrcLoops.count() == 0 || + DstLoops.count() == 0 || + (SrcLoops.count() == 1 && DstLoops.count() == 1))) + return Subscript::RDIV; + return Subscript::MIV; +} + + +// A wrapper around SCEV::isKnownPredicate. +// Looks for cases where we're interested in comparing for equality. +// If both X and Y have been identically sign or zero extended, +// it strips off the (confusing) extensions before invoking +// SCEV::isKnownPredicate. Perhaps, someday, the ScalarEvolution package +// will be similarly updated. +// +// If SCEV::isKnownPredicate can't prove the predicate, +// we try simple subtraction, which seems to help in some cases +// involving symbolics. +bool DependenceInfo::isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *X, + const SCEV *Y) const { + if (Pred == CmpInst::ICMP_EQ || + Pred == CmpInst::ICMP_NE) { + if ((isa<SCEVSignExtendExpr>(X) && + isa<SCEVSignExtendExpr>(Y)) || + (isa<SCEVZeroExtendExpr>(X) && + isa<SCEVZeroExtendExpr>(Y))) { + const SCEVCastExpr *CX = cast<SCEVCastExpr>(X); + const SCEVCastExpr *CY = cast<SCEVCastExpr>(Y); + const SCEV *Xop = CX->getOperand(); + const SCEV *Yop = CY->getOperand(); + if (Xop->getType() == Yop->getType()) { + X = Xop; + Y = Yop; + } + } + } + if (SE->isKnownPredicate(Pred, X, Y)) + return true; + // If SE->isKnownPredicate can't prove the condition, + // we try the brute-force approach of subtracting + // and testing the difference. + // By testing with SE->isKnownPredicate first, we avoid + // the possibility of overflow when the arguments are constants. + const SCEV *Delta = SE->getMinusSCEV(X, Y); + switch (Pred) { + case CmpInst::ICMP_EQ: + return Delta->isZero(); + case CmpInst::ICMP_NE: + return SE->isKnownNonZero(Delta); + case CmpInst::ICMP_SGE: + return SE->isKnownNonNegative(Delta); + case CmpInst::ICMP_SLE: + return SE->isKnownNonPositive(Delta); + case CmpInst::ICMP_SGT: + return SE->isKnownPositive(Delta); + case CmpInst::ICMP_SLT: + return SE->isKnownNegative(Delta); + default: + llvm_unreachable("unexpected predicate in isKnownPredicate"); + } +} + +/// Compare to see if S is less than Size, using isKnownNegative(S - max(Size, 1)) +/// with some extra checking if S is an AddRec and we can prove less-than using +/// the loop bounds. +bool DependenceInfo::isKnownLessThan(const SCEV *S, const SCEV *Size) const { + // First unify to the same type + auto *SType = dyn_cast<IntegerType>(S->getType()); + auto *SizeType = dyn_cast<IntegerType>(Size->getType()); + if (!SType || !SizeType) + return false; + Type *MaxType = + (SType->getBitWidth() >= SizeType->getBitWidth()) ? 
SType : SizeType; + S = SE->getTruncateOrZeroExtend(S, MaxType); + Size = SE->getTruncateOrZeroExtend(Size, MaxType); + + // Special check for addrecs using BE taken count + const SCEV *Bound = SE->getMinusSCEV(S, Size); + if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Bound)) { + if (AddRec->isAffine()) { + const SCEV *BECount = SE->getBackedgeTakenCount(AddRec->getLoop()); + if (!isa<SCEVCouldNotCompute>(BECount)) { + const SCEV *Limit = AddRec->evaluateAtIteration(BECount, *SE); + if (SE->isKnownNegative(Limit)) + return true; + } + } + } + + // Check using normal isKnownNegative + const SCEV *LimitedBound = + SE->getMinusSCEV(S, SE->getSMaxExpr(Size, SE->getOne(Size->getType()))); + return SE->isKnownNegative(LimitedBound); +} + +bool DependenceInfo::isKnownNonNegative(const SCEV *S, const Value *Ptr) const { + bool Inbounds = false; + if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(Ptr)) + Inbounds = SrcGEP->isInBounds(); + if (Inbounds) { + if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) { + if (AddRec->isAffine()) { + // We know S is for Ptr, the operand on a load/store, so doesn't wrap. + // If both parts are NonNegative, the end result will be NonNegative + if (SE->isKnownNonNegative(AddRec->getStart()) && + SE->isKnownNonNegative(AddRec->getOperand(1))) + return true; + } + } + } + + return SE->isKnownNonNegative(S); +} + +// All subscripts are all the same type. +// Loop bound may be smaller (e.g., a char). +// Should zero extend loop bound, since it's always >= 0. +// This routine collects upper bound and extends or truncates if needed. +// Truncating is safe when subscripts are known not to wrap. Cases without +// nowrap flags should have been rejected earlier. +// Return null if no bound available. +const SCEV *DependenceInfo::collectUpperBound(const Loop *L, Type *T) const { + if (SE->hasLoopInvariantBackedgeTakenCount(L)) { + const SCEV *UB = SE->getBackedgeTakenCount(L); + return SE->getTruncateOrZeroExtend(UB, T); + } + return nullptr; +} + + +// Calls collectUpperBound(), then attempts to cast it to SCEVConstant. +// If the cast fails, returns NULL. +const SCEVConstant *DependenceInfo::collectConstantUpperBound(const Loop *L, + Type *T) const { + if (const SCEV *UB = collectUpperBound(L, T)) + return dyn_cast<SCEVConstant>(UB); + return nullptr; +} + + +// testZIV - +// When we have a pair of subscripts of the form [c1] and [c2], +// where c1 and c2 are both loop invariant, we attack it using +// the ZIV test. Basically, we test by comparing the two values, +// but there are actually three possible results: +// 1) the values are equal, so there's a dependence +// 2) the values are different, so there's no dependence +// 3) the values might be equal, so we have to assume a dependence. +// +// Return true if dependence disproved. 
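// For example, [5] vs [5] is provably equal (case 1), [5] vs [6] is provably
// different (case 2), and [n] vs [m] for unrelated symbolic n and m cannot
// be decided either way (case 3).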
+bool DependenceInfo::testZIV(const SCEV *Src, const SCEV *Dst, + FullDependence &Result) const { + LLVM_DEBUG(dbgs() << " src = " << *Src << "\n"); + LLVM_DEBUG(dbgs() << " dst = " << *Dst << "\n"); + ++ZIVapplications; + if (isKnownPredicate(CmpInst::ICMP_EQ, Src, Dst)) { + LLVM_DEBUG(dbgs() << " provably dependent\n"); + return false; // provably dependent + } + if (isKnownPredicate(CmpInst::ICMP_NE, Src, Dst)) { + LLVM_DEBUG(dbgs() << " provably independent\n"); + ++ZIVindependence; + return true; // provably independent + } + LLVM_DEBUG(dbgs() << " possibly dependent\n"); + Result.Consistent = false; + return false; // possibly dependent +} + + +// strongSIVtest - +// From the paper, Practical Dependence Testing, Section 4.2.1 +// +// When we have a pair of subscripts of the form [c1 + a*i] and [c2 + a*i], +// where i is an induction variable, c1 and c2 are loop invariant, +// and a is a constant, we can solve it exactly using the Strong SIV test. +// +// Can prove independence. Failing that, can compute distance (and direction). +// In the presence of symbolic terms, we can sometimes make progress. +// +// If there's a dependence, +// +// c1 + a*i = c2 + a*i' +// +// The dependence distance is +// +// d = i' - i = (c1 - c2)/a +// +// A dependence only exists if d is an integer and abs(d) <= U, where U is the +// loop's upper bound. If a dependence exists, the dependence direction is +// defined as +// +// { < if d > 0 +// direction = { = if d = 0 +// { > if d < 0 +// +// Return true if dependence disproved. +bool DependenceInfo::strongSIVtest(const SCEV *Coeff, const SCEV *SrcConst, + const SCEV *DstConst, const Loop *CurLoop, + unsigned Level, FullDependence &Result, + Constraint &NewConstraint) const { + LLVM_DEBUG(dbgs() << "\tStrong SIV test\n"); + LLVM_DEBUG(dbgs() << "\t Coeff = " << *Coeff); + LLVM_DEBUG(dbgs() << ", " << *Coeff->getType() << "\n"); + LLVM_DEBUG(dbgs() << "\t SrcConst = " << *SrcConst); + LLVM_DEBUG(dbgs() << ", " << *SrcConst->getType() << "\n"); + LLVM_DEBUG(dbgs() << "\t DstConst = " << *DstConst); + LLVM_DEBUG(dbgs() << ", " << *DstConst->getType() << "\n"); + ++StrongSIVapplications; + assert(0 < Level && Level <= CommonLevels && "level out of range"); + Level--; + + const SCEV *Delta = SE->getMinusSCEV(SrcConst, DstConst); + LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta); + LLVM_DEBUG(dbgs() << ", " << *Delta->getType() << "\n"); + + // check that |Delta| < iteration count + if (const SCEV *UpperBound = collectUpperBound(CurLoop, Delta->getType())) { + LLVM_DEBUG(dbgs() << "\t UpperBound = " << *UpperBound); + LLVM_DEBUG(dbgs() << ", " << *UpperBound->getType() << "\n"); + const SCEV *AbsDelta = + SE->isKnownNonNegative(Delta) ? Delta : SE->getNegativeSCEV(Delta); + const SCEV *AbsCoeff = + SE->isKnownNonNegative(Coeff) ? Coeff : SE->getNegativeSCEV(Coeff); + const SCEV *Product = SE->getMulExpr(UpperBound, AbsCoeff); + if (isKnownPredicate(CmpInst::ICMP_SGT, AbsDelta, Product)) { + // Distance greater than trip count - no dependence + ++StrongSIVindependence; + ++StrongSIVsuccesses; + return true; + } + } + + // Can we compute distance? 
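+  // For instance, Delta = 6 and Coeff = 2 yield Distance = 3 with a
+  // remainder of 0, i.e., a dependence of distance 3 and direction <;
+  // Delta = 5 and Coeff = 2 leave a remainder of 1, so the dependence
+  // is disproved.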
+ if (isa<SCEVConstant>(Delta) && isa<SCEVConstant>(Coeff)) { + APInt ConstDelta = cast<SCEVConstant>(Delta)->getAPInt(); + APInt ConstCoeff = cast<SCEVConstant>(Coeff)->getAPInt(); + APInt Distance = ConstDelta; // these need to be initialized + APInt Remainder = ConstDelta; + APInt::sdivrem(ConstDelta, ConstCoeff, Distance, Remainder); + LLVM_DEBUG(dbgs() << "\t Distance = " << Distance << "\n"); + LLVM_DEBUG(dbgs() << "\t Remainder = " << Remainder << "\n"); + // Make sure Coeff divides Delta exactly + if (Remainder != 0) { + // Coeff doesn't divide Distance, no dependence + ++StrongSIVindependence; + ++StrongSIVsuccesses; + return true; + } + Result.DV[Level].Distance = SE->getConstant(Distance); + NewConstraint.setDistance(SE->getConstant(Distance), CurLoop); + if (Distance.sgt(0)) + Result.DV[Level].Direction &= Dependence::DVEntry::LT; + else if (Distance.slt(0)) + Result.DV[Level].Direction &= Dependence::DVEntry::GT; + else + Result.DV[Level].Direction &= Dependence::DVEntry::EQ; + ++StrongSIVsuccesses; + } + else if (Delta->isZero()) { + // since 0/X == 0 + Result.DV[Level].Distance = Delta; + NewConstraint.setDistance(Delta, CurLoop); + Result.DV[Level].Direction &= Dependence::DVEntry::EQ; + ++StrongSIVsuccesses; + } + else { + if (Coeff->isOne()) { + LLVM_DEBUG(dbgs() << "\t Distance = " << *Delta << "\n"); + Result.DV[Level].Distance = Delta; // since X/1 == X + NewConstraint.setDistance(Delta, CurLoop); + } + else { + Result.Consistent = false; + NewConstraint.setLine(Coeff, + SE->getNegativeSCEV(Coeff), + SE->getNegativeSCEV(Delta), CurLoop); + } + + // maybe we can get a useful direction + bool DeltaMaybeZero = !SE->isKnownNonZero(Delta); + bool DeltaMaybePositive = !SE->isKnownNonPositive(Delta); + bool DeltaMaybeNegative = !SE->isKnownNonNegative(Delta); + bool CoeffMaybePositive = !SE->isKnownNonPositive(Coeff); + bool CoeffMaybeNegative = !SE->isKnownNonNegative(Coeff); + // The double negatives above are confusing. + // It helps to read !SE->isKnownNonZero(Delta) + // as "Delta might be Zero" + unsigned NewDirection = Dependence::DVEntry::NONE; + if ((DeltaMaybePositive && CoeffMaybePositive) || + (DeltaMaybeNegative && CoeffMaybeNegative)) + NewDirection = Dependence::DVEntry::LT; + if (DeltaMaybeZero) + NewDirection |= Dependence::DVEntry::EQ; + if ((DeltaMaybeNegative && CoeffMaybePositive) || + (DeltaMaybePositive && CoeffMaybeNegative)) + NewDirection |= Dependence::DVEntry::GT; + if (NewDirection < Result.DV[Level].Direction) + ++StrongSIVsuccesses; + Result.DV[Level].Direction &= NewDirection; + } + return false; +} + + +// weakCrossingSIVtest - +// From the paper, Practical Dependence Testing, Section 4.2.2 +// +// When we have a pair of subscripts of the form [c1 + a*i] and [c2 - a*i], +// where i is an induction variable, c1 and c2 are loop invariant, +// and a is a constant, we can solve it exactly using the +// Weak-Crossing SIV test. +// +// Given c1 + a*i = c2 - a*i', we can look for the intersection of +// the two lines, where i = i', yielding +// +// c1 + a*i = c2 - a*i +// 2a*i = c2 - c1 +// i = (c2 - c1)/2a +// +// If i < 0, there is no dependence. +// If i > upperbound, there is no dependence. +// If i = 0 (i.e., if c1 = c2), there's a dependence with distance = 0. +// If i = upperbound, there's a dependence with distance = 0. +// If i is integral, there's a dependence (all directions). +// If the non-integer part = 1/2, there's a dependence (<> directions). +// Otherwise, there's no dependence. +// +// Can prove independence. 
Failing that, +// can sometimes refine the directions. +// Can determine iteration for splitting. +// +// Return true if dependence disproved. +bool DependenceInfo::weakCrossingSIVtest( + const SCEV *Coeff, const SCEV *SrcConst, const SCEV *DstConst, + const Loop *CurLoop, unsigned Level, FullDependence &Result, + Constraint &NewConstraint, const SCEV *&SplitIter) const { + LLVM_DEBUG(dbgs() << "\tWeak-Crossing SIV test\n"); + LLVM_DEBUG(dbgs() << "\t Coeff = " << *Coeff << "\n"); + LLVM_DEBUG(dbgs() << "\t SrcConst = " << *SrcConst << "\n"); + LLVM_DEBUG(dbgs() << "\t DstConst = " << *DstConst << "\n"); + ++WeakCrossingSIVapplications; + assert(0 < Level && Level <= CommonLevels && "Level out of range"); + Level--; + Result.Consistent = false; + const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst); + LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta << "\n"); + NewConstraint.setLine(Coeff, Coeff, Delta, CurLoop); + if (Delta->isZero()) { + Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::LT); + Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::GT); + ++WeakCrossingSIVsuccesses; + if (!Result.DV[Level].Direction) { + ++WeakCrossingSIVindependence; + return true; + } + Result.DV[Level].Distance = Delta; // = 0 + return false; + } + const SCEVConstant *ConstCoeff = dyn_cast<SCEVConstant>(Coeff); + if (!ConstCoeff) + return false; + + Result.DV[Level].Splitable = true; + if (SE->isKnownNegative(ConstCoeff)) { + ConstCoeff = dyn_cast<SCEVConstant>(SE->getNegativeSCEV(ConstCoeff)); + assert(ConstCoeff && + "dynamic cast of negative of ConstCoeff should yield constant"); + Delta = SE->getNegativeSCEV(Delta); + } + assert(SE->isKnownPositive(ConstCoeff) && "ConstCoeff should be positive"); + + // compute SplitIter for use by DependenceInfo::getSplitIteration() + SplitIter = SE->getUDivExpr( + SE->getSMaxExpr(SE->getZero(Delta->getType()), Delta), + SE->getMulExpr(SE->getConstant(Delta->getType(), 2), ConstCoeff)); + LLVM_DEBUG(dbgs() << "\t Split iter = " << *SplitIter << "\n"); + + const SCEVConstant *ConstDelta = dyn_cast<SCEVConstant>(Delta); + if (!ConstDelta) + return false; + + // We're certain that ConstCoeff > 0; therefore, + // if Delta < 0, then no dependence. + LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta << "\n"); + LLVM_DEBUG(dbgs() << "\t ConstCoeff = " << *ConstCoeff << "\n"); + if (SE->isKnownNegative(Delta)) { + // No dependence, Delta < 0 + ++WeakCrossingSIVindependence; + ++WeakCrossingSIVsuccesses; + return true; + } + + // We're certain that Delta > 0 and ConstCoeff > 0. 
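+  // For instance, with ConstCoeff = 3 and an upper bound of 10, the
+  // crossing point i = Delta/6 lies inside the iteration space only if
+  // Delta <= 60; any larger Delta disproves the dependence below.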
+  // Check Delta/(2*ConstCoeff) against upper loop bound
+  if (const SCEV *UpperBound = collectUpperBound(CurLoop, Delta->getType())) {
+    LLVM_DEBUG(dbgs() << "\t UpperBound = " << *UpperBound << "\n");
+    const SCEV *ConstantTwo = SE->getConstant(UpperBound->getType(), 2);
+    const SCEV *ML = SE->getMulExpr(SE->getMulExpr(ConstCoeff, UpperBound),
+                                    ConstantTwo);
+    LLVM_DEBUG(dbgs() << "\t ML = " << *ML << "\n");
+    if (isKnownPredicate(CmpInst::ICMP_SGT, Delta, ML)) {
+      // Delta too big, no dependence
+      ++WeakCrossingSIVindependence;
+      ++WeakCrossingSIVsuccesses;
+      return true;
+    }
+    if (isKnownPredicate(CmpInst::ICMP_EQ, Delta, ML)) {
+      // i = i' = UB
+      Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::LT);
+      Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::GT);
+      ++WeakCrossingSIVsuccesses;
+      if (!Result.DV[Level].Direction) {
+        ++WeakCrossingSIVindependence;
+        return true;
+      }
+      Result.DV[Level].Splitable = false;
+      Result.DV[Level].Distance = SE->getZero(Delta->getType());
+      return false;
+    }
+  }
+
+  // check that Coeff divides Delta
+  APInt APDelta = ConstDelta->getAPInt();
+  APInt APCoeff = ConstCoeff->getAPInt();
+  APInt Distance = APDelta; // these need to be initialized
+  APInt Remainder = APDelta;
+  APInt::sdivrem(APDelta, APCoeff, Distance, Remainder);
+  LLVM_DEBUG(dbgs() << "\t Remainder = " << Remainder << "\n");
+  if (Remainder != 0) {
+    // Coeff doesn't divide Delta, no dependence
+    ++WeakCrossingSIVindependence;
+    ++WeakCrossingSIVsuccesses;
+    return true;
+  }
+  LLVM_DEBUG(dbgs() << "\t Distance = " << Distance << "\n");
+
+  // if 2*Coeff doesn't divide Delta, then the equal direction isn't possible
+  APInt Two = APInt(Distance.getBitWidth(), 2, true);
+  Remainder = Distance.srem(Two);
+  LLVM_DEBUG(dbgs() << "\t Remainder = " << Remainder << "\n");
+  if (Remainder != 0) {
+    // Equal direction isn't possible
+    Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::EQ);
+    ++WeakCrossingSIVsuccesses;
+  }
+  return false;
+}
+
+
+// Kirch's algorithm, from
+//
+// Optimizing Supercompilers for Supercomputers
+// Michael Wolfe
+// MIT Press, 1989
+//
+// Program 2.1, page 29.
+// Computes the GCD of AM and BM.
+// Also finds a solution to the equation ax - by = gcd(a, b).
+// Returns true if dependence disproved; i.e., gcd does not divide Delta.
+static bool findGCD(unsigned Bits, const APInt &AM, const APInt &BM,
+                    const APInt &Delta, APInt &G, APInt &X, APInt &Y) {
+  APInt A0(Bits, 1, true), A1(Bits, 0, true);
+  APInt B0(Bits, 0, true), B1(Bits, 1, true);
+  APInt G0 = AM.abs();
+  APInt G1 = BM.abs();
+  APInt Q = G0; // these need to be initialized
+  APInt R = G0;
+  APInt::sdivrem(G0, G1, Q, R);
+  while (R != 0) {
+    APInt A2 = A0 - Q*A1; A0 = A1; A1 = A2;
+    APInt B2 = B0 - Q*B1; B0 = B1; B1 = B2;
+    G0 = G1; G1 = R;
+    APInt::sdivrem(G0, G1, Q, R);
+  }
+  G = G1;
+  LLVM_DEBUG(dbgs() << "\t GCD = " << G << "\n");
+  X = AM.slt(0) ? -A1 : A1;
+  Y = BM.slt(0) ?
B1 : -B1; + + // make sure gcd divides Delta + R = Delta.srem(G); + if (R != 0) + return true; // gcd doesn't divide Delta, no dependence + Q = Delta.sdiv(G); + X *= Q; + Y *= Q; + return false; +} + +static APInt floorOfQuotient(const APInt &A, const APInt &B) { + APInt Q = A; // these need to be initialized + APInt R = A; + APInt::sdivrem(A, B, Q, R); + if (R == 0) + return Q; + if ((A.sgt(0) && B.sgt(0)) || + (A.slt(0) && B.slt(0))) + return Q; + else + return Q - 1; +} + +static APInt ceilingOfQuotient(const APInt &A, const APInt &B) { + APInt Q = A; // these need to be initialized + APInt R = A; + APInt::sdivrem(A, B, Q, R); + if (R == 0) + return Q; + if ((A.sgt(0) && B.sgt(0)) || + (A.slt(0) && B.slt(0))) + return Q + 1; + else + return Q; +} + + +static +APInt maxAPInt(APInt A, APInt B) { + return A.sgt(B) ? A : B; +} + + +static +APInt minAPInt(APInt A, APInt B) { + return A.slt(B) ? A : B; +} + + +// exactSIVtest - +// When we have a pair of subscripts of the form [c1 + a1*i] and [c2 + a2*i], +// where i is an induction variable, c1 and c2 are loop invariant, and a1 +// and a2 are constant, we can solve it exactly using an algorithm developed +// by Banerjee and Wolfe. See Section 2.5.3 in +// +// Optimizing Supercompilers for Supercomputers +// Michael Wolfe +// MIT Press, 1989 +// +// It's slower than the specialized tests (strong SIV, weak-zero SIV, etc), +// so use them if possible. They're also a bit better with symbolics and, +// in the case of the strong SIV test, can compute Distances. +// +// Return true if dependence disproved. +bool DependenceInfo::exactSIVtest(const SCEV *SrcCoeff, const SCEV *DstCoeff, + const SCEV *SrcConst, const SCEV *DstConst, + const Loop *CurLoop, unsigned Level, + FullDependence &Result, + Constraint &NewConstraint) const { + LLVM_DEBUG(dbgs() << "\tExact SIV test\n"); + LLVM_DEBUG(dbgs() << "\t SrcCoeff = " << *SrcCoeff << " = AM\n"); + LLVM_DEBUG(dbgs() << "\t DstCoeff = " << *DstCoeff << " = BM\n"); + LLVM_DEBUG(dbgs() << "\t SrcConst = " << *SrcConst << "\n"); + LLVM_DEBUG(dbgs() << "\t DstConst = " << *DstConst << "\n"); + ++ExactSIVapplications; + assert(0 < Level && Level <= CommonLevels && "Level out of range"); + Level--; + Result.Consistent = false; + const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst); + LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta << "\n"); + NewConstraint.setLine(SrcCoeff, SE->getNegativeSCEV(DstCoeff), + Delta, CurLoop); + const SCEVConstant *ConstDelta = dyn_cast<SCEVConstant>(Delta); + const SCEVConstant *ConstSrcCoeff = dyn_cast<SCEVConstant>(SrcCoeff); + const SCEVConstant *ConstDstCoeff = dyn_cast<SCEVConstant>(DstCoeff); + if (!ConstDelta || !ConstSrcCoeff || !ConstDstCoeff) + return false; + + // find gcd + APInt G, X, Y; + APInt AM = ConstSrcCoeff->getAPInt(); + APInt BM = ConstDstCoeff->getAPInt(); + unsigned Bits = AM.getBitWidth(); + if (findGCD(Bits, AM, BM, ConstDelta->getAPInt(), G, X, Y)) { + // gcd doesn't divide Delta, no dependence + ++ExactSIVindependence; + ++ExactSIVsuccesses; + return true; + } + + LLVM_DEBUG(dbgs() << "\t X = " << X << ", Y = " << Y << "\n"); + + // since SCEV construction normalizes, LM = 0 + APInt UM(Bits, 1, true); + bool UMvalid = false; + // UM is perhaps unavailable, let's check + if (const SCEVConstant *CUB = + collectConstantUpperBound(CurLoop, Delta->getType())) { + UM = CUB->getAPInt(); + LLVM_DEBUG(dbgs() << "\t UM = " << UM << "\n"); + UMvalid = true; + } + + APInt TU(APInt::getSignedMaxValue(Bits)); + APInt TL(APInt::getSignedMinValue(Bits)); + + // 
test(BM/G, LM-X) and test(-BM/G, X-UM) + APInt TMUL = BM.sdiv(G); + if (TMUL.sgt(0)) { + TL = maxAPInt(TL, ceilingOfQuotient(-X, TMUL)); + LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n"); + if (UMvalid) { + TU = minAPInt(TU, floorOfQuotient(UM - X, TMUL)); + LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n"); + } + } + else { + TU = minAPInt(TU, floorOfQuotient(-X, TMUL)); + LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n"); + if (UMvalid) { + TL = maxAPInt(TL, ceilingOfQuotient(UM - X, TMUL)); + LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n"); + } + } + + // test(AM/G, LM-Y) and test(-AM/G, Y-UM) + TMUL = AM.sdiv(G); + if (TMUL.sgt(0)) { + TL = maxAPInt(TL, ceilingOfQuotient(-Y, TMUL)); + LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n"); + if (UMvalid) { + TU = minAPInt(TU, floorOfQuotient(UM - Y, TMUL)); + LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n"); + } + } + else { + TU = minAPInt(TU, floorOfQuotient(-Y, TMUL)); + LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n"); + if (UMvalid) { + TL = maxAPInt(TL, ceilingOfQuotient(UM - Y, TMUL)); + LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n"); + } + } + if (TL.sgt(TU)) { + ++ExactSIVindependence; + ++ExactSIVsuccesses; + return true; + } + + // explore directions + unsigned NewDirection = Dependence::DVEntry::NONE; + + // less than + APInt SaveTU(TU); // save these + APInt SaveTL(TL); + LLVM_DEBUG(dbgs() << "\t exploring LT direction\n"); + TMUL = AM - BM; + if (TMUL.sgt(0)) { + TL = maxAPInt(TL, ceilingOfQuotient(X - Y + 1, TMUL)); + LLVM_DEBUG(dbgs() << "\t\t TL = " << TL << "\n"); + } + else { + TU = minAPInt(TU, floorOfQuotient(X - Y + 1, TMUL)); + LLVM_DEBUG(dbgs() << "\t\t TU = " << TU << "\n"); + } + if (TL.sle(TU)) { + NewDirection |= Dependence::DVEntry::LT; + ++ExactSIVsuccesses; + } + + // equal + TU = SaveTU; // restore + TL = SaveTL; + LLVM_DEBUG(dbgs() << "\t exploring EQ direction\n"); + if (TMUL.sgt(0)) { + TL = maxAPInt(TL, ceilingOfQuotient(X - Y, TMUL)); + LLVM_DEBUG(dbgs() << "\t\t TL = " << TL << "\n"); + } + else { + TU = minAPInt(TU, floorOfQuotient(X - Y, TMUL)); + LLVM_DEBUG(dbgs() << "\t\t TU = " << TU << "\n"); + } + TMUL = BM - AM; + if (TMUL.sgt(0)) { + TL = maxAPInt(TL, ceilingOfQuotient(Y - X, TMUL)); + LLVM_DEBUG(dbgs() << "\t\t TL = " << TL << "\n"); + } + else { + TU = minAPInt(TU, floorOfQuotient(Y - X, TMUL)); + LLVM_DEBUG(dbgs() << "\t\t TU = " << TU << "\n"); + } + if (TL.sle(TU)) { + NewDirection |= Dependence::DVEntry::EQ; + ++ExactSIVsuccesses; + } + + // greater than + TU = SaveTU; // restore + TL = SaveTL; + LLVM_DEBUG(dbgs() << "\t exploring GT direction\n"); + if (TMUL.sgt(0)) { + TL = maxAPInt(TL, ceilingOfQuotient(Y - X + 1, TMUL)); + LLVM_DEBUG(dbgs() << "\t\t TL = " << TL << "\n"); + } + else { + TU = minAPInt(TU, floorOfQuotient(Y - X + 1, TMUL)); + LLVM_DEBUG(dbgs() << "\t\t TU = " << TU << "\n"); + } + if (TL.sle(TU)) { + NewDirection |= Dependence::DVEntry::GT; + ++ExactSIVsuccesses; + } + + // finished + Result.DV[Level].Direction &= NewDirection; + if (Result.DV[Level].Direction == Dependence::DVEntry::NONE) + ++ExactSIVindependence; + return Result.DV[Level].Direction == Dependence::DVEntry::NONE; +} + + + +// Return true if the divisor evenly divides the dividend. 
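+// For example, a dividend of -18 and a divisor of 6 give a remainder of 0,
+// so the result is true; a dividend of 10 and a divisor of 4 leave a
+// remainder of 2, so the result is false.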
+static +bool isRemainderZero(const SCEVConstant *Dividend, + const SCEVConstant *Divisor) { + const APInt &ConstDividend = Dividend->getAPInt(); + const APInt &ConstDivisor = Divisor->getAPInt(); + return ConstDividend.srem(ConstDivisor) == 0; +} + + +// weakZeroSrcSIVtest - +// From the paper, Practical Dependence Testing, Section 4.2.2 +// +// When we have a pair of subscripts of the form [c1] and [c2 + a*i], +// where i is an induction variable, c1 and c2 are loop invariant, +// and a is a constant, we can solve it exactly using the +// Weak-Zero SIV test. +// +// Given +// +// c1 = c2 + a*i +// +// we get +// +// (c1 - c2)/a = i +// +// If i is not an integer, there's no dependence. +// If i < 0 or > UB, there's no dependence. +// If i = 0, the direction is >= and peeling the +// 1st iteration will break the dependence. +// If i = UB, the direction is <= and peeling the +// last iteration will break the dependence. +// Otherwise, the direction is *. +// +// Can prove independence. Failing that, we can sometimes refine +// the directions. Can sometimes show that first or last +// iteration carries all the dependences (so worth peeling). +// +// (see also weakZeroDstSIVtest) +// +// Return true if dependence disproved. +bool DependenceInfo::weakZeroSrcSIVtest(const SCEV *DstCoeff, + const SCEV *SrcConst, + const SCEV *DstConst, + const Loop *CurLoop, unsigned Level, + FullDependence &Result, + Constraint &NewConstraint) const { + // For the WeakSIV test, it's possible the loop isn't common to + // the Src and Dst loops. If it isn't, then there's no need to + // record a direction. + LLVM_DEBUG(dbgs() << "\tWeak-Zero (src) SIV test\n"); + LLVM_DEBUG(dbgs() << "\t DstCoeff = " << *DstCoeff << "\n"); + LLVM_DEBUG(dbgs() << "\t SrcConst = " << *SrcConst << "\n"); + LLVM_DEBUG(dbgs() << "\t DstConst = " << *DstConst << "\n"); + ++WeakZeroSIVapplications; + assert(0 < Level && Level <= MaxLevels && "Level out of range"); + Level--; + Result.Consistent = false; + const SCEV *Delta = SE->getMinusSCEV(SrcConst, DstConst); + NewConstraint.setLine(SE->getZero(Delta->getType()), DstCoeff, Delta, + CurLoop); + LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta << "\n"); + if (isKnownPredicate(CmpInst::ICMP_EQ, SrcConst, DstConst)) { + if (Level < CommonLevels) { + Result.DV[Level].Direction &= Dependence::DVEntry::GE; + Result.DV[Level].PeelFirst = true; + ++WeakZeroSIVsuccesses; + } + return false; // dependences caused by first iteration + } + const SCEVConstant *ConstCoeff = dyn_cast<SCEVConstant>(DstCoeff); + if (!ConstCoeff) + return false; + const SCEV *AbsCoeff = + SE->isKnownNegative(ConstCoeff) ? + SE->getNegativeSCEV(ConstCoeff) : ConstCoeff; + const SCEV *NewDelta = + SE->isKnownNegative(ConstCoeff) ? 
SE->getNegativeSCEV(Delta) : Delta; + + // check that Delta/SrcCoeff < iteration count + // really check NewDelta < count*AbsCoeff + if (const SCEV *UpperBound = collectUpperBound(CurLoop, Delta->getType())) { + LLVM_DEBUG(dbgs() << "\t UpperBound = " << *UpperBound << "\n"); + const SCEV *Product = SE->getMulExpr(AbsCoeff, UpperBound); + if (isKnownPredicate(CmpInst::ICMP_SGT, NewDelta, Product)) { + ++WeakZeroSIVindependence; + ++WeakZeroSIVsuccesses; + return true; + } + if (isKnownPredicate(CmpInst::ICMP_EQ, NewDelta, Product)) { + // dependences caused by last iteration + if (Level < CommonLevels) { + Result.DV[Level].Direction &= Dependence::DVEntry::LE; + Result.DV[Level].PeelLast = true; + ++WeakZeroSIVsuccesses; + } + return false; + } + } + + // check that Delta/SrcCoeff >= 0 + // really check that NewDelta >= 0 + if (SE->isKnownNegative(NewDelta)) { + // No dependence, newDelta < 0 + ++WeakZeroSIVindependence; + ++WeakZeroSIVsuccesses; + return true; + } + + // if SrcCoeff doesn't divide Delta, then no dependence + if (isa<SCEVConstant>(Delta) && + !isRemainderZero(cast<SCEVConstant>(Delta), ConstCoeff)) { + ++WeakZeroSIVindependence; + ++WeakZeroSIVsuccesses; + return true; + } + return false; +} + + +// weakZeroDstSIVtest - +// From the paper, Practical Dependence Testing, Section 4.2.2 +// +// When we have a pair of subscripts of the form [c1 + a*i] and [c2], +// where i is an induction variable, c1 and c2 are loop invariant, +// and a is a constant, we can solve it exactly using the +// Weak-Zero SIV test. +// +// Given +// +// c1 + a*i = c2 +// +// we get +// +// i = (c2 - c1)/a +// +// If i is not an integer, there's no dependence. +// If i < 0 or > UB, there's no dependence. +// If i = 0, the direction is <= and peeling the +// 1st iteration will break the dependence. +// If i = UB, the direction is >= and peeling the +// last iteration will break the dependence. +// Otherwise, the direction is *. +// +// Can prove independence. Failing that, we can sometimes refine +// the directions. Can sometimes show that first or last +// iteration carries all the dependences (so worth peeling). +// +// (see also weakZeroSrcSIVtest) +// +// Return true if dependence disproved. +bool DependenceInfo::weakZeroDstSIVtest(const SCEV *SrcCoeff, + const SCEV *SrcConst, + const SCEV *DstConst, + const Loop *CurLoop, unsigned Level, + FullDependence &Result, + Constraint &NewConstraint) const { + // For the WeakSIV test, it's possible the loop isn't common to the + // Src and Dst loops. If it isn't, then there's no need to record a direction. 
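+  // For instance, with a src subscript of 2*i + 1 and a dst subscript of
+  // 11, the only possible collision is at i = (11 - 1)/2 = 5, so any
+  // dependence is confined to that single iteration (assuming the loop
+  // runs that far).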
+  LLVM_DEBUG(dbgs() << "\tWeak-Zero (dst) SIV test\n");
+  LLVM_DEBUG(dbgs() << "\t SrcCoeff = " << *SrcCoeff << "\n");
+  LLVM_DEBUG(dbgs() << "\t SrcConst = " << *SrcConst << "\n");
+  LLVM_DEBUG(dbgs() << "\t DstConst = " << *DstConst << "\n");
+  ++WeakZeroSIVapplications;
+  assert(0 < Level && Level <= SrcLevels && "Level out of range");
+  Level--;
+  Result.Consistent = false;
+  const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst);
+  NewConstraint.setLine(SrcCoeff, SE->getZero(Delta->getType()), Delta,
+                        CurLoop);
+  LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta << "\n");
+  if (isKnownPredicate(CmpInst::ICMP_EQ, DstConst, SrcConst)) {
+    if (Level < CommonLevels) {
+      Result.DV[Level].Direction &= Dependence::DVEntry::LE;
+      Result.DV[Level].PeelFirst = true;
+      ++WeakZeroSIVsuccesses;
+    }
+    return false; // dependences caused by first iteration
+  }
+  const SCEVConstant *ConstCoeff = dyn_cast<SCEVConstant>(SrcCoeff);
+  if (!ConstCoeff)
+    return false;
+  const SCEV *AbsCoeff =
+      SE->isKnownNegative(ConstCoeff) ?
+      SE->getNegativeSCEV(ConstCoeff) : ConstCoeff;
+  const SCEV *NewDelta =
+      SE->isKnownNegative(ConstCoeff) ? SE->getNegativeSCEV(Delta) : Delta;
+
+  // check that Delta/SrcCoeff < iteration count
+  // really check NewDelta < count*AbsCoeff
+  if (const SCEV *UpperBound = collectUpperBound(CurLoop, Delta->getType())) {
+    LLVM_DEBUG(dbgs() << "\t UpperBound = " << *UpperBound << "\n");
+    const SCEV *Product = SE->getMulExpr(AbsCoeff, UpperBound);
+    if (isKnownPredicate(CmpInst::ICMP_SGT, NewDelta, Product)) {
+      ++WeakZeroSIVindependence;
+      ++WeakZeroSIVsuccesses;
+      return true;
+    }
+    if (isKnownPredicate(CmpInst::ICMP_EQ, NewDelta, Product)) {
+      // dependences caused by last iteration
+      if (Level < CommonLevels) {
+        Result.DV[Level].Direction &= Dependence::DVEntry::GE;
+        Result.DV[Level].PeelLast = true;
+        ++WeakZeroSIVsuccesses;
+      }
+      return false;
+    }
+  }
+
+  // check that Delta/SrcCoeff >= 0
+  // really check that NewDelta >= 0
+  if (SE->isKnownNegative(NewDelta)) {
+    // No dependence, NewDelta < 0
+    ++WeakZeroSIVindependence;
+    ++WeakZeroSIVsuccesses;
+    return true;
+  }
+
+  // if SrcCoeff doesn't divide Delta, then no dependence
+  if (isa<SCEVConstant>(Delta) &&
+      !isRemainderZero(cast<SCEVConstant>(Delta), ConstCoeff)) {
+    ++WeakZeroSIVindependence;
+    ++WeakZeroSIVsuccesses;
+    return true;
+  }
+  return false;
+}
+
+
+// exactRDIVtest - Tests the RDIV subscript pair for dependence.
+// Things of the form [c1 + a*i] and [c2 + b*j],
+// where i and j are induction variables, c1 and c2 are loop invariant,
+// and a and b are constants.
+// Returns true if any possible dependence is disproved.
+// Marks the result as inconsistent.
+// Works in some cases that symbolicRDIVtest doesn't, and vice versa.
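+// For example, subscripts [2*i] and [2*j + 1] give coefficients with
+// gcd 2, which does not divide the constant difference 1, so any
+// dependence is disproved no matter what the loop bounds are.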
+bool DependenceInfo::exactRDIVtest(const SCEV *SrcCoeff, const SCEV *DstCoeff,
+                                   const SCEV *SrcConst, const SCEV *DstConst,
+                                   const Loop *SrcLoop, const Loop *DstLoop,
+                                   FullDependence &Result) const {
+  LLVM_DEBUG(dbgs() << "\tExact RDIV test\n");
+  LLVM_DEBUG(dbgs() << "\t SrcCoeff = " << *SrcCoeff << " = AM\n");
+  LLVM_DEBUG(dbgs() << "\t DstCoeff = " << *DstCoeff << " = BM\n");
+  LLVM_DEBUG(dbgs() << "\t SrcConst = " << *SrcConst << "\n");
+  LLVM_DEBUG(dbgs() << "\t DstConst = " << *DstConst << "\n");
+  ++ExactRDIVapplications;
+  Result.Consistent = false;
+  const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst);
+  LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta << "\n");
+  const SCEVConstant *ConstDelta = dyn_cast<SCEVConstant>(Delta);
+  const SCEVConstant *ConstSrcCoeff = dyn_cast<SCEVConstant>(SrcCoeff);
+  const SCEVConstant *ConstDstCoeff = dyn_cast<SCEVConstant>(DstCoeff);
+  if (!ConstDelta || !ConstSrcCoeff || !ConstDstCoeff)
+    return false;
+
+  // find gcd
+  APInt G, X, Y;
+  APInt AM = ConstSrcCoeff->getAPInt();
+  APInt BM = ConstDstCoeff->getAPInt();
+  unsigned Bits = AM.getBitWidth();
+  if (findGCD(Bits, AM, BM, ConstDelta->getAPInt(), G, X, Y)) {
+    // gcd doesn't divide Delta, no dependence
+    ++ExactRDIVindependence;
+    return true;
+  }
+
+  LLVM_DEBUG(dbgs() << "\t X = " << X << ", Y = " << Y << "\n");
+
+  // since SCEV construction seems to normalize, LM = 0
+  APInt SrcUM(Bits, 1, true);
+  bool SrcUMvalid = false;
+  // SrcUM is perhaps unavailable, let's check
+  if (const SCEVConstant *UpperBound =
+          collectConstantUpperBound(SrcLoop, Delta->getType())) {
+    SrcUM = UpperBound->getAPInt();
+    LLVM_DEBUG(dbgs() << "\t SrcUM = " << SrcUM << "\n");
+    SrcUMvalid = true;
+  }
+
+  APInt DstUM(Bits, 1, true);
+  bool DstUMvalid = false;
+  // DstUM is perhaps unavailable, let's check
+  if (const SCEVConstant *UpperBound =
+          collectConstantUpperBound(DstLoop, Delta->getType())) {
+    DstUM = UpperBound->getAPInt();
+    LLVM_DEBUG(dbgs() << "\t DstUM = " << DstUM << "\n");
+    DstUMvalid = true;
+  }
+
+  APInt TU(APInt::getSignedMaxValue(Bits));
+  APInt TL(APInt::getSignedMinValue(Bits));
+
+  // test(BM/G, LM-X) and test(-BM/G, X-UM)
+  APInt TMUL = BM.sdiv(G);
+  if (TMUL.sgt(0)) {
+    TL = maxAPInt(TL, ceilingOfQuotient(-X, TMUL));
+    LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n");
+    if (SrcUMvalid) {
+      TU = minAPInt(TU, floorOfQuotient(SrcUM - X, TMUL));
+      LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n");
+    }
+  }
+  else {
+    TU = minAPInt(TU, floorOfQuotient(-X, TMUL));
+    LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n");
+    if (SrcUMvalid) {
+      TL = maxAPInt(TL, ceilingOfQuotient(SrcUM - X, TMUL));
+      LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n");
+    }
+  }
+
+  // test(AM/G, LM-Y) and test(-AM/G, Y-UM)
+  TMUL = AM.sdiv(G);
+  if (TMUL.sgt(0)) {
+    TL = maxAPInt(TL, ceilingOfQuotient(-Y, TMUL));
+    LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n");
+    if (DstUMvalid) {
+      TU = minAPInt(TU, floorOfQuotient(DstUM - Y, TMUL));
+      LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n");
+    }
+  }
+  else {
+    TU = minAPInt(TU, floorOfQuotient(-Y, TMUL));
+    LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n");
+    if (DstUMvalid) {
+      TL = maxAPInt(TL, ceilingOfQuotient(DstUM - Y, TMUL));
+      LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n");
+    }
+  }
+  if (TL.sgt(TU))
+    ++ExactRDIVindependence;
+  return TL.sgt(TU);
+}
+
+
+// symbolicRDIVtest -
+// In Section 4.5 of the Practical Dependence Testing paper, the authors
+// introduce a special case of Banerjee's Inequalities (also called the
+// Extreme-Value Test) that can handle some of the SIV and
RDIV cases, +// particularly cases with symbolics. Since it's only able to disprove +// dependence (not compute distances or directions), we'll use it as a +// fall back for the other tests. +// +// When we have a pair of subscripts of the form [c1 + a1*i] and [c2 + a2*j] +// where i and j are induction variables and c1 and c2 are loop invariants, +// we can use the symbolic tests to disprove some dependences, serving as a +// backup for the RDIV test. Note that i and j can be the same variable, +// letting this test serve as a backup for the various SIV tests. +// +// For a dependence to exist, c1 + a1*i must equal c2 + a2*j for some +// 0 <= i <= N1 and some 0 <= j <= N2, where N1 and N2 are the (normalized) +// loop bounds for the i and j loops, respectively. So, ... +// +// c1 + a1*i = c2 + a2*j +// a1*i - a2*j = c2 - c1 +// +// To test for a dependence, we compute c2 - c1 and make sure it's in the +// range of the maximum and minimum possible values of a1*i - a2*j. +// Considering the signs of a1 and a2, we have 4 possible cases: +// +// 1) If a1 >= 0 and a2 >= 0, then +// a1*0 - a2*N2 <= c2 - c1 <= a1*N1 - a2*0 +// -a2*N2 <= c2 - c1 <= a1*N1 +// +// 2) If a1 >= 0 and a2 <= 0, then +// a1*0 - a2*0 <= c2 - c1 <= a1*N1 - a2*N2 +// 0 <= c2 - c1 <= a1*N1 - a2*N2 +// +// 3) If a1 <= 0 and a2 >= 0, then +// a1*N1 - a2*N2 <= c2 - c1 <= a1*0 - a2*0 +// a1*N1 - a2*N2 <= c2 - c1 <= 0 +// +// 4) If a1 <= 0 and a2 <= 0, then +// a1*N1 - a2*0 <= c2 - c1 <= a1*0 - a2*N2 +// a1*N1 <= c2 - c1 <= -a2*N2 +// +// return true if dependence disproved +bool DependenceInfo::symbolicRDIVtest(const SCEV *A1, const SCEV *A2, + const SCEV *C1, const SCEV *C2, + const Loop *Loop1, + const Loop *Loop2) const { + ++SymbolicRDIVapplications; + LLVM_DEBUG(dbgs() << "\ttry symbolic RDIV test\n"); + LLVM_DEBUG(dbgs() << "\t A1 = " << *A1); + LLVM_DEBUG(dbgs() << ", type = " << *A1->getType() << "\n"); + LLVM_DEBUG(dbgs() << "\t A2 = " << *A2 << "\n"); + LLVM_DEBUG(dbgs() << "\t C1 = " << *C1 << "\n"); + LLVM_DEBUG(dbgs() << "\t C2 = " << *C2 << "\n"); + const SCEV *N1 = collectUpperBound(Loop1, A1->getType()); + const SCEV *N2 = collectUpperBound(Loop2, A1->getType()); + LLVM_DEBUG(if (N1) dbgs() << "\t N1 = " << *N1 << "\n"); + LLVM_DEBUG(if (N2) dbgs() << "\t N2 = " << *N2 << "\n"); + const SCEV *C2_C1 = SE->getMinusSCEV(C2, C1); + const SCEV *C1_C2 = SE->getMinusSCEV(C1, C2); + LLVM_DEBUG(dbgs() << "\t C2 - C1 = " << *C2_C1 << "\n"); + LLVM_DEBUG(dbgs() << "\t C1 - C2 = " << *C1_C2 << "\n"); + if (SE->isKnownNonNegative(A1)) { + if (SE->isKnownNonNegative(A2)) { + // A1 >= 0 && A2 >= 0 + if (N1) { + // make sure that c2 - c1 <= a1*N1 + const SCEV *A1N1 = SE->getMulExpr(A1, N1); + LLVM_DEBUG(dbgs() << "\t A1*N1 = " << *A1N1 << "\n"); + if (isKnownPredicate(CmpInst::ICMP_SGT, C2_C1, A1N1)) { + ++SymbolicRDIVindependence; + return true; + } + } + if (N2) { + // make sure that -a2*N2 <= c2 - c1, or a2*N2 >= c1 - c2 + const SCEV *A2N2 = SE->getMulExpr(A2, N2); + LLVM_DEBUG(dbgs() << "\t A2*N2 = " << *A2N2 << "\n"); + if (isKnownPredicate(CmpInst::ICMP_SLT, A2N2, C1_C2)) { + ++SymbolicRDIVindependence; + return true; + } + } + } + else if (SE->isKnownNonPositive(A2)) { + // a1 >= 0 && a2 <= 0 + if (N1 && N2) { + // make sure that c2 - c1 <= a1*N1 - a2*N2 + const SCEV *A1N1 = SE->getMulExpr(A1, N1); + const SCEV *A2N2 = SE->getMulExpr(A2, N2); + const SCEV *A1N1_A2N2 = SE->getMinusSCEV(A1N1, A2N2); + LLVM_DEBUG(dbgs() << "\t A1*N1 - A2*N2 = " << *A1N1_A2N2 << "\n"); + if (isKnownPredicate(CmpInst::ICMP_SGT, C2_C1, 
A1N1_A2N2)) { + ++SymbolicRDIVindependence; + return true; + } + } + // make sure that 0 <= c2 - c1 + if (SE->isKnownNegative(C2_C1)) { + ++SymbolicRDIVindependence; + return true; + } + } + } + else if (SE->isKnownNonPositive(A1)) { + if (SE->isKnownNonNegative(A2)) { + // a1 <= 0 && a2 >= 0 + if (N1 && N2) { + // make sure that a1*N1 - a2*N2 <= c2 - c1 + const SCEV *A1N1 = SE->getMulExpr(A1, N1); + const SCEV *A2N2 = SE->getMulExpr(A2, N2); + const SCEV *A1N1_A2N2 = SE->getMinusSCEV(A1N1, A2N2); + LLVM_DEBUG(dbgs() << "\t A1*N1 - A2*N2 = " << *A1N1_A2N2 << "\n"); + if (isKnownPredicate(CmpInst::ICMP_SGT, A1N1_A2N2, C2_C1)) { + ++SymbolicRDIVindependence; + return true; + } + } + // make sure that c2 - c1 <= 0 + if (SE->isKnownPositive(C2_C1)) { + ++SymbolicRDIVindependence; + return true; + } + } + else if (SE->isKnownNonPositive(A2)) { + // a1 <= 0 && a2 <= 0 + if (N1) { + // make sure that a1*N1 <= c2 - c1 + const SCEV *A1N1 = SE->getMulExpr(A1, N1); + LLVM_DEBUG(dbgs() << "\t A1*N1 = " << *A1N1 << "\n"); + if (isKnownPredicate(CmpInst::ICMP_SGT, A1N1, C2_C1)) { + ++SymbolicRDIVindependence; + return true; + } + } + if (N2) { + // make sure that c2 - c1 <= -a2*N2, or c1 - c2 >= a2*N2 + const SCEV *A2N2 = SE->getMulExpr(A2, N2); + LLVM_DEBUG(dbgs() << "\t A2*N2 = " << *A2N2 << "\n"); + if (isKnownPredicate(CmpInst::ICMP_SLT, C1_C2, A2N2)) { + ++SymbolicRDIVindependence; + return true; + } + } + } + } + return false; +} + + +// testSIV - +// When we have a pair of subscripts of the form [c1 + a1*i] and [c2 - a2*i] +// where i is an induction variable, c1 and c2 are loop invariant, and a1 and +// a2 are constant, we attack it with an SIV test. While they can all be +// solved with the Exact SIV test, it's worthwhile to use simpler tests when +// they apply; they're cheaper and sometimes more precise. +// +// Return true if dependence disproved. 
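+// For example, the pair [c1 + 2*i] and [c2 + 2*i] is handled by the
+// strong SIV test, [c1 + 2*i] and [c2 - 2*i] by the weak-crossing SIV
+// test, and [c1 + 2*i] and [c2 + 3*i] by the exact SIV test.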
+bool DependenceInfo::testSIV(const SCEV *Src, const SCEV *Dst, unsigned &Level, + FullDependence &Result, Constraint &NewConstraint, + const SCEV *&SplitIter) const { + LLVM_DEBUG(dbgs() << " src = " << *Src << "\n"); + LLVM_DEBUG(dbgs() << " dst = " << *Dst << "\n"); + const SCEVAddRecExpr *SrcAddRec = dyn_cast<SCEVAddRecExpr>(Src); + const SCEVAddRecExpr *DstAddRec = dyn_cast<SCEVAddRecExpr>(Dst); + if (SrcAddRec && DstAddRec) { + const SCEV *SrcConst = SrcAddRec->getStart(); + const SCEV *DstConst = DstAddRec->getStart(); + const SCEV *SrcCoeff = SrcAddRec->getStepRecurrence(*SE); + const SCEV *DstCoeff = DstAddRec->getStepRecurrence(*SE); + const Loop *CurLoop = SrcAddRec->getLoop(); + assert(CurLoop == DstAddRec->getLoop() && + "both loops in SIV should be same"); + Level = mapSrcLoop(CurLoop); + bool disproven; + if (SrcCoeff == DstCoeff) + disproven = strongSIVtest(SrcCoeff, SrcConst, DstConst, CurLoop, + Level, Result, NewConstraint); + else if (SrcCoeff == SE->getNegativeSCEV(DstCoeff)) + disproven = weakCrossingSIVtest(SrcCoeff, SrcConst, DstConst, CurLoop, + Level, Result, NewConstraint, SplitIter); + else + disproven = exactSIVtest(SrcCoeff, DstCoeff, SrcConst, DstConst, CurLoop, + Level, Result, NewConstraint); + return disproven || + gcdMIVtest(Src, Dst, Result) || + symbolicRDIVtest(SrcCoeff, DstCoeff, SrcConst, DstConst, CurLoop, CurLoop); + } + if (SrcAddRec) { + const SCEV *SrcConst = SrcAddRec->getStart(); + const SCEV *SrcCoeff = SrcAddRec->getStepRecurrence(*SE); + const SCEV *DstConst = Dst; + const Loop *CurLoop = SrcAddRec->getLoop(); + Level = mapSrcLoop(CurLoop); + return weakZeroDstSIVtest(SrcCoeff, SrcConst, DstConst, CurLoop, + Level, Result, NewConstraint) || + gcdMIVtest(Src, Dst, Result); + } + if (DstAddRec) { + const SCEV *DstConst = DstAddRec->getStart(); + const SCEV *DstCoeff = DstAddRec->getStepRecurrence(*SE); + const SCEV *SrcConst = Src; + const Loop *CurLoop = DstAddRec->getLoop(); + Level = mapDstLoop(CurLoop); + return weakZeroSrcSIVtest(DstCoeff, SrcConst, DstConst, + CurLoop, Level, Result, NewConstraint) || + gcdMIVtest(Src, Dst, Result); + } + llvm_unreachable("SIV test expected at least one AddRec"); + return false; +} + + +// testRDIV - +// When we have a pair of subscripts of the form [c1 + a1*i] and [c2 + a2*j] +// where i and j are induction variables, c1 and c2 are loop invariant, +// and a1 and a2 are constant, we can solve it exactly with an easy adaptation +// of the Exact SIV test, the Restricted Double Index Variable (RDIV) test. +// It doesn't make sense to talk about distance or direction in this case, +// so there's no point in making special versions of the Strong SIV test or +// the Weak-crossing SIV test. +// +// With minor algebra, this test can also be used for things like +// [c1 + a1*i + a2*j][c2]. +// +// Return true if dependence disproved. 
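+// For example, the dependence equation for [c1 + a1*i + a2*j] and [c2]
+// is c1 + a1*i + a2*j = c2, which rearranges into the RDIV form
+// c1 + a1*i = c2 + (-a2)*j; the code below performs exactly this
+// rewrite by negating the outer coefficient.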
+bool DependenceInfo::testRDIV(const SCEV *Src, const SCEV *Dst, + FullDependence &Result) const { + // we have 3 possible situations here: + // 1) [a*i + b] and [c*j + d] + // 2) [a*i + c*j + b] and [d] + // 3) [b] and [a*i + c*j + d] + // We need to find what we've got and get organized + + const SCEV *SrcConst, *DstConst; + const SCEV *SrcCoeff, *DstCoeff; + const Loop *SrcLoop, *DstLoop; + + LLVM_DEBUG(dbgs() << " src = " << *Src << "\n"); + LLVM_DEBUG(dbgs() << " dst = " << *Dst << "\n"); + const SCEVAddRecExpr *SrcAddRec = dyn_cast<SCEVAddRecExpr>(Src); + const SCEVAddRecExpr *DstAddRec = dyn_cast<SCEVAddRecExpr>(Dst); + if (SrcAddRec && DstAddRec) { + SrcConst = SrcAddRec->getStart(); + SrcCoeff = SrcAddRec->getStepRecurrence(*SE); + SrcLoop = SrcAddRec->getLoop(); + DstConst = DstAddRec->getStart(); + DstCoeff = DstAddRec->getStepRecurrence(*SE); + DstLoop = DstAddRec->getLoop(); + } + else if (SrcAddRec) { + if (const SCEVAddRecExpr *tmpAddRec = + dyn_cast<SCEVAddRecExpr>(SrcAddRec->getStart())) { + SrcConst = tmpAddRec->getStart(); + SrcCoeff = tmpAddRec->getStepRecurrence(*SE); + SrcLoop = tmpAddRec->getLoop(); + DstConst = Dst; + DstCoeff = SE->getNegativeSCEV(SrcAddRec->getStepRecurrence(*SE)); + DstLoop = SrcAddRec->getLoop(); + } + else + llvm_unreachable("RDIV reached by surprising SCEVs"); + } + else if (DstAddRec) { + if (const SCEVAddRecExpr *tmpAddRec = + dyn_cast<SCEVAddRecExpr>(DstAddRec->getStart())) { + DstConst = tmpAddRec->getStart(); + DstCoeff = tmpAddRec->getStepRecurrence(*SE); + DstLoop = tmpAddRec->getLoop(); + SrcConst = Src; + SrcCoeff = SE->getNegativeSCEV(DstAddRec->getStepRecurrence(*SE)); + SrcLoop = DstAddRec->getLoop(); + } + else + llvm_unreachable("RDIV reached by surprising SCEVs"); + } + else + llvm_unreachable("RDIV expected at least one AddRec"); + return exactRDIVtest(SrcCoeff, DstCoeff, + SrcConst, DstConst, + SrcLoop, DstLoop, + Result) || + gcdMIVtest(Src, Dst, Result) || + symbolicRDIVtest(SrcCoeff, DstCoeff, + SrcConst, DstConst, + SrcLoop, DstLoop); +} + + +// Tests the single-subscript MIV pair (Src and Dst) for dependence. +// Return true if dependence disproved. +// Can sometimes refine direction vectors. +bool DependenceInfo::testMIV(const SCEV *Src, const SCEV *Dst, + const SmallBitVector &Loops, + FullDependence &Result) const { + LLVM_DEBUG(dbgs() << " src = " << *Src << "\n"); + LLVM_DEBUG(dbgs() << " dst = " << *Dst << "\n"); + Result.Consistent = false; + return gcdMIVtest(Src, Dst, Result) || + banerjeeMIVtest(Src, Dst, Loops, Result); +} + + +// Given a product, e.g., 10*X*Y, returns the first constant operand, +// in this case 10. If there is no constant part, returns NULL. +static +const SCEVConstant *getConstantPart(const SCEV *Expr) { + if (const auto *Constant = dyn_cast<SCEVConstant>(Expr)) + return Constant; + else if (const auto *Product = dyn_cast<SCEVMulExpr>(Expr)) + if (const auto *Constant = dyn_cast<SCEVConstant>(Product->getOperand(0))) + return Constant; + return nullptr; +} + + +//===----------------------------------------------------------------------===// +// gcdMIVtest - +// Tests an MIV subscript pair for dependence. +// Returns true if any possible dependence is disproved. +// Marks the result as inconsistent. +// Can sometimes disprove the equal direction for 1 or more loops, +// as discussed in Michael Wolfe's book, +// High Performance Compilers for Parallel Computing, page 235. +// +// We spend some effort (code!) 
to handle cases like
+// [10*i + 5*N*j + 15*M + 6], where i and j are induction variables,
+// but M and N are just loop-invariant variables.
+// This should help us handle linearized subscripts;
+// also makes this test a useful backup to the various SIV tests.
+//
+// It occurs to me that the presence of loop-invariant variables
+// changes the nature of the test from "greatest common divisor"
+// to "a common divisor".
+bool DependenceInfo::gcdMIVtest(const SCEV *Src, const SCEV *Dst,
+                                FullDependence &Result) const {
+  LLVM_DEBUG(dbgs() << "starting gcd\n");
+  ++GCDapplications;
+  unsigned BitWidth = SE->getTypeSizeInBits(Src->getType());
+  APInt RunningGCD = APInt::getNullValue(BitWidth);
+
+  // Examine Src coefficients.
+  // Compute running GCD and record source constant.
+  // Because we're looking for the constant at the end of the chain,
+  // we can't quit the loop just because the GCD == 1.
+  const SCEV *Coefficients = Src;
+  while (const SCEVAddRecExpr *AddRec =
+             dyn_cast<SCEVAddRecExpr>(Coefficients)) {
+    const SCEV *Coeff = AddRec->getStepRecurrence(*SE);
+    // If the coefficient is the product of a constant and other stuff,
+    // we can use the constant in the GCD computation.
+    const auto *Constant = getConstantPart(Coeff);
+    if (!Constant)
+      return false;
+    APInt ConstCoeff = Constant->getAPInt();
+    RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs());
+    Coefficients = AddRec->getStart();
+  }
+  const SCEV *SrcConst = Coefficients;
+
+  // Examine Dst coefficients.
+  // Compute running GCD and record destination constant.
+  // Because we're looking for the constant at the end of the chain,
+  // we can't quit the loop just because the GCD == 1.
+  Coefficients = Dst;
+  while (const SCEVAddRecExpr *AddRec =
+             dyn_cast<SCEVAddRecExpr>(Coefficients)) {
+    const SCEV *Coeff = AddRec->getStepRecurrence(*SE);
+    // If the coefficient is the product of a constant and other stuff,
+    // we can use the constant in the GCD computation.
+    const auto *Constant = getConstantPart(Coeff);
+    if (!Constant)
+      return false;
+    APInt ConstCoeff = Constant->getAPInt();
+    RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs());
+    Coefficients = AddRec->getStart();
+  }
+  const SCEV *DstConst = Coefficients;
+
+  APInt ExtraGCD = APInt::getNullValue(BitWidth);
+  const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst);
+  LLVM_DEBUG(dbgs() << " Delta = " << *Delta << "\n");
+  const SCEVConstant *Constant = dyn_cast<SCEVConstant>(Delta);
+  if (const SCEVAddExpr *Sum = dyn_cast<SCEVAddExpr>(Delta)) {
+    // If Delta is a sum of products, we may be able to make further progress.
+    for (unsigned Op = 0, Ops = Sum->getNumOperands(); Op < Ops; Op++) {
+      const SCEV *Operand = Sum->getOperand(Op);
+      if (isa<SCEVConstant>(Operand)) {
+        assert(!Constant && "Surprised to find multiple constants");
+        Constant = cast<SCEVConstant>(Operand);
+      }
+      else if (const SCEVMulExpr *Product = dyn_cast<SCEVMulExpr>(Operand)) {
+        // Search for a constant operand to participate in the GCD;
+        // if none is found, return false.
+        const SCEVConstant *ConstOp = getConstantPart(Product);
+        if (!ConstOp)
+          return false;
+        APInt ConstOpValue = ConstOp->getAPInt();
+        ExtraGCD = APIntOps::GreatestCommonDivisor(ExtraGCD,
+                                                   ConstOpValue.abs());
+      }
+      else
+        return false;
+    }
+  }
+  if (!Constant)
+    return false;
+  APInt ConstDelta = cast<SCEVConstant>(Constant)->getAPInt();
+  LLVM_DEBUG(dbgs() << " ConstDelta = " << ConstDelta << "\n");
+  if (ConstDelta == 0)
+    return false;
+  RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ExtraGCD);
+  LLVM_DEBUG(dbgs() << " RunningGCD = " << RunningGCD << "\n");
+  APInt Remainder = ConstDelta.srem(RunningGCD);
+  if (Remainder != 0) {
+    ++GCDindependence;
+    return true;
+  }
+
+  // Try to disprove equal directions.
+  // For example, given a subscript pair [3*i + 2*j] and [i' + 2*j' - 1],
+  // the code above can't disprove the dependence because the GCD = 1.
+  // So we consider what happens if i = i' and what happens if j = j'.
+  // If i = i', we can simplify the subscript to [2*i + 2*j] and [2*j' - 1],
+  // which is infeasible, so we can disallow the = direction for the i level.
+  // Setting j = j' doesn't help matters, so we end up with a direction vector
+  // of [<>, *].
+  //
+  // Given A[5*i + 10*j*M + 9*M*N] and A[15*i + 20*j*M - 21*N*M + 5],
+  // we need to remember that the constant part is 5 and the RunningGCD should
+  // be initialized to ExtraGCD = 30.
+  LLVM_DEBUG(dbgs() << " ExtraGCD = " << ExtraGCD << '\n');
+
+  bool Improved = false;
+  Coefficients = Src;
+  while (const SCEVAddRecExpr *AddRec =
+             dyn_cast<SCEVAddRecExpr>(Coefficients)) {
+    Coefficients = AddRec->getStart();
+    const Loop *CurLoop = AddRec->getLoop();
+    RunningGCD = ExtraGCD;
+    const SCEV *SrcCoeff = AddRec->getStepRecurrence(*SE);
+    const SCEV *DstCoeff = SE->getMinusSCEV(SrcCoeff, SrcCoeff);
+    const SCEV *Inner = Src;
+    while (RunningGCD != 1 && isa<SCEVAddRecExpr>(Inner)) {
+      AddRec = cast<SCEVAddRecExpr>(Inner);
+      const SCEV *Coeff = AddRec->getStepRecurrence(*SE);
+      if (CurLoop == AddRec->getLoop())
+        ; // SrcCoeff == Coeff
+      else {
+        // If the coefficient is the product of a constant and other stuff,
+        // we can use the constant in the GCD computation.
+        Constant = getConstantPart(Coeff);
+        if (!Constant)
+          return false;
+        APInt ConstCoeff = Constant->getAPInt();
+        RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs());
+      }
+      Inner = AddRec->getStart();
+    }
+    Inner = Dst;
+    while (RunningGCD != 1 && isa<SCEVAddRecExpr>(Inner)) {
+      AddRec = cast<SCEVAddRecExpr>(Inner);
+      const SCEV *Coeff = AddRec->getStepRecurrence(*SE);
+      if (CurLoop == AddRec->getLoop())
+        DstCoeff = Coeff;
+      else {
+        // If the coefficient is the product of a constant and other stuff,
+        // we can use the constant in the GCD computation.
+        Constant = getConstantPart(Coeff);
+        if (!Constant)
+          return false;
+        APInt ConstCoeff = Constant->getAPInt();
+        RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs());
+      }
+      Inner = AddRec->getStart();
+    }
+    Delta = SE->getMinusSCEV(SrcCoeff, DstCoeff);
+    // If the coefficient is the product of a constant and other stuff,
+    // we can use the constant in the GCD computation.
+    Constant = getConstantPart(Delta);
+    if (!Constant)
+      // The difference of the two coefficients might not be a product
+      // or constant, in which case we give up on this direction.
+ continue; + APInt ConstCoeff = Constant->getAPInt(); + RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs()); + LLVM_DEBUG(dbgs() << "\tRunningGCD = " << RunningGCD << "\n"); + if (RunningGCD != 0) { + Remainder = ConstDelta.srem(RunningGCD); + LLVM_DEBUG(dbgs() << "\tRemainder = " << Remainder << "\n"); + if (Remainder != 0) { + unsigned Level = mapSrcLoop(CurLoop); + Result.DV[Level - 1].Direction &= unsigned(~Dependence::DVEntry::EQ); + Improved = true; + } + } + } + if (Improved) + ++GCDsuccesses; + LLVM_DEBUG(dbgs() << "all done\n"); + return false; +} + + +//===----------------------------------------------------------------------===// +// banerjeeMIVtest - +// Use Banerjee's Inequalities to test an MIV subscript pair. +// (Wolfe, in the race-car book, calls this the Extreme Value Test.) +// Generally follows the discussion in Section 2.5.2 of +// +// Optimizing Supercompilers for Supercomputers +// Michael Wolfe +// +// The inequalities given on page 25 are simplified in that loops are +// normalized so that the lower bound is always 0 and the stride is always 1. +// For example, Wolfe gives +// +// LB^<_k = (A^-_k - B_k)^- (U_k - L_k - N_k) + (A_k - B_k)L_k - B_k N_k +// +// where A_k is the coefficient of the kth index in the source subscript, +// B_k is the coefficient of the kth index in the destination subscript, +// U_k is the upper bound of the kth index, L_k is the lower bound of the Kth +// index, and N_k is the stride of the kth index. Since all loops are normalized +// by the SCEV package, N_k = 1 and L_k = 0, allowing us to simplify the +// equation to +// +// LB^<_k = (A^-_k - B_k)^- (U_k - 0 - 1) + (A_k - B_k)0 - B_k 1 +// = (A^-_k - B_k)^- (U_k - 1) - B_k +// +// Similar simplifications are possible for the other equations. +// +// When we can't determine the number of iterations for a loop, +// we use NULL as an indicator for the worst case, infinity. +// When computing the upper bound, NULL denotes +inf; +// for the lower bound, NULL denotes -inf. +// +// Return true if dependence disproved. +bool DependenceInfo::banerjeeMIVtest(const SCEV *Src, const SCEV *Dst, + const SmallBitVector &Loops, + FullDependence &Result) const { + LLVM_DEBUG(dbgs() << "starting Banerjee\n"); + ++BanerjeeApplications; + LLVM_DEBUG(dbgs() << " Src = " << *Src << '\n'); + const SCEV *A0; + CoefficientInfo *A = collectCoeffInfo(Src, true, A0); + LLVM_DEBUG(dbgs() << " Dst = " << *Dst << '\n'); + const SCEV *B0; + CoefficientInfo *B = collectCoeffInfo(Dst, false, B0); + BoundInfo *Bound = new BoundInfo[MaxLevels + 1]; + const SCEV *Delta = SE->getMinusSCEV(B0, A0); + LLVM_DEBUG(dbgs() << "\tDelta = " << *Delta << '\n'); + + // Compute bounds for all the * directions. + LLVM_DEBUG(dbgs() << "\tBounds[*]\n"); + for (unsigned K = 1; K <= MaxLevels; ++K) { + Bound[K].Iterations = A[K].Iterations ? A[K].Iterations : B[K].Iterations; + Bound[K].Direction = Dependence::DVEntry::ALL; + Bound[K].DirSet = Dependence::DVEntry::NONE; + findBoundsALL(A, B, Bound, K); +#ifndef NDEBUG + LLVM_DEBUG(dbgs() << "\t " << K << '\t'); + if (Bound[K].Lower[Dependence::DVEntry::ALL]) + LLVM_DEBUG(dbgs() << *Bound[K].Lower[Dependence::DVEntry::ALL] << '\t'); + else + LLVM_DEBUG(dbgs() << "-inf\t"); + if (Bound[K].Upper[Dependence::DVEntry::ALL]) + LLVM_DEBUG(dbgs() << *Bound[K].Upper[Dependence::DVEntry::ALL] << '\n'); + else + LLVM_DEBUG(dbgs() << "+inf\n"); +#endif + } + + // Test the *, *, *, ... case. 
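+  // If even the all-* direction vector is infeasible, the dependence is
+  // disproved outright; otherwise exploreDirections refines each level
+  // to some subset of <, =, and >.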
+ bool Disproved = false; + if (testBounds(Dependence::DVEntry::ALL, 0, Bound, Delta)) { + // Explore the direction vector hierarchy. + unsigned DepthExpanded = 0; + unsigned NewDeps = exploreDirections(1, A, B, Bound, + Loops, DepthExpanded, Delta); + if (NewDeps > 0) { + bool Improved = false; + for (unsigned K = 1; K <= CommonLevels; ++K) { + if (Loops[K]) { + unsigned Old = Result.DV[K - 1].Direction; + Result.DV[K - 1].Direction = Old & Bound[K].DirSet; + Improved |= Old != Result.DV[K - 1].Direction; + if (!Result.DV[K - 1].Direction) { + Improved = false; + Disproved = true; + break; + } + } + } + if (Improved) + ++BanerjeeSuccesses; + } + else { + ++BanerjeeIndependence; + Disproved = true; + } + } + else { + ++BanerjeeIndependence; + Disproved = true; + } + delete [] Bound; + delete [] A; + delete [] B; + return Disproved; +} + + +// Hierarchically expands the direction vector +// search space, combining the directions of discovered dependences +// in the DirSet field of Bound. Returns the number of distinct +// dependences discovered. If the dependence is disproved, +// it will return 0. +unsigned DependenceInfo::exploreDirections(unsigned Level, CoefficientInfo *A, + CoefficientInfo *B, BoundInfo *Bound, + const SmallBitVector &Loops, + unsigned &DepthExpanded, + const SCEV *Delta) const { + if (Level > CommonLevels) { + // record result + LLVM_DEBUG(dbgs() << "\t["); + for (unsigned K = 1; K <= CommonLevels; ++K) { + if (Loops[K]) { + Bound[K].DirSet |= Bound[K].Direction; +#ifndef NDEBUG + switch (Bound[K].Direction) { + case Dependence::DVEntry::LT: + LLVM_DEBUG(dbgs() << " <"); + break; + case Dependence::DVEntry::EQ: + LLVM_DEBUG(dbgs() << " ="); + break; + case Dependence::DVEntry::GT: + LLVM_DEBUG(dbgs() << " >"); + break; + case Dependence::DVEntry::ALL: + LLVM_DEBUG(dbgs() << " *"); + break; + default: + llvm_unreachable("unexpected Bound[K].Direction"); + } +#endif + } + } + LLVM_DEBUG(dbgs() << " ]\n"); + return 1; + } + if (Loops[Level]) { + if (Level > DepthExpanded) { + DepthExpanded = Level; + // compute bounds for <, =, > at current level + findBoundsLT(A, B, Bound, Level); + findBoundsGT(A, B, Bound, Level); + findBoundsEQ(A, B, Bound, Level); +#ifndef NDEBUG + LLVM_DEBUG(dbgs() << "\tBound for level = " << Level << '\n'); + LLVM_DEBUG(dbgs() << "\t <\t"); + if (Bound[Level].Lower[Dependence::DVEntry::LT]) + LLVM_DEBUG(dbgs() << *Bound[Level].Lower[Dependence::DVEntry::LT] + << '\t'); + else + LLVM_DEBUG(dbgs() << "-inf\t"); + if (Bound[Level].Upper[Dependence::DVEntry::LT]) + LLVM_DEBUG(dbgs() << *Bound[Level].Upper[Dependence::DVEntry::LT] + << '\n'); + else + LLVM_DEBUG(dbgs() << "+inf\n"); + LLVM_DEBUG(dbgs() << "\t =\t"); + if (Bound[Level].Lower[Dependence::DVEntry::EQ]) + LLVM_DEBUG(dbgs() << *Bound[Level].Lower[Dependence::DVEntry::EQ] + << '\t'); + else + LLVM_DEBUG(dbgs() << "-inf\t"); + if (Bound[Level].Upper[Dependence::DVEntry::EQ]) + LLVM_DEBUG(dbgs() << *Bound[Level].Upper[Dependence::DVEntry::EQ] + << '\n'); + else + LLVM_DEBUG(dbgs() << "+inf\n"); + LLVM_DEBUG(dbgs() << "\t >\t"); + if (Bound[Level].Lower[Dependence::DVEntry::GT]) + LLVM_DEBUG(dbgs() << *Bound[Level].Lower[Dependence::DVEntry::GT] + << '\t'); + else + LLVM_DEBUG(dbgs() << "-inf\t"); + if (Bound[Level].Upper[Dependence::DVEntry::GT]) + LLVM_DEBUG(dbgs() << *Bound[Level].Upper[Dependence::DVEntry::GT] + << '\n'); + else + LLVM_DEBUG(dbgs() << "+inf\n"); +#endif + } + + unsigned NewDeps = 0; + + // test bounds for <, *, *, ... 
+ if (testBounds(Dependence::DVEntry::LT, Level, Bound, Delta)) + NewDeps += exploreDirections(Level + 1, A, B, Bound, + Loops, DepthExpanded, Delta); + + // Test bounds for =, *, *, ... + if (testBounds(Dependence::DVEntry::EQ, Level, Bound, Delta)) + NewDeps += exploreDirections(Level + 1, A, B, Bound, + Loops, DepthExpanded, Delta); + + // test bounds for >, *, *, ... + if (testBounds(Dependence::DVEntry::GT, Level, Bound, Delta)) + NewDeps += exploreDirections(Level + 1, A, B, Bound, + Loops, DepthExpanded, Delta); + + Bound[Level].Direction = Dependence::DVEntry::ALL; + return NewDeps; + } + else + return exploreDirections(Level + 1, A, B, Bound, Loops, DepthExpanded, Delta); +} + + +// Returns true iff the current bounds are plausible. +bool DependenceInfo::testBounds(unsigned char DirKind, unsigned Level, + BoundInfo *Bound, const SCEV *Delta) const { + Bound[Level].Direction = DirKind; + if (const SCEV *LowerBound = getLowerBound(Bound)) + if (isKnownPredicate(CmpInst::ICMP_SGT, LowerBound, Delta)) + return false; + if (const SCEV *UpperBound = getUpperBound(Bound)) + if (isKnownPredicate(CmpInst::ICMP_SGT, Delta, UpperBound)) + return false; + return true; +} + + +// Computes the upper and lower bounds for level K +// using the * direction. Records them in Bound. +// Wolfe gives the equations +// +// LB^*_k = (A^-_k - B^+_k)(U_k - L_k) + (A_k - B_k)L_k +// UB^*_k = (A^+_k - B^-_k)(U_k - L_k) + (A_k - B_k)L_k +// +// Since we normalize loops, we can simplify these equations to +// +// LB^*_k = (A^-_k - B^+_k)U_k +// UB^*_k = (A^+_k - B^-_k)U_k +// +// We must be careful to handle the case where the upper bound is unknown. +// Note that the lower bound is always <= 0 +// and the upper bound is always >= 0. +void DependenceInfo::findBoundsALL(CoefficientInfo *A, CoefficientInfo *B, + BoundInfo *Bound, unsigned K) const { + Bound[K].Lower[Dependence::DVEntry::ALL] = nullptr; // Default value = -infinity. + Bound[K].Upper[Dependence::DVEntry::ALL] = nullptr; // Default value = +infinity. + if (Bound[K].Iterations) { + Bound[K].Lower[Dependence::DVEntry::ALL] = + SE->getMulExpr(SE->getMinusSCEV(A[K].NegPart, B[K].PosPart), + Bound[K].Iterations); + Bound[K].Upper[Dependence::DVEntry::ALL] = + SE->getMulExpr(SE->getMinusSCEV(A[K].PosPart, B[K].NegPart), + Bound[K].Iterations); + } + else { + // If the difference is 0, we won't need to know the number of iterations. + if (isKnownPredicate(CmpInst::ICMP_EQ, A[K].NegPart, B[K].PosPart)) + Bound[K].Lower[Dependence::DVEntry::ALL] = + SE->getZero(A[K].Coeff->getType()); + if (isKnownPredicate(CmpInst::ICMP_EQ, A[K].PosPart, B[K].NegPart)) + Bound[K].Upper[Dependence::DVEntry::ALL] = + SE->getZero(A[K].Coeff->getType()); + } +} + + +// Computes the upper and lower bounds for level K +// using the = direction. Records them in Bound. +// Wolfe gives the equations +// +// LB^=_k = (A_k - B_k)^- (U_k - L_k) + (A_k - B_k)L_k +// UB^=_k = (A_k - B_k)^+ (U_k - L_k) + (A_k - B_k)L_k +// +// Since we normalize loops, we can simplify these equations to +// +// LB^=_k = (A_k - B_k)^- U_k +// UB^=_k = (A_k - B_k)^+ U_k +// +// We must be careful to handle the case where the upper bound is unknown. +// Note that the lower bound is always <= 0 +// and the upper bound is always >= 0. +void DependenceInfo::findBoundsEQ(CoefficientInfo *A, CoefficientInfo *B, + BoundInfo *Bound, unsigned K) const { + Bound[K].Lower[Dependence::DVEntry::EQ] = nullptr; // Default value = -infinity. 
+ Bound[K].Upper[Dependence::DVEntry::EQ] = nullptr; // Default value = +infinity. + if (Bound[K].Iterations) { + const SCEV *Delta = SE->getMinusSCEV(A[K].Coeff, B[K].Coeff); + const SCEV *NegativePart = getNegativePart(Delta); + Bound[K].Lower[Dependence::DVEntry::EQ] = + SE->getMulExpr(NegativePart, Bound[K].Iterations); + const SCEV *PositivePart = getPositivePart(Delta); + Bound[K].Upper[Dependence::DVEntry::EQ] = + SE->getMulExpr(PositivePart, Bound[K].Iterations); + } + else { + // If the positive/negative part of the difference is 0, + // we won't need to know the number of iterations. + const SCEV *Delta = SE->getMinusSCEV(A[K].Coeff, B[K].Coeff); + const SCEV *NegativePart = getNegativePart(Delta); + if (NegativePart->isZero()) + Bound[K].Lower[Dependence::DVEntry::EQ] = NegativePart; // Zero + const SCEV *PositivePart = getPositivePart(Delta); + if (PositivePart->isZero()) + Bound[K].Upper[Dependence::DVEntry::EQ] = PositivePart; // Zero + } +} + + +// Computes the upper and lower bounds for level K +// using the < direction. Records them in Bound. +// Wolfe gives the equations +// +// LB^<_k = (A^-_k - B_k)^- (U_k - L_k - N_k) + (A_k - B_k)L_k - B_k N_k +// UB^<_k = (A^+_k - B_k)^+ (U_k - L_k - N_k) + (A_k - B_k)L_k - B_k N_k +// +// Since we normalize loops, we can simplify these equations to +// +// LB^<_k = (A^-_k - B_k)^- (U_k - 1) - B_k +// UB^<_k = (A^+_k - B_k)^+ (U_k - 1) - B_k +// +// We must be careful to handle the case where the upper bound is unknown. +void DependenceInfo::findBoundsLT(CoefficientInfo *A, CoefficientInfo *B, + BoundInfo *Bound, unsigned K) const { + Bound[K].Lower[Dependence::DVEntry::LT] = nullptr; // Default value = -infinity. + Bound[K].Upper[Dependence::DVEntry::LT] = nullptr; // Default value = +infinity. + if (Bound[K].Iterations) { + const SCEV *Iter_1 = SE->getMinusSCEV( + Bound[K].Iterations, SE->getOne(Bound[K].Iterations->getType())); + const SCEV *NegPart = + getNegativePart(SE->getMinusSCEV(A[K].NegPart, B[K].Coeff)); + Bound[K].Lower[Dependence::DVEntry::LT] = + SE->getMinusSCEV(SE->getMulExpr(NegPart, Iter_1), B[K].Coeff); + const SCEV *PosPart = + getPositivePart(SE->getMinusSCEV(A[K].PosPart, B[K].Coeff)); + Bound[K].Upper[Dependence::DVEntry::LT] = + SE->getMinusSCEV(SE->getMulExpr(PosPart, Iter_1), B[K].Coeff); + } + else { + // If the positive/negative part of the difference is 0, + // we won't need to know the number of iterations. + const SCEV *NegPart = + getNegativePart(SE->getMinusSCEV(A[K].NegPart, B[K].Coeff)); + if (NegPart->isZero()) + Bound[K].Lower[Dependence::DVEntry::LT] = SE->getNegativeSCEV(B[K].Coeff); + const SCEV *PosPart = + getPositivePart(SE->getMinusSCEV(A[K].PosPart, B[K].Coeff)); + if (PosPart->isZero()) + Bound[K].Upper[Dependence::DVEntry::LT] = SE->getNegativeSCEV(B[K].Coeff); + } +} + + +// Computes the upper and lower bounds for level K +// using the > direction. Records them in Bound. +// Wolfe gives the equations +// +// LB^>_k = (A_k - B^+_k)^- (U_k - L_k - N_k) + (A_k - B_k)L_k + A_k N_k +// UB^>_k = (A_k - B^-_k)^+ (U_k - L_k - N_k) + (A_k - B_k)L_k + A_k N_k +// +// Since we normalize loops, we can simplify these equations to +// +// LB^>_k = (A_k - B^+_k)^- (U_k - 1) + A_k +// UB^>_k = (A_k - B^-_k)^+ (U_k - 1) + A_k +// +// We must be careful to handle the case where the upper bound is unknown. +void DependenceInfo::findBoundsGT(CoefficientInfo *A, CoefficientInfo *B, + BoundInfo *Bound, unsigned K) const { + Bound[K].Lower[Dependence::DVEntry::GT] = nullptr; // Default value = -infinity. 
+ Bound[K].Upper[Dependence::DVEntry::GT] = nullptr; // Default value = +infinity. + if (Bound[K].Iterations) { + const SCEV *Iter_1 = SE->getMinusSCEV( + Bound[K].Iterations, SE->getOne(Bound[K].Iterations->getType())); + const SCEV *NegPart = + getNegativePart(SE->getMinusSCEV(A[K].Coeff, B[K].PosPart)); + Bound[K].Lower[Dependence::DVEntry::GT] = + SE->getAddExpr(SE->getMulExpr(NegPart, Iter_1), A[K].Coeff); + const SCEV *PosPart = + getPositivePart(SE->getMinusSCEV(A[K].Coeff, B[K].NegPart)); + Bound[K].Upper[Dependence::DVEntry::GT] = + SE->getAddExpr(SE->getMulExpr(PosPart, Iter_1), A[K].Coeff); + } + else { + // If the positive/negative part of the difference is 0, + // we won't need to know the number of iterations. + const SCEV *NegPart = getNegativePart(SE->getMinusSCEV(A[K].Coeff, B[K].PosPart)); + if (NegPart->isZero()) + Bound[K].Lower[Dependence::DVEntry::GT] = A[K].Coeff; + const SCEV *PosPart = getPositivePart(SE->getMinusSCEV(A[K].Coeff, B[K].NegPart)); + if (PosPart->isZero()) + Bound[K].Upper[Dependence::DVEntry::GT] = A[K].Coeff; + } +} + + +// X^+ = max(X, 0) +const SCEV *DependenceInfo::getPositivePart(const SCEV *X) const { + return SE->getSMaxExpr(X, SE->getZero(X->getType())); +} + + +// X^- = min(X, 0) +const SCEV *DependenceInfo::getNegativePart(const SCEV *X) const { + return SE->getSMinExpr(X, SE->getZero(X->getType())); +} + + +// Walks through the subscript, +// collecting each coefficient, the associated loop bounds, +// and recording its positive and negative parts for later use. +DependenceInfo::CoefficientInfo * +DependenceInfo::collectCoeffInfo(const SCEV *Subscript, bool SrcFlag, + const SCEV *&Constant) const { + const SCEV *Zero = SE->getZero(Subscript->getType()); + CoefficientInfo *CI = new CoefficientInfo[MaxLevels + 1]; + for (unsigned K = 1; K <= MaxLevels; ++K) { + CI[K].Coeff = Zero; + CI[K].PosPart = Zero; + CI[K].NegPart = Zero; + CI[K].Iterations = nullptr; + } + while (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Subscript)) { + const Loop *L = AddRec->getLoop(); + unsigned K = SrcFlag ? mapSrcLoop(L) : mapDstLoop(L); + CI[K].Coeff = AddRec->getStepRecurrence(*SE); + CI[K].PosPart = getPositivePart(CI[K].Coeff); + CI[K].NegPart = getNegativePart(CI[K].Coeff); + CI[K].Iterations = collectUpperBound(L, Subscript->getType()); + Subscript = AddRec->getStart(); + } + Constant = Subscript; +#ifndef NDEBUG + LLVM_DEBUG(dbgs() << "\tCoefficient Info\n"); + for (unsigned K = 1; K <= MaxLevels; ++K) { + LLVM_DEBUG(dbgs() << "\t " << K << "\t" << *CI[K].Coeff); + LLVM_DEBUG(dbgs() << "\tPos Part = "); + LLVM_DEBUG(dbgs() << *CI[K].PosPart); + LLVM_DEBUG(dbgs() << "\tNeg Part = "); + LLVM_DEBUG(dbgs() << *CI[K].NegPart); + LLVM_DEBUG(dbgs() << "\tUpper Bound = "); + if (CI[K].Iterations) + LLVM_DEBUG(dbgs() << *CI[K].Iterations); + else + LLVM_DEBUG(dbgs() << "+inf"); + LLVM_DEBUG(dbgs() << '\n'); + } + LLVM_DEBUG(dbgs() << "\t Constant = " << *Subscript << '\n'); +#endif + return CI; +} + + +// Looks through all the bounds info and +// computes the lower bound given the current direction settings +// at each level. If the lower bound for any level is -inf, +// the result is -inf. 
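+// (Illustratively, with per-level lower bounds {-2, 0, -inf}, the -inf entry,
+// represented as a null pointer, makes the overall sum unknown and nullptr is
+// returned; without it the result here would be -2 + 0 = -2.)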
+const SCEV *DependenceInfo::getLowerBound(BoundInfo *Bound) const { + const SCEV *Sum = Bound[1].Lower[Bound[1].Direction]; + for (unsigned K = 2; Sum && K <= MaxLevels; ++K) { + if (Bound[K].Lower[Bound[K].Direction]) + Sum = SE->getAddExpr(Sum, Bound[K].Lower[Bound[K].Direction]); + else + Sum = nullptr; + } + return Sum; +} + + +// Looks through all the bounds info and +// computes the upper bound given the current direction settings +// at each level. If the upper bound at any level is +inf, +// the result is +inf. +const SCEV *DependenceInfo::getUpperBound(BoundInfo *Bound) const { + const SCEV *Sum = Bound[1].Upper[Bound[1].Direction]; + for (unsigned K = 2; Sum && K <= MaxLevels; ++K) { + if (Bound[K].Upper[Bound[K].Direction]) + Sum = SE->getAddExpr(Sum, Bound[K].Upper[Bound[K].Direction]); + else + Sum = nullptr; + } + return Sum; +} + + +//===----------------------------------------------------------------------===// +// Constraint manipulation for Delta test. + +// Given a linear SCEV, +// return the coefficient (the step) +// corresponding to the specified loop. +// If there isn't one, return 0. +// For example, given a*i + b*j + c*k, finding the coefficient +// corresponding to the j loop would yield b. +const SCEV *DependenceInfo::findCoefficient(const SCEV *Expr, + const Loop *TargetLoop) const { + const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Expr); + if (!AddRec) + return SE->getZero(Expr->getType()); + if (AddRec->getLoop() == TargetLoop) + return AddRec->getStepRecurrence(*SE); + return findCoefficient(AddRec->getStart(), TargetLoop); +} + + +// Given a linear SCEV, +// return the SCEV given by zeroing out the coefficient +// corresponding to the specified loop. +// For example, given a*i + b*j + c*k, zeroing the coefficient +// corresponding to the j loop would yield a*i + c*k. +const SCEV *DependenceInfo::zeroCoefficient(const SCEV *Expr, + const Loop *TargetLoop) const { + const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Expr); + if (!AddRec) + return Expr; // ignore + if (AddRec->getLoop() == TargetLoop) + return AddRec->getStart(); + return SE->getAddRecExpr(zeroCoefficient(AddRec->getStart(), TargetLoop), + AddRec->getStepRecurrence(*SE), + AddRec->getLoop(), + AddRec->getNoWrapFlags()); +} + + +// Given a linear SCEV Expr, +// return the SCEV given by adding some Value to the +// coefficient corresponding to the specified TargetLoop. +// For example, given a*i + b*j + c*k, adding 1 to the coefficient +// corresponding to the j loop would yield a*i + (b+1)*j + c*k. +const SCEV *DependenceInfo::addToCoefficient(const SCEV *Expr, + const Loop *TargetLoop, + const SCEV *Value) const { + const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Expr); + if (!AddRec) // create a new addRec + return SE->getAddRecExpr(Expr, + Value, + TargetLoop, + SCEV::FlagAnyWrap); // Worst case, with no info. + if (AddRec->getLoop() == TargetLoop) { + const SCEV *Sum = SE->getAddExpr(AddRec->getStepRecurrence(*SE), Value); + if (Sum->isZero()) + return AddRec->getStart(); + return SE->getAddRecExpr(AddRec->getStart(), + Sum, + AddRec->getLoop(), + AddRec->getNoWrapFlags()); + } + if (SE->isLoopInvariant(AddRec, TargetLoop)) + return SE->getAddRecExpr(AddRec, Value, TargetLoop, SCEV::FlagAnyWrap); + return SE->getAddRecExpr( + addToCoefficient(AddRec->getStart(), TargetLoop, Value), + AddRec->getStepRecurrence(*SE), AddRec->getLoop(), + AddRec->getNoWrapFlags()); +} + + +// Review the constraints, looking for opportunities +// to simplify a subscript pair (Src and Dst). 
+// Return true if some simplification occurs. +// If the simplification isn't exact (that is, if it is conservative +// in terms of dependence), set consistent to false. +// Corresponds to Figure 5 from the paper +// +// Practical Dependence Testing +// Goff, Kennedy, Tseng +// PLDI 1991 +bool DependenceInfo::propagate(const SCEV *&Src, const SCEV *&Dst, + SmallBitVector &Loops, + SmallVectorImpl<Constraint> &Constraints, + bool &Consistent) { + bool Result = false; + for (unsigned LI : Loops.set_bits()) { + LLVM_DEBUG(dbgs() << "\t Constraint[" << LI << "] is"); + LLVM_DEBUG(Constraints[LI].dump(dbgs())); + if (Constraints[LI].isDistance()) + Result |= propagateDistance(Src, Dst, Constraints[LI], Consistent); + else if (Constraints[LI].isLine()) + Result |= propagateLine(Src, Dst, Constraints[LI], Consistent); + else if (Constraints[LI].isPoint()) + Result |= propagatePoint(Src, Dst, Constraints[LI]); + } + return Result; +} + + +// Attempt to propagate a distance +// constraint into a subscript pair (Src and Dst). +// Return true if some simplification occurs. +// If the simplification isn't exact (that is, if it is conservative +// in terms of dependence), set consistent to false. +bool DependenceInfo::propagateDistance(const SCEV *&Src, const SCEV *&Dst, + Constraint &CurConstraint, + bool &Consistent) { + const Loop *CurLoop = CurConstraint.getAssociatedLoop(); + LLVM_DEBUG(dbgs() << "\t\tSrc is " << *Src << "\n"); + const SCEV *A_K = findCoefficient(Src, CurLoop); + if (A_K->isZero()) + return false; + const SCEV *DA_K = SE->getMulExpr(A_K, CurConstraint.getD()); + Src = SE->getMinusSCEV(Src, DA_K); + Src = zeroCoefficient(Src, CurLoop); + LLVM_DEBUG(dbgs() << "\t\tnew Src is " << *Src << "\n"); + LLVM_DEBUG(dbgs() << "\t\tDst is " << *Dst << "\n"); + Dst = addToCoefficient(Dst, CurLoop, SE->getNegativeSCEV(A_K)); + LLVM_DEBUG(dbgs() << "\t\tnew Dst is " << *Dst << "\n"); + if (!findCoefficient(Dst, CurLoop)->isZero()) + Consistent = false; + return true; +} + + +// Attempt to propagate a line +// constraint into a subscript pair (Src and Dst). +// Return true if some simplification occurs. +// If the simplification isn't exact (that is, if it is conservative +// in terms of dependence), set consistent to false. 
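+// (A line constraint relates the source and destination index values of the
+// associated loop as A*X + B*Y = C. Illustratively, when A is zero the
+// constraint fixes Y = C/B, so the cases below can substitute that value
+// and zero out the corresponding coefficient.)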
+bool DependenceInfo::propagateLine(const SCEV *&Src, const SCEV *&Dst, + Constraint &CurConstraint, + bool &Consistent) { + const Loop *CurLoop = CurConstraint.getAssociatedLoop(); + const SCEV *A = CurConstraint.getA(); + const SCEV *B = CurConstraint.getB(); + const SCEV *C = CurConstraint.getC(); + LLVM_DEBUG(dbgs() << "\t\tA = " << *A << ", B = " << *B << ", C = " << *C + << "\n"); + LLVM_DEBUG(dbgs() << "\t\tSrc = " << *Src << "\n"); + LLVM_DEBUG(dbgs() << "\t\tDst = " << *Dst << "\n"); + if (A->isZero()) { + const SCEVConstant *Bconst = dyn_cast<SCEVConstant>(B); + const SCEVConstant *Cconst = dyn_cast<SCEVConstant>(C); + if (!Bconst || !Cconst) return false; + APInt Beta = Bconst->getAPInt(); + APInt Charlie = Cconst->getAPInt(); + APInt CdivB = Charlie.sdiv(Beta); + assert(Charlie.srem(Beta) == 0 && "C should be evenly divisible by B"); + const SCEV *AP_K = findCoefficient(Dst, CurLoop); + // Src = SE->getAddExpr(Src, SE->getMulExpr(AP_K, SE->getConstant(CdivB))); + Src = SE->getMinusSCEV(Src, SE->getMulExpr(AP_K, SE->getConstant(CdivB))); + Dst = zeroCoefficient(Dst, CurLoop); + if (!findCoefficient(Src, CurLoop)->isZero()) + Consistent = false; + } + else if (B->isZero()) { + const SCEVConstant *Aconst = dyn_cast<SCEVConstant>(A); + const SCEVConstant *Cconst = dyn_cast<SCEVConstant>(C); + if (!Aconst || !Cconst) return false; + APInt Alpha = Aconst->getAPInt(); + APInt Charlie = Cconst->getAPInt(); + APInt CdivA = Charlie.sdiv(Alpha); + assert(Charlie.srem(Alpha) == 0 && "C should be evenly divisible by A"); + const SCEV *A_K = findCoefficient(Src, CurLoop); + Src = SE->getAddExpr(Src, SE->getMulExpr(A_K, SE->getConstant(CdivA))); + Src = zeroCoefficient(Src, CurLoop); + if (!findCoefficient(Dst, CurLoop)->isZero()) + Consistent = false; + } + else if (isKnownPredicate(CmpInst::ICMP_EQ, A, B)) { + const SCEVConstant *Aconst = dyn_cast<SCEVConstant>(A); + const SCEVConstant *Cconst = dyn_cast<SCEVConstant>(C); + if (!Aconst || !Cconst) return false; + APInt Alpha = Aconst->getAPInt(); + APInt Charlie = Cconst->getAPInt(); + APInt CdivA = Charlie.sdiv(Alpha); + assert(Charlie.srem(Alpha) == 0 && "C should be evenly divisible by A"); + const SCEV *A_K = findCoefficient(Src, CurLoop); + Src = SE->getAddExpr(Src, SE->getMulExpr(A_K, SE->getConstant(CdivA))); + Src = zeroCoefficient(Src, CurLoop); + Dst = addToCoefficient(Dst, CurLoop, A_K); + if (!findCoefficient(Dst, CurLoop)->isZero()) + Consistent = false; + } + else { + // paper is incorrect here, or perhaps just misleading + const SCEV *A_K = findCoefficient(Src, CurLoop); + Src = SE->getMulExpr(Src, A); + Dst = SE->getMulExpr(Dst, A); + Src = SE->getAddExpr(Src, SE->getMulExpr(A_K, C)); + Src = zeroCoefficient(Src, CurLoop); + Dst = addToCoefficient(Dst, CurLoop, SE->getMulExpr(A_K, B)); + if (!findCoefficient(Dst, CurLoop)->isZero()) + Consistent = false; + } + LLVM_DEBUG(dbgs() << "\t\tnew Src = " << *Src << "\n"); + LLVM_DEBUG(dbgs() << "\t\tnew Dst = " << *Dst << "\n"); + return true; +} + + +// Attempt to propagate a point +// constraint into a subscript pair (Src and Dst). +// Return true if some simplification occurs. 
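+// (Illustratively, a point constraint pins the two index values to X and Y;
+// substituting them turns the CurLoop terms of both subscripts into
+// constants, which is why both coefficients can be zeroed below.)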
+bool DependenceInfo::propagatePoint(const SCEV *&Src, const SCEV *&Dst, + Constraint &CurConstraint) { + const Loop *CurLoop = CurConstraint.getAssociatedLoop(); + const SCEV *A_K = findCoefficient(Src, CurLoop); + const SCEV *AP_K = findCoefficient(Dst, CurLoop); + const SCEV *XA_K = SE->getMulExpr(A_K, CurConstraint.getX()); + const SCEV *YAP_K = SE->getMulExpr(AP_K, CurConstraint.getY()); + LLVM_DEBUG(dbgs() << "\t\tSrc is " << *Src << "\n"); + Src = SE->getAddExpr(Src, SE->getMinusSCEV(XA_K, YAP_K)); + Src = zeroCoefficient(Src, CurLoop); + LLVM_DEBUG(dbgs() << "\t\tnew Src is " << *Src << "\n"); + LLVM_DEBUG(dbgs() << "\t\tDst is " << *Dst << "\n"); + Dst = zeroCoefficient(Dst, CurLoop); + LLVM_DEBUG(dbgs() << "\t\tnew Dst is " << *Dst << "\n"); + return true; +} + + +// Update direction vector entry based on the current constraint. +void DependenceInfo::updateDirection(Dependence::DVEntry &Level, + const Constraint &CurConstraint) const { + LLVM_DEBUG(dbgs() << "\tUpdate direction, constraint ="); + LLVM_DEBUG(CurConstraint.dump(dbgs())); + if (CurConstraint.isAny()) + ; // use defaults + else if (CurConstraint.isDistance()) { + // this one is consistent, the others aren't + Level.Scalar = false; + Level.Distance = CurConstraint.getD(); + unsigned NewDirection = Dependence::DVEntry::NONE; + if (!SE->isKnownNonZero(Level.Distance)) // if may be zero + NewDirection = Dependence::DVEntry::EQ; + if (!SE->isKnownNonPositive(Level.Distance)) // if may be positive + NewDirection |= Dependence::DVEntry::LT; + if (!SE->isKnownNonNegative(Level.Distance)) // if may be negative + NewDirection |= Dependence::DVEntry::GT; + Level.Direction &= NewDirection; + } + else if (CurConstraint.isLine()) { + Level.Scalar = false; + Level.Distance = nullptr; + // direction should be accurate + } + else if (CurConstraint.isPoint()) { + Level.Scalar = false; + Level.Distance = nullptr; + unsigned NewDirection = Dependence::DVEntry::NONE; + if (!isKnownPredicate(CmpInst::ICMP_NE, + CurConstraint.getY(), + CurConstraint.getX())) + // if X may be = Y + NewDirection |= Dependence::DVEntry::EQ; + if (!isKnownPredicate(CmpInst::ICMP_SLE, + CurConstraint.getY(), + CurConstraint.getX())) + // if Y may be > X + NewDirection |= Dependence::DVEntry::LT; + if (!isKnownPredicate(CmpInst::ICMP_SGE, + CurConstraint.getY(), + CurConstraint.getX())) + // if Y may be < X + NewDirection |= Dependence::DVEntry::GT; + Level.Direction &= NewDirection; + } + else + llvm_unreachable("constraint has unexpected kind"); +} + +/// Check if we can delinearize the subscripts. If the SCEVs representing the +/// source and destination array references are recurrences on a nested loop, +/// this function flattens the nested recurrences into separate recurrences +/// for each loop level. 
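+///
+/// For example (an illustrative case), an access A[i][j] into a parametric
+/// array 'double A[n][m]' reaches us as one flattened recurrence roughly of
+/// the form {{A,+,m}<outer>,+,1}<inner> (in units of the element size);
+/// delinearization recovers the per-dimension subscripts i and j along with
+/// the size m, so each dimension can then be tested as a simple SIV pair.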
+bool DependenceInfo::tryDelinearize(Instruction *Src, Instruction *Dst, + SmallVectorImpl<Subscript> &Pair) { + assert(isLoadOrStore(Src) && "instruction is not load or store"); + assert(isLoadOrStore(Dst) && "instruction is not load or store"); + Value *SrcPtr = getLoadStorePointerOperand(Src); + Value *DstPtr = getLoadStorePointerOperand(Dst); + + Loop *SrcLoop = LI->getLoopFor(Src->getParent()); + Loop *DstLoop = LI->getLoopFor(Dst->getParent()); + + // Below code mimics the code in Delinearization.cpp + const SCEV *SrcAccessFn = + SE->getSCEVAtScope(SrcPtr, SrcLoop); + const SCEV *DstAccessFn = + SE->getSCEVAtScope(DstPtr, DstLoop); + + const SCEVUnknown *SrcBase = + dyn_cast<SCEVUnknown>(SE->getPointerBase(SrcAccessFn)); + const SCEVUnknown *DstBase = + dyn_cast<SCEVUnknown>(SE->getPointerBase(DstAccessFn)); + + if (!SrcBase || !DstBase || SrcBase != DstBase) + return false; + + const SCEV *ElementSize = SE->getElementSize(Src); + if (ElementSize != SE->getElementSize(Dst)) + return false; + + const SCEV *SrcSCEV = SE->getMinusSCEV(SrcAccessFn, SrcBase); + const SCEV *DstSCEV = SE->getMinusSCEV(DstAccessFn, DstBase); + + const SCEVAddRecExpr *SrcAR = dyn_cast<SCEVAddRecExpr>(SrcSCEV); + const SCEVAddRecExpr *DstAR = dyn_cast<SCEVAddRecExpr>(DstSCEV); + if (!SrcAR || !DstAR || !SrcAR->isAffine() || !DstAR->isAffine()) + return false; + + // First step: collect parametric terms in both array references. + SmallVector<const SCEV *, 4> Terms; + SE->collectParametricTerms(SrcAR, Terms); + SE->collectParametricTerms(DstAR, Terms); + + // Second step: find subscript sizes. + SmallVector<const SCEV *, 4> Sizes; + SE->findArrayDimensions(Terms, Sizes, ElementSize); + + // Third step: compute the access functions for each subscript. + SmallVector<const SCEV *, 4> SrcSubscripts, DstSubscripts; + SE->computeAccessFunctions(SrcAR, SrcSubscripts, Sizes); + SE->computeAccessFunctions(DstAR, DstSubscripts, Sizes); + + // Fail when there is only a subscript: that's a linearized access function. + if (SrcSubscripts.size() < 2 || DstSubscripts.size() < 2 || + SrcSubscripts.size() != DstSubscripts.size()) + return false; + + int size = SrcSubscripts.size(); + + // Statically check that the array bounds are in-range. The first subscript we + // don't have a size for and it cannot overflow into another subscript, so is + // always safe. The others need to be 0 <= subscript[i] < bound, for both src + // and dst. + // FIXME: It may be better to record these sizes and add them as constraints + // to the dependency checks. + if (!DisableDelinearizationChecks) + for (int i = 1; i < size; ++i) { + if (!isKnownNonNegative(SrcSubscripts[i], SrcPtr)) + return false; + + if (!isKnownLessThan(SrcSubscripts[i], Sizes[i - 1])) + return false; + + if (!isKnownNonNegative(DstSubscripts[i], DstPtr)) + return false; + + if (!isKnownLessThan(DstSubscripts[i], Sizes[i - 1])) + return false; + } + + LLVM_DEBUG({ + dbgs() << "\nSrcSubscripts: "; + for (int i = 0; i < size; i++) + dbgs() << *SrcSubscripts[i]; + dbgs() << "\nDstSubscripts: "; + for (int i = 0; i < size; i++) + dbgs() << *DstSubscripts[i]; + }); + + // The delinearization transforms a single-subscript MIV dependence test into + // a multi-subscript SIV dependence test that is easier to compute. So we + // resize Pair to contain as many pairs of subscripts as the delinearization + // has found, and then initialize the pairs following the delinearization. 
+ Pair.resize(size); + for (int i = 0; i < size; ++i) { + Pair[i].Src = SrcSubscripts[i]; + Pair[i].Dst = DstSubscripts[i]; + unifySubscriptType(&Pair[i]); + } + + return true; +} + +//===----------------------------------------------------------------------===// + +#ifndef NDEBUG +// For debugging purposes, dump a small bit vector to dbgs(). +static void dumpSmallBitVector(SmallBitVector &BV) { + dbgs() << "{"; + for (unsigned VI : BV.set_bits()) { + dbgs() << VI; + if (BV.find_next(VI) >= 0) + dbgs() << ' '; + } + dbgs() << "}\n"; +} +#endif + +bool DependenceInfo::invalidate(Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &Inv) { + // Check if the analysis itself has been invalidated. + auto PAC = PA.getChecker<DependenceAnalysis>(); + if (!PAC.preserved() && !PAC.preservedSet<AllAnalysesOn<Function>>()) + return true; + + // Check transitive dependencies. + return Inv.invalidate<AAManager>(F, PA) || + Inv.invalidate<ScalarEvolutionAnalysis>(F, PA) || + Inv.invalidate<LoopAnalysis>(F, PA); +} + +// depends - +// Returns NULL if there is no dependence. +// Otherwise, return a Dependence with as many details as possible. +// Corresponds to Section 3.1 in the paper +// +// Practical Dependence Testing +// Goff, Kennedy, Tseng +// PLDI 1991 +// +// Care is required to keep the routine below, getSplitIteration(), +// up to date with respect to this routine. +std::unique_ptr<Dependence> +DependenceInfo::depends(Instruction *Src, Instruction *Dst, + bool PossiblyLoopIndependent) { + if (Src == Dst) + PossiblyLoopIndependent = false; + + if ((!Src->mayReadFromMemory() && !Src->mayWriteToMemory()) || + (!Dst->mayReadFromMemory() && !Dst->mayWriteToMemory())) + // if both instructions don't reference memory, there's no dependence + return nullptr; + + if (!isLoadOrStore(Src) || !isLoadOrStore(Dst)) { + // can only analyze simple loads and stores, i.e., no calls, invokes, etc. + LLVM_DEBUG(dbgs() << "can only handle simple loads and stores\n"); + return std::make_unique<Dependence>(Src, Dst); + } + + assert(isLoadOrStore(Src) && "instruction is not load or store"); + assert(isLoadOrStore(Dst) && "instruction is not load or store"); + Value *SrcPtr = getLoadStorePointerOperand(Src); + Value *DstPtr = getLoadStorePointerOperand(Dst); + + switch (underlyingObjectsAlias(AA, F->getParent()->getDataLayout(), + MemoryLocation::get(Dst), + MemoryLocation::get(Src))) { + case MayAlias: + case PartialAlias: + // cannot analyse objects if we don't understand their aliasing. + LLVM_DEBUG(dbgs() << "can't analyze may or partial alias\n"); + return std::make_unique<Dependence>(Src, Dst); + case NoAlias: + // If the objects noalias, they are distinct, accesses are independent. + LLVM_DEBUG(dbgs() << "no alias\n"); + return nullptr; + case MustAlias: + break; // The underlying objects alias; test accesses for dependence. 
+ } + + // establish loop nesting levels + establishNestingLevels(Src, Dst); + LLVM_DEBUG(dbgs() << " common nesting levels = " << CommonLevels << "\n"); + LLVM_DEBUG(dbgs() << " maximum nesting levels = " << MaxLevels << "\n"); + + FullDependence Result(Src, Dst, PossiblyLoopIndependent, CommonLevels); + ++TotalArrayPairs; + + unsigned Pairs = 1; + SmallVector<Subscript, 2> Pair(Pairs); + const SCEV *SrcSCEV = SE->getSCEV(SrcPtr); + const SCEV *DstSCEV = SE->getSCEV(DstPtr); + LLVM_DEBUG(dbgs() << " SrcSCEV = " << *SrcSCEV << "\n"); + LLVM_DEBUG(dbgs() << " DstSCEV = " << *DstSCEV << "\n"); + Pair[0].Src = SrcSCEV; + Pair[0].Dst = DstSCEV; + + if (Delinearize) { + if (tryDelinearize(Src, Dst, Pair)) { + LLVM_DEBUG(dbgs() << " delinearized\n"); + Pairs = Pair.size(); + } + } + + for (unsigned P = 0; P < Pairs; ++P) { + Pair[P].Loops.resize(MaxLevels + 1); + Pair[P].GroupLoops.resize(MaxLevels + 1); + Pair[P].Group.resize(Pairs); + removeMatchingExtensions(&Pair[P]); + Pair[P].Classification = + classifyPair(Pair[P].Src, LI->getLoopFor(Src->getParent()), + Pair[P].Dst, LI->getLoopFor(Dst->getParent()), + Pair[P].Loops); + Pair[P].GroupLoops = Pair[P].Loops; + Pair[P].Group.set(P); + LLVM_DEBUG(dbgs() << " subscript " << P << "\n"); + LLVM_DEBUG(dbgs() << "\tsrc = " << *Pair[P].Src << "\n"); + LLVM_DEBUG(dbgs() << "\tdst = " << *Pair[P].Dst << "\n"); + LLVM_DEBUG(dbgs() << "\tclass = " << Pair[P].Classification << "\n"); + LLVM_DEBUG(dbgs() << "\tloops = "); + LLVM_DEBUG(dumpSmallBitVector(Pair[P].Loops)); + } + + SmallBitVector Separable(Pairs); + SmallBitVector Coupled(Pairs); + + // Partition subscripts into separable and minimally-coupled groups + // Algorithm in paper is algorithmically better; + // this may be faster in practice. Check someday. + // + // Here's an example of how it works. Consider this code: + // + // for (i = ...) { + // for (j = ...) { + // for (k = ...) { + // for (l = ...) { + // for (m = ...) { + // A[i][j][k][m] = ...; + // ... = A[0][j][l][i + j]; + // } + // } + // } + // } + // } + // + // There are 4 subscripts here: + // 0 [i] and [0] + // 1 [j] and [j] + // 2 [k] and [l] + // 3 [m] and [i + j] + // + // We've already classified each subscript pair as ZIV, SIV, etc., + // and collected all the loops mentioned by pair P in Pair[P].Loops. + // In addition, we've initialized Pair[P].GroupLoops to Pair[P].Loops + // and set Pair[P].Group = {P}. + // + // Src Dst Classification Loops GroupLoops Group + // 0 [i] [0] SIV {1} {1} {0} + // 1 [j] [j] SIV {2} {2} {1} + // 2 [k] [l] RDIV {3,4} {3,4} {2} + // 3 [m] [i + j] MIV {1,2,5} {1,2,5} {3} + // + // For each subscript SI 0 .. 3, we consider each remaining subscript, SJ. + // So, 0 is compared against 1, 2, and 3; 1 is compared against 2 and 3, etc. + // + // We begin by comparing 0 and 1. The intersection of the GroupLoops is empty. + // Next, 0 and 2. Again, the intersection of their GroupLoops is empty. + // Next 0 and 3. The intersection of their GroupLoop = {1}, not empty, + // so Pair[3].Group = {0,3} and Done = false (that is, 0 will not be added + // to either Separable or Coupled). + // + // Next, we consider 1 and 2. The intersection of the GroupLoops is empty. + // Next, 1 and 3. The intersection of their GroupLoops = {2}, not empty, + // so Pair[3].Group = {0, 1, 3} and Done = false. + // + // Next, we compare 2 against 3. The intersection of the GroupLoops is empty. + // Since Done remains true, we add 2 to the set of Separable pairs. + // + // Finally, we consider 3. 
There's nothing to compare it with, + // so Done remains true and we add it to the Coupled set. + // Pair[3].Group = {0, 1, 3} and GroupLoops = {1, 2, 5}. + // + // In the end, we've got 1 separable subscript and 1 coupled group. + for (unsigned SI = 0; SI < Pairs; ++SI) { + if (Pair[SI].Classification == Subscript::NonLinear) { + // ignore these, but collect loops for later + ++NonlinearSubscriptPairs; + collectCommonLoops(Pair[SI].Src, + LI->getLoopFor(Src->getParent()), + Pair[SI].Loops); + collectCommonLoops(Pair[SI].Dst, + LI->getLoopFor(Dst->getParent()), + Pair[SI].Loops); + Result.Consistent = false; + } else if (Pair[SI].Classification == Subscript::ZIV) { + // always separable + Separable.set(SI); + } + else { + // SIV, RDIV, or MIV, so check for coupled group + bool Done = true; + for (unsigned SJ = SI + 1; SJ < Pairs; ++SJ) { + SmallBitVector Intersection = Pair[SI].GroupLoops; + Intersection &= Pair[SJ].GroupLoops; + if (Intersection.any()) { + // accumulate set of all the loops in group + Pair[SJ].GroupLoops |= Pair[SI].GroupLoops; + // accumulate set of all subscripts in group + Pair[SJ].Group |= Pair[SI].Group; + Done = false; + } + } + if (Done) { + if (Pair[SI].Group.count() == 1) { + Separable.set(SI); + ++SeparableSubscriptPairs; + } + else { + Coupled.set(SI); + ++CoupledSubscriptPairs; + } + } + } + } + + LLVM_DEBUG(dbgs() << " Separable = "); + LLVM_DEBUG(dumpSmallBitVector(Separable)); + LLVM_DEBUG(dbgs() << " Coupled = "); + LLVM_DEBUG(dumpSmallBitVector(Coupled)); + + Constraint NewConstraint; + NewConstraint.setAny(SE); + + // test separable subscripts + for (unsigned SI : Separable.set_bits()) { + LLVM_DEBUG(dbgs() << "testing subscript " << SI); + switch (Pair[SI].Classification) { + case Subscript::ZIV: + LLVM_DEBUG(dbgs() << ", ZIV\n"); + if (testZIV(Pair[SI].Src, Pair[SI].Dst, Result)) + return nullptr; + break; + case Subscript::SIV: { + LLVM_DEBUG(dbgs() << ", SIV\n"); + unsigned Level; + const SCEV *SplitIter = nullptr; + if (testSIV(Pair[SI].Src, Pair[SI].Dst, Level, Result, NewConstraint, + SplitIter)) + return nullptr; + break; + } + case Subscript::RDIV: + LLVM_DEBUG(dbgs() << ", RDIV\n"); + if (testRDIV(Pair[SI].Src, Pair[SI].Dst, Result)) + return nullptr; + break; + case Subscript::MIV: + LLVM_DEBUG(dbgs() << ", MIV\n"); + if (testMIV(Pair[SI].Src, Pair[SI].Dst, Pair[SI].Loops, Result)) + return nullptr; + break; + default: + llvm_unreachable("subscript has unexpected classification"); + } + } + + if (Coupled.count()) { + // test coupled subscript groups + LLVM_DEBUG(dbgs() << "starting on coupled subscripts\n"); + LLVM_DEBUG(dbgs() << "MaxLevels + 1 = " << MaxLevels + 1 << "\n"); + SmallVector<Constraint, 4> Constraints(MaxLevels + 1); + for (unsigned II = 0; II <= MaxLevels; ++II) + Constraints[II].setAny(SE); + for (unsigned SI : Coupled.set_bits()) { + LLVM_DEBUG(dbgs() << "testing subscript group " << SI << " { "); + SmallBitVector Group(Pair[SI].Group); + SmallBitVector Sivs(Pairs); + SmallBitVector Mivs(Pairs); + SmallBitVector ConstrainedLevels(MaxLevels + 1); + SmallVector<Subscript *, 4> PairsInGroup; + for (unsigned SJ : Group.set_bits()) { + LLVM_DEBUG(dbgs() << SJ << " "); + if (Pair[SJ].Classification == Subscript::SIV) + Sivs.set(SJ); + else + Mivs.set(SJ); + PairsInGroup.push_back(&Pair[SJ]); + } + unifySubscriptType(PairsInGroup); + LLVM_DEBUG(dbgs() << "}\n"); + while (Sivs.any()) { + bool Changed = false; + for (unsigned SJ : Sivs.set_bits()) { + LLVM_DEBUG(dbgs() << "testing subscript " << SJ << ", SIV\n"); + // SJ is an SIV 
subscript that's part of the current coupled group + unsigned Level; + const SCEV *SplitIter = nullptr; + LLVM_DEBUG(dbgs() << "SIV\n"); + if (testSIV(Pair[SJ].Src, Pair[SJ].Dst, Level, Result, NewConstraint, + SplitIter)) + return nullptr; + ConstrainedLevels.set(Level); + if (intersectConstraints(&Constraints[Level], &NewConstraint)) { + if (Constraints[Level].isEmpty()) { + ++DeltaIndependence; + return nullptr; + } + Changed = true; + } + Sivs.reset(SJ); + } + if (Changed) { + // propagate, possibly creating new SIVs and ZIVs + LLVM_DEBUG(dbgs() << " propagating\n"); + LLVM_DEBUG(dbgs() << "\tMivs = "); + LLVM_DEBUG(dumpSmallBitVector(Mivs)); + for (unsigned SJ : Mivs.set_bits()) { + // SJ is an MIV subscript that's part of the current coupled group + LLVM_DEBUG(dbgs() << "\tSJ = " << SJ << "\n"); + if (propagate(Pair[SJ].Src, Pair[SJ].Dst, Pair[SJ].Loops, + Constraints, Result.Consistent)) { + LLVM_DEBUG(dbgs() << "\t Changed\n"); + ++DeltaPropagations; + Pair[SJ].Classification = + classifyPair(Pair[SJ].Src, LI->getLoopFor(Src->getParent()), + Pair[SJ].Dst, LI->getLoopFor(Dst->getParent()), + Pair[SJ].Loops); + switch (Pair[SJ].Classification) { + case Subscript::ZIV: + LLVM_DEBUG(dbgs() << "ZIV\n"); + if (testZIV(Pair[SJ].Src, Pair[SJ].Dst, Result)) + return nullptr; + Mivs.reset(SJ); + break; + case Subscript::SIV: + Sivs.set(SJ); + Mivs.reset(SJ); + break; + case Subscript::RDIV: + case Subscript::MIV: + break; + default: + llvm_unreachable("bad subscript classification"); + } + } + } + } + } + + // test & propagate remaining RDIVs + for (unsigned SJ : Mivs.set_bits()) { + if (Pair[SJ].Classification == Subscript::RDIV) { + LLVM_DEBUG(dbgs() << "RDIV test\n"); + if (testRDIV(Pair[SJ].Src, Pair[SJ].Dst, Result)) + return nullptr; + // I don't yet understand how to propagate RDIV results + Mivs.reset(SJ); + } + } + + // test remaining MIVs + // This code is temporary. + // Better to somehow test all remaining subscripts simultaneously. + for (unsigned SJ : Mivs.set_bits()) { + if (Pair[SJ].Classification == Subscript::MIV) { + LLVM_DEBUG(dbgs() << "MIV test\n"); + if (testMIV(Pair[SJ].Src, Pair[SJ].Dst, Pair[SJ].Loops, Result)) + return nullptr; + } + else + llvm_unreachable("expected only MIV subscripts at this point"); + } + + // update Result.DV from constraint vector + LLVM_DEBUG(dbgs() << " updating\n"); + for (unsigned SJ : ConstrainedLevels.set_bits()) { + if (SJ > CommonLevels) + break; + updateDirection(Result.DV[SJ - 1], Constraints[SJ]); + if (Result.DV[SJ - 1].Direction == Dependence::DVEntry::NONE) + return nullptr; + } + } + } + + // Make sure the Scalar flags are set correctly. + SmallBitVector CompleteLoops(MaxLevels + 1); + for (unsigned SI = 0; SI < Pairs; ++SI) + CompleteLoops |= Pair[SI].Loops; + for (unsigned II = 1; II <= CommonLevels; ++II) + if (CompleteLoops[II]) + Result.DV[II - 1].Scalar = false; + + if (PossiblyLoopIndependent) { + // Make sure the LoopIndependent flag is set correctly. + // All directions must include equal, otherwise no + // loop-independent dependence is possible. + for (unsigned II = 1; II <= CommonLevels; ++II) { + if (!(Result.getDirection(II) & Dependence::DVEntry::EQ)) { + Result.LoopIndependent = false; + break; + } + } + } + else { + // On the other hand, if all directions are equal and there's no + // loop-independent dependence possible, then no dependence exists. 
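+    // E.g., (an illustrative case) querying the pair store A[i] / load A[i]
+    // with PossiblyLoopIndependent == false: every direction is '=', so the
+    // accesses can only conflict within a single iteration, and no
+    // loop-carried dependence exists.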
+ bool AllEqual = true; + for (unsigned II = 1; II <= CommonLevels; ++II) { + if (Result.getDirection(II) != Dependence::DVEntry::EQ) { + AllEqual = false; + break; + } + } + if (AllEqual) + return nullptr; + } + + return std::make_unique<FullDependence>(std::move(Result)); +} + + + +//===----------------------------------------------------------------------===// +// getSplitIteration - +// Rather than spend rarely-used space recording the splitting iteration +// during the Weak-Crossing SIV test, we re-compute it on demand. +// The re-computation is basically a repeat of the entire dependence test, +// though simplified since we know that the dependence exists. +// It's tedious, since we must go through all propagations, etc. +// +// Care is required to keep this code up to date with respect to the routine +// above, depends(). +// +// Generally, the dependence analyzer will be used to build +// a dependence graph for a function (basically a map from instructions +// to dependences). Looking for cycles in the graph shows us loops +// that cannot be trivially vectorized/parallelized. +// +// We can try to improve the situation by examining all the dependences +// that make up the cycle, looking for ones we can break. +// Sometimes, peeling the first or last iteration of a loop will break +// dependences, and we've got flags for those possibilities. +// Sometimes, splitting a loop at some other iteration will do the trick, +// and we've got a flag for that case. Rather than waste the space to +// record the exact iteration (since we rarely know), we provide +// a method that calculates the iteration. It's a drag that it must work +// from scratch, but wonderful in that it's possible. +// +// Here's an example: +// +// for (i = 0; i < 10; i++) +// A[i] = ... +// ... = A[11 - i] +// +// There's a loop-carried flow dependence from the store to the load, +// found by the weak-crossing SIV test. The dependence will have a flag, +// indicating that the dependence can be broken by splitting the loop. +// Calling getSplitIteration will return 5. +// Splitting the loop breaks the dependence, like so: +// +// for (i = 0; i <= 5; i++) +// A[i] = ... +// ... = A[11 - i] +// for (i = 6; i < 10; i++) +// A[i] = ... +// ... = A[11 - i] +// +// breaks the dependence and allows us to vectorize/parallelize +// both loops. 
+const SCEV *DependenceInfo::getSplitIteration(const Dependence &Dep, + unsigned SplitLevel) { + assert(Dep.isSplitable(SplitLevel) && + "Dep should be splitable at SplitLevel"); + Instruction *Src = Dep.getSrc(); + Instruction *Dst = Dep.getDst(); + assert(Src->mayReadFromMemory() || Src->mayWriteToMemory()); + assert(Dst->mayReadFromMemory() || Dst->mayWriteToMemory()); + assert(isLoadOrStore(Src)); + assert(isLoadOrStore(Dst)); + Value *SrcPtr = getLoadStorePointerOperand(Src); + Value *DstPtr = getLoadStorePointerOperand(Dst); + assert(underlyingObjectsAlias(AA, F->getParent()->getDataLayout(), + MemoryLocation::get(Dst), + MemoryLocation::get(Src)) == MustAlias); + + // establish loop nesting levels + establishNestingLevels(Src, Dst); + + FullDependence Result(Src, Dst, false, CommonLevels); + + unsigned Pairs = 1; + SmallVector<Subscript, 2> Pair(Pairs); + const SCEV *SrcSCEV = SE->getSCEV(SrcPtr); + const SCEV *DstSCEV = SE->getSCEV(DstPtr); + Pair[0].Src = SrcSCEV; + Pair[0].Dst = DstSCEV; + + if (Delinearize) { + if (tryDelinearize(Src, Dst, Pair)) { + LLVM_DEBUG(dbgs() << " delinearized\n"); + Pairs = Pair.size(); + } + } + + for (unsigned P = 0; P < Pairs; ++P) { + Pair[P].Loops.resize(MaxLevels + 1); + Pair[P].GroupLoops.resize(MaxLevels + 1); + Pair[P].Group.resize(Pairs); + removeMatchingExtensions(&Pair[P]); + Pair[P].Classification = + classifyPair(Pair[P].Src, LI->getLoopFor(Src->getParent()), + Pair[P].Dst, LI->getLoopFor(Dst->getParent()), + Pair[P].Loops); + Pair[P].GroupLoops = Pair[P].Loops; + Pair[P].Group.set(P); + } + + SmallBitVector Separable(Pairs); + SmallBitVector Coupled(Pairs); + + // partition subscripts into separable and minimally-coupled groups + for (unsigned SI = 0; SI < Pairs; ++SI) { + if (Pair[SI].Classification == Subscript::NonLinear) { + // ignore these, but collect loops for later + collectCommonLoops(Pair[SI].Src, + LI->getLoopFor(Src->getParent()), + Pair[SI].Loops); + collectCommonLoops(Pair[SI].Dst, + LI->getLoopFor(Dst->getParent()), + Pair[SI].Loops); + Result.Consistent = false; + } + else if (Pair[SI].Classification == Subscript::ZIV) + Separable.set(SI); + else { + // SIV, RDIV, or MIV, so check for coupled group + bool Done = true; + for (unsigned SJ = SI + 1; SJ < Pairs; ++SJ) { + SmallBitVector Intersection = Pair[SI].GroupLoops; + Intersection &= Pair[SJ].GroupLoops; + if (Intersection.any()) { + // accumulate set of all the loops in group + Pair[SJ].GroupLoops |= Pair[SI].GroupLoops; + // accumulate set of all subscripts in group + Pair[SJ].Group |= Pair[SI].Group; + Done = false; + } + } + if (Done) { + if (Pair[SI].Group.count() == 1) + Separable.set(SI); + else + Coupled.set(SI); + } + } + } + + Constraint NewConstraint; + NewConstraint.setAny(SE); + + // test separable subscripts + for (unsigned SI : Separable.set_bits()) { + switch (Pair[SI].Classification) { + case Subscript::SIV: { + unsigned Level; + const SCEV *SplitIter = nullptr; + (void) testSIV(Pair[SI].Src, Pair[SI].Dst, Level, + Result, NewConstraint, SplitIter); + if (Level == SplitLevel) { + assert(SplitIter != nullptr); + return SplitIter; + } + break; + } + case Subscript::ZIV: + case Subscript::RDIV: + case Subscript::MIV: + break; + default: + llvm_unreachable("subscript has unexpected classification"); + } + } + + if (Coupled.count()) { + // test coupled subscript groups + SmallVector<Constraint, 4> Constraints(MaxLevels + 1); + for (unsigned II = 0; II <= MaxLevels; ++II) + Constraints[II].setAny(SE); + for (unsigned SI : Coupled.set_bits()) { + 
SmallBitVector Group(Pair[SI].Group); + SmallBitVector Sivs(Pairs); + SmallBitVector Mivs(Pairs); + SmallBitVector ConstrainedLevels(MaxLevels + 1); + for (unsigned SJ : Group.set_bits()) { + if (Pair[SJ].Classification == Subscript::SIV) + Sivs.set(SJ); + else + Mivs.set(SJ); + } + while (Sivs.any()) { + bool Changed = false; + for (unsigned SJ : Sivs.set_bits()) { + // SJ is an SIV subscript that's part of the current coupled group + unsigned Level; + const SCEV *SplitIter = nullptr; + (void) testSIV(Pair[SJ].Src, Pair[SJ].Dst, Level, + Result, NewConstraint, SplitIter); + if (Level == SplitLevel && SplitIter) + return SplitIter; + ConstrainedLevels.set(Level); + if (intersectConstraints(&Constraints[Level], &NewConstraint)) + Changed = true; + Sivs.reset(SJ); + } + if (Changed) { + // propagate, possibly creating new SIVs and ZIVs + for (unsigned SJ : Mivs.set_bits()) { + // SJ is an MIV subscript that's part of the current coupled group + if (propagate(Pair[SJ].Src, Pair[SJ].Dst, + Pair[SJ].Loops, Constraints, Result.Consistent)) { + Pair[SJ].Classification = + classifyPair(Pair[SJ].Src, LI->getLoopFor(Src->getParent()), + Pair[SJ].Dst, LI->getLoopFor(Dst->getParent()), + Pair[SJ].Loops); + switch (Pair[SJ].Classification) { + case Subscript::ZIV: + Mivs.reset(SJ); + break; + case Subscript::SIV: + Sivs.set(SJ); + Mivs.reset(SJ); + break; + case Subscript::RDIV: + case Subscript::MIV: + break; + default: + llvm_unreachable("bad subscript classification"); + } + } + } + } + } + } + } + llvm_unreachable("somehow reached end of routine"); + return nullptr; +} diff --git a/llvm/lib/Analysis/DependenceGraphBuilder.cpp b/llvm/lib/Analysis/DependenceGraphBuilder.cpp new file mode 100644 index 000000000000..ed1d8351b2f0 --- /dev/null +++ b/llvm/lib/Analysis/DependenceGraphBuilder.cpp @@ -0,0 +1,228 @@ +//===- DependenceGraphBuilder.cpp ------------------------------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This file implements common steps of the build algorithm for construction +// of dependence graphs such as DDG and PDG. 
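+//
+// A typical build (an illustrative sketch; the exact driver lives in the
+// concrete graph class) runs the phases in this order:
+//   createFineGrainedNodes();
+//   createDefUseEdges();
+//   createMemoryDependencyEdges();
+//   createAndConnectRootNode();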
+//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/DependenceGraphBuilder.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/DDG.h" + +using namespace llvm; + +#define DEBUG_TYPE "dgb" + +STATISTIC(TotalGraphs, "Number of dependence graphs created."); +STATISTIC(TotalDefUseEdges, "Number of def-use edges created."); +STATISTIC(TotalMemoryEdges, "Number of memory dependence edges created."); +STATISTIC(TotalFineGrainedNodes, "Number of fine-grained nodes created."); +STATISTIC(TotalConfusedEdges, + "Number of confused memory dependencies between two nodes."); +STATISTIC(TotalEdgeReversals, + "Number of times the source and sink of dependence was reversed to " + "expose cycles in the graph."); + +using InstructionListType = SmallVector<Instruction *, 2>; + +//===--------------------------------------------------------------------===// +// AbstractDependenceGraphBuilder implementation +//===--------------------------------------------------------------------===// + +template <class G> +void AbstractDependenceGraphBuilder<G>::createFineGrainedNodes() { + ++TotalGraphs; + assert(IMap.empty() && "Expected empty instruction map at start"); + for (BasicBlock *BB : BBList) + for (Instruction &I : *BB) { + auto &NewNode = createFineGrainedNode(I); + IMap.insert(std::make_pair(&I, &NewNode)); + ++TotalFineGrainedNodes; + } +} + +template <class G> +void AbstractDependenceGraphBuilder<G>::createAndConnectRootNode() { + // Create a root node that connects to every connected component of the graph. + // This is done to allow graph iterators to visit all the disjoint components + // of the graph, in a single walk. + // + // This algorithm works by going through each node of the graph and for each + // node N, do a DFS starting from N. A rooted edge is established between the + // root node and N (if N is not yet visited). All the nodes reachable from N + // are marked as visited and are skipped in the DFS of subsequent nodes. + // + // Note: This algorithm tries to limit the number of edges out of the root + // node to some extent, but there may be redundant edges created depending on + // the iteration order. For example for a graph {A -> B}, an edge from the + // root node is added to both nodes if B is visited before A. While it does + // not result in minimal number of edges, this approach saves compile-time + // while keeping the number of edges in check. + auto &RootNode = createRootNode(); + df_iterator_default_set<const NodeType *, 4> Visited; + for (auto *N : Graph) { + if (*N == RootNode) + continue; + for (auto I : depth_first_ext(N, Visited)) + if (I == N) + createRootedEdge(RootNode, *N); + } +} + +template <class G> void AbstractDependenceGraphBuilder<G>::createDefUseEdges() { + for (NodeType *N : Graph) { + InstructionListType SrcIList; + N->collectInstructions([](const Instruction *I) { return true; }, SrcIList); + + // Use a set to mark the targets that we link to N, so we don't add + // duplicate def-use edges when more than one instruction in a target node + // use results of instructions that are contained in N. 
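+    // E.g., (illustrative) if N contains both 'a = ...' and 'b = ...' and a
+    // single node M uses a and b, the set below ensures only one def-use
+    // edge N -> M is created rather than two.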
+ SmallPtrSet<NodeType *, 4> VisitedTargets; + + for (Instruction *II : SrcIList) { + for (User *U : II->users()) { + Instruction *UI = dyn_cast<Instruction>(U); + if (!UI) + continue; + NodeType *DstNode = nullptr; + if (IMap.find(UI) != IMap.end()) + DstNode = IMap.find(UI)->second; + + // In the case of loops, the scope of the subgraph is all the + // basic blocks (and instructions within them) belonging to the loop. We + // simply ignore all the edges coming from (or going into) instructions + // or basic blocks outside of this range. + if (!DstNode) { + LLVM_DEBUG( + dbgs() + << "skipped def-use edge since the sink" << *UI + << " is outside the range of instructions being considered.\n"); + continue; + } + + // Self dependencies are ignored because they are redundant and + // uninteresting. + if (DstNode == N) { + LLVM_DEBUG(dbgs() + << "skipped def-use edge since the sink and the source (" + << N << ") are the same.\n"); + continue; + } + + if (VisitedTargets.insert(DstNode).second) { + createDefUseEdge(*N, *DstNode); + ++TotalDefUseEdges; + } + } + } + } +} + +template <class G> +void AbstractDependenceGraphBuilder<G>::createMemoryDependencyEdges() { + using DGIterator = typename G::iterator; + auto isMemoryAccess = [](const Instruction *I) { + return I->mayReadOrWriteMemory(); + }; + for (DGIterator SrcIt = Graph.begin(), E = Graph.end(); SrcIt != E; ++SrcIt) { + InstructionListType SrcIList; + (*SrcIt)->collectInstructions(isMemoryAccess, SrcIList); + if (SrcIList.empty()) + continue; + + for (DGIterator DstIt = SrcIt; DstIt != E; ++DstIt) { + if (**SrcIt == **DstIt) + continue; + InstructionListType DstIList; + (*DstIt)->collectInstructions(isMemoryAccess, DstIList); + if (DstIList.empty()) + continue; + bool ForwardEdgeCreated = false; + bool BackwardEdgeCreated = false; + for (Instruction *ISrc : SrcIList) { + for (Instruction *IDst : DstIList) { + auto D = DI.depends(ISrc, IDst, true); + if (!D) + continue; + + // If we have a dependence with its left-most non-'=' direction + // being '>' we need to reverse the direction of the edge, because + // the source of the dependence cannot occur after the sink. For + // confused dependencies, we will create edges in both directions to + // represent the possibility of a cycle. 
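+          // Illustrative example:
+          //   for (i = 1; i < n; ++i) {
+          //     ... = a[i-1];   // ISrc (textually first)
+          //     a[i] = ...;     // IDst
+          //   }
+          // depends(ISrc, IDst) reports direction '>': the value ISrc reads
+          // in iteration i was produced by IDst in iteration i-1, so the
+          // memory edge must run from IDst's node to ISrc's node.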
+ + auto createConfusedEdges = [&](NodeType &Src, NodeType &Dst) { + if (!ForwardEdgeCreated) { + createMemoryEdge(Src, Dst); + ++TotalMemoryEdges; + } + if (!BackwardEdgeCreated) { + createMemoryEdge(Dst, Src); + ++TotalMemoryEdges; + } + ForwardEdgeCreated = BackwardEdgeCreated = true; + ++TotalConfusedEdges; + }; + + auto createForwardEdge = [&](NodeType &Src, NodeType &Dst) { + if (!ForwardEdgeCreated) { + createMemoryEdge(Src, Dst); + ++TotalMemoryEdges; + } + ForwardEdgeCreated = true; + }; + + auto createBackwardEdge = [&](NodeType &Src, NodeType &Dst) { + if (!BackwardEdgeCreated) { + createMemoryEdge(Dst, Src); + ++TotalMemoryEdges; + } + BackwardEdgeCreated = true; + }; + + if (D->isConfused()) + createConfusedEdges(**SrcIt, **DstIt); + else if (D->isOrdered() && !D->isLoopIndependent()) { + bool ReversedEdge = false; + for (unsigned Level = 1; Level <= D->getLevels(); ++Level) { + if (D->getDirection(Level) == Dependence::DVEntry::EQ) + continue; + else if (D->getDirection(Level) == Dependence::DVEntry::GT) { + createBackwardEdge(**SrcIt, **DstIt); + ReversedEdge = true; + ++TotalEdgeReversals; + break; + } else if (D->getDirection(Level) == Dependence::DVEntry::LT) + break; + else { + createConfusedEdges(**SrcIt, **DstIt); + break; + } + } + if (!ReversedEdge) + createForwardEdge(**SrcIt, **DstIt); + } else + createForwardEdge(**SrcIt, **DstIt); + + // Avoid creating duplicate edges. + if (ForwardEdgeCreated && BackwardEdgeCreated) + break; + } + + // If we've created edges in both directions, there is no more + // unique edge that we can create between these two nodes, so we + // can exit early. + if (ForwardEdgeCreated && BackwardEdgeCreated) + break; + } + } + } +} + +template class llvm::AbstractDependenceGraphBuilder<DataDependenceGraph>; +template class llvm::DependenceGraphInfo<DDGNode>; diff --git a/llvm/lib/Analysis/DivergenceAnalysis.cpp b/llvm/lib/Analysis/DivergenceAnalysis.cpp new file mode 100644 index 000000000000..3d1be1e1cce0 --- /dev/null +++ b/llvm/lib/Analysis/DivergenceAnalysis.cpp @@ -0,0 +1,466 @@ +//===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a general divergence analysis for loop vectorization +// and GPU programs. It determines which branches and values in a loop or GPU +// program are divergent. It can help branch optimizations such as jump +// threading and loop unswitching to make better decisions. +// +// GPU programs typically use the SIMD execution model, where multiple threads +// in the same execution group have to execute in lock-step. Therefore, if the +// code contains divergent branches (i.e., threads in a group do not agree on +// which path of the branch to take), the group of threads has to execute all +// the paths from that branch with different subsets of threads enabled until +// they re-converge. +// +// Due to this execution model, some optimizations such as jump +// threading and loop unswitching can interfere with thread re-convergence. +// Therefore, an analysis that computes which branches in a GPU program are +// divergent can help the compiler to selectively run these optimizations. 
+// +// This implementation is derived from the Vectorization Analysis of the +// Region Vectorizer (RV). That implementation in turn is based on the approach +// described in +// +// Improving Performance of OpenCL on CPUs +// Ralf Karrenberg and Sebastian Hack +// CC '12 +// +// This DivergenceAnalysis implementation is generic in the sense that it does +// not itself identify original sources of divergence. +// Instead specialized adapter classes, (LoopDivergenceAnalysis) for loops and +// (GPUDivergenceAnalysis) for GPU programs, identify the sources of divergence +// (e.g., special variables that hold the thread ID or the iteration variable). +// +// The generic implementation propagates divergence to variables that are data +// or sync dependent on a source of divergence. +// +// While data dependency is a well-known concept, the notion of sync dependency +// is worth more explanation. Sync dependence characterizes the control flow +// aspect of the propagation of branch divergence. For example, +// +// %cond = icmp slt i32 %tid, 10 +// br i1 %cond, label %then, label %else +// then: +// br label %merge +// else: +// br label %merge +// merge: +// %a = phi i32 [ 0, %then ], [ 1, %else ] +// +// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid +// because %tid is not on its use-def chains, %a is sync dependent on %tid +// because the branch "br i1 %cond" depends on %tid and affects which value %a +// is assigned to. +// +// The sync dependence detection (which branch induces divergence in which join +// points) is implemented in the SyncDependenceAnalysis. +// +// The current DivergenceAnalysis implementation has the following limitations: +// 1. intra-procedural. It conservatively considers the arguments of a +// non-kernel-entry function and the return value of a function call as +// divergent. +// 2. memory as black box. It conservatively considers values loaded from +// generic or local address as divergent. This can be improved by leveraging +// pointer analysis and/or by modelling non-escaping memory objects in SSA +// as done in RV. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/DivergenceAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "divergence-analysis" + +// class DivergenceAnalysis +DivergenceAnalysis::DivergenceAnalysis( + const Function &F, const Loop *RegionLoop, const DominatorTree &DT, + const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm) + : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA), + IsLCSSAForm(IsLCSSAForm) {} + +void DivergenceAnalysis::markDivergent(const Value &DivVal) { + assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal)); + assert(!isAlwaysUniform(DivVal) && "cannot be a divergent"); + DivergentValues.insert(&DivVal); +} + +void DivergenceAnalysis::addUniformOverride(const Value &UniVal) { + UniformOverrides.insert(&UniVal); +} + +bool DivergenceAnalysis::updateTerminator(const Instruction &Term) const { + if (Term.getNumSuccessors() <= 1) + return false; + if (auto *BranchTerm = dyn_cast<BranchInst>(&Term)) { + assert(BranchTerm->isConditional()); + return isDivergent(*BranchTerm->getCondition()); + } + if (auto *SwitchTerm = dyn_cast<SwitchInst>(&Term)) { + return isDivergent(*SwitchTerm->getCondition()); + } + if (isa<InvokeInst>(Term)) { + return false; // ignore abnormal executions through landingpad + } + + llvm_unreachable("unexpected terminator"); +} + +bool DivergenceAnalysis::updateNormalInstruction(const Instruction &I) const { + // TODO function calls with side effects, etc + for (const auto &Op : I.operands()) { + if (isDivergent(*Op)) + return true; + } + return false; +} + +bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock, + const Value &Val) const { + const auto *Inst = dyn_cast<const Instruction>(&Val); + if (!Inst) + return false; + // check whether any divergent loop carrying Val terminates before control + // proceeds to ObservingBlock + for (const auto *Loop = LI.getLoopFor(Inst->getParent()); + Loop != RegionLoop && !Loop->contains(&ObservingBlock); + Loop = Loop->getParentLoop()) { + if (DivergentLoops.find(Loop) != DivergentLoops.end()) + return true; + } + + return false; +} + +bool DivergenceAnalysis::updatePHINode(const PHINode &Phi) const { + // joining divergent disjoint path in Phi parent block + if (!Phi.hasConstantOrUndefValue() && isJoinDivergent(*Phi.getParent())) { + return true; + } + + // An incoming value could be divergent by itself. + // Otherwise, an incoming value could be uniform within the loop + // that carries its definition but it may appear divergent + // from outside the loop. This happens when divergent loop exits + // drop definitions of that uniform value in different iterations. 
+  //
+  // for (int i = 0; i < n; ++i) {    // 'i' is uniform inside the loop
+  //   if (i % thread_id == 0) break; // divergent loop exit
+  // }
+  // int divI = i;                    // divI is divergent
+  for (size_t i = 0; i < Phi.getNumIncomingValues(); ++i) {
+    const auto *InVal = Phi.getIncomingValue(i);
+    if (isDivergent(*InVal) ||
+        isTemporalDivergent(*Phi.getParent(), *InVal)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+bool DivergenceAnalysis::inRegion(const Instruction &I) const {
+  return I.getParent() && inRegion(*I.getParent());
+}
+
+bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const {
+  return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB);
+}
+
+// marks all users of loop-carried values of the loop headed by LoopHeader as
+// divergent
+void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) {
+  auto *DivLoop = LI.getLoopFor(&LoopHeader);
+  assert(DivLoop && "loopHeader is not actually part of a loop");
+
+  SmallVector<BasicBlock *, 8> TaintStack;
+  DivLoop->getExitBlocks(TaintStack);
+
+  // Potential users of loop-carried values could be anywhere in the
+  // dominance region of DivLoop (including its fringes for phi nodes).
+  DenseSet<const BasicBlock *> Visited;
+  for (auto *Block : TaintStack) {
+    Visited.insert(Block);
+  }
+  Visited.insert(&LoopHeader);
+
+  while (!TaintStack.empty()) {
+    auto *UserBlock = TaintStack.back();
+    TaintStack.pop_back();
+
+    // don't spread divergence beyond the region
+    if (!inRegion(*UserBlock))
+      continue;
+
+    assert(!DivLoop->contains(UserBlock) &&
+           "irreducible control flow detected");
+
+    // phi nodes at the fringes of the dominance region
+    if (!DT.dominates(&LoopHeader, UserBlock)) {
+      // all PHI nodes of UserBlock become divergent
+      for (auto &Phi : UserBlock->phis()) {
+        Worklist.push_back(&Phi);
+      }
+      continue;
+    }
+
+    // taint outside users of values carried by DivLoop
+    for (auto &I : *UserBlock) {
+      if (isAlwaysUniform(I))
+        continue;
+      if (isDivergent(I))
+        continue;
+
+      for (auto &Op : I.operands()) {
+        auto *OpInst = dyn_cast<Instruction>(&Op);
+        if (!OpInst)
+          continue;
+        if (DivLoop->contains(OpInst->getParent())) {
+          markDivergent(I);
+          pushUsers(I);
+          break;
+        }
+      }
+    }
+
+    // visit all blocks in the dominance region
+    for (auto *SuccBlock : successors(UserBlock)) {
+      if (!Visited.insert(SuccBlock).second) {
+        continue;
+      }
+      TaintStack.push_back(SuccBlock);
+    }
+  }
+}
+
+void DivergenceAnalysis::pushPHINodes(const BasicBlock &Block) {
+  for (const auto &Phi : Block.phis()) {
+    if (isDivergent(Phi))
+      continue;
+    Worklist.push_back(&Phi);
+  }
+}
+
+void DivergenceAnalysis::pushUsers(const Value &V) {
+  for (const auto *User : V.users()) {
+    const auto *UserInst = dyn_cast<const Instruction>(User);
+    if (!UserInst)
+      continue;
+
+    if (isDivergent(*UserInst))
+      continue;
+
+    // only compute divergence inside the region
+    if (!inRegion(*UserInst))
+      continue;
+    Worklist.push_back(UserInst);
+  }
+}
+
+bool DivergenceAnalysis::propagateJoinDivergence(const BasicBlock &JoinBlock,
+                                                 const Loop *BranchLoop) {
+  LLVM_DEBUG(dbgs() << "\tpropJoinDiv " << JoinBlock.getName() << "\n");
+
+  // ignore divergence outside the region
+  if (!inRegion(JoinBlock)) {
+    return false;
+  }
+
+  // push non-divergent phi nodes in JoinBlock to the worklist
+  pushPHINodes(JoinBlock);
+
+  // JoinBlock is a divergent loop exit
+  if (BranchLoop && !BranchLoop->contains(&JoinBlock)) {
+    return true;
+  }
+
+  // disjoint-paths divergent at JoinBlock
+  markBlockJoinDivergent(JoinBlock);
+  return false;
+}
+
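+// For illustration (hypothetical IR, not from a test case): a divergent
+// branch inside a loop can make some threads leave the loop while others
+// keep iterating. The exit block is then a join point lying outside
+// BranchLoop, so propagateJoinDivergence() reports the loop itself as
+// divergent to its caller:
+//
+//   loop:
+//     ; ...
+//     %done = icmp ult i32 %tid, %n       ; divergent condition
+//     br i1 %done, label %exit, label %loop
+//   exit:                                 ; divergent loop exit
+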
+void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) {
+  LLVM_DEBUG(dbgs() << "propBranchDiv " << Term.getParent()->getName() << "\n");
+
+  markDivergent(Term);
+
+  const auto *BranchLoop = LI.getLoopFor(Term.getParent());
+
+  // whether there is a divergent loop exit from BranchLoop (if any)
+  bool IsBranchLoopDivergent = false;
+
+  // Iterate over all blocks reachable by disjoint paths from Term within the
+  // loop; this also covers loop exits that become divergent due to Term.
+  for (const auto *JoinBlock : SDA.join_blocks(Term)) {
+    IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop);
+  }
+
+  // BranchLoop becomes a divergent loop due to the divergent branch Term
+  if (IsBranchLoopDivergent) {
+    assert(BranchLoop);
+    if (!DivergentLoops.insert(BranchLoop).second) {
+      return;
+    }
+    propagateLoopDivergence(*BranchLoop);
+  }
+}
+
+void DivergenceAnalysis::propagateLoopDivergence(const Loop &ExitingLoop) {
+  LLVM_DEBUG(dbgs() << "propLoopDiv " << ExitingLoop.getName() << "\n");
+
+  // don't propagate beyond the region
+  if (!inRegion(*ExitingLoop.getHeader()))
+    return;
+
+  const auto *BranchLoop = ExitingLoop.getParentLoop();
+
+  // Uses of loop-carried values could occur anywhere
+  // within the dominance region of the definition. All loop-carried
+  // definitions are dominated by the loop header (reducible control).
+  // Thus all users have to be in the dominance region of the loop header,
+  // except PHI nodes that can also live at the fringes of the dom region
+  // (incoming defining value).
+  if (!IsLCSSAForm)
+    taintLoopLiveOuts(*ExitingLoop.getHeader());
+
+  // whether there is a divergent loop exit from BranchLoop (if any)
+  bool IsBranchLoopDivergent = false;
+
+  // Iterate over all blocks reachable by disjoint paths from the exits of
+  // ExitingLoop; this also covers loop exits (of BranchLoop) that in turn
+  // become divergent.
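+  //
+  // For illustration (hypothetical shape, not from a test case): if a join
+  // block of ExitingLoop's exits lies outside the parent loop BranchLoop as
+  // well, then BranchLoop in turn has a divergent loop exit, and the
+  // recursion below continues one loop level further up.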
+  for (const auto *JoinBlock : SDA.join_blocks(ExitingLoop)) {
+    IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop);
+  }
+
+  // BranchLoop becomes divergent due to a divergent loop exit in ExitingLoop
+  if (IsBranchLoopDivergent) {
+    assert(BranchLoop);
+    if (!DivergentLoops.insert(BranchLoop).second) {
+      return;
+    }
+    propagateLoopDivergence(*BranchLoop);
+  }
+}
+
+void DivergenceAnalysis::compute() {
+  for (auto *DivVal : DivergentValues) {
+    pushUsers(*DivVal);
+  }
+
+  // propagate divergence
+  while (!Worklist.empty()) {
+    const Instruction &I = *Worklist.back();
+    Worklist.pop_back();
+
+    // maintain uniformity of overrides
+    if (isAlwaysUniform(I))
+      continue;
+
+    bool WasDivergent = isDivergent(I);
+    if (WasDivergent)
+      continue;
+
+    // propagate divergence caused by terminator
+    if (I.isTerminator()) {
+      if (updateTerminator(I)) {
+        // propagate control divergence to affected instructions
+        propagateBranchDivergence(I);
+        continue;
+      }
+    }
+
+    // update divergence of I due to divergent operands
+    bool DivergentUpd = false;
+    const auto *Phi = dyn_cast<const PHINode>(&I);
+    if (Phi) {
+      DivergentUpd = updatePHINode(*Phi);
+    } else {
+      DivergentUpd = updateNormalInstruction(I);
+    }
+
+    // propagate value divergence to users
+    if (DivergentUpd) {
+      markDivergent(I);
+      pushUsers(I);
+    }
+  }
+}
+
+bool DivergenceAnalysis::isAlwaysUniform(const Value &V) const {
+  return UniformOverrides.find(&V) != UniformOverrides.end();
+}
+
+bool DivergenceAnalysis::isDivergent(const Value &V) const {
+  return DivergentValues.find(&V) != DivergentValues.end();
+}
+
+bool DivergenceAnalysis::isDivergentUse(const Use &U) const {
+  Value &V = *U.get();
+  Instruction &I = *cast<Instruction>(U.getUser());
+  return isDivergent(V) || isTemporalDivergent(*I.getParent(), V);
+}
+
+void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const {
+  if (DivergentValues.empty())
+    return;
+  // iterate instructions using instructions() to ensure a deterministic order.
+  for (auto &I : instructions(F)) {
+    if (isDivergent(I))
+      OS << "DIVERGENT:" << I << '\n';
+  }
+}
+
+// class GPUDivergenceAnalysis
+GPUDivergenceAnalysis::GPUDivergenceAnalysis(Function &F,
+                                             const DominatorTree &DT,
+                                             const PostDominatorTree &PDT,
+                                             const LoopInfo &LI,
+                                             const TargetTransformInfo &TTI)
+    : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, false) {
+  for (auto &I : instructions(F)) {
+    if (TTI.isSourceOfDivergence(&I)) {
+      DA.markDivergent(I);
+    } else if (TTI.isAlwaysUniform(&I)) {
+      DA.addUniformOverride(I);
+    }
+  }
+  for (auto &Arg : F.args()) {
+    if (TTI.isSourceOfDivergence(&Arg)) {
+      DA.markDivergent(Arg);
+    }
+  }
+
+  DA.compute();
+}
+
+bool GPUDivergenceAnalysis::isDivergent(const Value &val) const {
+  return DA.isDivergent(val);
+}
+
+bool GPUDivergenceAnalysis::isDivergentUse(const Use &use) const {
+  return DA.isDivergentUse(use);
+}
+
+void GPUDivergenceAnalysis::print(raw_ostream &OS, const Module *mod) const {
+  OS << "Divergence of kernel " << DA.getFunction().getName() << " {\n";
+  DA.print(OS, mod);
+  OS << "}\n";
+}
diff --git a/llvm/lib/Analysis/DomPrinter.cpp b/llvm/lib/Analysis/DomPrinter.cpp
new file mode 100644
index 000000000000..d9f43dd746ef
--- /dev/null
+++ b/llvm/lib/Analysis/DomPrinter.cpp
@@ -0,0 +1,297 @@
+//===- DomPrinter.cpp - DOT printer for the dominance trees ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the '-dot-dom' and '-dot-postdom' analysis passes, which
+// emit a dom.<fnname>.dot or postdom.<fnname>.dot file for each function in
+// the program, with a graph of the dominance/postdominance tree of that
+// function.
+//
+// There are also passes available to directly call dotty ('-view-dom' or
+// '-view-postdom'). By appending '-only', as in '-dot-dom-only', only the
+// names of the basic blocks are printed and their contents are hidden.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DomPrinter.h"
+#include "llvm/Analysis/DOTGraphTraitsPass.h"
+#include "llvm/Analysis/PostDominators.h"
+
+using namespace llvm;
+
+namespace llvm {
+template<>
+struct DOTGraphTraits<DomTreeNode*> : public DefaultDOTGraphTraits {
+
+  DOTGraphTraits (bool isSimple=false)
+    : DefaultDOTGraphTraits(isSimple) {}
+
+  std::string getNodeLabel(DomTreeNode *Node, DomTreeNode *Graph) {
+
+    BasicBlock *BB = Node->getBlock();
+
+    if (!BB)
+      return "Post dominance root node";
+
+    if (isSimple())
+      return DOTGraphTraits<const Function*>
+        ::getSimpleNodeLabel(BB, BB->getParent());
+    else
+      return DOTGraphTraits<const Function*>
+        ::getCompleteNodeLabel(BB, BB->getParent());
+  }
+};
+
+template<>
+struct DOTGraphTraits<DominatorTree*> : public DOTGraphTraits<DomTreeNode*> {
+
+  DOTGraphTraits (bool isSimple=false)
+    : DOTGraphTraits<DomTreeNode*>(isSimple) {}
+
+  static std::string getGraphName(DominatorTree *DT) {
+    return "Dominator tree";
+  }
+
+  std::string getNodeLabel(DomTreeNode *Node, DominatorTree *G) {
+    return DOTGraphTraits<DomTreeNode*>::getNodeLabel(Node, G->getRootNode());
+  }
+};
+
+template<>
+struct DOTGraphTraits<PostDominatorTree*>
+  : public DOTGraphTraits<DomTreeNode*> {
+
+  DOTGraphTraits (bool isSimple=false)
+    : DOTGraphTraits<DomTreeNode*>(isSimple) {}
+
+  static std::string getGraphName(PostDominatorTree *DT) {
+    return "Post dominator tree";
+  }
+
+  std::string getNodeLabel(DomTreeNode *Node, PostDominatorTree *G) {
+    return DOTGraphTraits<DomTreeNode*>::getNodeLabel(Node, G->getRootNode());
+  }
+};
+}
+
+void DominatorTree::viewGraph(const Twine &Name, const Twine &Title) {
+#ifndef NDEBUG
+  ViewGraph(this, Name, false, Title);
+#else
+  errs() << "DomTree dump not available, build with DEBUG\n";
+#endif // NDEBUG
+}
+
+void DominatorTree::viewGraph() {
+#ifndef NDEBUG
+  this->viewGraph("domtree", "Dominator Tree for function");
+#else
+  errs() << "DomTree dump not available, build with DEBUG\n";
+#endif // NDEBUG
+}
+
+namespace {
+struct DominatorTreeWrapperPassAnalysisGraphTraits {
+  static DominatorTree *getGraph(DominatorTreeWrapperPass *DTWP) {
+    return &DTWP->getDomTree();
+  }
+};
+
+struct DomViewer : public DOTGraphTraitsViewer<
+                       DominatorTreeWrapperPass, false, DominatorTree *,
+                       DominatorTreeWrapperPassAnalysisGraphTraits> {
+  static char ID;
+  DomViewer()
+      : DOTGraphTraitsViewer<DominatorTreeWrapperPass, false, DominatorTree *,
+                             DominatorTreeWrapperPassAnalysisGraphTraits>(
+            "dom", ID) {
+    initializeDomViewerPass(*PassRegistry::getPassRegistry());
+  }
+};
+
+struct DomOnlyViewer : public DOTGraphTraitsViewer<
+                           DominatorTreeWrapperPass, true, DominatorTree *,
+                           DominatorTreeWrapperPassAnalysisGraphTraits> {
+  static char ID;
+  DomOnlyViewer()
+      : DOTGraphTraitsViewer<DominatorTreeWrapperPass, true, DominatorTree *,
+
DominatorTreeWrapperPassAnalysisGraphTraits>( + "domonly", ID) { + initializeDomOnlyViewerPass(*PassRegistry::getPassRegistry()); + } +}; + +struct PostDominatorTreeWrapperPassAnalysisGraphTraits { + static PostDominatorTree *getGraph(PostDominatorTreeWrapperPass *PDTWP) { + return &PDTWP->getPostDomTree(); + } +}; + +struct PostDomViewer : public DOTGraphTraitsViewer< + PostDominatorTreeWrapperPass, false, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits> { + static char ID; + PostDomViewer() : + DOTGraphTraitsViewer<PostDominatorTreeWrapperPass, false, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits>( + "postdom", ID){ + initializePostDomViewerPass(*PassRegistry::getPassRegistry()); + } +}; + +struct PostDomOnlyViewer : public DOTGraphTraitsViewer< + PostDominatorTreeWrapperPass, true, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits> { + static char ID; + PostDomOnlyViewer() : + DOTGraphTraitsViewer<PostDominatorTreeWrapperPass, true, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits>( + "postdomonly", ID){ + initializePostDomOnlyViewerPass(*PassRegistry::getPassRegistry()); + } +}; +} // end anonymous namespace + +char DomViewer::ID = 0; +INITIALIZE_PASS(DomViewer, "view-dom", + "View dominance tree of function", false, false) + +char DomOnlyViewer::ID = 0; +INITIALIZE_PASS(DomOnlyViewer, "view-dom-only", + "View dominance tree of function (with no function bodies)", + false, false) + +char PostDomViewer::ID = 0; +INITIALIZE_PASS(PostDomViewer, "view-postdom", + "View postdominance tree of function", false, false) + +char PostDomOnlyViewer::ID = 0; +INITIALIZE_PASS(PostDomOnlyViewer, "view-postdom-only", + "View postdominance tree of function " + "(with no function bodies)", + false, false) + +namespace { +struct DomPrinter : public DOTGraphTraitsPrinter< + DominatorTreeWrapperPass, false, DominatorTree *, + DominatorTreeWrapperPassAnalysisGraphTraits> { + static char ID; + DomPrinter() + : DOTGraphTraitsPrinter<DominatorTreeWrapperPass, false, DominatorTree *, + DominatorTreeWrapperPassAnalysisGraphTraits>( + "dom", ID) { + initializeDomPrinterPass(*PassRegistry::getPassRegistry()); + } +}; + +struct DomOnlyPrinter : public DOTGraphTraitsPrinter< + DominatorTreeWrapperPass, true, DominatorTree *, + DominatorTreeWrapperPassAnalysisGraphTraits> { + static char ID; + DomOnlyPrinter() + : DOTGraphTraitsPrinter<DominatorTreeWrapperPass, true, DominatorTree *, + DominatorTreeWrapperPassAnalysisGraphTraits>( + "domonly", ID) { + initializeDomOnlyPrinterPass(*PassRegistry::getPassRegistry()); + } +}; + +struct PostDomPrinter + : public DOTGraphTraitsPrinter< + PostDominatorTreeWrapperPass, false, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits> { + static char ID; + PostDomPrinter() : + DOTGraphTraitsPrinter<PostDominatorTreeWrapperPass, false, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits>( + "postdom", ID) { + initializePostDomPrinterPass(*PassRegistry::getPassRegistry()); + } +}; + +struct PostDomOnlyPrinter + : public DOTGraphTraitsPrinter< + PostDominatorTreeWrapperPass, true, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits> { + static char ID; + PostDomOnlyPrinter() : + DOTGraphTraitsPrinter<PostDominatorTreeWrapperPass, true, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits>( + "postdomonly", ID) { + initializePostDomOnlyPrinterPass(*PassRegistry::getPassRegistry()); + } +}; +} // end 
anonymous namespace
+
+
+char DomPrinter::ID = 0;
+INITIALIZE_PASS(DomPrinter, "dot-dom",
+                "Print dominance tree of function to 'dot' file",
+                false, false)
+
+char DomOnlyPrinter::ID = 0;
+INITIALIZE_PASS(DomOnlyPrinter, "dot-dom-only",
+                "Print dominance tree of function to 'dot' file "
+                "(with no function bodies)",
+                false, false)
+
+char PostDomPrinter::ID = 0;
+INITIALIZE_PASS(PostDomPrinter, "dot-postdom",
+                "Print postdominance tree of function to 'dot' file",
+                false, false)
+
+char PostDomOnlyPrinter::ID = 0;
+INITIALIZE_PASS(PostDomOnlyPrinter, "dot-postdom-only",
+                "Print postdominance tree of function to 'dot' file "
+                "(with no function bodies)",
+                false, false)
+
+// Create methods available outside of this file, so that they can be used in
+// "include/llvm/LinkAllPasses.h". Otherwise the passes would be eliminated by
+// link-time optimization.
+
+FunctionPass *llvm::createDomPrinterPass() {
+  return new DomPrinter();
+}
+
+FunctionPass *llvm::createDomOnlyPrinterPass() {
+  return new DomOnlyPrinter();
+}
+
+FunctionPass *llvm::createDomViewerPass() {
+  return new DomViewer();
+}
+
+FunctionPass *llvm::createDomOnlyViewerPass() {
+  return new DomOnlyViewer();
+}
+
+FunctionPass *llvm::createPostDomPrinterPass() {
+  return new PostDomPrinter();
+}
+
+FunctionPass *llvm::createPostDomOnlyPrinterPass() {
+  return new PostDomOnlyPrinter();
+}
+
+FunctionPass *llvm::createPostDomViewerPass() {
+  return new PostDomViewer();
+}
+
+FunctionPass *llvm::createPostDomOnlyViewerPass() {
+  return new PostDomOnlyViewer();
+}
diff --git a/llvm/lib/Analysis/DomTreeUpdater.cpp b/llvm/lib/Analysis/DomTreeUpdater.cpp
new file mode 100644
index 000000000000..49215889cfd6
--- /dev/null
+++ b/llvm/lib/Analysis/DomTreeUpdater.cpp
@@ -0,0 +1,533 @@
+//===- DomTreeUpdater.cpp - DomTree/Post DomTree Updater --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the DomTreeUpdater class, which provides a uniform way
+// to update dominator tree related data structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/Support/GenericDomTree.h"
+#include <algorithm>
+#include <functional>
+#include <utility>
+
+namespace llvm {
+
+bool DomTreeUpdater::isUpdateValid(
+    const DominatorTree::UpdateType Update) const {
+  const auto *From = Update.getFrom();
+  const auto *To = Update.getTo();
+  const auto Kind = Update.getKind();
+
+  // Discard updates by inspecting the current state of successors of From.
+  // Since isUpdateValid() must be called *after* the Terminator of From is
+  // altered, we can determine whether the update is unnecessary for batch
+  // updates or invalid for a single update.
+  const bool HasEdge = llvm::any_of(
+      successors(From), [To](const BasicBlock *B) { return B == To; });
+
+  // If the IR does not match the update,
+  // 1. In batch updates, this update is unnecessary.
+  // 2. When called by insertEdge*()/deleteEdge*(), this update is invalid.
+  // Edge does not exist in IR.
+  if (Kind == DominatorTree::Insert && !HasEdge)
+    return false;
+
+  // Edge exists in IR.
+  if (Kind == DominatorTree::Delete && HasEdge)
+    return false;
+
+  return true;
+}
+
+bool DomTreeUpdater::isSelfDominance(
+    const DominatorTree::UpdateType Update) const {
+  // Won't affect DomTree and PostDomTree.
+  return Update.getFrom() == Update.getTo();
+}
+
+void DomTreeUpdater::applyDomTreeUpdates() {
+  // No pending DomTreeUpdates.
+  if (Strategy != UpdateStrategy::Lazy || !DT)
+    return;
+
+  // Only apply updates that have not yet been applied to the DomTree.
+  if (hasPendingDomTreeUpdates()) {
+    const auto I = PendUpdates.begin() + PendDTUpdateIndex;
+    const auto E = PendUpdates.end();
+    assert(I < E && "Iterator range invalid; there should be DomTree updates.");
+    DT->applyUpdates(ArrayRef<DominatorTree::UpdateType>(I, E));
+    PendDTUpdateIndex = PendUpdates.size();
+  }
+}
+
+void DomTreeUpdater::flush() {
+  applyDomTreeUpdates();
+  applyPostDomTreeUpdates();
+  dropOutOfDateUpdates();
+}
+
+void DomTreeUpdater::applyPostDomTreeUpdates() {
+  // No pending PostDomTreeUpdates.
+  if (Strategy != UpdateStrategy::Lazy || !PDT)
+    return;
+
+  // Only apply updates that have not yet been applied to the PostDomTree.
+  if (hasPendingPostDomTreeUpdates()) {
+    const auto I = PendUpdates.begin() + PendPDTUpdateIndex;
+    const auto E = PendUpdates.end();
+    assert(I < E &&
+           "Iterator range invalid; there should be PostDomTree updates.");
+    PDT->applyUpdates(ArrayRef<DominatorTree::UpdateType>(I, E));
+    PendPDTUpdateIndex = PendUpdates.size();
+  }
+}
+
+void DomTreeUpdater::tryFlushDeletedBB() {
+  if (!hasPendingUpdates())
+    forceFlushDeletedBB();
+}
+
+bool DomTreeUpdater::forceFlushDeletedBB() {
+  if (DeletedBBs.empty())
+    return false;
+
+  for (auto *BB : DeletedBBs) {
+    // After calling deleteBB or callbackDeleteBB under Lazy UpdateStrategy,
+    // validateDeleteBB() removes all instructions of DelBB and adds an
+    // UnreachableInst as its terminator. So we check whether the BasicBlock to
+    // delete only has an UnreachableInst inside.
+    assert(BB->getInstList().size() == 1 &&
+           isa<UnreachableInst>(BB->getTerminator()) &&
+           "DelBB has been modified while awaiting deletion.");
+    BB->removeFromParent();
+    eraseDelBBNode(BB);
+    delete BB;
+  }
+  DeletedBBs.clear();
+  Callbacks.clear();
+  return true;
+}
+
+void DomTreeUpdater::recalculate(Function &F) {
+
+  if (Strategy == UpdateStrategy::Eager) {
+    if (DT)
+      DT->recalculate(F);
+    if (PDT)
+      PDT->recalculate(F);
+    return;
+  }
+
+  // There is little performance gain if we pend the recalculation under
+  // Lazy UpdateStrategy, so we recalculate available trees immediately.
+
+  // Prevent forceFlushDeletedBB() from erasing DomTree or PostDomTree nodes.
+  IsRecalculatingDomTree = IsRecalculatingPostDomTree = true;
+
+  // Because all trees are going to be up-to-date after recalculation,
+  // flush awaiting deleted BasicBlocks.
+  forceFlushDeletedBB();
+  if (DT)
+    DT->recalculate(F);
+  if (PDT)
+    PDT->recalculate(F);
+
+  // Resume forceFlushDeletedBB() to erase DomTree or PostDomTree nodes.
+  IsRecalculatingDomTree = IsRecalculatingPostDomTree = false;
+  PendDTUpdateIndex = PendPDTUpdateIndex = PendUpdates.size();
+  dropOutOfDateUpdates();
+}
+
+bool DomTreeUpdater::hasPendingUpdates() const {
+  return hasPendingDomTreeUpdates() || hasPendingPostDomTreeUpdates();
+}
+
+bool DomTreeUpdater::hasPendingDomTreeUpdates() const {
+  if (!DT)
+    return false;
+  return PendUpdates.size() != PendDTUpdateIndex;
+}
+
+bool DomTreeUpdater::hasPendingPostDomTreeUpdates() const {
+  if (!PDT)
+    return false;
+  return PendUpdates.size() != PendPDTUpdateIndex;
+}
+
+bool DomTreeUpdater::isBBPendingDeletion(llvm::BasicBlock *DelBB) const {
+  if (Strategy == UpdateStrategy::Eager || DeletedBBs.empty())
+    return false;
+  return DeletedBBs.count(DelBB) != 0;
+}
+
+// The DT and PDT require that nodes related to updates are not deleted while
+// the update functions are called. So BasicBlock deletions must be pended
+// when the UpdateStrategy is Lazy. When the UpdateStrategy is Eager, the
+// BasicBlock will be deleted immediately.
+void DomTreeUpdater::deleteBB(BasicBlock *DelBB) {
+  validateDeleteBB(DelBB);
+  if (Strategy == UpdateStrategy::Lazy) {
+    DeletedBBs.insert(DelBB);
+    return;
+  }
+
+  DelBB->removeFromParent();
+  eraseDelBBNode(DelBB);
+  delete DelBB;
+}
+
+void DomTreeUpdater::callbackDeleteBB(
+    BasicBlock *DelBB, std::function<void(BasicBlock *)> Callback) {
+  validateDeleteBB(DelBB);
+  if (Strategy == UpdateStrategy::Lazy) {
+    Callbacks.push_back(CallBackOnDeletion(DelBB, Callback));
+    DeletedBBs.insert(DelBB);
+    return;
+  }
+
+  DelBB->removeFromParent();
+  eraseDelBBNode(DelBB);
+  Callback(DelBB);
+  delete DelBB;
+}
+
+void DomTreeUpdater::eraseDelBBNode(BasicBlock *DelBB) {
+  if (DT && !IsRecalculatingDomTree)
+    if (DT->getNode(DelBB))
+      DT->eraseNode(DelBB);
+
+  if (PDT && !IsRecalculatingPostDomTree)
+    if (PDT->getNode(DelBB))
+      PDT->eraseNode(DelBB);
+}
+
+void DomTreeUpdater::validateDeleteBB(BasicBlock *DelBB) {
+  assert(DelBB && "Invalid push_back of nullptr DelBB.");
+  assert(pred_empty(DelBB) && "DelBB has one or more predecessors.");
+  // DelBB is unreachable and all its instructions are dead.
+  while (!DelBB->empty()) {
+    Instruction &I = DelBB->back();
+    // Replace used instructions with an arbitrary value (undef).
+    if (!I.use_empty())
+      I.replaceAllUsesWith(llvm::UndefValue::get(I.getType()));
+    DelBB->getInstList().pop_back();
+  }
+  // Make sure DelBB has a valid terminator instruction. As long as DelBB is a
+  // child of Function F it must contain valid IR.
+  new UnreachableInst(DelBB->getContext(), DelBB);
+}
+
+void DomTreeUpdater::applyUpdates(ArrayRef<DominatorTree::UpdateType> Updates) {
+  if (!DT && !PDT)
+    return;
+
+  if (Strategy == UpdateStrategy::Lazy) {
+    for (const auto U : Updates)
+      if (!isSelfDominance(U))
+        PendUpdates.push_back(U);
+
+    return;
+  }
+
+  if (DT)
+    DT->applyUpdates(Updates);
+  if (PDT)
+    PDT->applyUpdates(Updates);
+}
+
+void DomTreeUpdater::applyUpdatesPermissive(
+    ArrayRef<DominatorTree::UpdateType> Updates) {
+  if (!DT && !PDT)
+    return;
+
+  SmallSet<std::pair<BasicBlock *, BasicBlock *>, 8> Seen;
+  SmallVector<DominatorTree::UpdateType, 8> DeduplicatedUpdates;
+  for (const auto U : Updates) {
+    auto Edge = std::make_pair(U.getFrom(), U.getTo());
+    // Because it is illegal to submit updates that have already been applied
+    // and updates to an edge need to be strictly ordered,
+    // it is safe to infer the existence of an edge from the first update
+    // to this edge.
+    // If the first update to an edge is "Delete", it means that the edge
+    // existed before. If the first update to an edge is "Insert", it means
+    // that the edge didn't exist before.
+    //
+    // For example, if the user submits {{Delete, A, B}, {Insert, A, B}},
+    // initially the edge A -> B must have existed, because
+    // 1. it is illegal to submit updates that have already been applied,
+    //    i.e., the user cannot delete a nonexistent edge, and
+    // 2. updates to an edge need to be strictly ordered.
+    // We can then safely ignore future updates to this edge and directly
+    // inspect the current CFG:
+    // a. If the edge still exists, then because the user cannot insert an
+    //    existing edge, both {Delete, A, B} and {Insert, A, B} actually
+    //    happened and resulted in a no-op. DTU won't submit any update in
+    //    this case.
+    // b. If the edge doesn't exist, we can infer that {Delete, A, B}
+    //    actually happened but {Insert, A, B} was an invalid update which
+    //    never happened. DTU will submit {Delete, A, B} in this case.
+    if (!isSelfDominance(U) && Seen.count(Edge) == 0) {
+      Seen.insert(Edge);
+      // If the update doesn't appear in the CFG, it means that
+      // either the change isn't made or relevant operations
+      // result in a no-op.
+      if (isUpdateValid(U)) {
+        if (isLazy())
+          PendUpdates.push_back(U);
+        else
+          DeduplicatedUpdates.push_back(U);
+      }
+    }
+  }
+
+  if (Strategy == UpdateStrategy::Lazy)
+    return;
+
+  if (DT)
+    DT->applyUpdates(DeduplicatedUpdates);
+  if (PDT)
+    PDT->applyUpdates(DeduplicatedUpdates);
+}
+
+DominatorTree &DomTreeUpdater::getDomTree() {
+  assert(DT && "Invalid acquisition of a null DomTree");
+  applyDomTreeUpdates();
+  dropOutOfDateUpdates();
+  return *DT;
+}
+
+PostDominatorTree &DomTreeUpdater::getPostDomTree() {
+  assert(PDT && "Invalid acquisition of a null PostDomTree");
+  applyPostDomTreeUpdates();
+  dropOutOfDateUpdates();
+  return *PDT;
+}
+
+void DomTreeUpdater::insertEdge(BasicBlock *From, BasicBlock *To) {
+
+#ifndef NDEBUG
+  assert(isUpdateValid({DominatorTree::Insert, From, To}) &&
+         "Inserted edge does not appear in the CFG");
+#endif
+
+  if (!DT && !PDT)
+    return;
+
+  // Won't affect DomTree and PostDomTree; discard update.
+  if (From == To)
+    return;
+
+  if (Strategy == UpdateStrategy::Eager) {
+    if (DT)
+      DT->insertEdge(From, To);
+    if (PDT)
+      PDT->insertEdge(From, To);
+    return;
+  }
+
+  PendUpdates.push_back({DominatorTree::Insert, From, To});
+}
+
+void DomTreeUpdater::insertEdgeRelaxed(BasicBlock *From, BasicBlock *To) {
+  if (From == To)
+    return;
+
+  if (!DT && !PDT)
+    return;
+
+  if (!isUpdateValid({DominatorTree::Insert, From, To}))
+    return;
+
+  if (Strategy == UpdateStrategy::Eager) {
+    if (DT)
+      DT->insertEdge(From, To);
+    if (PDT)
+      PDT->insertEdge(From, To);
+    return;
+  }
+
+  PendUpdates.push_back({DominatorTree::Insert, From, To});
+}
+
+void DomTreeUpdater::deleteEdge(BasicBlock *From, BasicBlock *To) {
+
+#ifndef NDEBUG
+  assert(isUpdateValid({DominatorTree::Delete, From, To}) &&
+         "Deleted edge still exists in the CFG!");
+#endif
+
+  if (!DT && !PDT)
+    return;
+
+  // Won't affect DomTree and PostDomTree; discard update.
+  if (From == To)
+    return;
+
+  if (Strategy == UpdateStrategy::Eager) {
+    if (DT)
+      DT->deleteEdge(From, To);
+    if (PDT)
+      PDT->deleteEdge(From, To);
+    return;
+  }
+
+  PendUpdates.push_back({DominatorTree::Delete, From, To});
+}
+
+void DomTreeUpdater::deleteEdgeRelaxed(BasicBlock *From, BasicBlock *To) {
+  if (From == To)
+    return;
+
+  if (!DT && !PDT)
+    return;
+
+  if (!isUpdateValid({DominatorTree::Delete, From, To}))
+    return;
+
+  if (Strategy == UpdateStrategy::Eager) {
+    if (DT)
+      DT->deleteEdge(From, To);
+    if (PDT)
+      PDT->deleteEdge(From, To);
+    return;
+  }
+
+  PendUpdates.push_back({DominatorTree::Delete, From, To});
+}
+
+void DomTreeUpdater::dropOutOfDateUpdates() {
+  if (Strategy == DomTreeUpdater::UpdateStrategy::Eager)
+    return;
+
+  tryFlushDeletedBB();
+
+  // Drop all updates applied by both trees.
+  if (!DT)
+    PendDTUpdateIndex = PendUpdates.size();
+  if (!PDT)
+    PendPDTUpdateIndex = PendUpdates.size();
+
+  const size_t dropIndex = std::min(PendDTUpdateIndex, PendPDTUpdateIndex);
+  const auto B = PendUpdates.begin();
+  const auto E = PendUpdates.begin() + dropIndex;
+  assert(B <= E && "Iterator out of range.");
+  PendUpdates.erase(B, E);
+  // Calculate current index.
+  PendDTUpdateIndex -= dropIndex;
+  PendPDTUpdateIndex -= dropIndex;
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void DomTreeUpdater::dump() const {
+  raw_ostream &OS = llvm::dbgs();
+
+  OS << "Available Trees: ";
+  if (DT || PDT) {
+    if (DT)
+      OS << "DomTree ";
+    if (PDT)
+      OS << "PostDomTree ";
+    OS << "\n";
+  } else
+    OS << "None\n";
+
+  OS << "UpdateStrategy: ";
+  if (Strategy == UpdateStrategy::Eager) {
+    OS << "Eager\n";
+    return;
+  } else
+    OS << "Lazy\n";
+  int Index = 0;
+
+  auto printUpdates =
+      [&](ArrayRef<DominatorTree::UpdateType>::const_iterator begin,
+          ArrayRef<DominatorTree::UpdateType>::const_iterator end) {
+        if (begin == end)
+          OS << "  None\n";
+        Index = 0;
+        for (auto It = begin, ItEnd = end; It != ItEnd; ++It) {
+          auto U = *It;
+          OS << "  " << Index << " : ";
+          ++Index;
+          if (U.getKind() == DominatorTree::Insert)
+            OS << "Insert, ";
+          else
+            OS << "Delete, ";
+          BasicBlock *From = U.getFrom();
+          if (From) {
+            auto S = From->getName();
+            if (!From->hasName())
+              S = "(no_name)";
+            OS << S << "(" << From << "), ";
+          } else {
+            OS << "(badref), ";
+          }
+          BasicBlock *To = U.getTo();
+          if (To) {
+            auto S = To->getName();
+            if (!To->hasName())
+              S = "(no_name)";
+            OS << S << "(" << To << ")\n";
+          } else {
+            OS << "(badref)\n";
+          }
+        }
+      };
+
+  if (DT) {
+    const auto I = PendUpdates.begin() + PendDTUpdateIndex;
+    assert(PendUpdates.begin() <= I && I <= PendUpdates.end() &&
+           "Iterator out of range.");
+    OS << "Applied but not cleared DomTreeUpdates:\n";
+    printUpdates(PendUpdates.begin(), I);
+    OS << "Pending DomTreeUpdates:\n";
+    printUpdates(I, PendUpdates.end());
+  }
+
+  if (PDT) {
+    const auto I = PendUpdates.begin() + PendPDTUpdateIndex;
+    assert(PendUpdates.begin() <= I && I <= PendUpdates.end() &&
+           "Iterator out of range.");
+    OS << "Applied but not cleared PostDomTreeUpdates:\n";
+    printUpdates(PendUpdates.begin(), I);
+    OS << "Pending PostDomTreeUpdates:\n";
+    printUpdates(I, PendUpdates.end());
+  }
+
+  OS << "Pending DeletedBBs:\n";
+  Index = 0;
+  for (auto BB : DeletedBBs) {
+    OS << "  " << Index << " : ";
+    ++Index;
+    if (BB->hasName())
+      OS << BB->getName() << "(";
+    else
+      OS << "(no_name)(";
+    OS << BB << ")\n";
+  }
+
+  OS << "Pending Callbacks:\n";
+  Index = 0;
+  for (auto BB : Callbacks) {
+    OS << "  " << Index << " : ";
+    ++Index;
+    if (BB->hasName())
+ OS << BB->getName() << "("; + else + OS << "(no_name)("; + OS << BB << ")\n"; + } +} +#endif +} // namespace llvm diff --git a/llvm/lib/Analysis/DominanceFrontier.cpp b/llvm/lib/Analysis/DominanceFrontier.cpp new file mode 100644 index 000000000000..f9a554acb7ea --- /dev/null +++ b/llvm/lib/Analysis/DominanceFrontier.cpp @@ -0,0 +1,96 @@ +//===- DominanceFrontier.cpp - Dominance Frontier Calculation -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/DominanceFrontier.h" +#include "llvm/Analysis/DominanceFrontierImpl.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +namespace llvm { + +template class DominanceFrontierBase<BasicBlock, false>; +template class DominanceFrontierBase<BasicBlock, true>; +template class ForwardDominanceFrontierBase<BasicBlock>; + +} // end namespace llvm + +char DominanceFrontierWrapperPass::ID = 0; + +INITIALIZE_PASS_BEGIN(DominanceFrontierWrapperPass, "domfrontier", + "Dominance Frontier Construction", true, true) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(DominanceFrontierWrapperPass, "domfrontier", + "Dominance Frontier Construction", true, true) + +DominanceFrontierWrapperPass::DominanceFrontierWrapperPass() + : FunctionPass(ID), DF() { + initializeDominanceFrontierWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void DominanceFrontierWrapperPass::releaseMemory() { + DF.releaseMemory(); +} + +bool DominanceFrontierWrapperPass::runOnFunction(Function &) { + releaseMemory(); + DF.analyze(getAnalysis<DominatorTreeWrapperPass>().getDomTree()); + return false; +} + +void DominanceFrontierWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<DominatorTreeWrapperPass>(); +} + +void DominanceFrontierWrapperPass::print(raw_ostream &OS, const Module *) const { + DF.print(OS); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void DominanceFrontierWrapperPass::dump() const { + print(dbgs()); +} +#endif + +/// Handle invalidation explicitly. +bool DominanceFrontier::invalidate(Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &) { + // Check whether the analysis, all analyses on functions, or the function's + // CFG have been preserved. 
+ auto PAC = PA.getChecker<DominanceFrontierAnalysis>(); + return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || + PAC.preservedSet<CFGAnalyses>()); +} + +AnalysisKey DominanceFrontierAnalysis::Key; + +DominanceFrontier DominanceFrontierAnalysis::run(Function &F, + FunctionAnalysisManager &AM) { + DominanceFrontier DF; + DF.analyze(AM.getResult<DominatorTreeAnalysis>(F)); + return DF; +} + +DominanceFrontierPrinterPass::DominanceFrontierPrinterPass(raw_ostream &OS) + : OS(OS) {} + +PreservedAnalyses +DominanceFrontierPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { + OS << "DominanceFrontier for function: " << F.getName() << "\n"; + AM.getResult<DominanceFrontierAnalysis>(F).print(OS); + + return PreservedAnalyses::all(); +} diff --git a/llvm/lib/Analysis/EHPersonalities.cpp b/llvm/lib/Analysis/EHPersonalities.cpp new file mode 100644 index 000000000000..2242541696a4 --- /dev/null +++ b/llvm/lib/Analysis/EHPersonalities.cpp @@ -0,0 +1,135 @@ +//===- EHPersonalities.cpp - Compute EH-related information ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/EHPersonalities.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +/// See if the given exception handling personality function is one that we +/// understand. If so, return a description of it; otherwise return Unknown. +EHPersonality llvm::classifyEHPersonality(const Value *Pers) { + const Function *F = + Pers ? 
dyn_cast<Function>(Pers->stripPointerCasts()) : nullptr;
+  if (!F)
+    return EHPersonality::Unknown;
+  return StringSwitch<EHPersonality>(F->getName())
+      .Case("__gnat_eh_personality", EHPersonality::GNU_Ada)
+      .Case("__gxx_personality_v0", EHPersonality::GNU_CXX)
+      .Case("__gxx_personality_seh0", EHPersonality::GNU_CXX)
+      .Case("__gxx_personality_sj0", EHPersonality::GNU_CXX_SjLj)
+      .Case("__gcc_personality_v0", EHPersonality::GNU_C)
+      .Case("__gcc_personality_seh0", EHPersonality::GNU_C)
+      .Case("__gcc_personality_sj0", EHPersonality::GNU_C_SjLj)
+      .Case("__objc_personality_v0", EHPersonality::GNU_ObjC)
+      .Case("_except_handler3", EHPersonality::MSVC_X86SEH)
+      .Case("_except_handler4", EHPersonality::MSVC_X86SEH)
+      .Case("__C_specific_handler", EHPersonality::MSVC_Win64SEH)
+      .Case("__CxxFrameHandler3", EHPersonality::MSVC_CXX)
+      .Case("ProcessCLRException", EHPersonality::CoreCLR)
+      .Case("rust_eh_personality", EHPersonality::Rust)
+      .Case("__gxx_wasm_personality_v0", EHPersonality::Wasm_CXX)
+      .Default(EHPersonality::Unknown);
+}
+
+StringRef llvm::getEHPersonalityName(EHPersonality Pers) {
+  switch (Pers) {
+  case EHPersonality::GNU_Ada:       return "__gnat_eh_personality";
+  case EHPersonality::GNU_CXX:       return "__gxx_personality_v0";
+  case EHPersonality::GNU_CXX_SjLj:  return "__gxx_personality_sj0";
+  case EHPersonality::GNU_C:         return "__gcc_personality_v0";
+  case EHPersonality::GNU_C_SjLj:    return "__gcc_personality_sj0";
+  case EHPersonality::GNU_ObjC:      return "__objc_personality_v0";
+  case EHPersonality::MSVC_X86SEH:   return "_except_handler3";
+  case EHPersonality::MSVC_Win64SEH: return "__C_specific_handler";
+  case EHPersonality::MSVC_CXX:      return "__CxxFrameHandler3";
+  case EHPersonality::CoreCLR:       return "ProcessCLRException";
+  case EHPersonality::Rust:          return "rust_eh_personality";
+  case EHPersonality::Wasm_CXX:      return "__gxx_wasm_personality_v0";
+  case EHPersonality::Unknown:       llvm_unreachable("Unknown EHPersonality!");
+  }
+
+  llvm_unreachable("Invalid EHPersonality!");
+}
+
+EHPersonality llvm::getDefaultEHPersonality(const Triple &T) {
+  return EHPersonality::GNU_C;
+}
+
+bool llvm::canSimplifyInvokeNoUnwind(const Function *F) {
+  EHPersonality Personality = classifyEHPersonality(F->getPersonalityFn());
+  // We can't simplify any invokes to nounwind functions if the personality
+  // function wants to catch asynchronous exceptions. The nounwind attribute
+  // only implies that the function does not throw synchronous exceptions.
+  return !isAsynchronousEHPersonality(Personality);
+}
+
+DenseMap<BasicBlock *, ColorVector> llvm::colorEHFunclets(Function &F) {
+  SmallVector<std::pair<BasicBlock *, BasicBlock *>, 16> Worklist;
+  BasicBlock *EntryBlock = &F.getEntryBlock();
+  DenseMap<BasicBlock *, ColorVector> BlockColors;
+
+  // Build up the color map, which maps each block to its set of 'colors'.
+  // For any block B the "colors" of B are the set of funclets F (possibly
+  // including a root "funclet" representing the main function) such that
+  // F will need to directly contain B or a copy of B (where the term "directly
+  // contain" is used to distinguish from being "transitively contained" in
+  // a nested funclet).
+  //
+  // Note: Despite not being a funclet in the truest sense, a catchswitch is
+  // considered to belong to its own funclet for the purposes of coloring.
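+  //
+  // For illustration (hypothetical CFG, not from a test case): given entry
+  // block E, a funclet head H (a block whose first non-PHI instruction is an
+  // EH pad), and a block S reachable from both, the walk below computes
+  // BlockColors[E] = {E}, BlockColors[H] = {H}, and BlockColors[S] = {E, H};
+  // the two colors on S mean each funclet needs its own copy of S.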
+ + DEBUG_WITH_TYPE("winehprepare-coloring", dbgs() << "\nColoring funclets for " + << F.getName() << "\n"); + + Worklist.push_back({EntryBlock, EntryBlock}); + + while (!Worklist.empty()) { + BasicBlock *Visiting; + BasicBlock *Color; + std::tie(Visiting, Color) = Worklist.pop_back_val(); + DEBUG_WITH_TYPE("winehprepare-coloring", + dbgs() << "Visiting " << Visiting->getName() << ", " + << Color->getName() << "\n"); + Instruction *VisitingHead = Visiting->getFirstNonPHI(); + if (VisitingHead->isEHPad()) { + // Mark this funclet head as a member of itself. + Color = Visiting; + } + // Note that this is a member of the given color. + ColorVector &Colors = BlockColors[Visiting]; + if (!is_contained(Colors, Color)) + Colors.push_back(Color); + else + continue; + + DEBUG_WITH_TYPE("winehprepare-coloring", + dbgs() << " Assigned color \'" << Color->getName() + << "\' to block \'" << Visiting->getName() + << "\'.\n"); + + BasicBlock *SuccColor = Color; + Instruction *Terminator = Visiting->getTerminator(); + if (auto *CatchRet = dyn_cast<CatchReturnInst>(Terminator)) { + Value *ParentPad = CatchRet->getCatchSwitchParentPad(); + if (isa<ConstantTokenNone>(ParentPad)) + SuccColor = EntryBlock; + else + SuccColor = cast<Instruction>(ParentPad)->getParent(); + } + + for (BasicBlock *Succ : successors(Visiting)) + Worklist.push_back({Succ, SuccColor}); + } + return BlockColors; +} diff --git a/llvm/lib/Analysis/GlobalsModRef.cpp b/llvm/lib/Analysis/GlobalsModRef.cpp new file mode 100644 index 000000000000..efdf9706ba3c --- /dev/null +++ b/llvm/lib/Analysis/GlobalsModRef.cpp @@ -0,0 +1,1026 @@ +//===- GlobalsModRef.cpp - Simple Mod/Ref Analysis for Globals ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This simple pass provides alias and mod/ref information for global values +// that do not have their address taken, and keeps track of whether functions +// read or write memory (are "pure"). For this simple (but very common) case, +// we can provide pretty accurate and useful information. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +using namespace llvm; + +#define DEBUG_TYPE "globalsmodref-aa" + +STATISTIC(NumNonAddrTakenGlobalVars, + "Number of global vars without address taken"); +STATISTIC(NumNonAddrTakenFunctions,"Number of functions without address taken"); +STATISTIC(NumNoMemFunctions, "Number of functions that do not access memory"); +STATISTIC(NumReadMemFunctions, "Number of functions that only read memory"); +STATISTIC(NumIndirectGlobalVars, "Number of indirect global objects"); + +// An option to enable unsafe alias results from the GlobalsModRef analysis. 
+// When enabled, GlobalsModRef will provide no-alias results which in extremely
+// rare cases may not be conservatively correct. In particular, in the face of
+// transforms which cause asymmetry between how effective GetUnderlyingObject
+// is for two pointers, it may produce incorrect results.
+//
+// These unsafe results have been returned by GMR for many years without
+// causing significant issues in the wild and so we provide a mechanism to
+// re-enable them for users of LLVM that have a particular performance
+// sensitivity and no known issues. The option also makes it easy to evaluate
+// the performance impact of these results.
+static cl::opt<bool> EnableUnsafeGlobalsModRefAliasResults(
+    "enable-unsafe-globalsmodref-alias-results", cl::init(false), cl::Hidden);
+
+/// The mod/ref information collected for a particular function.
+///
+/// We collect information about mod/ref behavior of a function here, both in
+/// general and as pertains to specific globals. We only have this detailed
+/// information when we know *something* useful about the behavior. If we
+/// saturate to fully general mod/ref, we remove the info for the function.
+class GlobalsAAResult::FunctionInfo {
+  typedef SmallDenseMap<const GlobalValue *, ModRefInfo, 16> GlobalInfoMapType;
+
+  /// Build a wrapper struct that has 8-byte alignment. All heap allocations
+  /// should provide this much alignment at least, but this makes it clear we
+  /// specifically rely on this amount of alignment.
+  struct alignas(8) AlignedMap {
+    AlignedMap() {}
+    AlignedMap(const AlignedMap &Arg) : Map(Arg.Map) {}
+    GlobalInfoMapType Map;
+  };
+
+  /// Pointer traits for our aligned map.
+  struct AlignedMapPointerTraits {
+    static inline void *getAsVoidPointer(AlignedMap *P) { return P; }
+    static inline AlignedMap *getFromVoidPointer(void *P) {
+      return (AlignedMap *)P;
+    }
+    enum { NumLowBitsAvailable = 3 };
+    static_assert(alignof(AlignedMap) >= (1 << NumLowBitsAvailable),
+                  "AlignedMap insufficiently aligned to have enough low bits.");
+  };
+
+  /// The bit that flags that this function may read any global. This is
+  /// chosen to mix together with ModRefInfo bits.
+  /// FIXME: This assumes the ModRefInfo lattice will remain 4 bits!
+  /// It overlaps with the ModRefInfo::Must bit!
+  /// FunctionInfo.getModRefInfo() masks out everything except ModRef so
+  /// this remains correct, but the Must info is lost.
+  enum { MayReadAnyGlobal = 4 };
+
+  /// Checks to document the invariants of the bit packing here.
+  static_assert((MayReadAnyGlobal & static_cast<int>(ModRefInfo::MustModRef)) ==
+                    0,
+                "ModRef and the MayReadAnyGlobal flag bits overlap.");
+  static_assert(((MayReadAnyGlobal |
+                  static_cast<int>(ModRefInfo::MustModRef)) >>
+                 AlignedMapPointerTraits::NumLowBitsAvailable) == 0,
+                "Insufficient low bits to store our flag and ModRef info.");
+
+public:
+  FunctionInfo() : Info() {}
+  ~FunctionInfo() {
+    delete Info.getPointer();
+  }
+  // Spell out the copy and move constructors and assignment operators to get
+  // deep copy semantics and correct move semantics in the face of the
+  // pointer-int pair.
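+  //
+  // For illustration of why these are spelled out (reasoning, not extra
+  // behavior): with the compiler-generated copy constructor both copies would
+  // share one AlignedMap pointer, and the destructor's
+  // `delete Info.getPointer()` would then run twice on the same map; the
+  // manual deep copy and the null-out in the move operations below avoid
+  // that.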
+ FunctionInfo(const FunctionInfo &Arg) + : Info(nullptr, Arg.Info.getInt()) { + if (const auto *ArgPtr = Arg.Info.getPointer()) + Info.setPointer(new AlignedMap(*ArgPtr)); + } + FunctionInfo(FunctionInfo &&Arg) + : Info(Arg.Info.getPointer(), Arg.Info.getInt()) { + Arg.Info.setPointerAndInt(nullptr, 0); + } + FunctionInfo &operator=(const FunctionInfo &RHS) { + delete Info.getPointer(); + Info.setPointerAndInt(nullptr, RHS.Info.getInt()); + if (const auto *RHSPtr = RHS.Info.getPointer()) + Info.setPointer(new AlignedMap(*RHSPtr)); + return *this; + } + FunctionInfo &operator=(FunctionInfo &&RHS) { + delete Info.getPointer(); + Info.setPointerAndInt(RHS.Info.getPointer(), RHS.Info.getInt()); + RHS.Info.setPointerAndInt(nullptr, 0); + return *this; + } + + /// This method clears MayReadAnyGlobal bit added by GlobalsAAResult to return + /// the corresponding ModRefInfo. It must align in functionality with + /// clearMust(). + ModRefInfo globalClearMayReadAnyGlobal(int I) const { + return ModRefInfo((I & static_cast<int>(ModRefInfo::ModRef)) | + static_cast<int>(ModRefInfo::NoModRef)); + } + + /// Returns the \c ModRefInfo info for this function. + ModRefInfo getModRefInfo() const { + return globalClearMayReadAnyGlobal(Info.getInt()); + } + + /// Adds new \c ModRefInfo for this function to its state. + void addModRefInfo(ModRefInfo NewMRI) { + Info.setInt(Info.getInt() | static_cast<int>(setMust(NewMRI))); + } + + /// Returns whether this function may read any global variable, and we don't + /// know which global. + bool mayReadAnyGlobal() const { return Info.getInt() & MayReadAnyGlobal; } + + /// Sets this function as potentially reading from any global. + void setMayReadAnyGlobal() { Info.setInt(Info.getInt() | MayReadAnyGlobal); } + + /// Returns the \c ModRefInfo info for this function w.r.t. a particular + /// global, which may be more precise than the general information above. + ModRefInfo getModRefInfoForGlobal(const GlobalValue &GV) const { + ModRefInfo GlobalMRI = + mayReadAnyGlobal() ? ModRefInfo::Ref : ModRefInfo::NoModRef; + if (AlignedMap *P = Info.getPointer()) { + auto I = P->Map.find(&GV); + if (I != P->Map.end()) + GlobalMRI = unionModRef(GlobalMRI, I->second); + } + return GlobalMRI; + } + + /// Add mod/ref info from another function into ours, saturating towards + /// ModRef. + void addFunctionInfo(const FunctionInfo &FI) { + addModRefInfo(FI.getModRefInfo()); + + if (FI.mayReadAnyGlobal()) + setMayReadAnyGlobal(); + + if (AlignedMap *P = FI.Info.getPointer()) + for (const auto &G : P->Map) + addModRefInfoForGlobal(*G.first, G.second); + } + + void addModRefInfoForGlobal(const GlobalValue &GV, ModRefInfo NewMRI) { + AlignedMap *P = Info.getPointer(); + if (!P) { + P = new AlignedMap(); + Info.setPointer(P); + } + auto &GlobalMRI = P->Map[&GV]; + GlobalMRI = unionModRef(GlobalMRI, NewMRI); + } + + /// Clear a global's ModRef info. Should be used when a global is being + /// deleted. + void eraseModRefInfoForGlobal(const GlobalValue &GV) { + if (AlignedMap *P = Info.getPointer()) + P->Map.erase(&GV); + } + +private: + /// All of the information is encoded into a single pointer, with a three bit + /// integer in the low three bits. The high bit provides a flag for when this + /// function may read any global. The low two bits are the ModRefInfo. And + /// the pointer, when non-null, points to a map from GlobalValue to + /// ModRefInfo specific to that GlobalValue. 
+ PointerIntPair<AlignedMap *, 3, unsigned, AlignedMapPointerTraits> Info; +}; + +void GlobalsAAResult::DeletionCallbackHandle::deleted() { + Value *V = getValPtr(); + if (auto *F = dyn_cast<Function>(V)) + GAR->FunctionInfos.erase(F); + + if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) { + if (GAR->NonAddressTakenGlobals.erase(GV)) { + // This global might be an indirect global. If so, remove it and + // remove any AllocRelatedValues for it. + if (GAR->IndirectGlobals.erase(GV)) { + // Remove any entries in AllocsForIndirectGlobals for this global. + for (auto I = GAR->AllocsForIndirectGlobals.begin(), + E = GAR->AllocsForIndirectGlobals.end(); + I != E; ++I) + if (I->second == GV) + GAR->AllocsForIndirectGlobals.erase(I); + } + + // Scan the function info we have collected and remove this global + // from all of them. + for (auto &FIPair : GAR->FunctionInfos) + FIPair.second.eraseModRefInfoForGlobal(*GV); + } + } + + // If this is an allocation related to an indirect global, remove it. + GAR->AllocsForIndirectGlobals.erase(V); + + // And clear out the handle. + setValPtr(nullptr); + GAR->Handles.erase(I); + // This object is now destroyed! +} + +FunctionModRefBehavior GlobalsAAResult::getModRefBehavior(const Function *F) { + FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior; + + if (FunctionInfo *FI = getFunctionInfo(F)) { + if (!isModOrRefSet(FI->getModRefInfo())) + Min = FMRB_DoesNotAccessMemory; + else if (!isModSet(FI->getModRefInfo())) + Min = FMRB_OnlyReadsMemory; + } + + return FunctionModRefBehavior(AAResultBase::getModRefBehavior(F) & Min); +} + +FunctionModRefBehavior +GlobalsAAResult::getModRefBehavior(const CallBase *Call) { + FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior; + + if (!Call->hasOperandBundles()) + if (const Function *F = Call->getCalledFunction()) + if (FunctionInfo *FI = getFunctionInfo(F)) { + if (!isModOrRefSet(FI->getModRefInfo())) + Min = FMRB_DoesNotAccessMemory; + else if (!isModSet(FI->getModRefInfo())) + Min = FMRB_OnlyReadsMemory; + } + + return FunctionModRefBehavior(AAResultBase::getModRefBehavior(Call) & Min); +} + +/// Returns the function info for the function, or null if we don't have +/// anything useful to say about it. +GlobalsAAResult::FunctionInfo * +GlobalsAAResult::getFunctionInfo(const Function *F) { + auto I = FunctionInfos.find(F); + if (I != FunctionInfos.end()) + return &I->second; + return nullptr; +} + +/// AnalyzeGlobals - Scan through the users of all of the internal +/// GlobalValue's in the program. If none of them have their "address taken" +/// (really, their address passed to something nontrivial), record this fact, +/// and record the functions that they are used directly in. +void GlobalsAAResult::AnalyzeGlobals(Module &M) { + SmallPtrSet<Function *, 32> TrackedFunctions; + for (Function &F : M) + if (F.hasLocalLinkage()) + if (!AnalyzeUsesOfPointer(&F)) { + // Remember that we are tracking this global. + NonAddressTakenGlobals.insert(&F); + TrackedFunctions.insert(&F); + Handles.emplace_front(*this, &F); + Handles.front().I = Handles.begin(); + ++NumNonAddrTakenFunctions; + } + + SmallPtrSet<Function *, 16> Readers, Writers; + for (GlobalVariable &GV : M.globals()) + if (GV.hasLocalLinkage()) { + if (!AnalyzeUsesOfPointer(&GV, &Readers, + GV.isConstant() ? 
nullptr : &Writers)) {
+        // Remember that we are tracking this global, and the mod/ref fns
+        NonAddressTakenGlobals.insert(&GV);
+        Handles.emplace_front(*this, &GV);
+        Handles.front().I = Handles.begin();
+
+        for (Function *Reader : Readers) {
+          if (TrackedFunctions.insert(Reader).second) {
+            Handles.emplace_front(*this, Reader);
+            Handles.front().I = Handles.begin();
+          }
+          FunctionInfos[Reader].addModRefInfoForGlobal(GV, ModRefInfo::Ref);
+        }
+
+        if (!GV.isConstant()) // No need to keep track of writers to constants
+          for (Function *Writer : Writers) {
+            if (TrackedFunctions.insert(Writer).second) {
+              Handles.emplace_front(*this, Writer);
+              Handles.front().I = Handles.begin();
+            }
+            FunctionInfos[Writer].addModRefInfoForGlobal(GV, ModRefInfo::Mod);
+          }
+        ++NumNonAddrTakenGlobalVars;
+
+        // If this global holds a pointer type, see if it is an indirect
+        // global.
+        if (GV.getValueType()->isPointerTy() &&
+            AnalyzeIndirectGlobalMemory(&GV))
+          ++NumIndirectGlobalVars;
+      }
+      Readers.clear();
+      Writers.clear();
+    }
+}
+
+/// AnalyzeUsesOfPointer - Look at all of the users of the specified pointer.
+/// If this is used by anything complex (i.e., the address escapes), return
+/// true. Also, while we are at it, keep track of those functions that read
+/// and write the value.
+///
+/// If OkayStoreDest is non-null, stores into this global are allowed.
+bool GlobalsAAResult::AnalyzeUsesOfPointer(Value *V,
+                                           SmallPtrSetImpl<Function *> *Readers,
+                                           SmallPtrSetImpl<Function *> *Writers,
+                                           GlobalValue *OkayStoreDest) {
+  if (!V->getType()->isPointerTy())
+    return true;
+
+  for (Use &U : V->uses()) {
+    User *I = U.getUser();
+    if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
+      if (Readers)
+        Readers->insert(LI->getParent()->getParent());
+    } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
+      if (V == SI->getOperand(1)) {
+        if (Writers)
+          Writers->insert(SI->getParent()->getParent());
+      } else if (SI->getOperand(1) != OkayStoreDest) {
+        return true; // Storing the pointer
+      }
+    } else if (Operator::getOpcode(I) == Instruction::GetElementPtr) {
+      if (AnalyzeUsesOfPointer(I, Readers, Writers))
+        return true;
+    } else if (Operator::getOpcode(I) == Instruction::BitCast) {
+      if (AnalyzeUsesOfPointer(I, Readers, Writers, OkayStoreDest))
+        return true;
+    } else if (auto *Call = dyn_cast<CallBase>(I)) {
+      // Make sure that this is just the function being called, not that it is
+      // being passed into the function.
+      if (Call->isDataOperand(&U)) {
+        // Detect calls to free.
+        if (Call->isArgOperand(&U) &&
+            isFreeCall(I, &GetTLI(*Call->getFunction()))) {
+          if (Writers)
+            Writers->insert(Call->getParent()->getParent());
+        } else {
+          return true; // Argument of an unknown call.
+        }
+      }
+    } else if (ICmpInst *ICI = dyn_cast<ICmpInst>(I)) {
+      if (!isa<ConstantPointerNull>(ICI->getOperand(1)))
+        return true; // Only comparisons against null are allowed.
+    } else if (Constant *C = dyn_cast<Constant>(I)) {
+      // Ignore constants which don't have any live uses.
+      if (isa<GlobalValue>(C) || C->isConstantUsed())
+        return true;
+    } else {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+/// AnalyzeIndirectGlobalMemory - We found a non-address-taken global variable
+/// which holds a pointer type. See if the global always points to non-aliased
+/// heap memory: that is, all initializers of the globals are allocations, and
+/// those allocations have no use other than initialization of the global.
+/// Further, all loads out of GV must directly use the memory, not store the
+/// pointer somewhere.
If this is true, we consider the memory pointed to by
+/// GV to be owned by GV and can disambiguate other pointers from it.
+bool GlobalsAAResult::AnalyzeIndirectGlobalMemory(GlobalVariable *GV) {
+  // Keep track of values related to the allocation of the memory, e.g. the
+  // value produced by the malloc call and any casts.
+  std::vector<Value *> AllocRelatedValues;
+
+  // If the initializer is a valid pointer, bail.
+  if (Constant *C = GV->getInitializer())
+    if (!C->isNullValue())
+      return false;
+
+  // Walk the user list of the global. If we find anything other than a direct
+  // load or store, bail out.
+  for (User *U : GV->users()) {
+    if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
+      // The pointer loaded from the global can only be used in simple ways:
+      // we allow addressing of it and loading from and storing to it. We do
+      // *not* allow storing the loaded pointer somewhere else or passing to a
+      // function.
+      if (AnalyzeUsesOfPointer(LI))
+        return false; // Loaded pointer escapes.
+      // TODO: Could try some IP mod/ref of the loaded pointer.
+    } else if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
+      // Storing the global itself.
+      if (SI->getOperand(0) == GV)
+        return false;
+
+      // If storing the null pointer, ignore it.
+      if (isa<ConstantPointerNull>(SI->getOperand(0)))
+        continue;
+
+      // Check the value being stored.
+      Value *Ptr = GetUnderlyingObject(SI->getOperand(0),
+                                       GV->getParent()->getDataLayout());
+
+      if (!isAllocLikeFn(Ptr, &GetTLI(*SI->getFunction())))
+        return false; // Too hard to analyze.
+
+      // Analyze all uses of the allocation. If any of them are used in a
+      // non-simple way (e.g. stored to another global) bail out.
+      if (AnalyzeUsesOfPointer(Ptr, /*Readers*/ nullptr, /*Writers*/ nullptr,
+                               GV))
+        return false; // Loaded pointer escapes.
+
+      // Remember that this allocation is related to the indirect global.
+      AllocRelatedValues.push_back(Ptr);
+    } else {
+      // Something complex, bail out.
+      return false;
+    }
+  }
+
+  // Okay, this is an indirect global. Remember all of the allocations for
+  // this global in AllocsForIndirectGlobals.
+  while (!AllocRelatedValues.empty()) {
+    AllocsForIndirectGlobals[AllocRelatedValues.back()] = GV;
+    Handles.emplace_front(*this, AllocRelatedValues.back());
+    Handles.front().I = Handles.begin();
+    AllocRelatedValues.pop_back();
+  }
+  IndirectGlobals.insert(GV);
+  Handles.emplace_front(*this, GV);
+  Handles.front().I = Handles.begin();
+  return true;
+}
+
+void GlobalsAAResult::CollectSCCMembership(CallGraph &CG) {
+  // We do a bottom-up SCC traversal of the call graph. In other words, we
+  // visit all callees before callers (leaf-first).
+  unsigned SCCID = 0;
+  for (scc_iterator<CallGraph *> I = scc_begin(&CG); !I.isAtEnd(); ++I) {
+    const std::vector<CallGraphNode *> &SCC = *I;
+    assert(!SCC.empty() && "SCC with no functions?");
+
+    for (auto *CGN : SCC)
+      if (Function *F = CGN->getFunction())
+        FunctionToSCCMap[F] = SCCID;
+    ++SCCID;
+  }
+}
+
+/// AnalyzeCallGraph - At this point, we know the functions where globals are
+/// immediately stored to and read from. Propagate this information up the call
+/// graph to all callers and compute the mod/ref info for all memory for each
+/// function.
+void GlobalsAAResult::AnalyzeCallGraph(CallGraph &CG, Module &M) {
+  // We do a bottom-up SCC traversal of the call graph. In other words, we
+  // visit all callees before callers (leaf-first). 
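+  // As an illustrative sketch (hypothetical module, not taken from this
+  // code): if @G is a tracked internal global and
+  //
+  //   static int G;
+  //   static void leaf() { G = 1; }   // FunctionInfo[leaf] gets Mod on @G
+  //   static void root() { leaf(); }  // root inherits Mod on @G from leaf
+  //
+  // then visiting leaf's SCC first guarantees its FunctionInfo is complete
+  // before root's is derived from it below.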
+  for (scc_iterator<CallGraph *> I = scc_begin(&CG); !I.isAtEnd(); ++I) {
+    const std::vector<CallGraphNode *> &SCC = *I;
+    assert(!SCC.empty() && "SCC with no functions?");
+
+    Function *F = SCC[0]->getFunction();
+
+    if (!F || !F->isDefinitionExact()) {
+      // Calls externally or not exact - can't say anything useful. Remove any
+      // existing function records (may have been created when scanning
+      // globals).
+      for (auto *Node : SCC)
+        FunctionInfos.erase(Node->getFunction());
+      continue;
+    }
+
+    FunctionInfo &FI = FunctionInfos[F];
+    Handles.emplace_front(*this, F);
+    Handles.front().I = Handles.begin();
+    bool KnowNothing = false;
+
+    // Collect the mod/ref properties due to called functions. We only compute
+    // one mod-ref set.
+    for (unsigned i = 0, e = SCC.size(); i != e && !KnowNothing; ++i) {
+      Function *F = SCC[i]->getFunction();
+      if (!F) {
+        KnowNothing = true;
+        break;
+      }
+
+      if (F->isDeclaration() || F->hasOptNone()) {
+        // Try to get mod/ref behaviour from function attributes.
+        if (F->doesNotAccessMemory()) {
+          // Can't do better than that!
+        } else if (F->onlyReadsMemory()) {
+          FI.addModRefInfo(ModRefInfo::Ref);
+          if (!F->isIntrinsic() && !F->onlyAccessesArgMemory())
+            // This function might call back into the module and read a global -
+            // consider every global as possibly being read by this function.
+            FI.setMayReadAnyGlobal();
+        } else {
+          FI.addModRefInfo(ModRefInfo::ModRef);
+          // Can't say anything useful unless it's an intrinsic - they don't
+          // read or write global variables of the kind considered here.
+          KnowNothing = !F->isIntrinsic();
+        }
+        continue;
+      }
+
+      for (CallGraphNode::iterator CI = SCC[i]->begin(), E = SCC[i]->end();
+           CI != E && !KnowNothing; ++CI)
+        if (Function *Callee = CI->second->getFunction()) {
+          if (FunctionInfo *CalleeFI = getFunctionInfo(Callee)) {
+            // Propagate function effect up.
+            FI.addFunctionInfo(*CalleeFI);
+          } else {
+            // Can't say anything about it. However, if it is inside our SCC,
+            // then nothing needs to be done.
+            CallGraphNode *CalleeNode = CG[Callee];
+            if (!is_contained(SCC, CalleeNode))
+              KnowNothing = true;
+          }
+        } else {
+          KnowNothing = true;
+        }
+    }
+
+    // If we can't say anything useful about this SCC, remove all SCC functions
+    // from the FunctionInfos map.
+    if (KnowNothing) {
+      for (auto *Node : SCC)
+        FunctionInfos.erase(Node->getFunction());
+      continue;
+    }
+
+    // Scan the function bodies for explicit loads or stores.
+    for (auto *Node : SCC) {
+      if (isModAndRefSet(FI.getModRefInfo()))
+        break; // The mod/ref lattice saturates here.
+
+      // Don't prove any properties based on the implementation of an optnone
+      // function. Function attributes were already used as a best approximation
+      // above.
+      if (Node->getFunction()->hasOptNone())
+        continue;
+
+      for (Instruction &I : instructions(Node->getFunction())) {
+        if (isModAndRefSet(FI.getModRefInfo()))
+          break; // The mod/ref lattice saturates here.
+
+        // We handle calls specially because the graph-relevant aspects are
+        // handled above.
+        if (auto *Call = dyn_cast<CallBase>(&I)) {
+          auto &TLI = GetTLI(*Node->getFunction());
+          if (isAllocationFn(Call, &TLI) || isFreeCall(Call, &TLI)) {
+            // FIXME: It is completely unclear why this is necessary and not
+            // handled by the above graph code.
+            FI.addModRefInfo(ModRefInfo::ModRef);
+          } else if (Function *Callee = Call->getCalledFunction()) {
+            // The callgraph doesn't include intrinsic calls.
+            if (Callee->isIntrinsic()) {
+              if (isa<DbgInfoIntrinsic>(Call))
+                // Don't let dbg intrinsics affect alias info. 
+                continue;
+
+              FunctionModRefBehavior Behaviour =
+                  AAResultBase::getModRefBehavior(Callee);
+              FI.addModRefInfo(createModRefInfo(Behaviour));
+            }
+          }
+          continue;
+        }
+
+        // For all non-call instructions, we use the primary predicates to
+        // determine whether they read or write memory.
+        if (I.mayReadFromMemory())
+          FI.addModRefInfo(ModRefInfo::Ref);
+        if (I.mayWriteToMemory())
+          FI.addModRefInfo(ModRefInfo::Mod);
+      }
+    }
+
+    if (!isModSet(FI.getModRefInfo()))
+      ++NumReadMemFunctions;
+    if (!isModOrRefSet(FI.getModRefInfo()))
+      ++NumNoMemFunctions;
+
+    // Finally, now that we know the full effect on this SCC, clone the
+    // information to each function in the SCC.
+    // FI is a reference into FunctionInfos, so copy it now so that it doesn't
+    // get invalidated if DenseMap decides to re-hash.
+    FunctionInfo CachedFI = FI;
+    for (unsigned i = 1, e = SCC.size(); i != e; ++i)
+      FunctionInfos[SCC[i]->getFunction()] = CachedFI;
+  }
+}
+
+// GV is a non-escaping global. V is a pointer address that has been loaded from.
+// If we can prove that V must escape, we can conclude that a load from V cannot
+// alias GV.
+static bool isNonEscapingGlobalNoAliasWithLoad(const GlobalValue *GV,
+                                               const Value *V,
+                                               int &Depth,
+                                               const DataLayout &DL) {
+  SmallPtrSet<const Value *, 8> Visited;
+  SmallVector<const Value *, 8> Inputs;
+  Visited.insert(V);
+  Inputs.push_back(V);
+  do {
+    const Value *Input = Inputs.pop_back_val();
+
+    if (isa<GlobalValue>(Input) || isa<Argument>(Input) || isa<CallInst>(Input) ||
+        isa<InvokeInst>(Input))
+      // Arguments to functions or returns from functions are inherently
+      // escaping, so we can immediately classify those as not aliasing any
+      // non-addr-taken globals.
+      //
+      // (Transitive) loads from a global are also safe - if this aliased
+      // another global, its address would escape, so no alias.
+      continue;
+
+    // Recurse through a limited number of selects, loads and PHIs. This is an
+    // arbitrary depth of 4, lower numbers could be used to fix compile time
+    // issues if needed, but this is generally expected to only be important
+    // for small depths.
+    if (++Depth > 4)
+      return false;
+
+    if (auto *LI = dyn_cast<LoadInst>(Input)) {
+      Inputs.push_back(GetUnderlyingObject(LI->getPointerOperand(), DL));
+      continue;
+    }
+    if (auto *SI = dyn_cast<SelectInst>(Input)) {
+      const Value *LHS = GetUnderlyingObject(SI->getTrueValue(), DL);
+      const Value *RHS = GetUnderlyingObject(SI->getFalseValue(), DL);
+      if (Visited.insert(LHS).second)
+        Inputs.push_back(LHS);
+      if (Visited.insert(RHS).second)
+        Inputs.push_back(RHS);
+      continue;
+    }
+    if (auto *PN = dyn_cast<PHINode>(Input)) {
+      for (const Value *Op : PN->incoming_values()) {
+        Op = GetUnderlyingObject(Op, DL);
+        if (Visited.insert(Op).second)
+          Inputs.push_back(Op);
+      }
+      continue;
+    }
+
+    return false;
+  } while (!Inputs.empty());
+
+  // All inputs were known to be no-alias.
+  return true;
+}
+
+// There are particular cases where we can conclude no-alias between
+// a non-addr-taken global and some other underlying object. Specifically,
+// a non-addr-taken global is known to not be escaped from any function. It is
+// also incorrect for a transformation to introduce an escape of a global in
+// a way that is observable when it was not there previously. One function
+// being transformed to introduce an escape which could possibly be observed
+// (via loading from a global or the return value for example) within another
+// function is never safe. 
If the observation is made through non-atomic
+// operations on different threads, it is a data-race and UB. If the
+// observation is well defined, by being observed the transformation would have
+// changed program behavior by introducing the observed escape, making it an
+// invalid transform.
+//
+// This property does require that transformations which *temporarily* escape
+// a global that was not previously escaped, prior to restoring it, cannot rely
+// on the results of GMR::alias. This seems a reasonable restriction, although
+// currently there is no way to enforce it. There is also no realistic
+// optimization pass that would make this mistake. The closest example is
+// a transformation pass which does reg2mem of SSA values but stores them into
+// global variables temporarily before restoring the global variable's value.
+// This could be useful to expose "benign" races for example. However, it seems
+// reasonable to require that a pass which introduces escapes of global
+// variables in this way either not trust AA results while the escape is
+// active, or be forced to operate as a module pass that cannot co-exist
+// with an alias analysis such as GMR.
+bool GlobalsAAResult::isNonEscapingGlobalNoAlias(const GlobalValue *GV,
+                                                 const Value *V) {
+  // In order to know that the underlying object cannot alias the
+  // non-addr-taken global, we must know that it would have to be an escape.
+  // Thus if the underlying object is a function argument, a load from
+  // a global, or the return of a function, it cannot alias. We can also
+  // recurse through PHI nodes and select nodes provided all of their inputs
+  // resolve to one of these known-escaping roots.
+  SmallPtrSet<const Value *, 8> Visited;
+  SmallVector<const Value *, 8> Inputs;
+  Visited.insert(V);
+  Inputs.push_back(V);
+  int Depth = 0;
+  do {
+    const Value *Input = Inputs.pop_back_val();
+
+    if (auto *InputGV = dyn_cast<GlobalValue>(Input)) {
+      // If one input is the very global we're querying against, then we can't
+      // conclude no-alias.
+      if (InputGV == GV)
+        return false;
+
+      // Distinct GlobalVariables never alias, unless overridden or zero-sized.
+      // FIXME: The condition can be refined, but be conservative for now.
+      auto *GVar = dyn_cast<GlobalVariable>(GV);
+      auto *InputGVar = dyn_cast<GlobalVariable>(InputGV);
+      if (GVar && InputGVar &&
+          !GVar->isDeclaration() && !InputGVar->isDeclaration() &&
+          !GVar->isInterposable() && !InputGVar->isInterposable()) {
+        Type *GVType = GVar->getInitializer()->getType();
+        Type *InputGVType = InputGVar->getInitializer()->getType();
+        if (GVType->isSized() && InputGVType->isSized() &&
+            (DL.getTypeAllocSize(GVType) > 0) &&
+            (DL.getTypeAllocSize(InputGVType) > 0))
+          continue;
+      }
+
+      // Conservatively return false, even though we could be smarter
+      // (e.g. look through GlobalAliases).
+      return false;
+    }
+
+    if (isa<Argument>(Input) || isa<CallInst>(Input) ||
+        isa<InvokeInst>(Input)) {
+      // Arguments to functions or returns from functions are inherently
+      // escaping, so we can immediately classify those as not aliasing any
+      // non-addr-taken globals.
+      continue;
+    }
+
+    // Recurse through a limited number of selects, loads and PHIs. This is an
+    // arbitrary depth of 4, lower numbers could be used to fix compile time
+    // issues if needed, but this is generally expected to only be important
+    // for small depths. 
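+    // For example (hypothetical IR), starting from
+    //   %q = phi i8* [ %arg1, %bb1 ], [ %arg2, %bb2 ]
+    //   %p = select i1 %c, i8* %arg3, i8* %q
+    // walking %p up to the escaping roots looks through one select and one
+    // phi, consuming two units of this depth budget.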
+    if (++Depth > 4)
+      return false;
+
+    if (auto *LI = dyn_cast<LoadInst>(Input)) {
+      // A pointer loaded from a global would have been captured, and we know
+      // that the global is non-escaping, so no alias.
+      const Value *Ptr = GetUnderlyingObject(LI->getPointerOperand(), DL);
+      if (isNonEscapingGlobalNoAliasWithLoad(GV, Ptr, Depth, DL))
+        // The load does not alias with GV.
+        continue;
+      // Otherwise, a load could come from anywhere, so bail.
+      return false;
+    }
+    if (auto *SI = dyn_cast<SelectInst>(Input)) {
+      const Value *LHS = GetUnderlyingObject(SI->getTrueValue(), DL);
+      const Value *RHS = GetUnderlyingObject(SI->getFalseValue(), DL);
+      if (Visited.insert(LHS).second)
+        Inputs.push_back(LHS);
+      if (Visited.insert(RHS).second)
+        Inputs.push_back(RHS);
+      continue;
+    }
+    if (auto *PN = dyn_cast<PHINode>(Input)) {
+      for (const Value *Op : PN->incoming_values()) {
+        Op = GetUnderlyingObject(Op, DL);
+        if (Visited.insert(Op).second)
+          Inputs.push_back(Op);
+      }
+      continue;
+    }
+
+    // FIXME: It would be good to handle other obvious no-alias cases here, but
+    // it isn't clear how to do so reasonably without building a small version
+    // of BasicAA into this code. We could recurse into AAResultBase::alias
+    // here but that seems likely to go poorly as we're inside the
+    // implementation of such a query. Until then, just conservatively return
+    // false.
+    return false;
+  } while (!Inputs.empty());
+
+  // If all the inputs to V were definitively no-alias, then V is no-alias.
+  return true;
+}
+
+/// alias - If one of the pointers is to a global that we are tracking, and the
+/// other is some random pointer, we know there cannot be an alias, because the
+/// address of the global isn't taken.
+AliasResult GlobalsAAResult::alias(const MemoryLocation &LocA,
+                                   const MemoryLocation &LocB,
+                                   AAQueryInfo &AAQI) {
+  // Get the base object these pointers point to.
+  const Value *UV1 = GetUnderlyingObject(LocA.Ptr, DL);
+  const Value *UV2 = GetUnderlyingObject(LocB.Ptr, DL);
+
+  // If either of the underlying values is a global, they may be non-addr-taken
+  // globals, which we can answer queries about.
+  const GlobalValue *GV1 = dyn_cast<GlobalValue>(UV1);
+  const GlobalValue *GV2 = dyn_cast<GlobalValue>(UV2);
+  if (GV1 || GV2) {
+    // If the global's address is taken, pretend we don't know it's a pointer to
+    // the global.
+    if (GV1 && !NonAddressTakenGlobals.count(GV1))
+      GV1 = nullptr;
+    if (GV2 && !NonAddressTakenGlobals.count(GV2))
+      GV2 = nullptr;
+
+    // If the two pointers are derived from two different non-addr-taken
+    // globals we know these can't alias.
+    if (GV1 && GV2 && GV1 != GV2)
+      return NoAlias;
+
+    // If one is and the other isn't, it isn't strictly safe but we can fake
+    // this result if necessary for performance. This does not appear to be
+    // a common problem in practice.
+    if (EnableUnsafeGlobalsModRefAliasResults)
+      if ((GV1 || GV2) && GV1 != GV2)
+        return NoAlias;
+
+    // Check for a special case where a non-escaping global can be used to
+    // conclude no-alias.
+    if ((GV1 || GV2) && GV1 != GV2) {
+      const GlobalValue *GV = GV1 ? GV1 : GV2;
+      const Value *UV = GV1 ? UV2 : UV1;
+      if (isNonEscapingGlobalNoAlias(GV, UV))
+        return NoAlias;
+    }
+
+    // Otherwise if they are both derived from the same addr-taken global, we
+    // can't know the two accesses don't overlap.
+  }
+
+  // These pointers may be based on the memory owned by an indirect global. If
+  // so, we may be able to handle this. First check to see if the base pointer
+  // is a direct load from an indirect global. 
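+  // Illustrative shape (hypothetical IR) of an indirect global and the
+  // pointers derived from it:
+  //   @IG = internal global i8* null
+  //   store i8* %alloc, i8** @IG     ; %alloc is a malloc-like result
+  //   %p = load i8*, i8** @IG
+  // Only %alloc and %p (and simple addressing off them) can name the
+  // pointed-to memory, so pointers rooted at two different indirect globals
+  // cannot alias.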
+  GV1 = GV2 = nullptr;
+  if (const LoadInst *LI = dyn_cast<LoadInst>(UV1))
+    if (GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getOperand(0)))
+      if (IndirectGlobals.count(GV))
+        GV1 = GV;
+  if (const LoadInst *LI = dyn_cast<LoadInst>(UV2))
+    if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getOperand(0)))
+      if (IndirectGlobals.count(GV))
+        GV2 = GV;
+
+  // These pointers may also be from an allocation for the indirect global. If
+  // so, also handle them.
+  if (!GV1)
+    GV1 = AllocsForIndirectGlobals.lookup(UV1);
+  if (!GV2)
+    GV2 = AllocsForIndirectGlobals.lookup(UV2);
+
+  // Now that we know whether the two pointers are related to indirect globals,
+  // use this to disambiguate the pointers. If the pointers are based on
+  // different indirect globals they cannot alias.
+  if (GV1 && GV2 && GV1 != GV2)
+    return NoAlias;
+
+  // If one is based on an indirect global and the other isn't, it isn't
+  // strictly safe but we can fake this result if necessary for performance.
+  // This does not appear to be a common problem in practice.
+  if (EnableUnsafeGlobalsModRefAliasResults)
+    if ((GV1 || GV2) && GV1 != GV2)
+      return NoAlias;
+
+  return AAResultBase::alias(LocA, LocB, AAQI);
+}
+
+ModRefInfo GlobalsAAResult::getModRefInfoForArgument(const CallBase *Call,
+                                                     const GlobalValue *GV,
+                                                     AAQueryInfo &AAQI) {
+  if (Call->doesNotAccessMemory())
+    return ModRefInfo::NoModRef;
+  ModRefInfo ConservativeResult =
+      Call->onlyReadsMemory() ? ModRefInfo::Ref : ModRefInfo::ModRef;
+
+  // Iterate through all the arguments to the called function. If any argument
+  // is based on GV, return the conservative result.
+  for (auto &A : Call->args()) {
+    SmallVector<const Value*, 4> Objects;
+    GetUnderlyingObjects(A, Objects, DL);
+
+    // All objects must be identified.
+    if (!all_of(Objects, isIdentifiedObject) &&
+        // Try ::alias to see if all objects are known not to alias GV.
+        !all_of(Objects, [&](const Value *V) {
+          return this->alias(MemoryLocation(V), MemoryLocation(GV), AAQI) ==
+                 NoAlias;
+        }))
+      return ConservativeResult;
+
+    if (is_contained(Objects, GV))
+      return ConservativeResult;
+  }
+
+  // We identified all objects in the argument list, and none of them were GV.
+  return ModRefInfo::NoModRef;
+}
+
+ModRefInfo GlobalsAAResult::getModRefInfo(const CallBase *Call,
+                                          const MemoryLocation &Loc,
+                                          AAQueryInfo &AAQI) {
+  ModRefInfo Known = ModRefInfo::ModRef;
+
+  // If we are asking for mod/ref info of a direct call with a pointer to a
+  // global we are tracking, return information if we have it. 
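+  // E.g. (hypothetical): for a direct call to internal @f and a location
+  // based on tracked internal global @g, if @f's FunctionInfo records
+  // neither Mod nor Ref for @g and no call argument is based on @g, Known
+  // becomes NoModRef below and we can return without querying further
+  // analyses.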
+  if (const GlobalValue *GV =
+          dyn_cast<GlobalValue>(GetUnderlyingObject(Loc.Ptr, DL)))
+    if (GV->hasLocalLinkage())
+      if (const Function *F = Call->getCalledFunction())
+        if (NonAddressTakenGlobals.count(GV))
+          if (const FunctionInfo *FI = getFunctionInfo(F))
+            Known = unionModRef(FI->getModRefInfoForGlobal(*GV),
+                                getModRefInfoForArgument(Call, GV, AAQI));
+
+  if (!isModOrRefSet(Known))
+    return ModRefInfo::NoModRef; // No need to query other mod/ref analyses
+  return intersectModRef(Known, AAResultBase::getModRefInfo(Call, Loc, AAQI));
+}
+
+GlobalsAAResult::GlobalsAAResult(
+    const DataLayout &DL,
+    std::function<const TargetLibraryInfo &(Function &F)> GetTLI)
+    : AAResultBase(), DL(DL), GetTLI(std::move(GetTLI)) {}
+
+GlobalsAAResult::GlobalsAAResult(GlobalsAAResult &&Arg)
+    : AAResultBase(std::move(Arg)), DL(Arg.DL), GetTLI(std::move(Arg.GetTLI)),
+      NonAddressTakenGlobals(std::move(Arg.NonAddressTakenGlobals)),
+      IndirectGlobals(std::move(Arg.IndirectGlobals)),
+      AllocsForIndirectGlobals(std::move(Arg.AllocsForIndirectGlobals)),
+      FunctionInfos(std::move(Arg.FunctionInfos)),
+      Handles(std::move(Arg.Handles)) {
+  // Update the parent for each DeletionCallbackHandle.
+  for (auto &H : Handles) {
+    assert(H.GAR == &Arg);
+    H.GAR = this;
+  }
+}
+
+GlobalsAAResult::~GlobalsAAResult() {}
+
+/*static*/ GlobalsAAResult GlobalsAAResult::analyzeModule(
+    Module &M, std::function<const TargetLibraryInfo &(Function &F)> GetTLI,
+    CallGraph &CG) {
+  GlobalsAAResult Result(M.getDataLayout(), GetTLI);
+
+  // Discover which functions aren't recursive, to feed into AnalyzeGlobals.
+  Result.CollectSCCMembership(CG);
+
+  // Find non-addr taken globals.
+  Result.AnalyzeGlobals(M);
+
+  // Propagate on CG.
+  Result.AnalyzeCallGraph(CG, M);
+
+  return Result;
+}
+
+AnalysisKey GlobalsAA::Key;
+
+GlobalsAAResult GlobalsAA::run(Module &M, ModuleAnalysisManager &AM) {
+  FunctionAnalysisManager &FAM =
+      AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+  auto GetTLI = [&FAM](Function &F) -> TargetLibraryInfo & {
+    return FAM.getResult<TargetLibraryAnalysis>(F);
+  };
+  return GlobalsAAResult::analyzeModule(M, GetTLI,
+                                        AM.getResult<CallGraphAnalysis>(M));
+}
+
+char GlobalsAAWrapperPass::ID = 0;
+INITIALIZE_PASS_BEGIN(GlobalsAAWrapperPass, "globals-aa",
+                      "Globals Alias Analysis", false, true)
+INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(GlobalsAAWrapperPass, "globals-aa",
+                    "Globals Alias Analysis", false, true)
+
+ModulePass *llvm::createGlobalsAAWrapperPass() {
+  return new GlobalsAAWrapperPass();
+}
+
+GlobalsAAWrapperPass::GlobalsAAWrapperPass() : ModulePass(ID) {
+  initializeGlobalsAAWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+bool GlobalsAAWrapperPass::runOnModule(Module &M) {
+  auto GetTLI = [this](Function &F) -> TargetLibraryInfo & {
+    return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+  };
+  Result.reset(new GlobalsAAResult(GlobalsAAResult::analyzeModule(
+      M, GetTLI, getAnalysis<CallGraphWrapperPass>().getCallGraph())));
+  return false;
+}
+
+bool GlobalsAAWrapperPass::doFinalization(Module &M) {
+  Result.reset();
+  return false;
+}
+
+void GlobalsAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.setPreservesAll();
+  AU.addRequired<CallGraphWrapperPass>();
+  AU.addRequired<TargetLibraryInfoWrapperPass>();
+}
diff --git a/llvm/lib/Analysis/GuardUtils.cpp b/llvm/lib/Analysis/GuardUtils.cpp
new file mode 100644
index 000000000000..cad92f6e56bb
--- /dev/null
+++ 
b/llvm/lib/Analysis/GuardUtils.cpp
@@ -0,0 +1,49 @@
+//===-- GuardUtils.cpp - Utils for work with guards -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Utils that are used to perform analyses related to guards and their
+// conditions.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/GuardUtils.h"
+#include "llvm/IR/PatternMatch.h"
+
+using namespace llvm;
+
+bool llvm::isGuard(const User *U) {
+  using namespace llvm::PatternMatch;
+  return match(U, m_Intrinsic<Intrinsic::experimental_guard>());
+}
+
+bool llvm::isGuardAsWidenableBranch(const User *U) {
+  Value *Condition, *WidenableCondition;
+  BasicBlock *GuardedBB, *DeoptBB;
+  if (!parseWidenableBranch(U, Condition, WidenableCondition, GuardedBB,
+                            DeoptBB))
+    return false;
+  using namespace llvm::PatternMatch;
+  for (auto &Insn : *DeoptBB) {
+    if (match(&Insn, m_Intrinsic<Intrinsic::experimental_deoptimize>()))
+      return true;
+    if (Insn.mayHaveSideEffects())
+      return false;
+  }
+  return false;
+}
+
+bool llvm::parseWidenableBranch(const User *U, Value *&Condition,
+                                Value *&WidenableCondition,
+                                BasicBlock *&IfTrueBB, BasicBlock *&IfFalseBB) {
+  using namespace llvm::PatternMatch;
+  if (!match(U, m_Br(m_And(m_Value(Condition), m_Value(WidenableCondition)),
+                     IfTrueBB, IfFalseBB)))
+    return false;
+  // TODO: At the moment, we only recognize the branch if the WC call is in
+  // this specific position. We should generalize!
+  return match(WidenableCondition,
+               m_Intrinsic<Intrinsic::experimental_widenable_condition>());
+}
diff --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp
new file mode 100644
index 000000000000..6fb600114bc6
--- /dev/null
+++ b/llvm/lib/Analysis/IVDescriptors.cpp
@@ -0,0 +1,1101 @@
+//===- llvm/Analysis/IVDescriptors.cpp - IndVar Descriptors -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file "describes" induction and recurrence variables. 
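+// For example (illustrative only), in the C loop
+//   for (i = 0; i != n; ++i) sum += a[i];
+// 'i' is an induction variable (it advances by a loop-invariant step every
+// iteration) and 'sum' is a recurrence (an integer-add reduction).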
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/IVDescriptors.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MustExecute.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/KnownBits.h"
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+#define DEBUG_TYPE "iv-descriptors"
+
+bool RecurrenceDescriptor::areAllUsesIn(Instruction *I,
+                                        SmallPtrSetImpl<Instruction *> &Set) {
+  for (User::op_iterator Use = I->op_begin(), E = I->op_end(); Use != E; ++Use)
+    if (!Set.count(dyn_cast<Instruction>(*Use)))
+      return false;
+  return true;
+}
+
+bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurrenceKind Kind) {
+  switch (Kind) {
+  default:
+    break;
+  case RK_IntegerAdd:
+  case RK_IntegerMult:
+  case RK_IntegerOr:
+  case RK_IntegerAnd:
+  case RK_IntegerXor:
+  case RK_IntegerMinMax:
+    return true;
+  }
+  return false;
+}
+
+bool RecurrenceDescriptor::isFloatingPointRecurrenceKind(RecurrenceKind Kind) {
+  return (Kind != RK_NoRecurrence) && !isIntegerRecurrenceKind(Kind);
+}
+
+bool RecurrenceDescriptor::isArithmeticRecurrenceKind(RecurrenceKind Kind) {
+  switch (Kind) {
+  default:
+    break;
+  case RK_IntegerAdd:
+  case RK_IntegerMult:
+  case RK_FloatAdd:
+  case RK_FloatMult:
+    return true;
+  }
+  return false;
+}
+
+/// Determines if Phi may have been type-promoted. If Phi has a single user
+/// that ANDs the Phi with a type mask, return the user. RT is updated to
+/// account for the narrower bit width represented by the mask, and the AND
+/// instruction is added to CI.
+static Instruction *lookThroughAnd(PHINode *Phi, Type *&RT,
+                                   SmallPtrSetImpl<Instruction *> &Visited,
+                                   SmallPtrSetImpl<Instruction *> &CI) {
+  if (!Phi->hasOneUse())
+    return Phi;
+
+  const APInt *M = nullptr;
+  Instruction *I, *J = cast<Instruction>(Phi->use_begin()->getUser());
+
+  // Matches either I & 2^x-1 or 2^x-1 & I. If we find a match, we update RT
+  // with a new integer type of the corresponding bit width.
+  if (match(J, m_c_And(m_Instruction(I), m_APInt(M)))) {
+    int32_t Bits = (*M + 1).exactLogBase2();
+    if (Bits > 0) {
+      RT = IntegerType::get(Phi->getContext(), Bits);
+      Visited.insert(Phi);
+      CI.insert(J);
+      return J;
+    }
+  }
+  return Phi;
+}
+
+/// Compute the minimal bit width needed to represent a reduction whose exit
+/// instruction is given by Exit. 
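+///
+/// For example (hypothetical), if Exit is an i32 add but demanded-bits shows
+/// only its low 8 bits are live out, this returns {i8, false}: the reduction
+/// can be evaluated in i8 and zero-extended back to i32 after the loop.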
+static std::pair<Type *, bool> computeRecurrenceType(Instruction *Exit,
+                                                     DemandedBits *DB,
+                                                     AssumptionCache *AC,
+                                                     DominatorTree *DT) {
+  bool IsSigned = false;
+  const DataLayout &DL = Exit->getModule()->getDataLayout();
+  uint64_t MaxBitWidth = DL.getTypeSizeInBits(Exit->getType());
+
+  if (DB) {
+    // Use the demanded bits analysis to determine the bits that are live out
+    // of the exit instruction, rounding up to the nearest power of two. If the
+    // use of demanded bits results in a smaller bit width, we know the value
+    // must be positive (i.e., IsSigned = false), because if this were not the
+    // case, the sign bit would have been demanded.
+    auto Mask = DB->getDemandedBits(Exit);
+    MaxBitWidth = Mask.getBitWidth() - Mask.countLeadingZeros();
+  }
+
+  if (MaxBitWidth == DL.getTypeSizeInBits(Exit->getType()) && AC && DT) {
+    // If demanded bits wasn't able to limit the bit width, we can try to use
+    // value tracking instead. This can be the case, for example, if the value
+    // may be negative.
+    auto NumSignBits = ComputeNumSignBits(Exit, DL, 0, AC, nullptr, DT);
+    auto NumTypeBits = DL.getTypeSizeInBits(Exit->getType());
+    MaxBitWidth = NumTypeBits - NumSignBits;
+    KnownBits Bits = computeKnownBits(Exit, DL);
+    if (!Bits.isNonNegative()) {
+      // If the value is not known to be non-negative, we set IsSigned to true,
+      // meaning that we will use sext instructions instead of zext
+      // instructions to restore the original type.
+      IsSigned = true;
+      if (!Bits.isNegative())
+        // If the value is not known to be negative, we don't know what the
+        // upper bit is, and therefore, we don't know what kind of extend we
+        // will need. In this case, just increase the bit width by one bit and
+        // use sext.
+        ++MaxBitWidth;
+    }
+  }
+  if (!isPowerOf2_64(MaxBitWidth))
+    MaxBitWidth = NextPowerOf2(MaxBitWidth);
+
+  return std::make_pair(Type::getIntNTy(Exit->getContext(), MaxBitWidth),
+                        IsSigned);
+}
+
+/// Collect cast instructions that can be ignored in the vectorizer's cost
+/// model, given a reduction exit value and the minimal type in which the
+/// reduction can be represented.
+static void collectCastsToIgnore(Loop *TheLoop, Instruction *Exit,
+                                 Type *RecurrenceType,
+                                 SmallPtrSetImpl<Instruction *> &Casts) {
+
+  SmallVector<Instruction *, 8> Worklist;
+  SmallPtrSet<Instruction *, 8> Visited;
+  Worklist.push_back(Exit);
+
+  while (!Worklist.empty()) {
+    Instruction *Val = Worklist.pop_back_val();
+    Visited.insert(Val);
+    if (auto *Cast = dyn_cast<CastInst>(Val))
+      if (Cast->getSrcTy() == RecurrenceType) {
+        // If the source type of a cast instruction is equal to the recurrence
+        // type, it will be eliminated, and should be ignored in the vectorizer
+        // cost model.
+        Casts.insert(Cast);
+        continue;
+      }
+
+    // Add all operands to the work list if they are loop-varying values that
+    // we haven't yet visited.
+    for (Value *O : cast<User>(Val)->operands())
+      if (auto *I = dyn_cast<Instruction>(O))
+        if (TheLoop->contains(I) && !Visited.count(I))
+          Worklist.push_back(I);
+  }
+}
+
+bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,
+                                           Loop *TheLoop, bool HasFunNoNaNAttr,
+                                           RecurrenceDescriptor &RedDes,
+                                           DemandedBits *DB,
+                                           AssumptionCache *AC,
+                                           DominatorTree *DT) {
+  if (Phi->getNumIncomingValues() != 2)
+    return false;
+
+  // Reduction variables are only found in the loop header block.
+  if (Phi->getParent() != TheLoop->getHeader())
+    return false;
+
+  // Obtain the reduction start value from the value that comes from the loop
+  // preheader. 
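+  // E.g. for a header phi of the form (hypothetical IR)
+  //   %sum = phi i32 [ %init, %preheader ], [ %sum.next, %latch ]
+  // the start value obtained below is %init.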
+  Value *RdxStart = Phi->getIncomingValueForBlock(TheLoop->getLoopPreheader());
+
+  // ExitInstruction is the single value which is used outside the loop.
+  // We only allow for a single reduction value to be used outside the loop.
+  // This includes users of the reduction and the recurrence variables
+  // themselves, which form a cycle that ends in the phi node.
+  Instruction *ExitInstruction = nullptr;
+  // Indicates that we found a reduction operation in our scan.
+  bool FoundReduxOp = false;
+
+  // We start with the PHI node and scan for all of the users of this
+  // instruction. All users must be instructions that can be used as reduction
+  // variables (such as ADD). We must have a single out-of-block user. The cycle
+  // must include the original PHI.
+  bool FoundStartPHI = false;
+
+  // To recognize min/max patterns formed by an icmp/select sequence, we store
+  // the number of instructions we saw from the recognized min/max pattern,
+  // to make sure we only see exactly the two instructions.
+  unsigned NumCmpSelectPatternInst = 0;
+  InstDesc ReduxDesc(false, nullptr);
+
+  // Data used for determining if the recurrence has been type-promoted.
+  Type *RecurrenceType = Phi->getType();
+  SmallPtrSet<Instruction *, 4> CastInsts;
+  Instruction *Start = Phi;
+  bool IsSigned = false;
+
+  SmallPtrSet<Instruction *, 8> VisitedInsts;
+  SmallVector<Instruction *, 8> Worklist;
+
+  // Return early if the recurrence kind does not match the type of Phi. If the
+  // recurrence kind is arithmetic, we attempt to look through AND operations
+  // resulting from the type promotion performed by InstCombine. Vector
+  // operations are not limited to the legal integer widths, so we may be able
+  // to evaluate the reduction in the narrower width.
+  if (RecurrenceType->isFloatingPointTy()) {
+    if (!isFloatingPointRecurrenceKind(Kind))
+      return false;
+  } else {
+    if (!isIntegerRecurrenceKind(Kind))
+      return false;
+    if (isArithmeticRecurrenceKind(Kind))
+      Start = lookThroughAnd(Phi, RecurrenceType, VisitedInsts, CastInsts);
+  }
+
+  Worklist.push_back(Start);
+  VisitedInsts.insert(Start);
+
+  // Start with all flags set because we will intersect this with the reduction
+  // flags from all the reduction operations.
+  FastMathFlags FMF = FastMathFlags::getFast();
+
+  // A value in the reduction can be used:
+  //  - By the reduction:
+  //      - Reduction operation:
+  //        - One use of reduction value (safe).
+  //        - Multiple use of reduction value (not safe).
+  //      - PHI:
+  //        - All uses of the PHI must be the reduction (safe).
+  //        - Otherwise, not safe.
+  //  - By instructions outside of the loop (safe).
+  //      * One value may have several outside users, but all outside
+  //        uses must be of the same value.
+  //  - By an instruction that is not part of the reduction (not safe).
+  //    This is either:
+  //      * An instruction type other than PHI or the reduction operation.
+  //      * A PHI in the header other than the initial PHI.
+  while (!Worklist.empty()) {
+    Instruction *Cur = Worklist.back();
+    Worklist.pop_back();
+
+    // No Users.
+    // If the instruction has no users then this is a broken chain and can't be
+    // a reduction variable.
+    if (Cur->use_empty())
+      return false;
+
+    bool IsAPhi = isa<PHINode>(Cur);
+
+    // A header PHI use other than the original PHI.
+    if (Cur != Phi && IsAPhi && Cur->getParent() == Phi->getParent())
+      return false;
+
+    // Reductions of instructions such as Div and Sub are only possible if the
+    // LHS is the reduction variable. 
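+    // E.g. (hypothetical IR) '%sum.next = sub i32 %sum, %x' can be
+    // vectorized as a reduction, but '%sum.next = sub i32 %x, %sum' cannot:
+    // the latter is not an accumulation into %sum.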
+    if (!Cur->isCommutative() && !IsAPhi && !isa<SelectInst>(Cur) &&
+        !isa<ICmpInst>(Cur) && !isa<FCmpInst>(Cur) &&
+        !VisitedInsts.count(dyn_cast<Instruction>(Cur->getOperand(0))))
+      return false;
+
+    // Any reduction instruction must be of one of the allowed kinds. We ignore
+    // the starting value (the Phi or an AND instruction if the Phi has been
+    // type-promoted).
+    if (Cur != Start) {
+      ReduxDesc = isRecurrenceInstr(Cur, Kind, ReduxDesc, HasFunNoNaNAttr);
+      if (!ReduxDesc.isRecurrence())
+        return false;
+      // FIXME: FMF is allowed on phi, but propagation is not handled correctly.
+      if (isa<FPMathOperator>(ReduxDesc.getPatternInst()) && !IsAPhi)
+        FMF &= ReduxDesc.getPatternInst()->getFastMathFlags();
+    }
+
+    bool IsASelect = isa<SelectInst>(Cur);
+
+    // A conditional reduction operation must only have two or fewer uses in
+    // VisitedInsts.
+    if (IsASelect && (Kind == RK_FloatAdd || Kind == RK_FloatMult) &&
+        hasMultipleUsesOf(Cur, VisitedInsts, 2))
+      return false;
+
+    // A reduction operation must only have one use of the reduction value.
+    if (!IsAPhi && !IsASelect && Kind != RK_IntegerMinMax &&
+        Kind != RK_FloatMinMax && hasMultipleUsesOf(Cur, VisitedInsts, 1))
+      return false;
+
+    // All inputs to a PHI node must be a reduction value.
+    if (IsAPhi && Cur != Phi && !areAllUsesIn(Cur, VisitedInsts))
+      return false;
+
+    if (Kind == RK_IntegerMinMax &&
+        (isa<ICmpInst>(Cur) || isa<SelectInst>(Cur)))
+      ++NumCmpSelectPatternInst;
+    if (Kind == RK_FloatMinMax && (isa<FCmpInst>(Cur) || isa<SelectInst>(Cur)))
+      ++NumCmpSelectPatternInst;
+
+    // Check whether we found a reduction operator.
+    FoundReduxOp |= !IsAPhi && Cur != Start;
+
+    // Process users of current instruction. Push non-PHI nodes after PHI nodes
+    // onto the stack. This way we are going to have seen all inputs to PHI
+    // nodes once we get to them.
+    SmallVector<Instruction *, 8> NonPHIs;
+    SmallVector<Instruction *, 8> PHIs;
+    for (User *U : Cur->users()) {
+      Instruction *UI = cast<Instruction>(U);
+
+      // Check if we found the exit user.
+      BasicBlock *Parent = UI->getParent();
+      if (!TheLoop->contains(Parent)) {
+        // If we already know this instruction is used externally, move on to
+        // the next user.
+        if (ExitInstruction == Cur)
+          continue;
+
+        // Exit if you find multiple values used outside or if the header phi
+        // node is being used. In this case the user uses the value of the
+        // previous iteration, in which case we would lose "VF-1" iterations of
+        // the reduction operation if we vectorize.
+        if (ExitInstruction != nullptr || Cur == Phi)
+          return false;
+
+        // The instruction used by an outside user must be the last instruction
+        // before we feed back to the reduction phi. Otherwise, we lose VF-1
+        // operations on the value.
+        if (!is_contained(Phi->operands(), Cur))
+          return false;
+
+        ExitInstruction = Cur;
+        continue;
+      }
+
+      // Process instructions only once (termination). Each reduction cycle
+      // value must only be used once, except by phi nodes and min/max
+      // reductions which are represented as a cmp followed by a select.
+      InstDesc IgnoredVal(false, nullptr);
+      if (VisitedInsts.insert(UI).second) {
+        if (isa<PHINode>(UI))
+          PHIs.push_back(UI);
+        else
+          NonPHIs.push_back(UI);
+      } else if (!isa<PHINode>(UI) &&
+                 ((!isa<FCmpInst>(UI) && !isa<ICmpInst>(UI) &&
+                   !isa<SelectInst>(UI)) ||
+                  (!isConditionalRdxPattern(Kind, UI).isRecurrence() &&
+                   !isMinMaxSelectCmpPattern(UI, IgnoredVal).isRecurrence())))
+        return false;
+
+      // Remember that we completed the cycle. 
+      if (UI == Phi)
+        FoundStartPHI = true;
+    }
+    Worklist.append(PHIs.begin(), PHIs.end());
+    Worklist.append(NonPHIs.begin(), NonPHIs.end());
+  }
+
+  // This means we have seen one but not the other instruction of the
+  // pattern or more than just a select and cmp.
+  if ((Kind == RK_IntegerMinMax || Kind == RK_FloatMinMax) &&
+      NumCmpSelectPatternInst != 2)
+    return false;
+
+  if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction)
+    return false;
+
+  if (Start != Phi) {
+    // If the starting value is not the same as the phi node, we speculatively
+    // looked through an 'and' instruction when evaluating a potential
+    // arithmetic reduction to determine if it may have been type-promoted.
+    //
+    // We now compute the minimal bit width that is required to represent the
+    // reduction. If this is the same width that was indicated by the 'and', we
+    // can represent the reduction in the smaller type. The 'and' instruction
+    // will be eliminated since it will essentially be a cast instruction that
+    // can be ignored in the cost model. If we compute a different type than we
+    // did when evaluating the 'and', the 'and' will not be eliminated, and we
+    // will end up with different kinds of operations in the recurrence
+    // expression (e.g., RK_IntegerAND, RK_IntegerADD). We give up if this is
+    // the case.
+    //
+    // The vectorizer relies on InstCombine to perform the actual
+    // type-shrinking. It does this by inserting instructions to truncate the
+    // exit value of the reduction to the width indicated by RecurrenceType and
+    // then extend this value back to the original width. If IsSigned is false,
+    // a 'zext' instruction will be generated; otherwise, a 'sext' will be
+    // used.
+    //
+    // TODO: We should not rely on InstCombine to rewrite the reduction in the
+    //       smaller type. We should just generate a correctly typed expression
+    //       to begin with.
+    Type *ComputedType;
+    std::tie(ComputedType, IsSigned) =
+        computeRecurrenceType(ExitInstruction, DB, AC, DT);
+    if (ComputedType != RecurrenceType)
+      return false;
+
+    // The recurrence expression will be represented in a narrower type. If
+    // there are any cast instructions that will be unnecessary, collect them
+    // in CastInsts. Note that the 'and' instruction was already included in
+    // this list.
+    //
+    // TODO: A better way to represent this may be to tag in some way all the
+    //       instructions that are a part of the reduction. The vectorizer cost
+    //       model could then apply the recurrence type to these instructions,
+    //       without needing a white list of instructions to ignore.
+    collectCastsToIgnore(TheLoop, ExitInstruction, RecurrenceType, CastInsts);
+  }
+
+  // We found a reduction var if we have reached the original phi node and we
+  // only have a single instruction with out-of-loop users.
+
+  // The ExitInstruction (the instruction which is allowed to have out-of-loop
+  // users) is saved as part of the RecurrenceDescriptor.
+
+  // Save the description of this reduction variable.
+  RecurrenceDescriptor RD(
+      RdxStart, ExitInstruction, Kind, FMF, ReduxDesc.getMinMaxKind(),
+      ReduxDesc.getUnsafeAlgebraInst(), RecurrenceType, IsSigned, CastInsts);
+  RedDes = RD;
+
+  return true;
+}
+
+/// Returns true if the instruction is a Select(ICmp(X, Y), X, Y) instruction
+/// pattern corresponding to a min(X, Y) or max(X, Y). 
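+/// For example (hypothetical IR), a signed max:
+///   %cmp = icmp sgt i32 %x, %y
+///   %max = select i1 %cmp, i32 %x, i32 %y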
+RecurrenceDescriptor::InstDesc
+RecurrenceDescriptor::isMinMaxSelectCmpPattern(Instruction *I, InstDesc &Prev) {
+
+  assert((isa<ICmpInst>(I) || isa<FCmpInst>(I) || isa<SelectInst>(I)) &&
+         "Expected a cmp or select instruction");
+  Instruction *Cmp = nullptr;
+  SelectInst *Select = nullptr;
+
+  // We must handle the select(cmp()) as a single instruction. Advance to the
+  // select.
+  if ((Cmp = dyn_cast<ICmpInst>(I)) || (Cmp = dyn_cast<FCmpInst>(I))) {
+    if (!Cmp->hasOneUse() || !(Select = dyn_cast<SelectInst>(*I->user_begin())))
+      return InstDesc(false, I);
+    return InstDesc(Select, Prev.getMinMaxKind());
+  }
+
+  // Only handle single use cases for now.
+  if (!(Select = dyn_cast<SelectInst>(I)))
+    return InstDesc(false, I);
+  if (!(Cmp = dyn_cast<ICmpInst>(I->getOperand(0))) &&
+      !(Cmp = dyn_cast<FCmpInst>(I->getOperand(0))))
+    return InstDesc(false, I);
+  if (!Cmp->hasOneUse())
+    return InstDesc(false, I);
+
+  Value *CmpLeft;
+  Value *CmpRight;
+
+  // Look for a min/max pattern.
+  if (m_UMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+    return InstDesc(Select, MRK_UIntMin);
+  else if (m_UMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+    return InstDesc(Select, MRK_UIntMax);
+  else if (m_SMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+    return InstDesc(Select, MRK_SIntMax);
+  else if (m_SMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+    return InstDesc(Select, MRK_SIntMin);
+  else if (m_OrdFMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+    return InstDesc(Select, MRK_FloatMin);
+  else if (m_OrdFMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+    return InstDesc(Select, MRK_FloatMax);
+  else if (m_UnordFMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+    return InstDesc(Select, MRK_FloatMin);
+  else if (m_UnordFMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+    return InstDesc(Select, MRK_FloatMax);
+
+  return InstDesc(false, I);
+}
+
+/// Returns true if the select instruction has users in the compare-and-add
+/// reduction pattern below. The select instruction argument is the last one
+/// in the sequence.
+///
+/// %sum.1 = phi ...
+/// ...
+/// %cmp = fcmp pred %0, %CFP
+/// %add = fadd %0, %sum.1
+/// %sum.2 = select %cmp, %add, %sum.1
+RecurrenceDescriptor::InstDesc
+RecurrenceDescriptor::isConditionalRdxPattern(
+    RecurrenceKind Kind, Instruction *I) {
+  SelectInst *SI = dyn_cast<SelectInst>(I);
+  if (!SI)
+    return InstDesc(false, I);
+
+  CmpInst *CI = dyn_cast<CmpInst>(SI->getCondition());
+  // Only handle single use cases for now.
+  if (!CI || !CI->hasOneUse())
+    return InstDesc(false, I);
+
+  Value *TrueVal = SI->getTrueValue();
+  Value *FalseVal = SI->getFalseValue();
+  // For now, handle only the case where exactly one of the select's operands
+  // is a PHI node.
+  if ((isa<PHINode>(*TrueVal) && isa<PHINode>(*FalseVal)) ||
+      (!isa<PHINode>(*TrueVal) && !isa<PHINode>(*FalseVal)))
+    return InstDesc(false, I);
+
+  Instruction *I1 =
+      isa<PHINode>(*TrueVal) ? 
dyn_cast<Instruction>(FalseVal)
+                             : dyn_cast<Instruction>(TrueVal);
+  if (!I1 || !I1->isBinaryOp())
+    return InstDesc(false, I);
+
+  Value *Op1, *Op2;
+  if ((m_FAdd(m_Value(Op1), m_Value(Op2)).match(I1) ||
+       m_FSub(m_Value(Op1), m_Value(Op2)).match(I1)) &&
+      I1->isFast())
+    return InstDesc(Kind == RK_FloatAdd, SI);
+
+  if (m_FMul(m_Value(Op1), m_Value(Op2)).match(I1) && (I1->isFast()))
+    return InstDesc(Kind == RK_FloatMult, SI);
+
+  return InstDesc(false, I);
+}
+
+RecurrenceDescriptor::InstDesc
+RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurrenceKind Kind,
+                                        InstDesc &Prev, bool HasFunNoNaNAttr) {
+  Instruction *UAI = Prev.getUnsafeAlgebraInst();
+  if (!UAI && isa<FPMathOperator>(I) && !I->hasAllowReassoc())
+    UAI = I; // Found an unsafe (unvectorizable) algebra instruction.
+
+  switch (I->getOpcode()) {
+  default:
+    return InstDesc(false, I);
+  case Instruction::PHI:
+    return InstDesc(I, Prev.getMinMaxKind(), Prev.getUnsafeAlgebraInst());
+  case Instruction::Sub:
+  case Instruction::Add:
+    return InstDesc(Kind == RK_IntegerAdd, I);
+  case Instruction::Mul:
+    return InstDesc(Kind == RK_IntegerMult, I);
+  case Instruction::And:
+    return InstDesc(Kind == RK_IntegerAnd, I);
+  case Instruction::Or:
+    return InstDesc(Kind == RK_IntegerOr, I);
+  case Instruction::Xor:
+    return InstDesc(Kind == RK_IntegerXor, I);
+  case Instruction::FMul:
+    return InstDesc(Kind == RK_FloatMult, I, UAI);
+  case Instruction::FSub:
+  case Instruction::FAdd:
+    return InstDesc(Kind == RK_FloatAdd, I, UAI);
+  case Instruction::Select:
+    if (Kind == RK_FloatAdd || Kind == RK_FloatMult)
+      return isConditionalRdxPattern(Kind, I);
+    LLVM_FALLTHROUGH;
+  case Instruction::FCmp:
+  case Instruction::ICmp:
+    if (Kind != RK_IntegerMinMax &&
+        (!HasFunNoNaNAttr || Kind != RK_FloatMinMax))
+      return InstDesc(false, I);
+    return isMinMaxSelectCmpPattern(I, Prev);
+  }
+}
+
+bool RecurrenceDescriptor::hasMultipleUsesOf(
+    Instruction *I, SmallPtrSetImpl<Instruction *> &Insts,
+    unsigned MaxNumUses) {
+  unsigned NumUses = 0;
+  for (User::op_iterator Use = I->op_begin(), E = I->op_end(); Use != E;
+       ++Use) {
+    if (Insts.count(dyn_cast<Instruction>(*Use)))
+      ++NumUses;
+    if (NumUses > MaxNumUses)
+      return true;
+  }
+
+  return false;
+}
+
+bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
+                                          RecurrenceDescriptor &RedDes,
+                                          DemandedBits *DB, AssumptionCache *AC,
+                                          DominatorTree *DT) {
+
+  BasicBlock *Header = TheLoop->getHeader();
+  Function &F = *Header->getParent();
+  bool HasFunNoNaNAttr =
+      F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
+
+  if (AddReductionVar(Phi, RK_IntegerAdd, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+                      AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found an ADD reduction PHI." << *Phi << "\n");
+    return true;
+  }
+  if (AddReductionVar(Phi, RK_IntegerMult, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+                      AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found a MUL reduction PHI." << *Phi << "\n");
+    return true;
+  }
+  if (AddReductionVar(Phi, RK_IntegerOr, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+                      AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found an OR reduction PHI." << *Phi << "\n");
+    return true;
+  }
+  if (AddReductionVar(Phi, RK_IntegerAnd, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+                      AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found an AND reduction PHI." << *Phi << "\n");
+    return true;
+  }
+  if (AddReductionVar(Phi, RK_IntegerXor, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+                      AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found a XOR reduction PHI." 
<< *Phi << "\n");
+    return true;
+  }
+  if (AddReductionVar(Phi, RK_IntegerMinMax, TheLoop, HasFunNoNaNAttr, RedDes,
+                      DB, AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found a MINMAX reduction PHI." << *Phi << "\n");
+    return true;
+  }
+  if (AddReductionVar(Phi, RK_FloatMult, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+                      AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n");
+    return true;
+  }
+  if (AddReductionVar(Phi, RK_FloatAdd, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+                      AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found an FAdd reduction PHI." << *Phi << "\n");
+    return true;
+  }
+  if (AddReductionVar(Phi, RK_FloatMinMax, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+                      AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found a float MINMAX reduction PHI." << *Phi
+                      << "\n");
+    return true;
+  }
+  // Not a reduction of known type.
+  return false;
+}
+
+bool RecurrenceDescriptor::isFirstOrderRecurrence(
+    PHINode *Phi, Loop *TheLoop,
+    DenseMap<Instruction *, Instruction *> &SinkAfter, DominatorTree *DT) {
+
+  // Ensure the phi node is in the loop header and has two incoming values.
+  if (Phi->getParent() != TheLoop->getHeader() ||
+      Phi->getNumIncomingValues() != 2)
+    return false;
+
+  // Ensure the loop has a preheader and a single latch block. The loop
+  // vectorizer will need the latch to set up the next iteration of the loop.
+  auto *Preheader = TheLoop->getLoopPreheader();
+  auto *Latch = TheLoop->getLoopLatch();
+  if (!Preheader || !Latch)
+    return false;
+
+  // Ensure the phi node's incoming blocks are the loop preheader and latch.
+  if (Phi->getBasicBlockIndex(Preheader) < 0 ||
+      Phi->getBasicBlockIndex(Latch) < 0)
+    return false;
+
+  // Get the previous value. The previous value comes from the latch edge while
+  // the initial value comes from the preheader edge.
+  auto *Previous = dyn_cast<Instruction>(Phi->getIncomingValueForBlock(Latch));
+  if (!Previous || !TheLoop->contains(Previous) || isa<PHINode>(Previous) ||
+      SinkAfter.count(Previous)) // Cannot rely on dominance due to motion.
+    return false;
+
+  // Ensure every user of the phi node is dominated by the previous value.
+  // The dominance requirement ensures the loop vectorizer will not need to
+  // vectorize the initial value prior to the first iteration of the loop.
+  // TODO: Consider extending this sinking to handle other kinds of instructions
+  //       and expressions, beyond sinking a single cast past Previous.
+  if (Phi->hasOneUse()) {
+    auto *I = Phi->user_back();
+    if (I->isCast() && (I->getParent() == Phi->getParent()) && I->hasOneUse() &&
+        DT->dominates(Previous, I->user_back())) {
+      if (!DT->dominates(Previous, I)) // Otherwise we're good w/o sinking.
+        SinkAfter[I] = Previous;
+      return true;
+    }
+  }
+
+  for (User *U : Phi->users())
+    if (auto *I = dyn_cast<Instruction>(U)) {
+      if (!DT->dominates(Previous, I))
+        return false;
+    }
+
+  return true;
+}
+
+/// This function returns the identity element (or neutral element) for
+/// the operation K.
+Constant *RecurrenceDescriptor::getRecurrenceIdentity(RecurrenceKind K,
+                                                      Type *Tp) {
+  switch (K) {
+  case RK_IntegerXor:
+  case RK_IntegerAdd:
+  case RK_IntegerOr:
+    // Adding, Xoring, Oring zero to a number does not change it.
+    return ConstantInt::get(Tp, 0);
+  case RK_IntegerMult:
+    // Multiplying a number by 1 does not change it.
+    return ConstantInt::get(Tp, 1);
+  case RK_IntegerAnd:
+    // AND-ing a number with an all-1 value does not change it.
+    return ConstantInt::get(Tp, -1, true);
+  case RK_FloatMult:
+    // Multiplying a number by 1 does not change it. 
+    return ConstantFP::get(Tp, 1.0L);
+  case RK_FloatAdd:
+    // Adding zero to a number does not change it.
+    return ConstantFP::get(Tp, 0.0L);
+  default:
+    llvm_unreachable("Unknown recurrence kind");
+  }
+}
+
+/// This function translates the recurrence kind to an LLVM binary operator.
+unsigned RecurrenceDescriptor::getRecurrenceBinOp(RecurrenceKind Kind) {
+  switch (Kind) {
+  case RK_IntegerAdd:
+    return Instruction::Add;
+  case RK_IntegerMult:
+    return Instruction::Mul;
+  case RK_IntegerOr:
+    return Instruction::Or;
+  case RK_IntegerAnd:
+    return Instruction::And;
+  case RK_IntegerXor:
+    return Instruction::Xor;
+  case RK_FloatMult:
+    return Instruction::FMul;
+  case RK_FloatAdd:
+    return Instruction::FAdd;
+  case RK_IntegerMinMax:
+    return Instruction::ICmp;
+  case RK_FloatMinMax:
+    return Instruction::FCmp;
+  default:
+    llvm_unreachable("Unknown recurrence operation");
+  }
+}
+
+InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K,
+                                         const SCEV *Step, BinaryOperator *BOp,
+                                         SmallVectorImpl<Instruction *> *Casts)
+    : StartValue(Start), IK(K), Step(Step), InductionBinOp(BOp) {
+  assert(IK != IK_NoInduction && "Not an induction");
+
+  // Start value type should match the induction kind and the value
+  // itself should not be null.
+  assert(StartValue && "StartValue is null");
+  assert((IK != IK_PtrInduction || StartValue->getType()->isPointerTy()) &&
+         "StartValue is not a pointer for pointer induction");
+  assert((IK != IK_IntInduction || StartValue->getType()->isIntegerTy()) &&
+         "StartValue is not an integer for integer induction");
+
+  // Check the Step Value. It should be a non-zero integer value.
+  assert((!getConstIntStepValue() || !getConstIntStepValue()->isZero()) &&
+         "Step value is zero");
+
+  assert((IK != IK_PtrInduction || getConstIntStepValue()) &&
+         "Step value should be constant for pointer induction");
+  assert((IK == IK_FpInduction || Step->getType()->isIntegerTy()) &&
+         "StepValue is not an integer");
+
+  assert((IK != IK_FpInduction || Step->getType()->isFloatingPointTy()) &&
+         "StepValue is not FP for FpInduction");
+  assert((IK != IK_FpInduction ||
+          (InductionBinOp &&
+           (InductionBinOp->getOpcode() == Instruction::FAdd ||
+            InductionBinOp->getOpcode() == Instruction::FSub))) &&
+         "Binary opcode should be specified for FP induction");
+
+  if (Casts) {
+    for (auto &Inst : *Casts) {
+      RedundantCasts.push_back(Inst);
+    }
+  }
+}
+
+int InductionDescriptor::getConsecutiveDirection() const {
+  ConstantInt *ConstStep = getConstIntStepValue();
+  if (ConstStep && (ConstStep->isOne() || ConstStep->isMinusOne()))
+    return ConstStep->getSExtValue();
+  return 0;
+}
+
+ConstantInt *InductionDescriptor::getConstIntStepValue() const {
+  if (isa<SCEVConstant>(Step))
+    return dyn_cast<ConstantInt>(cast<SCEVConstant>(Step)->getValue());
+  return nullptr;
+}
+
+bool InductionDescriptor::isFPInductionPHI(PHINode *Phi, const Loop *TheLoop,
+                                           ScalarEvolution *SE,
+                                           InductionDescriptor &D) {
+
+  // Here we only handle FP induction variables.
+  assert(Phi->getType()->isFloatingPointTy() && "Unexpected Phi type");
+
+  if (TheLoop->getHeader() != Phi->getParent())
+    return false;
+
+  // The loop may have multiple entrances or multiple exits; we can analyze
+  // this phi if it has a unique entry value and a unique backedge value. 
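+  // A typical analyzable FP induction looks like (hypothetical IR):
+  //   %x = phi float [ 1.0, %preheader ], [ %x.next, %latch ]
+  //   %x.next = fadd float %x, %step   ; %step is loop-invariant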
+  if (Phi->getNumIncomingValues() != 2)
+    return false;
+  Value *BEValue = nullptr, *StartValue = nullptr;
+  if (TheLoop->contains(Phi->getIncomingBlock(0))) {
+    BEValue = Phi->getIncomingValue(0);
+    StartValue = Phi->getIncomingValue(1);
+  } else {
+    assert(TheLoop->contains(Phi->getIncomingBlock(1)) &&
+           "Unexpected Phi node in the loop");
+    BEValue = Phi->getIncomingValue(1);
+    StartValue = Phi->getIncomingValue(0);
+  }
+
+  BinaryOperator *BOp = dyn_cast<BinaryOperator>(BEValue);
+  if (!BOp)
+    return false;
+
+  Value *Addend = nullptr;
+  if (BOp->getOpcode() == Instruction::FAdd) {
+    if (BOp->getOperand(0) == Phi)
+      Addend = BOp->getOperand(1);
+    else if (BOp->getOperand(1) == Phi)
+      Addend = BOp->getOperand(0);
+  } else if (BOp->getOpcode() == Instruction::FSub)
+    if (BOp->getOperand(0) == Phi)
+      Addend = BOp->getOperand(1);
+
+  if (!Addend)
+    return false;
+
+  // The addend should be loop invariant.
+  if (auto *I = dyn_cast<Instruction>(Addend))
+    if (TheLoop->contains(I))
+      return false;
+
+  // FP Step has unknown SCEV.
+  const SCEV *Step = SE->getUnknown(Addend);
+  D = InductionDescriptor(StartValue, IK_FpInduction, Step, BOp);
+  return true;
+}
+
+/// This function is called when we suspect that the update-chain of a phi node
+/// (whose symbolic SCEV expression is in \p PhiScev) contains redundant casts
+/// that can be ignored. (This can happen when the PSCEV rewriter adds a runtime
+/// predicate P under which the SCEV expression for the phi can be the
+/// AddRecurrence \p AR; see createAddRecFromPHIWithCast.) We want to find the
+/// cast instructions that are involved in the update-chain of this induction.
+/// A caller that adds the required runtime predicate can be free to drop these
+/// cast instructions, and compute the phi using \p AR (instead of some scev
+/// expression with casts).
+///
+/// For example, without a predicate the scev expression can take the following
+/// form:
+///       (Ext ix (Trunc iy ( Start + i*Step ) to ix) to iy)
+///
+/// It corresponds to the following IR sequence:
+/// %for.body:
+///   %x = phi i64 [ 0, %ph ], [ %add, %for.body ]
+///   %casted_phi = "ExtTrunc i64 %x"
+///   %add = add i64 %casted_phi, %step
+///
+/// where %x is given in \p PN,
+/// PSE.getSCEV(%x) is equal to PSE.getSCEV(%casted_phi) under a predicate,
+/// and the IR sequence that "ExtTrunc i64 %x" represents can take one of
+/// several forms, for example, such as:
+///   ExtTrunc1:    %casted_phi = and  %x, 2^n-1
+/// or:
+///   ExtTrunc2:    %t = shl %x, m
+///                 %casted_phi = ashr %t, m
+///
+/// If we are able to find such a sequence, we return the instructions
+/// we found, namely %casted_phi and the instructions on its use-def chain up
+/// to the phi (not including the phi).
+static bool getCastsForInductionPHI(PredicatedScalarEvolution &PSE,
+                                    const SCEVUnknown *PhiScev,
+                                    const SCEVAddRecExpr *AR,
+                                    SmallVectorImpl<Instruction *> &CastInsts) {
+
+  assert(CastInsts.empty() && "CastInsts is expected to be empty.");
+  auto *PN = cast<PHINode>(PhiScev->getValue());
+  assert(PSE.getSCEV(PN) == AR && "Unexpected phi node SCEV expression");
+  const Loop *L = AR->getLoop();
+
+  // Find any cast instructions that participate in the def-use chain of
+  // PhiScev in the loop.
+  // FORNOW/TODO: We currently expect the def-use chain to include only
+  // two-operand instructions, where one of the operands is an invariant.
+  // createAddRecFromPHIWithCasts() currently does not support anything more
+  // involved than that, so we keep the search simple. This can be
+  // extended/generalized as needed. 
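+  // The getDef lambda below encodes that restriction: for a two-operand
+  // instruction it returns the loop-variant operand, e.g. (hypothetical)
+  // for '%t = shl i64 %x, 3' it returns %x, and it returns nullptr for
+  // anything it cannot decompose that way.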
+
+ auto getDef = [&](const Value *Val) -> Value * {
+ const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Val);
+ if (!BinOp)
+ return nullptr;
+ Value *Op0 = BinOp->getOperand(0);
+ Value *Op1 = BinOp->getOperand(1);
+ Value *Def = nullptr;
+ if (L->isLoopInvariant(Op0))
+ Def = Op1;
+ else if (L->isLoopInvariant(Op1))
+ Def = Op0;
+ return Def;
+ };
+
+ // Look for the instruction that defines the induction via the
+ // loop backedge.
+ BasicBlock *Latch = L->getLoopLatch();
+ if (!Latch)
+ return false;
+ Value *Val = PN->getIncomingValueForBlock(Latch);
+ if (!Val)
+ return false;
+
+ // Follow the def-use chain until the induction phi is reached.
+ // If on the way we encounter a Value that has the same SCEV Expr as the
+ // phi node, we can consider the instructions we visit from that point
+ // as part of the cast-sequence that can be ignored.
+ bool InCastSequence = false;
+ auto *Inst = dyn_cast<Instruction>(Val);
+ while (Val != PN) {
+ // If we encountered a phi node other than PN, or if we left the loop,
+ // we bail out.
+ if (!Inst || !L->contains(Inst)) {
+ return false;
+ }
+ auto *AddRec = dyn_cast<SCEVAddRecExpr>(PSE.getSCEV(Val));
+ if (AddRec && PSE.areAddRecsEqualWithPreds(AddRec, AR))
+ InCastSequence = true;
+ if (InCastSequence) {
+ // Only the last instruction in the cast sequence is expected to have
+ // uses outside the induction def-use chain.
+ if (!CastInsts.empty())
+ if (!Inst->hasOneUse())
+ return false;
+ CastInsts.push_back(Inst);
+ }
+ Val = getDef(Val);
+ if (!Val)
+ return false;
+ Inst = dyn_cast<Instruction>(Val);
+ }
+
+ return InCastSequence;
+}
+
+bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
+ PredicatedScalarEvolution &PSE,
+ InductionDescriptor &D, bool Assume) {
+ Type *PhiTy = Phi->getType();
+
+ // Handle integer and pointer induction variables.
+ // We now also handle FP inductions, but we do not try to build a recurrence
+ // expression from the PHI node in place.
+
+ if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy() && !PhiTy->isFloatTy() &&
+ !PhiTy->isDoubleTy() && !PhiTy->isHalfTy())
+ return false;
+
+ if (PhiTy->isFloatingPointTy())
+ return isFPInductionPHI(Phi, TheLoop, PSE.getSE(), D);
+
+ const SCEV *PhiScev = PSE.getSCEV(Phi);
+ const auto *AR = dyn_cast<SCEVAddRecExpr>(PhiScev);
+
+ // We need this expression to be an AddRecExpr.
+ if (Assume && !AR)
+ AR = PSE.getAsAddRec(Phi);
+
+ if (!AR) {
+ LLVM_DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n");
+ return false;
+ }
+
+ // Record any cast instructions that participate in the induction update.
+ const auto *SymbolicPhi = dyn_cast<SCEVUnknown>(PhiScev);
+ // If we started from an UnknownSCEV, and managed to build an addRecurrence
+ // only after enabling Assume with PSCEV, this means we may have encountered
+ // cast instructions that required adding a runtime check in order to
+ // guarantee the correctness of the AddRecurrence representation of the
+ // induction.
+ if (PhiScev != AR && SymbolicPhi) {
+ SmallVector<Instruction *, 2> Casts;
+ if (getCastsForInductionPHI(PSE, SymbolicPhi, AR, Casts))
+ return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR, &Casts);
+ }
+
+ return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR);
+}
+
+bool InductionDescriptor::isInductionPHI(
+ PHINode *Phi, const Loop *TheLoop, ScalarEvolution *SE,
+ InductionDescriptor &D, const SCEV *Expr,
+ SmallVectorImpl<Instruction *> *CastsToIgnore) {
+ Type *PhiTy = Phi->getType();
+ // We only handle integer and pointer induction variables.
+ if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy())
+ return false;
+
+ // Check that the PHI is consecutive.
+ const SCEV *PhiScev = Expr ? Expr : SE->getSCEV(Phi);
+ const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PhiScev);
+
+ if (!AR) {
+ LLVM_DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n");
+ return false;
+ }
+
+ if (AR->getLoop() != TheLoop) {
+ // FIXME: We should treat this as a uniform. Unfortunately, we
+ // don't currently know how to handle uniform PHIs.
+ LLVM_DEBUG(
+ dbgs() << "LV: PHI is a recurrence with respect to an outer loop.\n");
+ return false;
+ }
+
+ Value *StartValue =
+ Phi->getIncomingValueForBlock(AR->getLoop()->getLoopPreheader());
+
+ BasicBlock *Latch = AR->getLoop()->getLoopLatch();
+ if (!Latch)
+ return false;
+ BinaryOperator *BOp =
+ dyn_cast<BinaryOperator>(Phi->getIncomingValueForBlock(Latch));
+
+ const SCEV *Step = AR->getStepRecurrence(*SE);
+ // Calculate the pointer stride and check if it is consecutive.
+ // The stride may be a constant or a loop invariant integer value.
+ const SCEVConstant *ConstStep = dyn_cast<SCEVConstant>(Step);
+ if (!ConstStep && !SE->isLoopInvariant(Step, TheLoop))
+ return false;
+
+ if (PhiTy->isIntegerTy()) {
+ D = InductionDescriptor(StartValue, IK_IntInduction, Step, BOp,
+ CastsToIgnore);
+ return true;
+ }
+
+ assert(PhiTy->isPointerTy() && "The PHI must be a pointer");
+ // The step of a pointer induction should be a constant.
+ if (!ConstStep)
+ return false;
+
+ ConstantInt *CV = ConstStep->getValue();
+ Type *PointerElementType = PhiTy->getPointerElementType();
+ // The pointer stride cannot be determined if the pointer element type is not
+ // sized.
+ if (!PointerElementType->isSized())
+ return false;
+
+ const DataLayout &DL = Phi->getModule()->getDataLayout();
+ int64_t Size = static_cast<int64_t>(DL.getTypeAllocSize(PointerElementType));
+ if (!Size)
+ return false;
+
+ int64_t CVSize = CV->getSExtValue();
+ if (CVSize % Size)
+ return false;
+ auto *StepValue =
+ SE->getConstant(CV->getType(), CVSize / Size, true /* signed */);
+ D = InductionDescriptor(StartValue, IK_PtrInduction, StepValue, BOp);
+ return true;
+}
diff --git a/llvm/lib/Analysis/IVUsers.cpp b/llvm/lib/Analysis/IVUsers.cpp
new file mode 100644
index 000000000000..681a0cf7e981
--- /dev/null
+++ b/llvm/lib/Analysis/IVUsers.cpp
@@ -0,0 +1,426 @@
+//===- IVUsers.cpp - Induction Variable Users -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements bookkeeping for "interesting" users of expressions
+// computed from induction variables.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/IVUsers.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+using namespace llvm;
+
+#define DEBUG_TYPE "iv-users"
+
+AnalysisKey IVUsersAnalysis::Key;
+
+IVUsers IVUsersAnalysis::run(Loop &L, LoopAnalysisManager &AM,
+ LoopStandardAnalysisResults &AR) {
+ return IVUsers(&L, &AR.AC, &AR.LI, &AR.DT, &AR.SE);
+}
+
+char IVUsersWrapperPass::ID = 0;
+INITIALIZE_PASS_BEGIN(IVUsersWrapperPass, "iv-users",
+ "Induction Variable Users", false, true)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
+INITIALIZE_PASS_END(IVUsersWrapperPass, "iv-users", "Induction Variable Users",
+ false, true)
+
+Pass *llvm::createIVUsersPass() { return new IVUsersWrapperPass(); }
+
+/// isInteresting - Test whether the given expression is "interesting" when
+/// used by the given instruction, within the context of analyzing the
+/// given loop.
+static bool isInteresting(const SCEV *S, const Instruction *I, const Loop *L,
+ ScalarEvolution *SE, LoopInfo *LI) {
+ // An addrec is interesting if it's affine or if it has an interesting start.
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
+ // Keep things simple. Don't touch loop-variant strides unless they're
+ // only used outside the loop and we can simplify them.
+ if (AR->getLoop() == L)
+ return AR->isAffine() ||
+ (!L->contains(I) &&
+ SE->getSCEVAtScope(AR, LI->getLoopFor(I->getParent())) != AR);
+ // Otherwise recurse to see if the start value is interesting, and that
+ // the step value is not interesting, since we don't yet know how to
+ // do effective SCEV expansions for addrecs with interesting steps.
+ return isInteresting(AR->getStart(), I, L, SE, LI) &&
+ !isInteresting(AR->getStepRecurrence(*SE), I, L, SE, LI);
+ }
+
+ // An add is interesting if exactly one of its operands is interesting.
+ if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
+ bool AnyInterestingYet = false;
+ for (const auto *Op : Add->operands())
+ if (isInteresting(Op, I, L, SE, LI)) {
+ if (AnyInterestingYet)
+ return false;
+ AnyInterestingYet = true;
+ }
+ return AnyInterestingYet;
+ }
+
+ // Nothing else is interesting here.
+ return false;
+}
+
+/// Return true if all loop headers that dominate this block are in simplified
+/// form.
+static bool isSimplifiedLoopNest(BasicBlock *BB, const DominatorTree *DT,
+ const LoopInfo *LI,
+ SmallPtrSetImpl<Loop*> &SimpleLoopNests) {
+ Loop *NearestLoop = nullptr;
+ for (DomTreeNode *Rung = DT->getNode(BB);
+ Rung; Rung = Rung->getIDom()) {
+ BasicBlock *DomBB = Rung->getBlock();
+ Loop *DomLoop = LI->getLoopFor(DomBB);
+ if (DomLoop && DomLoop->getHeader() == DomBB) {
+ // If the domtree walk reaches a loop with no preheader, return false.
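+ // (A loop in simplified form is what LoopSimplify guarantees: a
+ // preheader, a single backedge, and dedicated exit blocks.)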
+ if (!DomLoop->isLoopSimplifyForm()) + return false; + // If we have already checked this loop nest, stop checking. + if (SimpleLoopNests.count(DomLoop)) + break; + // If we have not already checked this loop nest, remember the loop + // header nearest to BB. The nearest loop may not contain BB. + if (!NearestLoop) + NearestLoop = DomLoop; + } + } + if (NearestLoop) + SimpleLoopNests.insert(NearestLoop); + return true; +} + +/// IVUseShouldUsePostIncValue - We have discovered a "User" of an IV expression +/// and now we need to decide whether the user should use the preinc or post-inc +/// value. If this user should use the post-inc version of the IV, return true. +/// +/// Choosing wrong here can break dominance properties (if we choose to use the +/// post-inc value when we cannot) or it can end up adding extra live-ranges to +/// the loop, resulting in reg-reg copies (if we use the pre-inc value when we +/// should use the post-inc value). +static bool IVUseShouldUsePostIncValue(Instruction *User, Value *Operand, + const Loop *L, DominatorTree *DT) { + // If the user is in the loop, use the preinc value. + if (L->contains(User)) + return false; + + BasicBlock *LatchBlock = L->getLoopLatch(); + if (!LatchBlock) + return false; + + // Ok, the user is outside of the loop. If it is dominated by the latch + // block, use the post-inc value. + if (DT->dominates(LatchBlock, User->getParent())) + return true; + + // There is one case we have to be careful of: PHI nodes. These little guys + // can live in blocks that are not dominated by the latch block, but (since + // their uses occur in the predecessor block, not the block the PHI lives in) + // should still use the post-inc value. Check for this case now. + PHINode *PN = dyn_cast<PHINode>(User); + if (!PN || !Operand) + return false; // not a phi, not dominated by latch block. + + // Look at all of the uses of Operand by the PHI node. If any use corresponds + // to a block that is not dominated by the latch block, give up and use the + // preincremented value. + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) == Operand && + !DT->dominates(LatchBlock, PN->getIncomingBlock(i))) + return false; + + // Okay, all uses of Operand by PN are in predecessor blocks that really are + // dominated by the latch block. Use the post-incremented value. + return true; +} + +/// AddUsersImpl - Inspect the specified instruction. If it is a +/// reducible SCEV, recursively add its users to the IVUsesByStride set and +/// return true. Otherwise, return false. +bool IVUsers::AddUsersImpl(Instruction *I, + SmallPtrSetImpl<Loop*> &SimpleLoopNests) { + const DataLayout &DL = I->getModule()->getDataLayout(); + + // Add this IV user to the Processed set before returning false to ensure that + // all IV users are members of the set. See IVUsers::isIVUserOrOperand. + if (!Processed.insert(I).second) + return true; // Instruction already handled. + + if (!SE->isSCEVable(I->getType())) + return false; // Void and FP expressions cannot be reduced. + + // IVUsers is used by LSR which assumes that all SCEV expressions are safe to + // pass to SCEVExpander. Expressions are not safe to expand if they represent + // operations that are not safe to speculate, namely integer division. + if (!isa<PHINode>(I) && !isSafeToSpeculativelyExecute(I)) + return false; + + // LSR is not APInt clean, do not touch integers bigger than 64-bits. + // Also avoid creating IVs of non-native types. 
For example, we don't want a + // 64-bit IV in 32-bit code just because the loop has one 64-bit cast. + uint64_t Width = SE->getTypeSizeInBits(I->getType()); + if (Width > 64 || !DL.isLegalInteger(Width)) + return false; + + // Don't attempt to promote ephemeral values to indvars. They will be removed + // later anyway. + if (EphValues.count(I)) + return false; + + // Get the symbolic expression for this instruction. + const SCEV *ISE = SE->getSCEV(I); + + // If we've come to an uninteresting expression, stop the traversal and + // call this a user. + if (!isInteresting(ISE, I, L, SE, LI)) + return false; + + SmallPtrSet<Instruction *, 4> UniqueUsers; + for (Use &U : I->uses()) { + Instruction *User = cast<Instruction>(U.getUser()); + if (!UniqueUsers.insert(User).second) + continue; + + // Do not infinitely recurse on PHI nodes. + if (isa<PHINode>(User) && Processed.count(User)) + continue; + + // Only consider IVUsers that are dominated by simplified loop + // headers. Otherwise, SCEVExpander will crash. + BasicBlock *UseBB = User->getParent(); + // A phi's use is live out of its predecessor block. + if (PHINode *PHI = dyn_cast<PHINode>(User)) { + unsigned OperandNo = U.getOperandNo(); + unsigned ValNo = PHINode::getIncomingValueNumForOperand(OperandNo); + UseBB = PHI->getIncomingBlock(ValNo); + } + if (!isSimplifiedLoopNest(UseBB, DT, LI, SimpleLoopNests)) + return false; + + // Descend recursively, but not into PHI nodes outside the current loop. + // It's important to see the entire expression outside the loop to get + // choices that depend on addressing mode use right, although we won't + // consider references outside the loop in all cases. + // If User is already in Processed, we don't want to recurse into it again, + // but do want to record a second reference in the same instruction. + bool AddUserToIVUsers = false; + if (LI->getLoopFor(User->getParent()) != L) { + if (isa<PHINode>(User) || Processed.count(User) || + !AddUsersImpl(User, SimpleLoopNests)) { + LLVM_DEBUG(dbgs() << "FOUND USER in other loop: " << *User << '\n' + << " OF SCEV: " << *ISE << '\n'); + AddUserToIVUsers = true; + } + } else if (Processed.count(User) || !AddUsersImpl(User, SimpleLoopNests)) { + LLVM_DEBUG(dbgs() << "FOUND USER: " << *User << '\n' + << " OF SCEV: " << *ISE << '\n'); + AddUserToIVUsers = true; + } + + if (AddUserToIVUsers) { + // Okay, we found a user that we cannot reduce. + IVStrideUse &NewUse = AddUser(User, I); + // Autodetect the post-inc loop set, populating NewUse.PostIncLoops. + // The regular return value here is discarded; instead of recording + // it, we just recompute it when we need it. + const SCEV *OriginalISE = ISE; + + auto NormalizePred = [&](const SCEVAddRecExpr *AR) { + auto *L = AR->getLoop(); + bool Result = IVUseShouldUsePostIncValue(User, I, L, DT); + if (Result) + NewUse.PostIncLoops.insert(L); + return Result; + }; + + ISE = normalizeForPostIncUseIf(ISE, NormalizePred, *SE); + + // PostIncNormalization effectively simplifies the expression under + // pre-increment assumptions. Those assumptions (no wrapping) might not + // hold for the post-inc value. Catch such cases by making sure the + // transformation is invertible. + if (OriginalISE != ISE) { + const SCEV *DenormalizedISE = + denormalizeForPostIncUse(ISE, NewUse.PostIncLoops, *SE); + + // If we normalized the expression, but denormalization doesn't give the + // original one, discard this user. 
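+ // (For intuition: denormalizing maps {S,+,X}<L> back to {S+X,+,X}<L>, so
+ // normalization effectively subtracts X from the start; if either step
+ // wraps, the round trip may not reproduce the original expression.)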
+ if (OriginalISE != DenormalizedISE) { + LLVM_DEBUG(dbgs() + << " DISCARDING (NORMALIZATION ISN'T INVERTIBLE): " + << *ISE << '\n'); + IVUses.pop_back(); + return false; + } + } + LLVM_DEBUG(if (SE->getSCEV(I) != ISE) dbgs() + << " NORMALIZED TO: " << *ISE << '\n'); + } + } + return true; +} + +bool IVUsers::AddUsersIfInteresting(Instruction *I) { + // SCEVExpander can only handle users that are dominated by simplified loop + // entries. Keep track of all loops that are only dominated by other simple + // loops so we don't traverse the domtree for each user. + SmallPtrSet<Loop*,16> SimpleLoopNests; + + return AddUsersImpl(I, SimpleLoopNests); +} + +IVStrideUse &IVUsers::AddUser(Instruction *User, Value *Operand) { + IVUses.push_back(new IVStrideUse(this, User, Operand)); + return IVUses.back(); +} + +IVUsers::IVUsers(Loop *L, AssumptionCache *AC, LoopInfo *LI, DominatorTree *DT, + ScalarEvolution *SE) + : L(L), AC(AC), LI(LI), DT(DT), SE(SE), IVUses() { + // Collect ephemeral values so that AddUsersIfInteresting skips them. + EphValues.clear(); + CodeMetrics::collectEphemeralValues(L, AC, EphValues); + + // Find all uses of induction variables in this loop, and categorize + // them by stride. Start by finding all of the PHI nodes in the header for + // this loop. If they are induction variables, inspect their uses. + for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) + (void)AddUsersIfInteresting(&*I); +} + +void IVUsers::print(raw_ostream &OS, const Module *M) const { + OS << "IV Users for loop "; + L->getHeader()->printAsOperand(OS, false); + if (SE->hasLoopInvariantBackedgeTakenCount(L)) { + OS << " with backedge-taken count " << *SE->getBackedgeTakenCount(L); + } + OS << ":\n"; + + for (const IVStrideUse &IVUse : IVUses) { + OS << " "; + IVUse.getOperandValToReplace()->printAsOperand(OS, false); + OS << " = " << *getReplacementExpr(IVUse); + for (auto PostIncLoop : IVUse.PostIncLoops) { + OS << " (post-inc with loop "; + PostIncLoop->getHeader()->printAsOperand(OS, false); + OS << ")"; + } + OS << " in "; + if (IVUse.getUser()) + IVUse.getUser()->print(OS); + else + OS << "Printing <null> User"; + OS << '\n'; + } +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void IVUsers::dump() const { print(dbgs()); } +#endif + +void IVUsers::releaseMemory() { + Processed.clear(); + IVUses.clear(); +} + +IVUsersWrapperPass::IVUsersWrapperPass() : LoopPass(ID) { + initializeIVUsersWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void IVUsersWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.setPreservesAll(); +} + +bool IVUsersWrapperPass::runOnLoop(Loop *L, LPPassManager &LPM) { + auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache( + *L->getHeader()->getParent()); + auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + + IU.reset(new IVUsers(L, AC, LI, DT, SE)); + return false; +} + +void IVUsersWrapperPass::print(raw_ostream &OS, const Module *M) const { + IU->print(OS, M); +} + +void IVUsersWrapperPass::releaseMemory() { IU->releaseMemory(); } + +/// getReplacementExpr - Return a SCEV expression which computes the +/// value of the OperandValToReplace. 
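+/// Unlike getExpr below, the result is not normalized for post-increment
+/// uses.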
+const SCEV *IVUsers::getReplacementExpr(const IVStrideUse &IU) const { + return SE->getSCEV(IU.getOperandValToReplace()); +} + +/// getExpr - Return the expression for the use. +const SCEV *IVUsers::getExpr(const IVStrideUse &IU) const { + return normalizeForPostIncUse(getReplacementExpr(IU), IU.getPostIncLoops(), + *SE); +} + +static const SCEVAddRecExpr *findAddRecForLoop(const SCEV *S, const Loop *L) { + if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { + if (AR->getLoop() == L) + return AR; + return findAddRecForLoop(AR->getStart(), L); + } + + if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { + for (const auto *Op : Add->operands()) + if (const SCEVAddRecExpr *AR = findAddRecForLoop(Op, L)) + return AR; + return nullptr; + } + + return nullptr; +} + +const SCEV *IVUsers::getStride(const IVStrideUse &IU, const Loop *L) const { + if (const SCEVAddRecExpr *AR = findAddRecForLoop(getExpr(IU), L)) + return AR->getStepRecurrence(*SE); + return nullptr; +} + +void IVStrideUse::transformToPostInc(const Loop *L) { + PostIncLoops.insert(L); +} + +void IVStrideUse::deleted() { + // Remove this user from the list. + Parent->Processed.erase(this->getUser()); + Parent->IVUses.erase(this); + // this now dangles! +} diff --git a/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp b/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp new file mode 100644 index 000000000000..68153de8219f --- /dev/null +++ b/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp @@ -0,0 +1,106 @@ +//===-- IndirectCallPromotionAnalysis.cpp - Find promotion candidates ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Helper methods for identifying profitable indirect call promotion +// candidates for an instruction when the indirect-call value profile metadata +// is available. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/IndirectCallPromotionAnalysis.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/IndirectCallVisitor.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/ProfileData/InstrProf.h" +#include "llvm/Support/Debug.h" +#include <string> +#include <utility> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "pgo-icall-prom-analysis" + +// The percent threshold for the direct-call target (this call site vs the +// remaining call count) for it to be considered as the promotion target. +static cl::opt<unsigned> ICPRemainingPercentThreshold( + "icp-remaining-percent-threshold", cl::init(30), cl::Hidden, cl::ZeroOrMore, + cl::desc("The percentage threshold against remaining unpromoted indirect " + "call count for the promotion")); + +// The percent threshold for the direct-call target (this call site vs the +// total call count) for it to be considered as the promotion target. +static cl::opt<unsigned> + ICPTotalPercentThreshold("icp-total-percent-threshold", cl::init(5), + cl::Hidden, cl::ZeroOrMore, + cl::desc("The percentage threshold against total " + "count for the promotion")); + +// Set the maximum number of targets to promote for a single indirect-call +// callsite. 
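+// (With the default of 3 below, at most the three hottest profiled targets
+// are considered for promotion at a site; any remaining targets stay behind
+// the indirect call.)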
+static cl::opt<unsigned> + MaxNumPromotions("icp-max-prom", cl::init(3), cl::Hidden, cl::ZeroOrMore, + cl::desc("Max number of promotions for a single indirect " + "call callsite")); + +ICallPromotionAnalysis::ICallPromotionAnalysis() { + ValueDataArray = std::make_unique<InstrProfValueData[]>(MaxNumPromotions); +} + +bool ICallPromotionAnalysis::isPromotionProfitable(uint64_t Count, + uint64_t TotalCount, + uint64_t RemainingCount) { + return Count * 100 >= ICPRemainingPercentThreshold * RemainingCount && + Count * 100 >= ICPTotalPercentThreshold * TotalCount; +} + +// Indirect-call promotion heuristic. The direct targets are sorted based on +// the count. Stop at the first target that is not promoted. Returns the +// number of candidates deemed profitable. +uint32_t ICallPromotionAnalysis::getProfitablePromotionCandidates( + const Instruction *Inst, uint32_t NumVals, uint64_t TotalCount) { + ArrayRef<InstrProfValueData> ValueDataRef(ValueDataArray.get(), NumVals); + + LLVM_DEBUG(dbgs() << " \nWork on callsite " << *Inst + << " Num_targets: " << NumVals << "\n"); + + uint32_t I = 0; + uint64_t RemainingCount = TotalCount; + for (; I < MaxNumPromotions && I < NumVals; I++) { + uint64_t Count = ValueDataRef[I].Count; + assert(Count <= RemainingCount); + LLVM_DEBUG(dbgs() << " Candidate " << I << " Count=" << Count + << " Target_func: " << ValueDataRef[I].Value << "\n"); + + if (!isPromotionProfitable(Count, TotalCount, RemainingCount)) { + LLVM_DEBUG(dbgs() << " Not promote: Cold target.\n"); + return I; + } + RemainingCount -= Count; + } + return I; +} + +ArrayRef<InstrProfValueData> +ICallPromotionAnalysis::getPromotionCandidatesForInstruction( + const Instruction *I, uint32_t &NumVals, uint64_t &TotalCount, + uint32_t &NumCandidates) { + bool Res = + getValueProfDataFromInst(*I, IPVK_IndirectCallTarget, MaxNumPromotions, + ValueDataArray.get(), NumVals, TotalCount); + if (!Res) { + NumCandidates = 0; + return ArrayRef<InstrProfValueData>(); + } + NumCandidates = getProfitablePromotionCandidates(I, NumVals, TotalCount); + return ArrayRef<InstrProfValueData>(ValueDataArray.get(), NumVals); +} diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp new file mode 100644 index 000000000000..89811ec0e377 --- /dev/null +++ b/llvm/lib/Analysis/InlineCost.cpp @@ -0,0 +1,2223 @@ +//===- InlineCost.cpp - Cost analysis for inliner -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements inline cost analysis. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/InlineCost.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/CallingConv.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DEBUG_TYPE "inline-cost" + +STATISTIC(NumCallsAnalyzed, "Number of call sites analyzed"); + +static cl::opt<int> InlineThreshold( + "inline-threshold", cl::Hidden, cl::init(225), cl::ZeroOrMore, + cl::desc("Control the amount of inlining to perform (default = 225)")); + +static cl::opt<int> HintThreshold( + "inlinehint-threshold", cl::Hidden, cl::init(325), cl::ZeroOrMore, + cl::desc("Threshold for inlining functions with inline hint")); + +static cl::opt<int> + ColdCallSiteThreshold("inline-cold-callsite-threshold", cl::Hidden, + cl::init(45), cl::ZeroOrMore, + cl::desc("Threshold for inlining cold callsites")); + +// We introduce this threshold to help performance of instrumentation based +// PGO before we actually hook up inliner with analysis passes such as BPI and +// BFI. +static cl::opt<int> ColdThreshold( + "inlinecold-threshold", cl::Hidden, cl::init(45), cl::ZeroOrMore, + cl::desc("Threshold for inlining functions with cold attribute")); + +static cl::opt<int> + HotCallSiteThreshold("hot-callsite-threshold", cl::Hidden, cl::init(3000), + cl::ZeroOrMore, + cl::desc("Threshold for hot callsites ")); + +static cl::opt<int> LocallyHotCallSiteThreshold( + "locally-hot-callsite-threshold", cl::Hidden, cl::init(525), cl::ZeroOrMore, + cl::desc("Threshold for locally hot callsites ")); + +static cl::opt<int> ColdCallSiteRelFreq( + "cold-callsite-rel-freq", cl::Hidden, cl::init(2), cl::ZeroOrMore, + cl::desc("Maximum block frequency, expressed as a percentage of caller's " + "entry frequency, for a callsite to be cold in the absence of " + "profile information.")); + +static cl::opt<int> HotCallSiteRelFreq( + "hot-callsite-rel-freq", cl::Hidden, cl::init(60), cl::ZeroOrMore, + cl::desc("Minimum block frequency, expressed as a multiple of caller's " + "entry frequency, for a callsite to be hot in the absence of " + "profile information.")); + +static cl::opt<bool> OptComputeFullInlineCost( + "inline-cost-full", cl::Hidden, cl::init(false), cl::ZeroOrMore, + cl::desc("Compute the full inline cost of a call site even when the cost " + "exceeds the threshold.")); + +namespace { + +class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { + typedef InstVisitor<CallAnalyzer, bool> Base; + friend class InstVisitor<CallAnalyzer, bool>; + + /// The TargetTransformInfo available for this compilation. 
+ const TargetTransformInfo &TTI; + + /// Getter for the cache of @llvm.assume intrinsics. + std::function<AssumptionCache &(Function &)> &GetAssumptionCache; + + /// Getter for BlockFrequencyInfo + Optional<function_ref<BlockFrequencyInfo &(Function &)>> &GetBFI; + + /// Profile summary information. + ProfileSummaryInfo *PSI; + + /// The called function. + Function &F; + + // Cache the DataLayout since we use it a lot. + const DataLayout &DL; + + /// The OptimizationRemarkEmitter available for this compilation. + OptimizationRemarkEmitter *ORE; + + /// The candidate callsite being analyzed. Please do not use this to do + /// analysis in the caller function; we want the inline cost query to be + /// easily cacheable. Instead, use the cover function paramHasAttr. + CallBase &CandidateCall; + + /// Tunable parameters that control the analysis. + const InlineParams &Params; + + /// Upper bound for the inlining cost. Bonuses are being applied to account + /// for speculative "expected profit" of the inlining decision. + int Threshold; + + /// Inlining cost measured in abstract units, accounts for all the + /// instructions expected to be executed for a given function invocation. + /// Instructions that are statically proven to be dead based on call-site + /// arguments are not counted here. + int Cost = 0; + + bool ComputeFullInlineCost; + + bool IsCallerRecursive = false; + bool IsRecursiveCall = false; + bool ExposesReturnsTwice = false; + bool HasDynamicAlloca = false; + bool ContainsNoDuplicateCall = false; + bool HasReturn = false; + bool HasIndirectBr = false; + bool HasUninlineableIntrinsic = false; + bool InitsVargArgs = false; + + /// Number of bytes allocated statically by the callee. + uint64_t AllocatedSize = 0; + unsigned NumInstructions = 0; + unsigned NumVectorInstructions = 0; + + /// Bonus to be applied when percentage of vector instructions in callee is + /// high (see more details in updateThreshold). + int VectorBonus = 0; + /// Bonus to be applied when the callee has only one reachable basic block. + int SingleBBBonus = 0; + + /// While we walk the potentially-inlined instructions, we build up and + /// maintain a mapping of simplified values specific to this callsite. The + /// idea is to propagate any special information we have about arguments to + /// this call through the inlinable section of the function, and account for + /// likely simplifications post-inlining. The most important aspect we track + /// is CFG altering simplifications -- when we prove a basic block dead, that + /// can cause dramatic shifts in the cost of inlining a function. + DenseMap<Value *, Constant *> SimplifiedValues; + + /// Keep track of the values which map back (through function arguments) to + /// allocas on the caller stack which could be simplified through SROA. + DenseMap<Value *, Value *> SROAArgValues; + + /// The mapping of caller Alloca values to their accumulated cost savings. If + /// we have to disable SROA for one of the allocas, this tells us how much + /// cost must be added. + DenseMap<Value *, int> SROAArgCosts; + + /// Keep track of values which map to a pointer base and constant offset. + DenseMap<Value *, std::pair<Value *, APInt>> ConstantOffsetPtrs; + + /// Keep track of dead blocks due to the constant arguments. + SetVector<BasicBlock *> DeadBlocks; + + /// The mapping of the blocks to their known unique successors due to the + /// constant arguments. 
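+ /// For example, if a conditional branch is known to fold one way given the
+ /// call site's constant arguments, the surviving successor is recorded here
+ /// and the untaken edge can be treated as dead.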
+ DenseMap<BasicBlock *, BasicBlock *> KnownSuccessors; + + /// Model the elimination of repeated loads that is expected to happen + /// whenever we simplify away the stores that would otherwise cause them to be + /// loads. + bool EnableLoadElimination; + SmallPtrSet<Value *, 16> LoadAddrSet; + int LoadEliminationCost = 0; + + // Custom simplification helper routines. + bool isAllocaDerivedArg(Value *V); + bool lookupSROAArgAndCost(Value *V, Value *&Arg, + DenseMap<Value *, int>::iterator &CostIt); + void disableSROA(DenseMap<Value *, int>::iterator CostIt); + void disableSROA(Value *V); + void findDeadBlocks(BasicBlock *CurrBB, BasicBlock *NextBB); + void accumulateSROACost(DenseMap<Value *, int>::iterator CostIt, + int InstructionCost); + void disableLoadElimination(); + bool isGEPFree(GetElementPtrInst &GEP); + bool canFoldInboundsGEP(GetElementPtrInst &I); + bool accumulateGEPOffset(GEPOperator &GEP, APInt &Offset); + bool simplifyCallSite(Function *F, CallBase &Call); + template <typename Callable> + bool simplifyInstruction(Instruction &I, Callable Evaluate); + ConstantInt *stripAndComputeInBoundsConstantOffsets(Value *&V); + + /// Return true if the given argument to the function being considered for + /// inlining has the given attribute set either at the call site or the + /// function declaration. Primarily used to inspect call site specific + /// attributes since these can be more precise than the ones on the callee + /// itself. + bool paramHasAttr(Argument *A, Attribute::AttrKind Attr); + + /// Return true if the given value is known non null within the callee if + /// inlined through this particular callsite. + bool isKnownNonNullInCallee(Value *V); + + /// Update Threshold based on callsite properties such as callee + /// attributes and callee hotness for PGO builds. The Callee is explicitly + /// passed to support analyzing indirect calls whose target is inferred by + /// analysis. + void updateThreshold(CallBase &Call, Function &Callee); + + /// Return true if size growth is allowed when inlining the callee at \p Call. + bool allowSizeGrowth(CallBase &Call); + + /// Return true if \p Call is a cold callsite. + bool isColdCallSite(CallBase &Call, BlockFrequencyInfo *CallerBFI); + + /// Return a higher threshold if \p Call is a hot callsite. + Optional<int> getHotCallSiteThreshold(CallBase &Call, + BlockFrequencyInfo *CallerBFI); + + // Custom analysis routines. + InlineResult analyzeBlock(BasicBlock *BB, + SmallPtrSetImpl<const Value *> &EphValues); + + /// Handle a capped 'int' increment for Cost. + void addCost(int64_t Inc, int64_t UpperBound = INT_MAX) { + assert(UpperBound > 0 && UpperBound <= INT_MAX && "invalid upper bound"); + Cost = (int)std::min(UpperBound, Cost + Inc); + } + + // Disable several entry points to the visitor so we don't accidentally use + // them by declaring but not defining them here. + void visit(Module *); + void visit(Module &); + void visit(Function *); + void visit(Function &); + void visit(BasicBlock *); + void visit(BasicBlock &); + + // Provide base case for our instruction visit. + bool visitInstruction(Instruction &I); + + // Our visit overrides. 
+ bool visitAlloca(AllocaInst &I); + bool visitPHI(PHINode &I); + bool visitGetElementPtr(GetElementPtrInst &I); + bool visitBitCast(BitCastInst &I); + bool visitPtrToInt(PtrToIntInst &I); + bool visitIntToPtr(IntToPtrInst &I); + bool visitCastInst(CastInst &I); + bool visitUnaryInstruction(UnaryInstruction &I); + bool visitCmpInst(CmpInst &I); + bool visitSub(BinaryOperator &I); + bool visitBinaryOperator(BinaryOperator &I); + bool visitFNeg(UnaryOperator &I); + bool visitLoad(LoadInst &I); + bool visitStore(StoreInst &I); + bool visitExtractValue(ExtractValueInst &I); + bool visitInsertValue(InsertValueInst &I); + bool visitCallBase(CallBase &Call); + bool visitReturnInst(ReturnInst &RI); + bool visitBranchInst(BranchInst &BI); + bool visitSelectInst(SelectInst &SI); + bool visitSwitchInst(SwitchInst &SI); + bool visitIndirectBrInst(IndirectBrInst &IBI); + bool visitResumeInst(ResumeInst &RI); + bool visitCleanupReturnInst(CleanupReturnInst &RI); + bool visitCatchReturnInst(CatchReturnInst &RI); + bool visitUnreachableInst(UnreachableInst &I); + +public: + CallAnalyzer(const TargetTransformInfo &TTI, + std::function<AssumptionCache &(Function &)> &GetAssumptionCache, + Optional<function_ref<BlockFrequencyInfo &(Function &)>> &GetBFI, + ProfileSummaryInfo *PSI, OptimizationRemarkEmitter *ORE, + Function &Callee, CallBase &Call, const InlineParams &Params) + : TTI(TTI), GetAssumptionCache(GetAssumptionCache), GetBFI(GetBFI), + PSI(PSI), F(Callee), DL(F.getParent()->getDataLayout()), ORE(ORE), + CandidateCall(Call), Params(Params), Threshold(Params.DefaultThreshold), + ComputeFullInlineCost(OptComputeFullInlineCost || + Params.ComputeFullInlineCost || ORE), + EnableLoadElimination(true) {} + + InlineResult analyzeCall(CallBase &Call); + + int getThreshold() { return Threshold; } + int getCost() { return Cost; } + + // Keep a bunch of stats about the cost savings found so we can print them + // out when debugging. + unsigned NumConstantArgs = 0; + unsigned NumConstantOffsetPtrArgs = 0; + unsigned NumAllocaArgs = 0; + unsigned NumConstantPtrCmps = 0; + unsigned NumConstantPtrDiffs = 0; + unsigned NumInstructionsSimplified = 0; + unsigned SROACostSavings = 0; + unsigned SROACostSavingsLost = 0; + + void dump(); +}; + +} // namespace + +/// Test whether the given value is an Alloca-derived function argument. +bool CallAnalyzer::isAllocaDerivedArg(Value *V) { + return SROAArgValues.count(V); +} + +/// Lookup the SROA-candidate argument and cost iterator which V maps to. +/// Returns false if V does not map to a SROA-candidate. +bool CallAnalyzer::lookupSROAArgAndCost( + Value *V, Value *&Arg, DenseMap<Value *, int>::iterator &CostIt) { + if (SROAArgValues.empty() || SROAArgCosts.empty()) + return false; + + DenseMap<Value *, Value *>::iterator ArgIt = SROAArgValues.find(V); + if (ArgIt == SROAArgValues.end()) + return false; + + Arg = ArgIt->second; + CostIt = SROAArgCosts.find(Arg); + return CostIt != SROAArgCosts.end(); +} + +/// Disable SROA for the candidate marked by this cost iterator. +/// +/// This marks the candidate as no longer viable for SROA, and adds the cost +/// savings associated with it back into the inline cost measurement. +void CallAnalyzer::disableSROA(DenseMap<Value *, int>::iterator CostIt) { + // If we're no longer able to perform SROA we need to undo its cost savings + // and prevent subsequent analysis. 
+ addCost(CostIt->second); + SROACostSavings -= CostIt->second; + SROACostSavingsLost += CostIt->second; + SROAArgCosts.erase(CostIt); + disableLoadElimination(); +} + +/// If 'V' maps to a SROA candidate, disable SROA for it. +void CallAnalyzer::disableSROA(Value *V) { + Value *SROAArg; + DenseMap<Value *, int>::iterator CostIt; + if (lookupSROAArgAndCost(V, SROAArg, CostIt)) + disableSROA(CostIt); +} + +/// Accumulate the given cost for a particular SROA candidate. +void CallAnalyzer::accumulateSROACost(DenseMap<Value *, int>::iterator CostIt, + int InstructionCost) { + CostIt->second += InstructionCost; + SROACostSavings += InstructionCost; +} + +void CallAnalyzer::disableLoadElimination() { + if (EnableLoadElimination) { + addCost(LoadEliminationCost); + LoadEliminationCost = 0; + EnableLoadElimination = false; + } +} + +/// Accumulate a constant GEP offset into an APInt if possible. +/// +/// Returns false if unable to compute the offset for any reason. Respects any +/// simplified values known during the analysis of this callsite. +bool CallAnalyzer::accumulateGEPOffset(GEPOperator &GEP, APInt &Offset) { + unsigned IntPtrWidth = DL.getIndexTypeSizeInBits(GEP.getType()); + assert(IntPtrWidth == Offset.getBitWidth()); + + for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP); + GTI != GTE; ++GTI) { + ConstantInt *OpC = dyn_cast<ConstantInt>(GTI.getOperand()); + if (!OpC) + if (Constant *SimpleOp = SimplifiedValues.lookup(GTI.getOperand())) + OpC = dyn_cast<ConstantInt>(SimpleOp); + if (!OpC) + return false; + if (OpC->isZero()) + continue; + + // Handle a struct index, which adds its field offset to the pointer. + if (StructType *STy = GTI.getStructTypeOrNull()) { + unsigned ElementIdx = OpC->getZExtValue(); + const StructLayout *SL = DL.getStructLayout(STy); + Offset += APInt(IntPtrWidth, SL->getElementOffset(ElementIdx)); + continue; + } + + APInt TypeSize(IntPtrWidth, DL.getTypeAllocSize(GTI.getIndexedType())); + Offset += OpC->getValue().sextOrTrunc(IntPtrWidth) * TypeSize; + } + return true; +} + +/// Use TTI to check whether a GEP is free. +/// +/// Respects any simplified values known during the analysis of this callsite. +bool CallAnalyzer::isGEPFree(GetElementPtrInst &GEP) { + SmallVector<Value *, 4> Operands; + Operands.push_back(GEP.getOperand(0)); + for (User::op_iterator I = GEP.idx_begin(), E = GEP.idx_end(); I != E; ++I) + if (Constant *SimpleOp = SimplifiedValues.lookup(*I)) + Operands.push_back(SimpleOp); + else + Operands.push_back(*I); + return TargetTransformInfo::TCC_Free == TTI.getUserCost(&GEP, Operands); +} + +bool CallAnalyzer::visitAlloca(AllocaInst &I) { + // Check whether inlining will turn a dynamic alloca into a static + // alloca and handle that case. + if (I.isArrayAllocation()) { + Constant *Size = SimplifiedValues.lookup(I.getArraySize()); + if (auto *AllocSize = dyn_cast_or_null<ConstantInt>(Size)) { + Type *Ty = I.getAllocatedType(); + AllocatedSize = SaturatingMultiplyAdd( + AllocSize->getLimitedValue(), DL.getTypeAllocSize(Ty).getFixedSize(), + AllocatedSize); + return Base::visitAlloca(I); + } + } + + // Accumulate the allocated size. + if (I.isStaticAlloca()) { + Type *Ty = I.getAllocatedType(); + AllocatedSize = SaturatingAdd(DL.getTypeAllocSize(Ty).getFixedSize(), + AllocatedSize); + } + + // We will happily inline static alloca instructions. + if (I.isStaticAlloca()) + return Base::visitAlloca(I); + + // FIXME: This is overly conservative. 
 Dynamic allocas are inefficient for
+ // a variety of reasons, and so we would like to not inline them into
+ // functions which don't currently have a dynamic alloca. This simply
+ // disables inlining altogether in the presence of a dynamic alloca.
+ HasDynamicAlloca = true;
+ return false;
+}
+
+bool CallAnalyzer::visitPHI(PHINode &I) {
+ // FIXME: We need to propagate SROA *disabling* through phi nodes, even
+ // though we don't want to propagate its bonuses. The idea is to disable
+ // SROA if it *might* be used in an inappropriate manner.
+
+ // Phi nodes are always zero-cost.
+ // FIXME: Pointer sizes may differ between different address spaces, so do we
+ // need to use the correct address space in the call to getPointerSizeInBits
+ // here? Or could we skip the getPointerSizeInBits call completely? As far as
+ // I can see the ZeroOffset is used as a dummy value, so we can probably use
+ // any bit width for the ZeroOffset?
+ APInt ZeroOffset = APInt::getNullValue(DL.getPointerSizeInBits(0));
+ bool CheckSROA = I.getType()->isPointerTy();
+
+ // Track the constant or pointer with constant offset we've seen so far.
+ Constant *FirstC = nullptr;
+ std::pair<Value *, APInt> FirstBaseAndOffset = {nullptr, ZeroOffset};
+ Value *FirstV = nullptr;
+
+ for (unsigned i = 0, e = I.getNumIncomingValues(); i != e; ++i) {
+ BasicBlock *Pred = I.getIncomingBlock(i);
+ // If the incoming block is dead, skip the incoming block.
+ if (DeadBlocks.count(Pred))
+ continue;
+ // If the parent block of phi is not the known successor of the incoming
+ // block, skip the incoming block.
+ BasicBlock *KnownSuccessor = KnownSuccessors[Pred];
+ if (KnownSuccessor && KnownSuccessor != I.getParent())
+ continue;
+
+ Value *V = I.getIncomingValue(i);
+ // If the incoming value is this phi itself, skip the incoming value.
+ if (&I == V)
+ continue;
+
+ Constant *C = dyn_cast<Constant>(V);
+ if (!C)
+ C = SimplifiedValues.lookup(V);
+
+ std::pair<Value *, APInt> BaseAndOffset = {nullptr, ZeroOffset};
+ if (!C && CheckSROA)
+ BaseAndOffset = ConstantOffsetPtrs.lookup(V);
+
+ if (!C && !BaseAndOffset.first)
+ // The incoming value is neither a constant nor a pointer with constant
+ // offset; exit early.
+ return true;
+
+ if (FirstC) {
+ if (FirstC == C)
+ // If we've seen a constant incoming value before and it is the same
+ // constant we see this time, continue checking the next incoming
+ // value.
+ continue;
+ // Otherwise, exit early: either we see a different constant this time,
+ // or we saw a constant before but have a pointer with a constant offset
+ // this time.
+ return true;
+ }
+
+ if (FirstV) {
+ // The same logic as above, but check pointer with constant offset here.
+ if (FirstBaseAndOffset == BaseAndOffset)
+ continue;
+ return true;
+ }
+
+ if (C) {
+ // This is the first time we've seen a constant; record it.
+ FirstC = C;
+ continue;
+ }
+
+ // The remaining case is that this is the first time we've seen a pointer
+ // with a constant offset; record it.
+ FirstV = V;
+ FirstBaseAndOffset = BaseAndOffset;
+ }
+
+ // Check if we can map phi to a constant.
+ if (FirstC) {
+ SimplifiedValues[&I] = FirstC;
+ return true;
+ }
+
+ // Check if we can map phi to a pointer with constant offset.
+ if (FirstBaseAndOffset.first) {
+ ConstantOffsetPtrs[&I] = FirstBaseAndOffset;
+
+ Value *SROAArg;
+ DenseMap<Value *, int>::iterator CostIt;
+ if (lookupSROAArgAndCost(FirstV, SROAArg, CostIt))
+ SROAArgValues[&I] = SROAArg;
+ }
+
+ return true;
+}
+
+/// Check whether we can fold GEPs of constant-offset call site argument
+/// pointers.
+/// This requires target data and inbounds GEPs. +/// +/// \return true if the specified GEP can be folded. +bool CallAnalyzer::canFoldInboundsGEP(GetElementPtrInst &I) { + // Check if we have a base + offset for the pointer. + std::pair<Value *, APInt> BaseAndOffset = + ConstantOffsetPtrs.lookup(I.getPointerOperand()); + if (!BaseAndOffset.first) + return false; + + // Check if the offset of this GEP is constant, and if so accumulate it + // into Offset. + if (!accumulateGEPOffset(cast<GEPOperator>(I), BaseAndOffset.second)) + return false; + + // Add the result as a new mapping to Base + Offset. + ConstantOffsetPtrs[&I] = BaseAndOffset; + + return true; +} + +bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) { + Value *SROAArg; + DenseMap<Value *, int>::iterator CostIt; + bool SROACandidate = + lookupSROAArgAndCost(I.getPointerOperand(), SROAArg, CostIt); + + // Lambda to check whether a GEP's indices are all constant. + auto IsGEPOffsetConstant = [&](GetElementPtrInst &GEP) { + for (User::op_iterator I = GEP.idx_begin(), E = GEP.idx_end(); I != E; ++I) + if (!isa<Constant>(*I) && !SimplifiedValues.lookup(*I)) + return false; + return true; + }; + + if ((I.isInBounds() && canFoldInboundsGEP(I)) || IsGEPOffsetConstant(I)) { + if (SROACandidate) + SROAArgValues[&I] = SROAArg; + + // Constant GEPs are modeled as free. + return true; + } + + // Variable GEPs will require math and will disable SROA. + if (SROACandidate) + disableSROA(CostIt); + return isGEPFree(I); +} + +/// Simplify \p I if its operands are constants and update SimplifiedValues. +/// \p Evaluate is a callable specific to instruction type that evaluates the +/// instruction when all the operands are constants. +template <typename Callable> +bool CallAnalyzer::simplifyInstruction(Instruction &I, Callable Evaluate) { + SmallVector<Constant *, 2> COps; + for (Value *Op : I.operands()) { + Constant *COp = dyn_cast<Constant>(Op); + if (!COp) + COp = SimplifiedValues.lookup(Op); + if (!COp) + return false; + COps.push_back(COp); + } + auto *C = Evaluate(COps); + if (!C) + return false; + SimplifiedValues[&I] = C; + return true; +} + +bool CallAnalyzer::visitBitCast(BitCastInst &I) { + // Propagate constants through bitcasts. + if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { + return ConstantExpr::getBitCast(COps[0], I.getType()); + })) + return true; + + // Track base/offsets through casts + std::pair<Value *, APInt> BaseAndOffset = + ConstantOffsetPtrs.lookup(I.getOperand(0)); + // Casts don't change the offset, just wrap it up. + if (BaseAndOffset.first) + ConstantOffsetPtrs[&I] = BaseAndOffset; + + // Also look for SROA candidates here. + Value *SROAArg; + DenseMap<Value *, int>::iterator CostIt; + if (lookupSROAArgAndCost(I.getOperand(0), SROAArg, CostIt)) + SROAArgValues[&I] = SROAArg; + + // Bitcasts are always zero cost. + return true; +} + +bool CallAnalyzer::visitPtrToInt(PtrToIntInst &I) { + // Propagate constants through ptrtoint. + if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { + return ConstantExpr::getPtrToInt(COps[0], I.getType()); + })) + return true; + + // Track base/offset pairs when converted to a plain integer provided the + // integer is large enough to represent the pointer. 
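+ // E.g. (illustrative) a ptrtoint of an i8* to i64 on a 64-bit target keeps
+ // the (base, offset) pair tracked, while truncating to i32 would drop bits
+ // and so is not tracked.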
+ unsigned IntegerSize = I.getType()->getScalarSizeInBits();
+ unsigned AS = I.getOperand(0)->getType()->getPointerAddressSpace();
+ if (IntegerSize >= DL.getPointerSizeInBits(AS)) {
+ std::pair<Value *, APInt> BaseAndOffset =
+ ConstantOffsetPtrs.lookup(I.getOperand(0));
+ if (BaseAndOffset.first)
+ ConstantOffsetPtrs[&I] = BaseAndOffset;
+ }
+
+ // This is really weird. Technically, ptrtoint will disable SROA. However,
+ // unless that ptrtoint is *used* somewhere in the live basic blocks after
+ // inlining, it will be nuked, and SROA should proceed. All of the uses which
+ // would block SROA would also block SROA if applied directly to a pointer,
+ // and so we can just add the integer in here. The only places where SROA is
+ // preserved either cannot fire on an integer, or won't in-and-of themselves
+ // disable SROA (ext) w/o some later use that we would see and disable.
+ Value *SROAArg;
+ DenseMap<Value *, int>::iterator CostIt;
+ if (lookupSROAArgAndCost(I.getOperand(0), SROAArg, CostIt))
+ SROAArgValues[&I] = SROAArg;
+
+ return TargetTransformInfo::TCC_Free == TTI.getUserCost(&I);
+}
+
+bool CallAnalyzer::visitIntToPtr(IntToPtrInst &I) {
+ // Propagate constants through inttoptr.
+ if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) {
+ return ConstantExpr::getIntToPtr(COps[0], I.getType());
+ }))
+ return true;
+
+ // Track base/offset pairs when round-tripped through a pointer without
+ // modifications provided the integer is not too large.
+ Value *Op = I.getOperand(0);
+ unsigned IntegerSize = Op->getType()->getScalarSizeInBits();
+ if (IntegerSize <= DL.getPointerTypeSizeInBits(I.getType())) {
+ std::pair<Value *, APInt> BaseAndOffset = ConstantOffsetPtrs.lookup(Op);
+ if (BaseAndOffset.first)
+ ConstantOffsetPtrs[&I] = BaseAndOffset;
+ }
+
+ // "Propagate" SROA here in the same manner as we do for ptrtoint above.
+ Value *SROAArg;
+ DenseMap<Value *, int>::iterator CostIt;
+ if (lookupSROAArgAndCost(Op, SROAArg, CostIt))
+ SROAArgValues[&I] = SROAArg;
+
+ return TargetTransformInfo::TCC_Free == TTI.getUserCost(&I);
+}
+
+bool CallAnalyzer::visitCastInst(CastInst &I) {
+ // Propagate constants through casts.
+ if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) {
+ return ConstantExpr::getCast(I.getOpcode(), COps[0], I.getType());
+ }))
+ return true;
+
+ // Disable SROA in the face of arbitrary casts we don't whitelist elsewhere.
+ disableSROA(I.getOperand(0));
+
+ // If this is a floating-point cast, and the target says this operation
+ // is expensive, this may eventually become a library call. Treat the cost
+ // as such.
+ switch (I.getOpcode()) {
+ case Instruction::FPTrunc:
+ case Instruction::FPExt:
+ case Instruction::UIToFP:
+ case Instruction::SIToFP:
+ case Instruction::FPToUI:
+ case Instruction::FPToSI:
+ if (TTI.getFPOpCost(I.getType()) == TargetTransformInfo::TCC_Expensive)
+ addCost(InlineConstants::CallPenalty);
+ break;
+ default:
+ break;
+ }
+
+ return TargetTransformInfo::TCC_Free == TTI.getUserCost(&I);
+}
+
+bool CallAnalyzer::visitUnaryInstruction(UnaryInstruction &I) {
+ Value *Operand = I.getOperand(0);
+ if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) {
+ return ConstantFoldInstOperands(&I, COps[0], DL);
+ }))
+ return true;
+
+ // Disable any SROA on the argument to arbitrary unary instructions.
+ disableSROA(Operand);
+
+ return false;
+}
+
+bool CallAnalyzer::paramHasAttr(Argument *A, Attribute::AttrKind Attr) {
+ return CandidateCall.paramHasAttr(A->getArgNo(), Attr);
+}
+
+bool CallAnalyzer::isKnownNonNullInCallee(Value *V) {
+ // Does the *call site* have the NonNull attribute set on an argument? We
+ // use the attribute on the call site to memoize any analysis done in the
+ // caller. This will also trip if the callee function has a non-null
+ // parameter attribute, but that's a less interesting case because hopefully
+ // the callee would already have been simplified based on that.
+ if (Argument *A = dyn_cast<Argument>(V))
+ if (paramHasAttr(A, Attribute::NonNull))
+ return true;
+
+ // Is this an alloca in the caller? This is distinct from the attribute case
+ // above because attributes aren't updated within the inliner itself and we
+ // always want to catch the alloca-derived case.
+ if (isAllocaDerivedArg(V))
+ // We can actually predict the result of comparisons between an
+ // alloca-derived value and null. Note that this fires regardless of
+ // SROA firing.
+ return true;
+
+ return false;
+}
+
+bool CallAnalyzer::allowSizeGrowth(CallBase &Call) {
+ // If the normal destination of the invoke or the parent block of the call
+ // site is unreachable-terminated, there is little point in inlining this
+ // unless there is literally zero cost.
+ // FIXME: Note that it is possible that an unreachable-terminated block has a
+ // hot entry. For example, in the scenario below, inlining hot_call_X() may
+ // be beneficial:
+ // main() {
+ // hot_call_1();
+ // ...
+ // hot_call_N()
+ // exit(0);
+ // }
+ // For now, we are not handling this corner case here as it is rare in real
+ // code. In the future, we should elaborate on this based on BPI and BFI in
+ // more general threshold-adjusting heuristics in updateThreshold().
+ if (InvokeInst *II = dyn_cast<InvokeInst>(&Call)) {
+ if (isa<UnreachableInst>(II->getNormalDest()->getTerminator()))
+ return false;
+ } else if (isa<UnreachableInst>(Call.getParent()->getTerminator()))
+ return false;
+
+ return true;
+}
+
+bool CallAnalyzer::isColdCallSite(CallBase &Call,
+ BlockFrequencyInfo *CallerBFI) {
+ // If global profile summary is available, then callsite's coldness is
+ // determined based on that.
+ if (PSI && PSI->hasProfileSummary())
+ return PSI->isColdCallSite(CallSite(&Call), CallerBFI);
+
+ // Otherwise we need BFI to be available.
+ if (!CallerBFI)
+ return false;
+
+ // Determine if the callsite is cold relative to caller's entry. We could
+ // potentially cache the computation of scaled entry frequency, but the added
+ // complexity is not worth it unless this scaling shows up high in the
+ // profiles.
+ const BranchProbability ColdProb(ColdCallSiteRelFreq, 100);
+ auto CallSiteBB = Call.getParent();
+ auto CallSiteFreq = CallerBFI->getBlockFreq(CallSiteBB);
+ auto CallerEntryFreq =
+ CallerBFI->getBlockFreq(&(Call.getCaller()->getEntryBlock()));
+ return CallSiteFreq < CallerEntryFreq * ColdProb;
+}
+
+Optional<int>
+CallAnalyzer::getHotCallSiteThreshold(CallBase &Call,
+ BlockFrequencyInfo *CallerBFI) {
+
+ // If global profile summary is available, then callsite's hotness is
+ // determined based on that.
+ if (PSI && PSI->hasProfileSummary() &&
+ PSI->isHotCallSite(CallSite(&Call), CallerBFI))
+ return Params.HotCallSiteThreshold;
+
+ // Otherwise we need BFI to be available and to have a locally hot callsite
+ // threshold.
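+ // (Illustrative numbers: with the defaults above, -hot-callsite-rel-freq=60
+ // and -locally-hot-callsite-threshold=525, a callsite whose block runs at
+ // least 60x as often as the caller's entry gets the 525 threshold.)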
+ if (!CallerBFI || !Params.LocallyHotCallSiteThreshold) + return None; + + // Determine if the callsite is hot relative to the caller's entry. We could + // potentially cache the computation of scaled entry frequency, but the added + // complexity is not worth it unless this scaling shows up high in the + // profiles. + auto CallSiteBB = Call.getParent(); + auto CallSiteFreq = CallerBFI->getBlockFreq(CallSiteBB).getFrequency(); + auto CallerEntryFreq = CallerBFI->getEntryFreq(); + if (CallSiteFreq >= CallerEntryFreq * HotCallSiteRelFreq) + return Params.LocallyHotCallSiteThreshold; + + // Otherwise treat it normally. + return None; +} + +void CallAnalyzer::updateThreshold(CallBase &Call, Function &Callee) { + // If no size growth is allowed for this inlining, set Threshold to 0. + if (!allowSizeGrowth(Call)) { + Threshold = 0; + return; + } + + Function *Caller = Call.getCaller(); + + // Return min(A, B) if B is valid. + auto MinIfValid = [](int A, Optional<int> B) { + return B ? std::min(A, B.getValue()) : A; + }; + + // Return max(A, B) if B is valid. + auto MaxIfValid = [](int A, Optional<int> B) { + return B ? std::max(A, B.getValue()) : A; + }; + + // Various bonus percentages. These are multiplied by Threshold to get the + // bonus values. + // SingleBBBonus: This bonus is applied if the callee has a single reachable + // basic block at the given callsite context. This is speculatively applied + // and withdrawn if more than one basic block is seen. + // + // LastCallToStaticBonus: This large bonus is applied to ensure the inlining + // of the last call to a static function, as inlining such functions is + // guaranteed to reduce code size. + // + // These bonus percentages may be set to 0 based on properties of the caller + // and the callsite. + int SingleBBBonusPercent = 50; + int VectorBonusPercent = TTI.getInlinerVectorBonusPercent(); + int LastCallToStaticBonus = InlineConstants::LastCallToStaticBonus; + + // Lambda to set all the above bonuses and bonus percentages to 0. + auto DisallowAllBonuses = [&]() { + SingleBBBonusPercent = 0; + VectorBonusPercent = 0; + LastCallToStaticBonus = 0; + }; + + // Use the OptMinSizeThreshold or OptSizeThreshold knob if they are available + // and reduce the threshold if the caller has the necessary attribute. + if (Caller->hasMinSize()) { + Threshold = MinIfValid(Threshold, Params.OptMinSizeThreshold); + // For minsize, we want to disable the single BB bonus and the vector + // bonuses, but not the last-call-to-static bonus. Inlining the last call to + // a static function will, at the minimum, eliminate the parameter setup and + // call/return instructions. + SingleBBBonusPercent = 0; + VectorBonusPercent = 0; + } else if (Caller->hasOptSize()) + Threshold = MinIfValid(Threshold, Params.OptSizeThreshold); + + // Adjust the threshold based on the inlinehint attribute and profile-based + // hotness information if the caller does not have the MinSize attribute. + if (!Caller->hasMinSize()) { + if (Callee.hasFnAttribute(Attribute::InlineHint)) + Threshold = MaxIfValid(Threshold, Params.HintThreshold); + + // FIXME: After switching to the new passmanager, simplify the logic below + // by checking only the callsite hotness/coldness as we will reliably + // have local profile information. + // + // Callsite hotness and coldness can be determined if a sample profile is + // used (which adds hotness metadata to calls) or if the caller's + // BlockFrequencyInfo is available. + BlockFrequencyInfo *CallerBFI = GetBFI ?
&((*GetBFI)(*Caller)) : nullptr; + auto HotCallSiteThreshold = getHotCallSiteThreshold(Call, CallerBFI); + if (!Caller->hasOptSize() && HotCallSiteThreshold) { + LLVM_DEBUG(dbgs() << "Hot callsite.\n"); + // FIXME: This should update the threshold only if it exceeds the + // current threshold, but AutoFDO + ThinLTO currently relies on this + // behavior to prevent inlining of hot callsites during ThinLTO + // compile phase. + Threshold = HotCallSiteThreshold.getValue(); + } else if (isColdCallSite(Call, CallerBFI)) { + LLVM_DEBUG(dbgs() << "Cold callsite.\n"); + // Do not apply bonuses for a cold callsite including the + // LastCallToStatic bonus. While this bonus might result in code size + // reduction, it can cause the size of a non-cold caller to increase + // preventing it from being inlined. + DisallowAllBonuses(); + Threshold = MinIfValid(Threshold, Params.ColdCallSiteThreshold); + } else if (PSI) { + // Use callee's global profile information only if we have no way of + // determining this via callsite information. + if (PSI->isFunctionEntryHot(&Callee)) { + LLVM_DEBUG(dbgs() << "Hot callee.\n"); + // If callsite hotness can not be determined, we may still know + // that the callee is hot and treat it as a weaker hint for threshold + // increase. + Threshold = MaxIfValid(Threshold, Params.HintThreshold); + } else if (PSI->isFunctionEntryCold(&Callee)) { + LLVM_DEBUG(dbgs() << "Cold callee.\n"); + // Do not apply bonuses for a cold callee including the + // LastCallToStatic bonus. While this bonus might result in code size + // reduction, it can cause the size of a non-cold caller to increase + // preventing it from being inlined. + DisallowAllBonuses(); + Threshold = MinIfValid(Threshold, Params.ColdThreshold); + } + } + } + + // Finally, take the target-specific inlining threshold multiplier into + // account. + Threshold *= TTI.getInliningThresholdMultiplier(); + + SingleBBBonus = Threshold * SingleBBBonusPercent / 100; + VectorBonus = Threshold * VectorBonusPercent / 100; + + bool OnlyOneCallAndLocalLinkage = + F.hasLocalLinkage() && F.hasOneUse() && &F == Call.getCalledFunction(); + // If there is only one call of the function, and it has internal linkage, + // the cost of inlining it drops dramatically. It may seem odd to update + // Cost in updateThreshold, but the bonus depends on the logic in this method. + if (OnlyOneCallAndLocalLinkage) + Cost -= LastCallToStaticBonus; +} + +bool CallAnalyzer::visitCmpInst(CmpInst &I) { + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + // First try to handle simplified comparisons. + if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { + return ConstantExpr::getCompare(I.getPredicate(), COps[0], COps[1]); + })) + return true; + + if (I.getOpcode() == Instruction::FCmp) + return false; + + // Otherwise look for a comparison between constant offset pointers with + // a common base. + Value *LHSBase, *RHSBase; + APInt LHSOffset, RHSOffset; + std::tie(LHSBase, LHSOffset) = ConstantOffsetPtrs.lookup(LHS); + if (LHSBase) { + std::tie(RHSBase, RHSOffset) = ConstantOffsetPtrs.lookup(RHS); + if (RHSBase && LHSBase == RHSBase) { + // We have common bases, fold the icmp to a constant based on the + // offsets. 
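+ // Illustrative sketch (hypothetical IR): if %a and %b are both tracked + // with base %base and offsets 4 and 8, the code below folds + // "icmp ult %a, %b" to "icmp ult 4, 8", i.e. the constant true.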
+ Constant *CLHS = ConstantInt::get(LHS->getContext(), LHSOffset); + Constant *CRHS = ConstantInt::get(RHS->getContext(), RHSOffset); + if (Constant *C = ConstantExpr::getICmp(I.getPredicate(), CLHS, CRHS)) { + SimplifiedValues[&I] = C; + ++NumConstantPtrCmps; + return true; + } + } + } + + // If the comparison is an equality comparison with null, we can simplify it + // if we know the value (argument) can't be null + if (I.isEquality() && isa<ConstantPointerNull>(I.getOperand(1)) && + isKnownNonNullInCallee(I.getOperand(0))) { + bool IsNotEqual = I.getPredicate() == CmpInst::ICMP_NE; + SimplifiedValues[&I] = IsNotEqual ? ConstantInt::getTrue(I.getType()) + : ConstantInt::getFalse(I.getType()); + return true; + } + // Finally check for SROA candidates in comparisons. + Value *SROAArg; + DenseMap<Value *, int>::iterator CostIt; + if (lookupSROAArgAndCost(I.getOperand(0), SROAArg, CostIt)) { + if (isa<ConstantPointerNull>(I.getOperand(1))) { + accumulateSROACost(CostIt, InlineConstants::InstrCost); + return true; + } + + disableSROA(CostIt); + } + + return false; +} + +bool CallAnalyzer::visitSub(BinaryOperator &I) { + // Try to handle a special case: we can fold computing the difference of two + // constant-related pointers. + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + Value *LHSBase, *RHSBase; + APInt LHSOffset, RHSOffset; + std::tie(LHSBase, LHSOffset) = ConstantOffsetPtrs.lookup(LHS); + if (LHSBase) { + std::tie(RHSBase, RHSOffset) = ConstantOffsetPtrs.lookup(RHS); + if (RHSBase && LHSBase == RHSBase) { + // We have common bases, fold the subtract to a constant based on the + // offsets. + Constant *CLHS = ConstantInt::get(LHS->getContext(), LHSOffset); + Constant *CRHS = ConstantInt::get(RHS->getContext(), RHSOffset); + if (Constant *C = ConstantExpr::getSub(CLHS, CRHS)) { + SimplifiedValues[&I] = C; + ++NumConstantPtrDiffs; + return true; + } + } + } + + // Otherwise, fall back to the generic logic for simplifying and handling + // instructions. + return Base::visitSub(I); +} + +bool CallAnalyzer::visitBinaryOperator(BinaryOperator &I) { + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + Constant *CLHS = dyn_cast<Constant>(LHS); + if (!CLHS) + CLHS = SimplifiedValues.lookup(LHS); + Constant *CRHS = dyn_cast<Constant>(RHS); + if (!CRHS) + CRHS = SimplifiedValues.lookup(RHS); + + Value *SimpleV = nullptr; + if (auto FI = dyn_cast<FPMathOperator>(&I)) + SimpleV = SimplifyBinOp(I.getOpcode(), CLHS ? CLHS : LHS, + CRHS ? CRHS : RHS, FI->getFastMathFlags(), DL); + else + SimpleV = + SimplifyBinOp(I.getOpcode(), CLHS ? CLHS : LHS, CRHS ? CRHS : RHS, DL); + + if (Constant *C = dyn_cast_or_null<Constant>(SimpleV)) + SimplifiedValues[&I] = C; + + if (SimpleV) + return true; + + // Disable any SROA on arguments to arbitrary, unsimplified binary operators. + disableSROA(LHS); + disableSROA(RHS); + + // If the instruction is floating point, and the target says this operation + // is expensive, this may eventually become a library call. Treat the cost + // as such. Unless it's fneg which can be implemented with an xor. + using namespace llvm::PatternMatch; + if (I.getType()->isFloatingPointTy() && + TTI.getFPOpCost(I.getType()) == TargetTransformInfo::TCC_Expensive && + !match(&I, m_FNeg(m_Value()))) + addCost(InlineConstants::CallPenalty); + + return false; +} + +bool CallAnalyzer::visitFNeg(UnaryOperator &I) { + Value *Op = I.getOperand(0); + Constant *COp = dyn_cast<Constant>(Op); + if (!COp) + COp = SimplifiedValues.lookup(Op); + + Value *SimpleV = SimplifyFNegInst(COp ? 
COp : Op, + cast<FPMathOperator>(I).getFastMathFlags(), + DL); + + if (Constant *C = dyn_cast_or_null<Constant>(SimpleV)) + SimplifiedValues[&I] = C; + + if (SimpleV) + return true; + + // Disable any SROA on arguments to arbitrary, unsimplified fneg. + disableSROA(Op); + + return false; +} + +bool CallAnalyzer::visitLoad(LoadInst &I) { + Value *SROAArg; + DenseMap<Value *, int>::iterator CostIt; + if (lookupSROAArgAndCost(I.getPointerOperand(), SROAArg, CostIt)) { + if (I.isSimple()) { + accumulateSROACost(CostIt, InlineConstants::InstrCost); + return true; + } + + disableSROA(CostIt); + } + + // If the data is already loaded from this address and hasn't been clobbered + // by any stores or calls, this load is likely to be redundant and can be + // eliminated. + if (EnableLoadElimination && + !LoadAddrSet.insert(I.getPointerOperand()).second && I.isUnordered()) { + LoadEliminationCost += InlineConstants::InstrCost; + return true; + } + + return false; +} + +bool CallAnalyzer::visitStore(StoreInst &I) { + Value *SROAArg; + DenseMap<Value *, int>::iterator CostIt; + if (lookupSROAArgAndCost(I.getPointerOperand(), SROAArg, CostIt)) { + if (I.isSimple()) { + accumulateSROACost(CostIt, InlineConstants::InstrCost); + return true; + } + + disableSROA(CostIt); + } + + // The store can potentially clobber loads and prevent repeated loads from + // being eliminated. + // FIXME: + // 1. We can probably keep an initial set of eliminatable loads subtracted + // from the cost even when we finally see a store. We just need to disable + // *further* accumulation of elimination savings. + // 2. We should probably at some point thread MemorySSA for the callee into + // this and then use that to actually compute *really* precise savings. + disableLoadElimination(); + return false; +} + +bool CallAnalyzer::visitExtractValue(ExtractValueInst &I) { + // Constant folding for extract value is trivial. + if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { + return ConstantExpr::getExtractValue(COps[0], I.getIndices()); + })) + return true; + + // SROA can look through these but give them a cost. + return false; +} + +bool CallAnalyzer::visitInsertValue(InsertValueInst &I) { + // Constant folding for insert value is trivial. + if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { + return ConstantExpr::getInsertValue(/*AggregateOperand*/ COps[0], + /*InsertedValueOperand*/ COps[1], + I.getIndices()); + })) + return true; + + // SROA can look through these but give them a cost. + return false; +} + +/// Try to simplify a call site. +/// +/// Takes a concrete function and callsite and tries to actually simplify it by +/// analyzing the arguments and the call itself with instsimplify. Returns true +/// if it has simplified the callsite to some other entity (a constant), making +/// it free. +bool CallAnalyzer::simplifyCallSite(Function *F, CallBase &Call) { + // FIXME: Using the instsimplify logic directly for this is inefficient + // because we have to continually rebuild the argument list even when no + // simplifications can be performed. Until that is fixed with remapping + // inside of instsimplify, directly constant fold calls here. + if (!canConstantFoldCallTo(&Call, F)) + return false; + + // Try to re-map the arguments to constants.
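+ // Illustration (hypothetical values): if the callee body contains + // "call double @cos(double %d)" and %d was simplified to 0.0 at this call + // site, ConstantFoldCall can fold the call to 1.0 and it becomes free.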
+ SmallVector<Constant *, 4> ConstantArgs; + ConstantArgs.reserve(Call.arg_size()); + for (Value *I : Call.args()) { + Constant *C = dyn_cast<Constant>(I); + if (!C) + C = dyn_cast_or_null<Constant>(SimplifiedValues.lookup(I)); + if (!C) + return false; // This argument doesn't map to a constant. + + ConstantArgs.push_back(C); + } + if (Constant *C = ConstantFoldCall(&Call, F, ConstantArgs)) { + SimplifiedValues[&Call] = C; + return true; + } + + return false; +} + +bool CallAnalyzer::visitCallBase(CallBase &Call) { + if (Call.hasFnAttr(Attribute::ReturnsTwice) && + !F.hasFnAttribute(Attribute::ReturnsTwice)) { + // This aborts the entire analysis. + ExposesReturnsTwice = true; + return false; + } + if (isa<CallInst>(Call) && cast<CallInst>(Call).cannotDuplicate()) + ContainsNoDuplicateCall = true; + + if (Function *F = Call.getCalledFunction()) { + // When we have a concrete function, first try to simplify it directly. + if (simplifyCallSite(F, Call)) + return true; + + // Next check if it is an intrinsic we know about. + // FIXME: Lift this into part of the InstVisitor. + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Call)) { + switch (II->getIntrinsicID()) { + default: + if (!Call.onlyReadsMemory() && !isAssumeLikeIntrinsic(II)) + disableLoadElimination(); + return Base::visitCallBase(Call); + + case Intrinsic::load_relative: + // This is normally lowered to 4 LLVM instructions. + addCost(3 * InlineConstants::InstrCost); + return false; + + case Intrinsic::memset: + case Intrinsic::memcpy: + case Intrinsic::memmove: + disableLoadElimination(); + // SROA can usually chew through these intrinsics, but they aren't free. + return false; + case Intrinsic::icall_branch_funnel: + case Intrinsic::localescape: + HasUninlineableIntrinsic = true; + return false; + case Intrinsic::vastart: + InitsVargArgs = true; + return false; + } + } + + if (F == Call.getFunction()) { + // This flag will fully abort the analysis, so don't bother with anything + // else. + IsRecursiveCall = true; + return false; + } + + if (TTI.isLoweredToCall(F)) { + // We account for the average 1 instruction per call argument setup + // here. + addCost(Call.arg_size() * InlineConstants::InstrCost); + + // Everything other than inline ASM will also have a significant cost + // merely from making the call. + if (!isa<InlineAsm>(Call.getCalledValue())) + addCost(InlineConstants::CallPenalty); + } + + if (!Call.onlyReadsMemory()) + disableLoadElimination(); + return Base::visitCallBase(Call); + } + + // Otherwise we're in a very special case -- an indirect function call. See + // if we can be particularly clever about this. + Value *Callee = Call.getCalledValue(); + + // First, pay the price of the argument setup. We account for the average + // 1 instruction per call argument setup here. + addCost(Call.arg_size() * InlineConstants::InstrCost); + + // Next, check if this happens to be an indirect function call to a known + // function in this inline context. If not, we've done all we can. + Function *F = dyn_cast_or_null<Function>(SimplifiedValues.lookup(Callee)); + if (!F) { + if (!Call.onlyReadsMemory()) + disableLoadElimination(); + return Base::visitCallBase(Call); + } + + // If we have a constant that we are calling as a function, we can peer + // through it and see the function target. This happens not infrequently + // during devirtualization and so we want to give it a hefty bonus for + // inlining, but cap that bonus in the event that inlining wouldn't pan + // out. 
Pretend to inline the function, with a custom threshold. + auto IndirectCallParams = Params; + IndirectCallParams.DefaultThreshold = InlineConstants::IndirectCallThreshold; + CallAnalyzer CA(TTI, GetAssumptionCache, GetBFI, PSI, ORE, *F, Call, + IndirectCallParams); + if (CA.analyzeCall(Call)) { + // We were able to inline the indirect call! Subtract the cost from the + // threshold to get the bonus we want to apply, but don't go below zero. + Cost -= std::max(0, CA.getThreshold() - CA.getCost()); + } + + if (!F->onlyReadsMemory()) + disableLoadElimination(); + return Base::visitCallBase(Call); +} + +bool CallAnalyzer::visitReturnInst(ReturnInst &RI) { + // At least one return instruction will be free after inlining. + bool Free = !HasReturn; + HasReturn = true; + return Free; +} + +bool CallAnalyzer::visitBranchInst(BranchInst &BI) { + // We model unconditional branches as essentially free -- they really + // shouldn't exist at all, but handling them makes the behavior of the + // inliner more regular and predictable. Interestingly, conditional branches + // which will fold away are also free. + return BI.isUnconditional() || isa<ConstantInt>(BI.getCondition()) || + dyn_cast_or_null<ConstantInt>( + SimplifiedValues.lookup(BI.getCondition())); +} + +bool CallAnalyzer::visitSelectInst(SelectInst &SI) { + bool CheckSROA = SI.getType()->isPointerTy(); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + + Constant *TrueC = dyn_cast<Constant>(TrueVal); + if (!TrueC) + TrueC = SimplifiedValues.lookup(TrueVal); + Constant *FalseC = dyn_cast<Constant>(FalseVal); + if (!FalseC) + FalseC = SimplifiedValues.lookup(FalseVal); + Constant *CondC = + dyn_cast_or_null<Constant>(SimplifiedValues.lookup(SI.getCondition())); + + if (!CondC) { + // Select C, X, X => X + if (TrueC == FalseC && TrueC) { + SimplifiedValues[&SI] = TrueC; + return true; + } + + if (!CheckSROA) + return Base::visitSelectInst(SI); + + std::pair<Value *, APInt> TrueBaseAndOffset = + ConstantOffsetPtrs.lookup(TrueVal); + std::pair<Value *, APInt> FalseBaseAndOffset = + ConstantOffsetPtrs.lookup(FalseVal); + if (TrueBaseAndOffset == FalseBaseAndOffset && TrueBaseAndOffset.first) { + ConstantOffsetPtrs[&SI] = TrueBaseAndOffset; + + Value *SROAArg; + DenseMap<Value *, int>::iterator CostIt; + if (lookupSROAArgAndCost(TrueVal, SROAArg, CostIt)) + SROAArgValues[&SI] = SROAArg; + return true; + } + + return Base::visitSelectInst(SI); + } + + // Select condition is a constant. + Value *SelectedV = CondC->isAllOnesValue() + ? TrueVal + : (CondC->isNullValue()) ? FalseVal : nullptr; + if (!SelectedV) { + // Condition is a vector constant that is not all 1s or all 0s. If all + // operands are constants, ConstantExpr::getSelect() can handle the cases + // such as select vectors. + if (TrueC && FalseC) { + if (auto *C = ConstantExpr::getSelect(CondC, TrueC, FalseC)) { + SimplifiedValues[&SI] = C; + return true; + } + } + return Base::visitSelectInst(SI); + } + + // Condition is either all 1s or all 0s. SI can be simplified. 
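+ // For instance (hypothetical IR), "select i1 true, i32 %a, i32 %b" reaches + // this point and simplifies to %a; a false condition would select %b.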
+ if (Constant *SelectedC = dyn_cast<Constant>(SelectedV)) { + SimplifiedValues[&SI] = SelectedC; + return true; + } + + if (!CheckSROA) + return true; + + std::pair<Value *, APInt> BaseAndOffset = + ConstantOffsetPtrs.lookup(SelectedV); + if (BaseAndOffset.first) { + ConstantOffsetPtrs[&SI] = BaseAndOffset; + + Value *SROAArg; + DenseMap<Value *, int>::iterator CostIt; + if (lookupSROAArgAndCost(SelectedV, SROAArg, CostIt)) + SROAArgValues[&SI] = SROAArg; + } + + return true; +} + +bool CallAnalyzer::visitSwitchInst(SwitchInst &SI) { + // We model unconditional switches as free; see the comments on handling + // branches. + if (isa<ConstantInt>(SI.getCondition())) + return true; + if (Value *V = SimplifiedValues.lookup(SI.getCondition())) + if (isa<ConstantInt>(V)) + return true; + + // Assume the most general case where the switch is lowered into + // either a jump table, bit test, or a balanced binary tree consisting of + // case clusters without merging adjacent clusters with the same + // destination. We do not consider the switches that are lowered with a mix + // of jump table/bit test/binary search tree. The cost of the switch is + // proportional to the size of the tree or the size of the jump table range. + // + // NB: We convert large switches which are just used to initialize large phi + // nodes to lookup tables instead in simplify-cfg, so this shouldn't prevent + // inlining those. It will prevent inlining in cases where the optimization + // does not (yet) fire. + + // Maximum valid cost increase in this function. + int CostUpperBound = INT_MAX - InlineConstants::InstrCost - 1; + + unsigned JumpTableSize = 0; + unsigned NumCaseCluster = + TTI.getEstimatedNumberOfCaseClusters(SI, JumpTableSize); + + // If suitable for a jump table, consider the cost for the table size and + // the branch to the destination. + if (JumpTableSize) { + int64_t JTCost = (int64_t)JumpTableSize * InlineConstants::InstrCost + + 4 * InlineConstants::InstrCost; + + addCost(JTCost, (int64_t)CostUpperBound); + return false; + } + + // Considering forming a binary search tree, we should find the number of + // nodes, which is the same as the number of comparisons when lowered. For a + // given number of clusters, n, we can define a recursive function, f(n), to + // find the number of nodes in the tree. The recursion is: + // f(n) = 1 + f(n/2) + f(n - n/2), when n > 3, + // and f(n) = n, when n <= 3. + // This will lead to a binary tree where each leaf is either f(2) or f(3) + // when n > 3. So, the number of comparisons from leaves should be n, while + // the number of non-leaf nodes should be: + // 2^(log2(n) - 1) - 1 + // = 2^log2(n) * 2^-1 - 1 + // = n / 2 - 1. + // Considering comparisons from leaf and non-leaf nodes, we can estimate the + // number of comparisons in a simple closed form: + // n + n / 2 - 1 = n * 3 / 2 - 1 + // For example, n = 8 clusters gives an estimate of 8 * 3 / 2 - 1 = 11 + // comparisons. + if (NumCaseCluster <= 3) { + // Suppose a comparison includes one compare and one conditional branch. + addCost(NumCaseCluster * 2 * InlineConstants::InstrCost); + return false; + } + + int64_t ExpectedNumberOfCompare = 3 * (int64_t)NumCaseCluster / 2 - 1; + int64_t SwitchCost = + ExpectedNumberOfCompare * 2 * InlineConstants::InstrCost; + + addCost(SwitchCost, (int64_t)CostUpperBound); + return false; +} + +bool CallAnalyzer::visitIndirectBrInst(IndirectBrInst &IBI) { + // We never want to inline functions that contain an indirectbr.
This is + // incorrect because all the blockaddresses (in static global initializers, + // for example) would still refer to the original function, and this indirect + // jump would jump from the inlined copy of the function into the original + // function, which is extremely undefined behavior. + // FIXME: This logic isn't really right; we can safely inline functions with + // indirectbrs as long as no other function or global references the + // blockaddress of a block within the current function. + HasIndirectBr = true; + return false; +} + +bool CallAnalyzer::visitResumeInst(ResumeInst &RI) { + // FIXME: It's not clear that a single instruction is an accurate model for + // the inline cost of a resume instruction. + return false; +} + +bool CallAnalyzer::visitCleanupReturnInst(CleanupReturnInst &CRI) { + // FIXME: It's not clear that a single instruction is an accurate model for + // the inline cost of a cleanupret instruction. + return false; +} + +bool CallAnalyzer::visitCatchReturnInst(CatchReturnInst &CRI) { + // FIXME: It's not clear that a single instruction is an accurate model for + // the inline cost of a catchret instruction. + return false; +} + +bool CallAnalyzer::visitUnreachableInst(UnreachableInst &I) { + // FIXME: It might be reasonable to discount the cost of instructions leading + // to unreachable as they have the lowest possible impact on both runtime and + // code size. + return true; // No actual code is needed for unreachable. +} + +bool CallAnalyzer::visitInstruction(Instruction &I) { + // Some instructions are free. All of the free intrinsics can also be + // handled by SROA, etc. + if (TargetTransformInfo::TCC_Free == TTI.getUserCost(&I)) + return true; + + // We found something we don't understand or can't handle. Mark any SROA-able + // values in the operand list as no longer viable. + for (User::op_iterator OI = I.op_begin(), OE = I.op_end(); OI != OE; ++OI) + disableSROA(*OI); + + return false; +} + +/// Analyze a basic block for its contribution to the inline cost. +/// +/// This method walks the analyzer over every instruction in the given basic +/// block and accounts for its cost during inlining at this callsite. It +/// aborts early if the threshold has been exceeded or an impossible-to-inline +/// construct has been detected. It returns false if inlining is no longer +/// viable, and true if inlining remains viable. +InlineResult +CallAnalyzer::analyzeBlock(BasicBlock *BB, + SmallPtrSetImpl<const Value *> &EphValues) { + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { + // FIXME: Currently, the number of instructions in a function, regardless of + // our ability to simplify them during inlining to constants or dead code, + // is actually used by the vector bonus heuristic. As long as that's true, + // we have to special case debug intrinsics here to prevent differences in + // inlining due to debug symbols. Eventually, the number of unsimplified + // instructions shouldn't factor into the cost computation, but until then, + // hack around it here. + if (isa<DbgInfoIntrinsic>(I)) + continue; + + // Skip ephemeral values. + if (EphValues.count(&*I)) + continue; + + ++NumInstructions; + if (isa<ExtractElementInst>(I) || I->getType()->isVectorTy()) + ++NumVectorInstructions; + + // If the instruction simplified to a constant, there is no cost to this + // instruction. Visit the instructions using our InstVisitor to account for + // all of the per-instruction logic.
The visit tree returns true if we + // consumed the instruction in any way, and false if the instruction's base + // cost should count against inlining. + if (Base::visit(&*I)) + ++NumInstructionsSimplified; + else + addCost(InlineConstants::InstrCost); + + using namespace ore; + // If visiting this instruction detected an uninlinable pattern, abort. + InlineResult IR; + if (IsRecursiveCall) + IR = "recursive"; + else if (ExposesReturnsTwice) + IR = "exposes returns twice"; + else if (HasDynamicAlloca) + IR = "dynamic alloca"; + else if (HasIndirectBr) + IR = "indirect branch"; + else if (HasUninlineableIntrinsic) + IR = "uninlinable intrinsic"; + else if (InitsVargArgs) + IR = "varargs"; + if (!IR) { + if (ORE) + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", + &CandidateCall) + << NV("Callee", &F) << " has uninlinable pattern (" + << NV("InlineResult", IR.message) + << ") and cost is not fully computed"; + }); + return IR; + } + + // If the caller is a recursive function then we don't want to inline + // functions which allocate a lot of stack space because it would increase + // the caller stack usage dramatically. + if (IsCallerRecursive && + AllocatedSize > InlineConstants::TotalAllocaSizeRecursiveCaller) { + InlineResult IR = "recursive and allocates too much stack space"; + if (ORE) + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", + &CandidateCall) + << NV("Callee", &F) << " is " << NV("InlineResult", IR.message) + << ". Cost is not fully computed"; + }); + return IR; + } + + // Check if we've passed the maximum possible threshold so we don't spin in + // huge basic blocks that will never inline. + if (Cost >= Threshold && !ComputeFullInlineCost) + return false; + } + + return true; +} + +/// Compute the base pointer and cumulative constant offsets for V. +/// +/// This strips all constant offsets off of V, leaving it the base pointer, and +/// accumulates the total constant offset applied in the returned constant. It +/// returns null if V is not a pointer, and returns the constant '0' if there +/// are no constant offsets applied. +ConstantInt *CallAnalyzer::stripAndComputeInBoundsConstantOffsets(Value *&V) { + if (!V->getType()->isPointerTy()) + return nullptr; + + unsigned AS = V->getType()->getPointerAddressSpace(); + unsigned IntPtrWidth = DL.getIndexSizeInBits(AS); + APInt Offset = APInt::getNullValue(IntPtrWidth); + + // Even though we don't look through PHI nodes, we could be called on an + // instruction in an unreachable block, which may be on a cycle. + SmallPtrSet<Value *, 4> Visited; + Visited.insert(V); + do { + if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { + if (!GEP->isInBounds() || !accumulateGEPOffset(*GEP, Offset)) + return nullptr; + V = GEP->getPointerOperand(); + } else if (Operator::getOpcode(V) == Instruction::BitCast) { + V = cast<Operator>(V)->getOperand(0); + } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { + if (GA->isInterposable()) + break; + V = GA->getAliasee(); + } else { + break; + } + assert(V->getType()->isPointerTy() && "Unexpected operand type!"); + } while (Visited.insert(V).second); + + Type *IntPtrTy = DL.getIntPtrType(V->getContext(), AS); + return cast<ConstantInt>(ConstantInt::get(IntPtrTy, Offset)); +} + +/// Find dead blocks due to deleted CFG edges during inlining. +/// +/// If we know the successor of the current block, \p CurrBB, has to be \p +/// NextBB, the other successors of \p CurrBB are dead if these successors have +/// no live incoming CFG edges.
If one block is found to be dead, we can +/// continue growing the dead block list by checking the successors of the dead +/// blocks to see if all their incoming edges are dead or not. +void CallAnalyzer::findDeadBlocks(BasicBlock *CurrBB, BasicBlock *NextBB) { + auto IsEdgeDead = [&](BasicBlock *Pred, BasicBlock *Succ) { + // A CFG edge is dead if the predecessor is dead or the predecessor has a + // known successor which is not the one under exam. + return (DeadBlocks.count(Pred) || + (KnownSuccessors[Pred] && KnownSuccessors[Pred] != Succ)); + }; + + auto IsNewlyDead = [&](BasicBlock *BB) { + // If all the edges to a block are dead, the block is also dead. + return (!DeadBlocks.count(BB) && + llvm::all_of(predecessors(BB), + [&](BasicBlock *P) { return IsEdgeDead(P, BB); })); + }; + + for (BasicBlock *Succ : successors(CurrBB)) { + if (Succ == NextBB || !IsNewlyDead(Succ)) + continue; + SmallVector<BasicBlock *, 4> NewDead; + NewDead.push_back(Succ); + while (!NewDead.empty()) { + BasicBlock *Dead = NewDead.pop_back_val(); + if (DeadBlocks.insert(Dead)) + // Continue growing the dead block lists. + for (BasicBlock *S : successors(Dead)) + if (IsNewlyDead(S)) + NewDead.push_back(S); + } + } +} + +/// Analyze a call site for potential inlining. +/// +/// Returns true if inlining this call is viable, and false if it is not +/// viable. It computes the cost and adjusts the threshold based on numerous +/// factors and heuristics. If this method returns false but the computed cost +/// is below the computed threshold, then inlining was forcibly disabled by +/// some artifact of the routine. +InlineResult CallAnalyzer::analyzeCall(CallBase &Call) { + ++NumCallsAnalyzed; + + // Perform some tweaks to the cost and threshold based on the direct + // callsite information. + + // We want to more aggressively inline vector-dense kernels, so up the + // threshold, and we'll lower it if the % of vector instructions gets too + // low. Note that these bonuses are some what arbitrary and evolved over time + // by accident as much as because they are principled bonuses. + // + // FIXME: It would be nice to remove all such bonuses. At least it would be + // nice to base the bonus values on something more scientific. + assert(NumInstructions == 0); + assert(NumVectorInstructions == 0); + + // Update the threshold based on callsite properties + updateThreshold(Call, F); + + // While Threshold depends on commandline options that can take negative + // values, we want to enforce the invariant that the computed threshold and + // bonuses are non-negative. + assert(Threshold >= 0); + assert(SingleBBBonus >= 0); + assert(VectorBonus >= 0); + + // Speculatively apply all possible bonuses to Threshold. If cost exceeds + // this Threshold any time, and cost cannot decrease, we can stop processing + // the rest of the function body. + Threshold += (SingleBBBonus + VectorBonus); + + // Give out bonuses for the callsite, as the instructions setting them up + // will be gone after inlining. + addCost(-getCallsiteCost(Call, DL)); + + // If this function uses the coldcc calling convention, prefer not to inline + // it. + if (F.getCallingConv() == CallingConv::Cold) + Cost += InlineConstants::ColdccPenalty; + + // Check if we're done. This can happen due to bonuses and penalties. + if (Cost >= Threshold && !ComputeFullInlineCost) + return "high cost"; + + if (F.empty()) + return true; + + Function *Caller = Call.getFunction(); + // Check if the caller function is recursive itself. 
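+ // (A conservative check: any call-site use of the caller that is itself + // located in the caller's body is treated as recursion.)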
+ for (User *U : Caller->users()) { + CallBase *Call = dyn_cast<CallBase>(U); + if (Call && Call->getFunction() == Caller) { + IsCallerRecursive = true; + break; + } + } + + // Populate our simplified values by mapping from function arguments to call + // arguments with known important simplifications. + auto CAI = Call.arg_begin(); + for (Function::arg_iterator FAI = F.arg_begin(), FAE = F.arg_end(); + FAI != FAE; ++FAI, ++CAI) { + assert(CAI != Call.arg_end()); + if (Constant *C = dyn_cast<Constant>(CAI)) + SimplifiedValues[&*FAI] = C; + + Value *PtrArg = *CAI; + if (ConstantInt *C = stripAndComputeInBoundsConstantOffsets(PtrArg)) { + ConstantOffsetPtrs[&*FAI] = std::make_pair(PtrArg, C->getValue()); + + // We can SROA any pointer arguments derived from alloca instructions. + if (isa<AllocaInst>(PtrArg)) { + SROAArgValues[&*FAI] = PtrArg; + SROAArgCosts[PtrArg] = 0; + } + } + } + NumConstantArgs = SimplifiedValues.size(); + NumConstantOffsetPtrArgs = ConstantOffsetPtrs.size(); + NumAllocaArgs = SROAArgValues.size(); + + // FIXME: If a caller has multiple calls to a callee, we end up recomputing + // the ephemeral values multiple times (and they're completely determined by + // the callee, so this is purely duplicate work). + SmallPtrSet<const Value *, 32> EphValues; + CodeMetrics::collectEphemeralValues(&F, &GetAssumptionCache(F), EphValues); + + // The worklist of live basic blocks in the callee *after* inlining. We avoid + // adding basic blocks of the callee which can be proven to be dead for this + // particular call site in order to get more accurate cost estimates. This + // requires a somewhat heavyweight iteration pattern: we need to walk the + // basic blocks in a breadth-first order as we insert live successors. To + // accomplish this, we use a small-size optimized SetVector, prioritizing + // small iterations because we exit once we cross our threshold. + typedef SetVector<BasicBlock *, SmallVector<BasicBlock *, 16>, + SmallPtrSet<BasicBlock *, 16>> + BBSetVector; + BBSetVector BBWorklist; + BBWorklist.insert(&F.getEntryBlock()); + bool SingleBB = true; + // Note that we *must not* cache the size; this loop grows the worklist. + for (unsigned Idx = 0; Idx != BBWorklist.size(); ++Idx) { + // Bail out the moment we cross the threshold. This means we'll undercount + // the cost, but only when undercounting doesn't matter. + if (Cost >= Threshold && !ComputeFullInlineCost) + break; + + BasicBlock *BB = BBWorklist[Idx]; + if (BB->empty()) + continue; + + // Disallow inlining a blockaddress with uses other than strictly callbr. + // A blockaddress only has defined behavior for an indirect branch in the + // same function, and we do not currently support inlining indirect + // branches. But, the inliner may not see an indirect branch that ends up + // being dead code at a particular call site. If the blockaddress escapes + // the function, e.g., via a global variable, inlining may lead to an + // invalid cross-function reference. + // FIXME: pr/39560: continue relaxing this overt restriction. + if (BB->hasAddressTaken()) + for (User *U : BlockAddress::get(&*BB)->users()) + if (!isa<CallBrInst>(*U)) + return "blockaddress used outside of callbr"; + + // Analyze the cost of this block. If we blow through the threshold, this + // returns false, and we can bail out.
+ InlineResult IR = analyzeBlock(BB, EphValues); + if (!IR) + return IR; + + Instruction *TI = BB->getTerminator(); + + // Add in the live successors by first checking whether we have a terminator + // that may be simplified based on the values simplified by this call. + if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { + if (BI->isConditional()) { + Value *Cond = BI->getCondition(); + if (ConstantInt *SimpleCond = + dyn_cast_or_null<ConstantInt>(SimplifiedValues.lookup(Cond))) { + BasicBlock *NextBB = BI->getSuccessor(SimpleCond->isZero() ? 1 : 0); + BBWorklist.insert(NextBB); + KnownSuccessors[BB] = NextBB; + findDeadBlocks(BB, NextBB); + continue; + } + } + } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { + Value *Cond = SI->getCondition(); + if (ConstantInt *SimpleCond = + dyn_cast_or_null<ConstantInt>(SimplifiedValues.lookup(Cond))) { + BasicBlock *NextBB = SI->findCaseValue(SimpleCond)->getCaseSuccessor(); + BBWorklist.insert(NextBB); + KnownSuccessors[BB] = NextBB; + findDeadBlocks(BB, NextBB); + continue; + } + } + + // If we're unable to select a particular successor, just count all of + // them. + for (unsigned TIdx = 0, TSize = TI->getNumSuccessors(); TIdx != TSize; + ++TIdx) + BBWorklist.insert(TI->getSuccessor(TIdx)); + + // If we had any successors at this point, then post-inlining is likely to + // have them as well. Note that we assume any basic blocks which existed + // due to branches or switches which folded above will also fold after + // inlining. + if (SingleBB && TI->getNumSuccessors() > 1) { + // Take off the bonus we applied to the threshold. + Threshold -= SingleBBBonus; + SingleBB = false; + } + } + + bool OnlyOneCallAndLocalLinkage = + F.hasLocalLinkage() && F.hasOneUse() && &F == Call.getCalledFunction(); + // If this is a noduplicate call, we can still inline as long as + // inlining this would cause the removal of the callee (so the instruction + // is not actually duplicated, just moved). + if (!OnlyOneCallAndLocalLinkage && ContainsNoDuplicateCall) + return "noduplicate"; + + // Loops generally act a lot like calls in that they act like barriers to + // code motion, require a certain amount of setup, etc. So when optimizing + // for size, we penalize any call sites that perform loops. We do this after + // all other costs here, so will likely only be dealing with relatively small + // functions (and hence DT and LI will hopefully be cheap). + if (Caller->hasMinSize()) { + DominatorTree DT(F); + LoopInfo LI(DT); + int NumLoops = 0; + for (Loop *L : LI) { + // Ignore loops that will not be executed. + if (DeadBlocks.count(L->getHeader())) + continue; + NumLoops++; + } + addCost(NumLoops * InlineConstants::CallPenalty); + } + + // We applied the maximum possible vector bonus at the beginning. Now, + // subtract the excess bonus, if any, from the Threshold before + // comparing against Cost. + if (NumVectorInstructions <= NumInstructions / 10) + Threshold -= VectorBonus; + else if (NumVectorInstructions <= NumInstructions / 2) + Threshold -= VectorBonus / 2; + + return Cost < std::max(1, Threshold); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +/// Dump stats about this call's analysis.
+LLVM_DUMP_METHOD void CallAnalyzer::dump() { +#define DEBUG_PRINT_STAT(x) dbgs() << " " #x ": " << x << "\n" + DEBUG_PRINT_STAT(NumConstantArgs); + DEBUG_PRINT_STAT(NumConstantOffsetPtrArgs); + DEBUG_PRINT_STAT(NumAllocaArgs); + DEBUG_PRINT_STAT(NumConstantPtrCmps); + DEBUG_PRINT_STAT(NumConstantPtrDiffs); + DEBUG_PRINT_STAT(NumInstructionsSimplified); + DEBUG_PRINT_STAT(NumInstructions); + DEBUG_PRINT_STAT(SROACostSavings); + DEBUG_PRINT_STAT(SROACostSavingsLost); + DEBUG_PRINT_STAT(LoadEliminationCost); + DEBUG_PRINT_STAT(ContainsNoDuplicateCall); + DEBUG_PRINT_STAT(Cost); + DEBUG_PRINT_STAT(Threshold); +#undef DEBUG_PRINT_STAT +} +#endif + +/// Test that there are no attribute conflicts between Caller and Callee +/// that prevent inlining. +static bool functionsHaveCompatibleAttributes(Function *Caller, + Function *Callee, + TargetTransformInfo &TTI) { + return TTI.areInlineCompatible(Caller, Callee) && + AttributeFuncs::areInlineCompatible(*Caller, *Callee); +} + +int llvm::getCallsiteCost(CallBase &Call, const DataLayout &DL) { + int Cost = 0; + for (unsigned I = 0, E = Call.arg_size(); I != E; ++I) { + if (Call.isByValArgument(I)) { + // We approximate the number of loads and stores needed by dividing the + // size of the byval type by the target's pointer size. + PointerType *PTy = cast<PointerType>(Call.getArgOperand(I)->getType()); + unsigned TypeSize = DL.getTypeSizeInBits(PTy->getElementType()); + unsigned AS = PTy->getAddressSpace(); + unsigned PointerSize = DL.getPointerSizeInBits(AS); + // Ceiling division. + unsigned NumStores = (TypeSize + PointerSize - 1) / PointerSize; + + // If it generates more than 8 stores, it is likely to be expanded as an + // inline memcpy, so we take that as an upper bound. Otherwise we assume + // one load and one store per word copied. + // FIXME: The maxStoresPerMemcpy setting from the target should be used + // here instead of a magic number of 8, but it's not available via + // DataLayout. + NumStores = std::min(NumStores, 8U); + + Cost += 2 * NumStores * InlineConstants::InstrCost; + } else { + // For non-byval arguments, subtract off one instruction per call + // argument. + Cost += InlineConstants::InstrCost; + } + } + // The call instruction also disappears after inlining. + Cost += InlineConstants::InstrCost + InlineConstants::CallPenalty; + return Cost; +} + +InlineCost llvm::getInlineCost( + CallBase &Call, const InlineParams &Params, TargetTransformInfo &CalleeTTI, + std::function<AssumptionCache &(Function &)> &GetAssumptionCache, + Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI, + ProfileSummaryInfo *PSI, OptimizationRemarkEmitter *ORE) { + return getInlineCost(Call, Call.getCalledFunction(), Params, CalleeTTI, + GetAssumptionCache, GetBFI, PSI, ORE); +} + +InlineCost llvm::getInlineCost( + CallBase &Call, Function *Callee, const InlineParams &Params, + TargetTransformInfo &CalleeTTI, + std::function<AssumptionCache &(Function &)> &GetAssumptionCache, + Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI, + ProfileSummaryInfo *PSI, OptimizationRemarkEmitter *ORE) { + + // Cannot inline indirect calls. + if (!Callee) + return llvm::InlineCost::getNever("indirect call"); + + // Never inline calls with byval arguments that do not have the alloca + // address space. Since byval arguments can be replaced with a copy to an + // alloca, the inlined code would need to be adjusted to handle the fact that + // the argument is in the alloca address space (so it is somewhat complicated + // to solve).
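+ // Hypothetical example: with a data layout declaring "A5" (allocas in + // address space 5, as on AMDGPU), a byval argument whose pointer type is in + // address space 0 would be rejected by the check below.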
+ unsigned AllocaAS = Callee->getParent()->getDataLayout().getAllocaAddrSpace(); + for (unsigned I = 0, E = Call.arg_size(); I != E; ++I) + if (Call.isByValArgument(I)) { + PointerType *PTy = cast<PointerType>(Call.getArgOperand(I)->getType()); + if (PTy->getAddressSpace() != AllocaAS) + return llvm::InlineCost::getNever("byval arguments without alloca" + " address space"); + } + + // Calls to functions with always-inline attributes should be inlined + // whenever possible. + if (Call.hasFnAttr(Attribute::AlwaysInline)) { + auto IsViable = isInlineViable(*Callee); + if (IsViable) + return llvm::InlineCost::getAlways("always inline attribute"); + return llvm::InlineCost::getNever(IsViable.message); + } + + // Never inline functions with conflicting attributes (unless callee has + // always-inline attribute). + Function *Caller = Call.getCaller(); + if (!functionsHaveCompatibleAttributes(Caller, Callee, CalleeTTI)) + return llvm::InlineCost::getNever("conflicting attributes"); + + // Don't inline this call if the caller has the optnone attribute. + if (Caller->hasOptNone()) + return llvm::InlineCost::getNever("optnone attribute"); + + // Don't inline a function that treats null pointer as valid into a caller + // that does not have this attribute. + if (!Caller->nullPointerIsDefined() && Callee->nullPointerIsDefined()) + return llvm::InlineCost::getNever("nullptr definitions incompatible"); + + // Don't inline functions which can be interposed at link-time. + if (Callee->isInterposable()) + return llvm::InlineCost::getNever("interposable"); + + // Don't inline functions marked noinline. + if (Callee->hasFnAttribute(Attribute::NoInline)) + return llvm::InlineCost::getNever("noinline function attribute"); + + // Don't inline call sites marked noinline. + if (Call.isNoInline()) + return llvm::InlineCost::getNever("noinline call site attribute"); + + LLVM_DEBUG(llvm::dbgs() << " Analyzing call of " << Callee->getName() + << "... (caller:" << Caller->getName() << ")\n"); + + CallAnalyzer CA(CalleeTTI, GetAssumptionCache, GetBFI, PSI, ORE, *Callee, + Call, Params); + InlineResult ShouldInline = CA.analyzeCall(Call); + + LLVM_DEBUG(CA.dump()); + + // Check if there was a reason to force inlining or no inlining. + if (!ShouldInline && CA.getCost() < CA.getThreshold()) + return InlineCost::getNever(ShouldInline.message); + if (ShouldInline && CA.getCost() >= CA.getThreshold()) + return InlineCost::getAlways("empty function"); + + return llvm::InlineCost::get(CA.getCost(), CA.getThreshold()); +} + +InlineResult llvm::isInlineViable(Function &F) { + bool ReturnsTwice = F.hasFnAttribute(Attribute::ReturnsTwice); + for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) { + // Disallow inlining of functions which contain indirect branches. + if (isa<IndirectBrInst>(BI->getTerminator())) + return "contains indirect branches"; + + // Disallow inlining of blockaddresses which are used by non-callbr + // instructions. + if (BI->hasAddressTaken()) + for (User *U : BlockAddress::get(&*BI)->users()) + if (!isa<CallBrInst>(*U)) + return "blockaddress used outside of callbr"; + + for (auto &II : *BI) { + CallBase *Call = dyn_cast<CallBase>(&II); + if (!Call) + continue; + + // Disallow recursive calls. + if (&F == Call->getCalledFunction()) + return "recursive call"; + + // Disallow calls which expose returns-twice to a function not previously + // attributed as such. 
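+ // The canonical returns-twice example is setjmp: inlining a body that + // calls it into a caller not marked returns_twice would hide that fact, + // which is why it is rejected below.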
+ if (!ReturnsTwice && isa<CallInst>(Call) && + cast<CallInst>(Call)->canReturnTwice()) + return "exposes returns-twice attribute"; + + if (Call->getCalledFunction()) + switch (Call->getCalledFunction()->getIntrinsicID()) { + default: + break; + // Disallow inlining of @llvm.icall.branch.funnel because current + // backend can't separate call targets from call arguments. + case llvm::Intrinsic::icall_branch_funnel: + return "disallowed inlining of @llvm.icall.branch.funnel"; + // Disallow inlining functions that call @llvm.localescape. Doing this + // correctly would require major changes to the inliner. + case llvm::Intrinsic::localescape: + return "disallowed inlining of @llvm.localescape"; + // Disallow inlining of functions that initialize VarArgs with va_start. + case llvm::Intrinsic::vastart: + return "contains VarArgs initialized with va_start"; + } + } + } + + return true; +} + +// APIs to create InlineParams based on command line flags and/or other +// parameters. + +InlineParams llvm::getInlineParams(int Threshold) { + InlineParams Params; + + // This field is the threshold to use for a callee by default. This is + // derived from one or more of: + // * optimization or size-optimization levels, + // * a value passed to createFunctionInliningPass function, or + // * the -inline-threshold flag. + // If the -inline-threshold flag is explicitly specified, that is used + // irrespective of anything else. + if (InlineThreshold.getNumOccurrences() > 0) + Params.DefaultThreshold = InlineThreshold; + else + Params.DefaultThreshold = Threshold; + + // Set the HintThreshold knob from the -inlinehint-threshold. + Params.HintThreshold = HintThreshold; + + // Set the HotCallSiteThreshold knob from the -hot-callsite-threshold. + Params.HotCallSiteThreshold = HotCallSiteThreshold; + + // If the -locally-hot-callsite-threshold is explicitly specified, use it to + // populate LocallyHotCallSiteThreshold. Later, we populate + // Params.LocallyHotCallSiteThreshold from -locally-hot-callsite-threshold if + // we know that optimization level is O3 (in the getInlineParams variant that + // takes the opt and size levels). + // FIXME: Remove this check (and make the assignment unconditional) after + // addressing size regression issues at O2. + if (LocallyHotCallSiteThreshold.getNumOccurrences() > 0) + Params.LocallyHotCallSiteThreshold = LocallyHotCallSiteThreshold; + + // Set the ColdCallSiteThreshold knob from the -inline-cold-callsite-threshold. + Params.ColdCallSiteThreshold = ColdCallSiteThreshold; + + // Set the OptMinSizeThreshold and OptSizeThreshold params only if the + // -inlinehint-threshold commandline option is not explicitly given. If that + // option is present, then its value applies even for callees with size and + // minsize attributes. + // If the -inline-threshold is not specified, set the ColdThreshold from the + // -inlinecold-threshold even if it is not explicitly passed. 
If + // -inline-threshold is specified, then -inlinecold-threshold needs to be + // explicitly specified to set the ColdThreshold knob + if (InlineThreshold.getNumOccurrences() == 0) { + Params.OptMinSizeThreshold = InlineConstants::OptMinSizeThreshold; + Params.OptSizeThreshold = InlineConstants::OptSizeThreshold; + Params.ColdThreshold = ColdThreshold; + } else if (ColdThreshold.getNumOccurrences() > 0) { + Params.ColdThreshold = ColdThreshold; + } + return Params; +} + +InlineParams llvm::getInlineParams() { + return getInlineParams(InlineThreshold); +} + +// Compute the default threshold for inlining based on the opt level and the +// size opt level. +static int computeThresholdFromOptLevels(unsigned OptLevel, + unsigned SizeOptLevel) { + if (OptLevel > 2) + return InlineConstants::OptAggressiveThreshold; + if (SizeOptLevel == 1) // -Os + return InlineConstants::OptSizeThreshold; + if (SizeOptLevel == 2) // -Oz + return InlineConstants::OptMinSizeThreshold; + return InlineThreshold; +} + +InlineParams llvm::getInlineParams(unsigned OptLevel, unsigned SizeOptLevel) { + auto Params = + getInlineParams(computeThresholdFromOptLevels(OptLevel, SizeOptLevel)); + // At O3, use the value of -locally-hot-callsite-threshold option to populate + // Params.LocallyHotCallSiteThreshold. Below O3, this flag has effect only + // when it is specified explicitly. + if (OptLevel > 2) + Params.LocallyHotCallSiteThreshold = LocallyHotCallSiteThreshold; + return Params; +} diff --git a/llvm/lib/Analysis/InstCount.cpp b/llvm/lib/Analysis/InstCount.cpp new file mode 100644 index 000000000000..943a99a5f46d --- /dev/null +++ b/llvm/lib/Analysis/InstCount.cpp @@ -0,0 +1,78 @@ +//===-- InstCount.cpp - Collects the count of all instructions ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass collects the count of all instructions and reports them. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +#define DEBUG_TYPE "instcount" + +STATISTIC(TotalInsts , "Number of instructions (of all types)"); +STATISTIC(TotalBlocks, "Number of basic blocks"); +STATISTIC(TotalFuncs , "Number of non-external functions"); + +#define HANDLE_INST(N, OPCODE, CLASS) \ + STATISTIC(Num ## OPCODE ## Inst, "Number of " #OPCODE " insts"); + +#include "llvm/IR/Instruction.def" + +namespace { + class InstCount : public FunctionPass, public InstVisitor<InstCount> { + friend class InstVisitor<InstCount>; + + void visitFunction (Function &F) { ++TotalFuncs; } + void visitBasicBlock(BasicBlock &BB) { ++TotalBlocks; } + +#define HANDLE_INST(N, OPCODE, CLASS) \ + void visit##OPCODE(CLASS &) { ++Num##OPCODE##Inst; ++TotalInsts; } + +#include "llvm/IR/Instruction.def" + + void visitInstruction(Instruction &I) { + errs() << "Instruction Count does not know about " << I; + llvm_unreachable(nullptr); + } + public: + static char ID; // Pass identification, replacement for typeid + InstCount() : FunctionPass(ID) { + initializeInstCountPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + } + void print(raw_ostream &O, const Module *M) const override {} + + }; +} + +char InstCount::ID = 0; +INITIALIZE_PASS(InstCount, "instcount", + "Counts the various types of Instructions", false, true) + +FunctionPass *llvm::createInstCountPass() { return new InstCount(); } + +// InstCount::runOnFunction - This is the main Analysis entry point for a +// function. +// +bool InstCount::runOnFunction(Function &F) { + visit(F); + return false; +} diff --git a/llvm/lib/Analysis/InstructionPrecedenceTracking.cpp b/llvm/lib/Analysis/InstructionPrecedenceTracking.cpp new file mode 100644 index 000000000000..35190ce3e11a --- /dev/null +++ b/llvm/lib/Analysis/InstructionPrecedenceTracking.cpp @@ -0,0 +1,160 @@ +//===-- InstructionPrecedenceTracking.cpp -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Implements a class that is able to define some instructions as "special" +// (e.g. as having implicit control flow, or writing memory, or having another +// interesting property) and then efficiently answers queries of the following +// types: +// 1. Are there any special instructions in the block of interest? +// 2. Return the first of the special instructions in the given block; +// 3. Check if the given instruction is preceded by the first special +// instruction in the same block. +// The class provides caching that allows these queries to be answered quickly.
The +// user must make sure that the cached data is invalidated properly whenever +// a content of some tracked block is changed. +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/InstructionPrecedenceTracking.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/PatternMatch.h" + +using namespace llvm; + +#ifndef NDEBUG +static cl::opt<bool> ExpensiveAsserts( + "ipt-expensive-asserts", + cl::desc("Perform expensive assert validation on every query to Instruction" + " Precedence Tracking"), + cl::init(false), cl::Hidden); +#endif + +const Instruction *InstructionPrecedenceTracking::getFirstSpecialInstruction( + const BasicBlock *BB) { +#ifndef NDEBUG + // If there is a bug connected to invalid cache, turn on ExpensiveAsserts to + // catch this situation as early as possible. + if (ExpensiveAsserts) + validateAll(); + else + validate(BB); +#endif + + if (FirstSpecialInsts.find(BB) == FirstSpecialInsts.end()) { + fill(BB); + assert(FirstSpecialInsts.find(BB) != FirstSpecialInsts.end() && "Must be!"); + } + return FirstSpecialInsts[BB]; +} + +bool InstructionPrecedenceTracking::hasSpecialInstructions( + const BasicBlock *BB) { + return getFirstSpecialInstruction(BB) != nullptr; +} + +bool InstructionPrecedenceTracking::isPreceededBySpecialInstruction( + const Instruction *Insn) { + const Instruction *MaybeFirstSpecial = + getFirstSpecialInstruction(Insn->getParent()); + return MaybeFirstSpecial && OI.dominates(MaybeFirstSpecial, Insn); +} + +void InstructionPrecedenceTracking::fill(const BasicBlock *BB) { + FirstSpecialInsts.erase(BB); + for (auto &I : *BB) + if (isSpecialInstruction(&I)) { + FirstSpecialInsts[BB] = &I; + return; + } + + // Mark this block as having no special instructions. + FirstSpecialInsts[BB] = nullptr; +} + +#ifndef NDEBUG +void InstructionPrecedenceTracking::validate(const BasicBlock *BB) const { + auto It = FirstSpecialInsts.find(BB); + // Bail if we don't have anything cached for this block. + if (It == FirstSpecialInsts.end()) + return; + + for (const Instruction &Insn : *BB) + if (isSpecialInstruction(&Insn)) { + assert(It->second == &Insn && + "Cached first special instruction is wrong!"); + return; + } + + assert(It->second == nullptr && + "Block is marked as having special instructions but in fact it has " + "none!"); +} + +void InstructionPrecedenceTracking::validateAll() const { + // Check that for every known block the cached value is correct. + for (auto &It : FirstSpecialInsts) + validate(It.first); +} +#endif + +void InstructionPrecedenceTracking::insertInstructionTo(const Instruction *Inst, + const BasicBlock *BB) { + if (isSpecialInstruction(Inst)) + FirstSpecialInsts.erase(BB); + OI.invalidateBlock(BB); +} + +void InstructionPrecedenceTracking::removeInstruction(const Instruction *Inst) { + if (isSpecialInstruction(Inst)) + FirstSpecialInsts.erase(Inst->getParent()); + OI.invalidateBlock(Inst->getParent()); +} + +void InstructionPrecedenceTracking::clear() { + for (auto It : FirstSpecialInsts) + OI.invalidateBlock(It.first); + FirstSpecialInsts.clear(); +#ifndef NDEBUG + // The map should be valid after clearing (at least empty). + validateAll(); +#endif +} + +bool ImplicitControlFlowTracking::isSpecialInstruction( + const Instruction *Insn) const { + // If a block's instruction doesn't always pass the control to its successor + // instruction, mark the block as having implicit control flow. 
+  // We use this fact to avoid incorrect assumptions of the form "if A is
+  // executed and B post-dominates A, then B is also executed". This is not
+  // true if there is an implicit control flow instruction (e.g. a guard)
+  // between them.
+  //
+  // TODO: Currently, isGuaranteedToTransferExecutionToSuccessor returns false
+  // for volatile stores and loads because they can trap. The discussion on
+  // whether or not it is correct is still ongoing. We might want to get rid
+  // of this logic in the future. Anyway, trapping instructions shouldn't
+  // introduce implicit control flow, so we explicitly allow them here. This
+  // must be removed once isGuaranteedToTransferExecutionToSuccessor is fixed.
+  if (isGuaranteedToTransferExecutionToSuccessor(Insn))
+    return false;
+  if (isa<LoadInst>(Insn)) {
+    assert(cast<LoadInst>(Insn)->isVolatile() &&
+           "Non-volatile load should transfer execution to successor!");
+    return false;
+  }
+  if (isa<StoreInst>(Insn)) {
+    assert(cast<StoreInst>(Insn)->isVolatile() &&
+           "Non-volatile store should transfer execution to successor!");
+    return false;
+  }
+  return true;
+}
+
+bool MemoryWriteTracking::isSpecialInstruction(
+    const Instruction *Insn) const {
+  using namespace PatternMatch;
+  if (match(Insn, m_Intrinsic<Intrinsic::experimental_widenable_condition>()))
+    return false;
+  return Insn->mayWriteToMemory();
+}
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
new file mode 100644
index 000000000000..cb8987721700
--- /dev/null
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -0,0 +1,5518 @@
+//===- InstructionSimplify.cpp - Fold instruction operands ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements routines for folding instructions into simpler forms
+// that do not require creating new instructions. This does constant folding
+// ("add i32 1, 1" -> "2") but can also handle non-constant operands, either
+// returning a constant ("and i32 %x, 0" -> "0") or an already existing value
+// ("and i32 %x, %x" -> "%x"). All operands are assumed to have already been
+// simplified: this is usually true, and assuming it simplifies the logic (if
+// they have not been simplified, the results are correct but may be
+// suboptimal).
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/CmpInstAnalysis.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Support/KnownBits.h" +#include <algorithm> +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "instsimplify" + +enum { RecursionLimit = 3 }; + +STATISTIC(NumExpand, "Number of expansions"); +STATISTIC(NumReassoc, "Number of reassociations"); + +static Value *SimplifyAndInst(Value *, Value *, const SimplifyQuery &, unsigned); +static Value *simplifyUnOp(unsigned, Value *, const SimplifyQuery &, unsigned); +static Value *simplifyFPUnOp(unsigned, Value *, const FastMathFlags &, + const SimplifyQuery &, unsigned); +static Value *SimplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &, + unsigned); +static Value *SimplifyBinOp(unsigned, Value *, Value *, const FastMathFlags &, + const SimplifyQuery &, unsigned); +static Value *SimplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &, + unsigned); +static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, + const SimplifyQuery &Q, unsigned MaxRecurse); +static Value *SimplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned); +static Value *SimplifyXorInst(Value *, Value *, const SimplifyQuery &, unsigned); +static Value *SimplifyCastInst(unsigned, Value *, Type *, + const SimplifyQuery &, unsigned); +static Value *SimplifyGEPInst(Type *, ArrayRef<Value *>, const SimplifyQuery &, + unsigned); + +static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal, + Value *FalseVal) { + BinaryOperator::BinaryOps BinOpCode; + if (auto *BO = dyn_cast<BinaryOperator>(Cond)) + BinOpCode = BO->getOpcode(); + else + return nullptr; + + CmpInst::Predicate ExpectedPred, Pred1, Pred2; + if (BinOpCode == BinaryOperator::Or) { + ExpectedPred = ICmpInst::ICMP_NE; + } else if (BinOpCode == BinaryOperator::And) { + ExpectedPred = ICmpInst::ICMP_EQ; + } else + return nullptr; + + // %A = icmp eq %TV, %FV + // %B = icmp eq %X, %Y (and one of these is a select operand) + // %C = and %A, %B + // %D = select %C, %TV, %FV + // --> + // %FV + + // %A = icmp ne %TV, %FV + // %B = icmp ne %X, %Y (and one of these is a select operand) + // %C = or %A, %B + // %D = select %C, %TV, %FV + // --> + // %TV + Value *X, *Y; + if (!match(Cond, m_c_BinOp(m_c_ICmp(Pred1, m_Specific(TrueVal), + m_Specific(FalseVal)), + m_ICmp(Pred2, m_Value(X), m_Value(Y)))) || + Pred1 != Pred2 || Pred1 != ExpectedPred) + return nullptr; + + if (X == TrueVal || X == FalseVal || Y == TrueVal || Y == FalseVal) + return BinOpCode == BinaryOperator::Or ? 
TrueVal : FalseVal; + + return nullptr; +} + +/// For a boolean type or a vector of boolean type, return false or a vector +/// with every element false. +static Constant *getFalse(Type *Ty) { + return ConstantInt::getFalse(Ty); +} + +/// For a boolean type or a vector of boolean type, return true or a vector +/// with every element true. +static Constant *getTrue(Type *Ty) { + return ConstantInt::getTrue(Ty); +} + +/// isSameCompare - Is V equivalent to the comparison "LHS Pred RHS"? +static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS, + Value *RHS) { + CmpInst *Cmp = dyn_cast<CmpInst>(V); + if (!Cmp) + return false; + CmpInst::Predicate CPred = Cmp->getPredicate(); + Value *CLHS = Cmp->getOperand(0), *CRHS = Cmp->getOperand(1); + if (CPred == Pred && CLHS == LHS && CRHS == RHS) + return true; + return CPred == CmpInst::getSwappedPredicate(Pred) && CLHS == RHS && + CRHS == LHS; +} + +/// Does the given value dominate the specified phi node? +static bool valueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) { + Instruction *I = dyn_cast<Instruction>(V); + if (!I) + // Arguments and constants dominate all instructions. + return true; + + // If we are processing instructions (and/or basic blocks) that have not been + // fully added to a function, the parent nodes may still be null. Simply + // return the conservative answer in these cases. + if (!I->getParent() || !P->getParent() || !I->getFunction()) + return false; + + // If we have a DominatorTree then do a precise test. + if (DT) + return DT->dominates(I, P); + + // Otherwise, if the instruction is in the entry block and is not an invoke, + // then it obviously dominates all phi nodes. + if (I->getParent() == &I->getFunction()->getEntryBlock() && + !isa<InvokeInst>(I)) + return true; + + return false; +} + +/// Simplify "A op (B op' C)" by distributing op over op', turning it into +/// "(A op B) op' (A op C)". Here "op" is given by Opcode and "op'" is +/// given by OpcodeToExpand, while "A" corresponds to LHS and "B op' C" to RHS. +/// Also performs the transform "(A op' B) op C" -> "(A op C) op' (B op C)". +/// Returns the simplified value, or null if no simplification was performed. +static Value *ExpandBinOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, + Instruction::BinaryOps OpcodeToExpand, + const SimplifyQuery &Q, unsigned MaxRecurse) { + // Recursion is always used, so bail out at once if we already hit the limit. + if (!MaxRecurse--) + return nullptr; + + // Check whether the expression has the form "(A op' B) op C". + if (BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS)) + if (Op0->getOpcode() == OpcodeToExpand) { + // It does! Try turning it into "(A op C) op' (B op C)". + Value *A = Op0->getOperand(0), *B = Op0->getOperand(1), *C = RHS; + // Do "A op C" and "B op C" both simplify? + if (Value *L = SimplifyBinOp(Opcode, A, C, Q, MaxRecurse)) + if (Value *R = SimplifyBinOp(Opcode, B, C, Q, MaxRecurse)) { + // They do! Return "L op' R" if it simplifies or is already available. + // If "L op' R" equals "A op' B" then "L op' R" is just the LHS. + if ((L == A && R == B) || (Instruction::isCommutative(OpcodeToExpand) + && L == B && R == A)) { + ++NumExpand; + return LHS; + } + // Otherwise return "L op' R" if it simplifies. + if (Value *V = SimplifyBinOp(OpcodeToExpand, L, R, Q, MaxRecurse)) { + ++NumExpand; + return V; + } + } + } + + // Check whether the expression has the form "A op (B op' C)". 
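+  // As an illustrative instance (not the only one), take Opcode == Mul and
+  // OpcodeToExpand == Add, the combination SimplifyMulInst feeds in below:
+  // "A * (B + C)" is considered as "(A * B) + (A * C)", and the expansion is
+  // returned only if both products simplify, so no new instructions are
+  // created.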
+ if (BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS)) + if (Op1->getOpcode() == OpcodeToExpand) { + // It does! Try turning it into "(A op B) op' (A op C)". + Value *A = LHS, *B = Op1->getOperand(0), *C = Op1->getOperand(1); + // Do "A op B" and "A op C" both simplify? + if (Value *L = SimplifyBinOp(Opcode, A, B, Q, MaxRecurse)) + if (Value *R = SimplifyBinOp(Opcode, A, C, Q, MaxRecurse)) { + // They do! Return "L op' R" if it simplifies or is already available. + // If "L op' R" equals "B op' C" then "L op' R" is just the RHS. + if ((L == B && R == C) || (Instruction::isCommutative(OpcodeToExpand) + && L == C && R == B)) { + ++NumExpand; + return RHS; + } + // Otherwise return "L op' R" if it simplifies. + if (Value *V = SimplifyBinOp(OpcodeToExpand, L, R, Q, MaxRecurse)) { + ++NumExpand; + return V; + } + } + } + + return nullptr; +} + +/// Generic simplifications for associative binary operations. +/// Returns the simpler value, or null if none was found. +static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, + Value *LHS, Value *RHS, + const SimplifyQuery &Q, + unsigned MaxRecurse) { + assert(Instruction::isAssociative(Opcode) && "Not an associative operation!"); + + // Recursion is always used, so bail out at once if we already hit the limit. + if (!MaxRecurse--) + return nullptr; + + BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS); + BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS); + + // Transform: "(A op B) op C" ==> "A op (B op C)" if it simplifies completely. + if (Op0 && Op0->getOpcode() == Opcode) { + Value *A = Op0->getOperand(0); + Value *B = Op0->getOperand(1); + Value *C = RHS; + + // Does "B op C" simplify? + if (Value *V = SimplifyBinOp(Opcode, B, C, Q, MaxRecurse)) { + // It does! Return "A op V" if it simplifies or is already available. + // If V equals B then "A op V" is just the LHS. + if (V == B) return LHS; + // Otherwise return "A op V" if it simplifies. + if (Value *W = SimplifyBinOp(Opcode, A, V, Q, MaxRecurse)) { + ++NumReassoc; + return W; + } + } + } + + // Transform: "A op (B op C)" ==> "(A op B) op C" if it simplifies completely. + if (Op1 && Op1->getOpcode() == Opcode) { + Value *A = LHS; + Value *B = Op1->getOperand(0); + Value *C = Op1->getOperand(1); + + // Does "A op B" simplify? + if (Value *V = SimplifyBinOp(Opcode, A, B, Q, MaxRecurse)) { + // It does! Return "V op C" if it simplifies or is already available. + // If V equals B then "V op C" is just the RHS. + if (V == B) return RHS; + // Otherwise return "V op C" if it simplifies. + if (Value *W = SimplifyBinOp(Opcode, V, C, Q, MaxRecurse)) { + ++NumReassoc; + return W; + } + } + } + + // The remaining transforms require commutativity as well as associativity. + if (!Instruction::isCommutative(Opcode)) + return nullptr; + + // Transform: "(A op B) op C" ==> "(C op A) op B" if it simplifies completely. + if (Op0 && Op0->getOpcode() == Opcode) { + Value *A = Op0->getOperand(0); + Value *B = Op0->getOperand(1); + Value *C = RHS; + + // Does "C op A" simplify? + if (Value *V = SimplifyBinOp(Opcode, C, A, Q, MaxRecurse)) { + // It does! Return "V op B" if it simplifies or is already available. + // If V equals A then "V op B" is just the LHS. + if (V == A) return LHS; + // Otherwise return "V op B" if it simplifies. + if (Value *W = SimplifyBinOp(Opcode, V, B, Q, MaxRecurse)) { + ++NumReassoc; + return W; + } + } + } + + // Transform: "A op (B op C)" ==> "B op (C op A)" if it simplifies completely. 
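+  // E.g. for Add this catches "A + (B + C)" where "C + A" folds to 0: the
+  // result "B + 0" then simplifies to B.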
+ if (Op1 && Op1->getOpcode() == Opcode) { + Value *A = LHS; + Value *B = Op1->getOperand(0); + Value *C = Op1->getOperand(1); + + // Does "C op A" simplify? + if (Value *V = SimplifyBinOp(Opcode, C, A, Q, MaxRecurse)) { + // It does! Return "B op V" if it simplifies or is already available. + // If V equals C then "B op V" is just the RHS. + if (V == C) return RHS; + // Otherwise return "B op V" if it simplifies. + if (Value *W = SimplifyBinOp(Opcode, B, V, Q, MaxRecurse)) { + ++NumReassoc; + return W; + } + } + } + + return nullptr; +} + +/// In the case of a binary operation with a select instruction as an operand, +/// try to simplify the binop by seeing whether evaluating it on both branches +/// of the select results in the same value. Returns the common value if so, +/// otherwise returns null. +static Value *ThreadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, + Value *RHS, const SimplifyQuery &Q, + unsigned MaxRecurse) { + // Recursion is always used, so bail out at once if we already hit the limit. + if (!MaxRecurse--) + return nullptr; + + SelectInst *SI; + if (isa<SelectInst>(LHS)) { + SI = cast<SelectInst>(LHS); + } else { + assert(isa<SelectInst>(RHS) && "No select instruction operand!"); + SI = cast<SelectInst>(RHS); + } + + // Evaluate the BinOp on the true and false branches of the select. + Value *TV; + Value *FV; + if (SI == LHS) { + TV = SimplifyBinOp(Opcode, SI->getTrueValue(), RHS, Q, MaxRecurse); + FV = SimplifyBinOp(Opcode, SI->getFalseValue(), RHS, Q, MaxRecurse); + } else { + TV = SimplifyBinOp(Opcode, LHS, SI->getTrueValue(), Q, MaxRecurse); + FV = SimplifyBinOp(Opcode, LHS, SI->getFalseValue(), Q, MaxRecurse); + } + + // If they simplified to the same value, then return the common value. + // If they both failed to simplify then return null. + if (TV == FV) + return TV; + + // If one branch simplified to undef, return the other one. + if (TV && isa<UndefValue>(TV)) + return FV; + if (FV && isa<UndefValue>(FV)) + return TV; + + // If applying the operation did not change the true and false select values, + // then the result of the binop is the select itself. + if (TV == SI->getTrueValue() && FV == SI->getFalseValue()) + return SI; + + // If one branch simplified and the other did not, and the simplified + // value is equal to the unsimplified one, return the simplified value. + // For example, select (cond, X, X & Z) & Z -> X & Z. + if ((FV && !TV) || (TV && !FV)) { + // Check that the simplified value has the form "X op Y" where "op" is the + // same as the original operation. + Instruction *Simplified = dyn_cast<Instruction>(FV ? FV : TV); + if (Simplified && Simplified->getOpcode() == unsigned(Opcode)) { + // The value that didn't simplify is "UnsimplifiedLHS op UnsimplifiedRHS". + // We already know that "op" is the same as for the simplified value. See + // if the operands match too. If so, return the simplified value. + Value *UnsimplifiedBranch = FV ? SI->getTrueValue() : SI->getFalseValue(); + Value *UnsimplifiedLHS = SI == LHS ? UnsimplifiedBranch : LHS; + Value *UnsimplifiedRHS = SI == LHS ? 
RHS : UnsimplifiedBranch; + if (Simplified->getOperand(0) == UnsimplifiedLHS && + Simplified->getOperand(1) == UnsimplifiedRHS) + return Simplified; + if (Simplified->isCommutative() && + Simplified->getOperand(1) == UnsimplifiedLHS && + Simplified->getOperand(0) == UnsimplifiedRHS) + return Simplified; + } + } + + return nullptr; +} + +/// In the case of a comparison with a select instruction, try to simplify the +/// comparison by seeing whether both branches of the select result in the same +/// value. Returns the common value if so, otherwise returns null. +static Value *ThreadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS, + Value *RHS, const SimplifyQuery &Q, + unsigned MaxRecurse) { + // Recursion is always used, so bail out at once if we already hit the limit. + if (!MaxRecurse--) + return nullptr; + + // Make sure the select is on the LHS. + if (!isa<SelectInst>(LHS)) { + std::swap(LHS, RHS); + Pred = CmpInst::getSwappedPredicate(Pred); + } + assert(isa<SelectInst>(LHS) && "Not comparing with a select instruction!"); + SelectInst *SI = cast<SelectInst>(LHS); + Value *Cond = SI->getCondition(); + Value *TV = SI->getTrueValue(); + Value *FV = SI->getFalseValue(); + + // Now that we have "cmp select(Cond, TV, FV), RHS", analyse it. + // Does "cmp TV, RHS" simplify? + Value *TCmp = SimplifyCmpInst(Pred, TV, RHS, Q, MaxRecurse); + if (TCmp == Cond) { + // It not only simplified, it simplified to the select condition. Replace + // it with 'true'. + TCmp = getTrue(Cond->getType()); + } else if (!TCmp) { + // It didn't simplify. However if "cmp TV, RHS" is equal to the select + // condition then we can replace it with 'true'. Otherwise give up. + if (!isSameCompare(Cond, Pred, TV, RHS)) + return nullptr; + TCmp = getTrue(Cond->getType()); + } + + // Does "cmp FV, RHS" simplify? + Value *FCmp = SimplifyCmpInst(Pred, FV, RHS, Q, MaxRecurse); + if (FCmp == Cond) { + // It not only simplified, it simplified to the select condition. Replace + // it with 'false'. + FCmp = getFalse(Cond->getType()); + } else if (!FCmp) { + // It didn't simplify. However if "cmp FV, RHS" is equal to the select + // condition then we can replace it with 'false'. Otherwise give up. + if (!isSameCompare(Cond, Pred, FV, RHS)) + return nullptr; + FCmp = getFalse(Cond->getType()); + } + + // If both sides simplified to the same value, then use it as the result of + // the original comparison. + if (TCmp == FCmp) + return TCmp; + + // The remaining cases only make sense if the select condition has the same + // type as the result of the comparison, so bail out if this is not so. + if (Cond->getType()->isVectorTy() != RHS->getType()->isVectorTy()) + return nullptr; + // If the false value simplified to false, then the result of the compare + // is equal to "Cond && TCmp". This also catches the case when the false + // value simplified to false and the true value to true, returning "Cond". + if (match(FCmp, m_Zero())) + if (Value *V = SimplifyAndInst(Cond, TCmp, Q, MaxRecurse)) + return V; + // If the true value simplified to true, then the result of the compare + // is equal to "Cond || FCmp". + if (match(TCmp, m_One())) + if (Value *V = SimplifyOrInst(Cond, FCmp, Q, MaxRecurse)) + return V; + // Finally, if the false value simplified to true and the true value to + // false, then the result of the compare is equal to "!Cond". 
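+  // A hypothetical instance: in "icmp eq (select %c, i32 0, i32 1), 1" the
+  // true arm compares 0 with 1 (false) and the false arm compares 1 with 1
+  // (true), so the whole compare folds to "xor %c, true", i.e. "!%c".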
+ if (match(FCmp, m_One()) && match(TCmp, m_Zero())) + if (Value *V = + SimplifyXorInst(Cond, Constant::getAllOnesValue(Cond->getType()), + Q, MaxRecurse)) + return V; + + return nullptr; +} + +/// In the case of a binary operation with an operand that is a PHI instruction, +/// try to simplify the binop by seeing whether evaluating it on the incoming +/// phi values yields the same result for every value. If so returns the common +/// value, otherwise returns null. +static Value *ThreadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS, + Value *RHS, const SimplifyQuery &Q, + unsigned MaxRecurse) { + // Recursion is always used, so bail out at once if we already hit the limit. + if (!MaxRecurse--) + return nullptr; + + PHINode *PI; + if (isa<PHINode>(LHS)) { + PI = cast<PHINode>(LHS); + // Bail out if RHS and the phi may be mutually interdependent due to a loop. + if (!valueDominatesPHI(RHS, PI, Q.DT)) + return nullptr; + } else { + assert(isa<PHINode>(RHS) && "No PHI instruction operand!"); + PI = cast<PHINode>(RHS); + // Bail out if LHS and the phi may be mutually interdependent due to a loop. + if (!valueDominatesPHI(LHS, PI, Q.DT)) + return nullptr; + } + + // Evaluate the BinOp on the incoming phi values. + Value *CommonValue = nullptr; + for (Value *Incoming : PI->incoming_values()) { + // If the incoming value is the phi node itself, it can safely be skipped. + if (Incoming == PI) continue; + Value *V = PI == LHS ? + SimplifyBinOp(Opcode, Incoming, RHS, Q, MaxRecurse) : + SimplifyBinOp(Opcode, LHS, Incoming, Q, MaxRecurse); + // If the operation failed to simplify, or simplified to a different value + // to previously, then give up. + if (!V || (CommonValue && V != CommonValue)) + return nullptr; + CommonValue = V; + } + + return CommonValue; +} + +/// In the case of a comparison with a PHI instruction, try to simplify the +/// comparison by seeing whether comparing with all of the incoming phi values +/// yields the same result every time. If so returns the common result, +/// otherwise returns null. +static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS, + const SimplifyQuery &Q, unsigned MaxRecurse) { + // Recursion is always used, so bail out at once if we already hit the limit. + if (!MaxRecurse--) + return nullptr; + + // Make sure the phi is on the LHS. + if (!isa<PHINode>(LHS)) { + std::swap(LHS, RHS); + Pred = CmpInst::getSwappedPredicate(Pred); + } + assert(isa<PHINode>(LHS) && "Not comparing with a phi instruction!"); + PHINode *PI = cast<PHINode>(LHS); + + // Bail out if RHS and the phi may be mutually interdependent due to a loop. + if (!valueDominatesPHI(RHS, PI, Q.DT)) + return nullptr; + + // Evaluate the BinOp on the incoming phi values. + Value *CommonValue = nullptr; + for (Value *Incoming : PI->incoming_values()) { + // If the incoming value is the phi node itself, it can safely be skipped. + if (Incoming == PI) continue; + Value *V = SimplifyCmpInst(Pred, Incoming, RHS, Q, MaxRecurse); + // If the operation failed to simplify, or simplified to a different value + // to previously, then give up. 
+    if (!V || (CommonValue && V != CommonValue))
+      return nullptr;
+    CommonValue = V;
+  }
+
+  return CommonValue;
+}
+
+static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode,
+                                       Value *&Op0, Value *&Op1,
+                                       const SimplifyQuery &Q) {
+  if (auto *CLHS = dyn_cast<Constant>(Op0)) {
+    if (auto *CRHS = dyn_cast<Constant>(Op1))
+      return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL);
+
+    // Canonicalize the constant to the RHS if this is a commutative operation.
+    if (Instruction::isCommutative(Opcode))
+      std::swap(Op0, Op1);
+  }
+  return nullptr;
+}
+
+/// Given operands for an Add, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
+                              const SimplifyQuery &Q, unsigned MaxRecurse) {
+  if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q))
+    return C;
+
+  // X + undef -> undef
+  if (match(Op1, m_Undef()))
+    return Op1;
+
+  // X + 0 -> X
+  if (match(Op1, m_Zero()))
+    return Op0;
+
+  // If the two operands are negations of each other, return 0.
+  if (isKnownNegation(Op0, Op1))
+    return Constant::getNullValue(Op0->getType());
+
+  // X + (Y - X) -> Y
+  // (Y - X) + X -> Y
+  // E.g. X + -X -> 0
+  Value *Y = nullptr;
+  if (match(Op1, m_Sub(m_Value(Y), m_Specific(Op0))) ||
+      match(Op0, m_Sub(m_Value(Y), m_Specific(Op1))))
+    return Y;
+
+  // X + ~X -> -1 since ~X = -X-1
+  Type *Ty = Op0->getType();
+  if (match(Op0, m_Not(m_Specific(Op1))) ||
+      match(Op1, m_Not(m_Specific(Op0))))
+    return Constant::getAllOnesValue(Ty);
+
+  // add nsw/nuw (xor Y, signmask), signmask --> Y
+  // The no-wrapping add guarantees that the top bit will be set by the add.
+  // Therefore, the xor must be clearing the already set sign bit of Y.
+  if ((IsNSW || IsNUW) && match(Op1, m_SignMask()) &&
+      match(Op0, m_Xor(m_Value(Y), m_SignMask())))
+    return Y;
+
+  // add nuw %x, -1 -> -1, because %x can only be 0.
+  if (IsNUW && match(Op1, m_AllOnes()))
+    return Op1; // Which is -1.
+
+  // i1 add -> xor.
+  if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1))
+    if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1))
+      return V;
+
+  // Try some generic simplifications for associative operations.
+  if (Value *V = SimplifyAssociativeBinOp(Instruction::Add, Op0, Op1, Q,
+                                          MaxRecurse))
+    return V;
+
+  // Threading Add over selects and phi nodes is pointless, so don't bother.
+  // Threading over the select in "A + select(cond, B, C)" means evaluating
+  // "A+B" and "A+C" and seeing if they are equal; but they are equal if and
+  // only if B and C are equal. If B and C are equal then (since we assume
+  // that operands have already been simplified) "select(cond, B, C)" should
+  // have been simplified to the common value of B and C already. Analysing
+  // "A+B" and "A+C" thus gains nothing, but costs compile time. Similarly
+  // for threading over phi nodes.
+
+  return nullptr;
+}
+
+Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
+                             const SimplifyQuery &Query) {
+  return ::SimplifyAddInst(Op0, Op1, IsNSW, IsNUW, Query, RecursionLimit);
+}
+
+/// Compute the base pointer and cumulative constant offsets for V.
+///
+/// This strips all constant offsets off of V, leaving it the base pointer, and
+/// accumulates the total constant offset applied in the returned constant.
+/// V must be a pointer (or vector of pointers); the constant '0' is returned
+/// if there are no constant offsets applied.
+///
+/// This is very similar to GetPointerBaseWithConstantOffset except it doesn't
+/// follow non-inbounds geps.
This allows it to remain usable for icmp ult/etc. +/// folding. +static Constant *stripAndComputeConstantOffsets(const DataLayout &DL, Value *&V, + bool AllowNonInbounds = false) { + assert(V->getType()->isPtrOrPtrVectorTy()); + + Type *IntPtrTy = DL.getIntPtrType(V->getType())->getScalarType(); + APInt Offset = APInt::getNullValue(IntPtrTy->getIntegerBitWidth()); + + V = V->stripAndAccumulateConstantOffsets(DL, Offset, AllowNonInbounds); + // As that strip may trace through `addrspacecast`, need to sext or trunc + // the offset calculated. + IntPtrTy = DL.getIntPtrType(V->getType())->getScalarType(); + Offset = Offset.sextOrTrunc(IntPtrTy->getIntegerBitWidth()); + + Constant *OffsetIntPtr = ConstantInt::get(IntPtrTy, Offset); + if (V->getType()->isVectorTy()) + return ConstantVector::getSplat(V->getType()->getVectorNumElements(), + OffsetIntPtr); + return OffsetIntPtr; +} + +/// Compute the constant difference between two pointer values. +/// If the difference is not a constant, returns zero. +static Constant *computePointerDifference(const DataLayout &DL, Value *LHS, + Value *RHS) { + Constant *LHSOffset = stripAndComputeConstantOffsets(DL, LHS); + Constant *RHSOffset = stripAndComputeConstantOffsets(DL, RHS); + + // If LHS and RHS are not related via constant offsets to the same base + // value, there is nothing we can do here. + if (LHS != RHS) + return nullptr; + + // Otherwise, the difference of LHS - RHS can be computed as: + // LHS - RHS + // = (LHSOffset + Base) - (RHSOffset + Base) + // = LHSOffset - RHSOffset + return ConstantExpr::getSub(LHSOffset, RHSOffset); +} + +/// Given operands for a Sub, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::Sub, Op0, Op1, Q)) + return C; + + // X - undef -> undef + // undef - X -> undef + if (match(Op0, m_Undef()) || match(Op1, m_Undef())) + return UndefValue::get(Op0->getType()); + + // X - 0 -> X + if (match(Op1, m_Zero())) + return Op0; + + // X - X -> 0 + if (Op0 == Op1) + return Constant::getNullValue(Op0->getType()); + + // Is this a negation? + if (match(Op0, m_Zero())) { + // 0 - X -> 0 if the sub is NUW. + if (isNUW) + return Constant::getNullValue(Op0->getType()); + + KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (Known.Zero.isMaxSignedValue()) { + // Op1 is either 0 or the minimum signed value. If the sub is NSW, then + // Op1 must be 0 because negating the minimum signed value is undefined. + if (isNSW) + return Constant::getNullValue(Op0->getType()); + + // 0 - X -> X if X is 0 or the minimum signed value. + return Op1; + } + } + + // (X + Y) - Z -> X + (Y - Z) or Y + (X - Z) if everything simplifies. + // For example, (X + Y) - Y -> X; (Y + X) - Y -> X + Value *X = nullptr, *Y = nullptr, *Z = Op1; + if (MaxRecurse && match(Op0, m_Add(m_Value(X), m_Value(Y)))) { // (X + Y) - Z + // See if "V === Y - Z" simplifies. + if (Value *V = SimplifyBinOp(Instruction::Sub, Y, Z, Q, MaxRecurse-1)) + // It does! Now see if "X + V" simplifies. + if (Value *W = SimplifyBinOp(Instruction::Add, X, V, Q, MaxRecurse-1)) { + // It does, we successfully reassociated! + ++NumReassoc; + return W; + } + // See if "V === X - Z" simplifies. + if (Value *V = SimplifyBinOp(Instruction::Sub, X, Z, Q, MaxRecurse-1)) + // It does! Now see if "Y + V" simplifies. 
+ if (Value *W = SimplifyBinOp(Instruction::Add, Y, V, Q, MaxRecurse-1)) { + // It does, we successfully reassociated! + ++NumReassoc; + return W; + } + } + + // X - (Y + Z) -> (X - Y) - Z or (X - Z) - Y if everything simplifies. + // For example, X - (X + 1) -> -1 + X = Op0; + if (MaxRecurse && match(Op1, m_Add(m_Value(Y), m_Value(Z)))) { // X - (Y + Z) + // See if "V === X - Y" simplifies. + if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse-1)) + // It does! Now see if "V - Z" simplifies. + if (Value *W = SimplifyBinOp(Instruction::Sub, V, Z, Q, MaxRecurse-1)) { + // It does, we successfully reassociated! + ++NumReassoc; + return W; + } + // See if "V === X - Z" simplifies. + if (Value *V = SimplifyBinOp(Instruction::Sub, X, Z, Q, MaxRecurse-1)) + // It does! Now see if "V - Y" simplifies. + if (Value *W = SimplifyBinOp(Instruction::Sub, V, Y, Q, MaxRecurse-1)) { + // It does, we successfully reassociated! + ++NumReassoc; + return W; + } + } + + // Z - (X - Y) -> (Z - X) + Y if everything simplifies. + // For example, X - (X - Y) -> Y. + Z = Op0; + if (MaxRecurse && match(Op1, m_Sub(m_Value(X), m_Value(Y)))) // Z - (X - Y) + // See if "V === Z - X" simplifies. + if (Value *V = SimplifyBinOp(Instruction::Sub, Z, X, Q, MaxRecurse-1)) + // It does! Now see if "V + Y" simplifies. + if (Value *W = SimplifyBinOp(Instruction::Add, V, Y, Q, MaxRecurse-1)) { + // It does, we successfully reassociated! + ++NumReassoc; + return W; + } + + // trunc(X) - trunc(Y) -> trunc(X - Y) if everything simplifies. + if (MaxRecurse && match(Op0, m_Trunc(m_Value(X))) && + match(Op1, m_Trunc(m_Value(Y)))) + if (X->getType() == Y->getType()) + // See if "V === X - Y" simplifies. + if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse-1)) + // It does! Now see if "trunc V" simplifies. + if (Value *W = SimplifyCastInst(Instruction::Trunc, V, Op0->getType(), + Q, MaxRecurse - 1)) + // It does, return the simplified "trunc V". + return W; + + // Variations on GEP(base, I, ...) - GEP(base, i, ...) -> GEP(null, I-i, ...). + if (match(Op0, m_PtrToInt(m_Value(X))) && + match(Op1, m_PtrToInt(m_Value(Y)))) + if (Constant *Result = computePointerDifference(Q.DL, X, Y)) + return ConstantExpr::getIntegerCast(Result, Op0->getType(), true); + + // i1 sub -> xor. + if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) + if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1)) + return V; + + // Threading Sub over selects and phi nodes is pointless, so don't bother. + // Threading over the select in "A - select(cond, B, C)" means evaluating + // "A-B" and "A-C" and seeing if they are equal; but they are equal if and + // only if B and C are equal. If B and C are equal then (since we assume + // that operands have already been simplified) "select(cond, B, C)" should + // have been simplified to the common value of B and C already. Analysing + // "A-B" and "A-C" thus gains nothing, but costs compile time. Similarly + // for threading over phi nodes. + + return nullptr; +} + +Value *llvm::SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, + const SimplifyQuery &Q) { + return ::SimplifySubInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit); +} + +/// Given operands for a Mul, see if we can fold the result. +/// If not, this returns null. 
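+/// For example, "mul i32 %x, 0" folds to 0, "mul i32 %x, 1" folds to %x, and
+/// an i1 mul is forwarded to SimplifyAndInst, since multiplying single-bit
+/// values is the same as ANDing them.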
+static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, + unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q)) + return C; + + // X * undef -> 0 + // X * 0 -> 0 + if (match(Op1, m_CombineOr(m_Undef(), m_Zero()))) + return Constant::getNullValue(Op0->getType()); + + // X * 1 -> X + if (match(Op1, m_One())) + return Op0; + + // (X / Y) * Y -> X if the division is exact. + Value *X = nullptr; + if (Q.IIQ.UseInstrInfo && + (match(Op0, + m_Exact(m_IDiv(m_Value(X), m_Specific(Op1)))) || // (X / Y) * Y + match(Op1, m_Exact(m_IDiv(m_Value(X), m_Specific(Op0)))))) // Y * (X / Y) + return X; + + // i1 mul -> and. + if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) + if (Value *V = SimplifyAndInst(Op0, Op1, Q, MaxRecurse-1)) + return V; + + // Try some generic simplifications for associative operations. + if (Value *V = SimplifyAssociativeBinOp(Instruction::Mul, Op0, Op1, Q, + MaxRecurse)) + return V; + + // Mul distributes over Add. Try some generic simplifications based on this. + if (Value *V = ExpandBinOp(Instruction::Mul, Op0, Op1, Instruction::Add, + Q, MaxRecurse)) + return V; + + // If the operation is with the result of a select instruction, check whether + // operating on either branch of the select always yields the same value. + if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) + if (Value *V = ThreadBinOpOverSelect(Instruction::Mul, Op0, Op1, Q, + MaxRecurse)) + return V; + + // If the operation is with the result of a phi instruction, check whether + // operating on all incoming values of the phi always yields the same value. + if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) + if (Value *V = ThreadBinOpOverPHI(Instruction::Mul, Op0, Op1, Q, + MaxRecurse)) + return V; + + return nullptr; +} + +Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyMulInst(Op0, Op1, Q, RecursionLimit); +} + +/// Check for common or similar folds of integer division or integer remainder. +/// This applies to all 4 opcodes (sdiv/udiv/srem/urem). +static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) { + Type *Ty = Op0->getType(); + + // X / undef -> undef + // X % undef -> undef + if (match(Op1, m_Undef())) + return Op1; + + // X / 0 -> undef + // X % 0 -> undef + // We don't need to preserve faults! + if (match(Op1, m_Zero())) + return UndefValue::get(Ty); + + // If any element of a constant divisor vector is zero or undef, the whole op + // is undef. + auto *Op1C = dyn_cast<Constant>(Op1); + if (Op1C && Ty->isVectorTy()) { + unsigned NumElts = Ty->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = Op1C->getAggregateElement(i); + if (Elt && (Elt->isNullValue() || isa<UndefValue>(Elt))) + return UndefValue::get(Ty); + } + } + + // undef / X -> 0 + // undef % X -> 0 + if (match(Op0, m_Undef())) + return Constant::getNullValue(Ty); + + // 0 / X -> 0 + // 0 % X -> 0 + if (match(Op0, m_Zero())) + return Constant::getNullValue(Op0->getType()); + + // X / X -> 1 + // X % X -> 0 + if (Op0 == Op1) + return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty); + + // X / 1 -> X + // X % 1 -> 0 + // If this is a boolean op (single-bit element type), we can't have + // division-by-zero or remainder-by-zero, so assume the divisor is 1. + // Similarly, if we're zero-extending a boolean divisor, then assume it's a 1. 
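+  // E.g. "udiv i8 %x, (zext i1 %b to i8)" can only avoid division by zero
+  // when %b is true, so it folds to %x.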
+ Value *X; + if (match(Op1, m_One()) || Ty->isIntOrIntVectorTy(1) || + (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))) + return IsDiv ? Op0 : Constant::getNullValue(Ty); + + return nullptr; +} + +/// Given a predicate and two operands, return true if the comparison is true. +/// This is a helper for div/rem simplification where we return some other value +/// when we can prove a relationship between the operands. +static bool isICmpTrue(ICmpInst::Predicate Pred, Value *LHS, Value *RHS, + const SimplifyQuery &Q, unsigned MaxRecurse) { + Value *V = SimplifyICmpInst(Pred, LHS, RHS, Q, MaxRecurse); + Constant *C = dyn_cast_or_null<Constant>(V); + return (C && C->isAllOnesValue()); +} + +/// Return true if we can simplify X / Y to 0. Remainder can adapt that answer +/// to simplify X % Y to X. +static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q, + unsigned MaxRecurse, bool IsSigned) { + // Recursion is always used, so bail out at once if we already hit the limit. + if (!MaxRecurse--) + return false; + + if (IsSigned) { + // |X| / |Y| --> 0 + // + // We require that 1 operand is a simple constant. That could be extended to + // 2 variables if we computed the sign bit for each. + // + // Make sure that a constant is not the minimum signed value because taking + // the abs() of that is undefined. + Type *Ty = X->getType(); + const APInt *C; + if (match(X, m_APInt(C)) && !C->isMinSignedValue()) { + // Is the variable divisor magnitude always greater than the constant + // dividend magnitude? + // |Y| > |C| --> Y < -abs(C) or Y > abs(C) + Constant *PosDividendC = ConstantInt::get(Ty, C->abs()); + Constant *NegDividendC = ConstantInt::get(Ty, -C->abs()); + if (isICmpTrue(CmpInst::ICMP_SLT, Y, NegDividendC, Q, MaxRecurse) || + isICmpTrue(CmpInst::ICMP_SGT, Y, PosDividendC, Q, MaxRecurse)) + return true; + } + if (match(Y, m_APInt(C))) { + // Special-case: we can't take the abs() of a minimum signed value. If + // that's the divisor, then all we have to do is prove that the dividend + // is also not the minimum signed value. + if (C->isMinSignedValue()) + return isICmpTrue(CmpInst::ICMP_NE, X, Y, Q, MaxRecurse); + + // Is the variable dividend magnitude always less than the constant + // divisor magnitude? + // |X| < |C| --> X > -abs(C) and X < abs(C) + Constant *PosDivisorC = ConstantInt::get(Ty, C->abs()); + Constant *NegDivisorC = ConstantInt::get(Ty, -C->abs()); + if (isICmpTrue(CmpInst::ICMP_SGT, X, NegDivisorC, Q, MaxRecurse) && + isICmpTrue(CmpInst::ICMP_SLT, X, PosDivisorC, Q, MaxRecurse)) + return true; + } + return false; + } + + // IsSigned == false. + // Is the dividend unsigned less than the divisor? + return isICmpTrue(ICmpInst::ICMP_ULT, X, Y, Q, MaxRecurse); +} + +/// These are simplifications common to SDiv and UDiv. +static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; + + if (Value *V = simplifyDivRem(Op0, Op1, true)) + return V; + + bool IsSigned = Opcode == Instruction::SDiv; + + // (X * Y) / Y -> X if the multiplication does not overflow. + Value *X; + if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) { + auto *Mul = cast<OverflowingBinaryOperator>(Op0); + // If the Mul does not overflow, then we are good to go. + if ((IsSigned && Q.IIQ.hasNoSignedWrap(Mul)) || + (!IsSigned && Q.IIQ.hasNoUnsignedWrap(Mul))) + return X; + // If X has the form X = A / Y, then X * Y cannot overflow. 
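+  // (Division rounds toward zero, so the magnitude of (A / Y) * Y never
+  // exceeds the magnitude of A, and the multiply cannot wrap.)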
+ if ((IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) || + (!IsSigned && match(X, m_UDiv(m_Value(), m_Specific(Op1))))) + return X; + } + + // (X rem Y) / Y -> 0 + if ((IsSigned && match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) || + (!IsSigned && match(Op0, m_URem(m_Value(), m_Specific(Op1))))) + return Constant::getNullValue(Op0->getType()); + + // (X /u C1) /u C2 -> 0 if C1 * C2 overflow + ConstantInt *C1, *C2; + if (!IsSigned && match(Op0, m_UDiv(m_Value(X), m_ConstantInt(C1))) && + match(Op1, m_ConstantInt(C2))) { + bool Overflow; + (void)C1->getValue().umul_ov(C2->getValue(), Overflow); + if (Overflow) + return Constant::getNullValue(Op0->getType()); + } + + // If the operation is with the result of a select instruction, check whether + // operating on either branch of the select always yields the same value. + if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) + if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse)) + return V; + + // If the operation is with the result of a phi instruction, check whether + // operating on all incoming values of the phi always yields the same value. + if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) + if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse)) + return V; + + if (isDivZero(Op0, Op1, Q, MaxRecurse, IsSigned)) + return Constant::getNullValue(Op0->getType()); + + return nullptr; +} + +/// These are simplifications common to SRem and URem. +static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; + + if (Value *V = simplifyDivRem(Op0, Op1, false)) + return V; + + // (X % Y) % Y -> X % Y + if ((Opcode == Instruction::SRem && + match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) || + (Opcode == Instruction::URem && + match(Op0, m_URem(m_Value(), m_Specific(Op1))))) + return Op0; + + // (X << Y) % X -> 0 + if (Q.IIQ.UseInstrInfo && + ((Opcode == Instruction::SRem && + match(Op0, m_NSWShl(m_Specific(Op1), m_Value()))) || + (Opcode == Instruction::URem && + match(Op0, m_NUWShl(m_Specific(Op1), m_Value()))))) + return Constant::getNullValue(Op0->getType()); + + // If the operation is with the result of a select instruction, check whether + // operating on either branch of the select always yields the same value. + if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) + if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse)) + return V; + + // If the operation is with the result of a phi instruction, check whether + // operating on all incoming values of the phi always yields the same value. + if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) + if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse)) + return V; + + // If X / Y == 0, then X % Y == X. + if (isDivZero(Op0, Op1, Q, MaxRecurse, Opcode == Instruction::SRem)) + return Op0; + + return nullptr; +} + +/// Given operands for an SDiv, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, + unsigned MaxRecurse) { + // If two operands are negated and no signed overflow, return -1. 
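+  // E.g. "sdiv i32 (sub nsw i32 0, %x), %x" is -1 for every non-zero %x; a
+  // zero %x would make the division immediate UB anyway, and nsw rules out
+  // the INT_MIN wrap case.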
+ if (isKnownNegation(Op0, Op1, /*NeedNSW=*/true)) + return Constant::getAllOnesValue(Op0->getType()); + + return simplifyDiv(Instruction::SDiv, Op0, Op1, Q, MaxRecurse); +} + +Value *llvm::SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifySDivInst(Op0, Op1, Q, RecursionLimit); +} + +/// Given operands for a UDiv, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, + unsigned MaxRecurse) { + return simplifyDiv(Instruction::UDiv, Op0, Op1, Q, MaxRecurse); +} + +Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyUDivInst(Op0, Op1, Q, RecursionLimit); +} + +/// Given operands for an SRem, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, + unsigned MaxRecurse) { + // If the divisor is 0, the result is undefined, so assume the divisor is -1. + // srem Op0, (sext i1 X) --> srem Op0, -1 --> 0 + Value *X; + if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) + return ConstantInt::getNullValue(Op0->getType()); + + // If the two operands are negated, return 0. + if (isKnownNegation(Op0, Op1)) + return ConstantInt::getNullValue(Op0->getType()); + + return simplifyRem(Instruction::SRem, Op0, Op1, Q, MaxRecurse); +} + +Value *llvm::SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifySRemInst(Op0, Op1, Q, RecursionLimit); +} + +/// Given operands for a URem, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, + unsigned MaxRecurse) { + return simplifyRem(Instruction::URem, Op0, Op1, Q, MaxRecurse); +} + +Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyURemInst(Op0, Op1, Q, RecursionLimit); +} + +/// Returns true if a shift by \c Amount always yields undef. +static bool isUndefShift(Value *Amount) { + Constant *C = dyn_cast<Constant>(Amount); + if (!C) + return false; + + // X shift by undef -> undef because it may shift by the bitwidth. + if (isa<UndefValue>(C)) + return true; + + // Shifting by the bitwidth or more is undefined. + if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) + if (CI->getValue().getLimitedValue() >= + CI->getType()->getScalarSizeInBits()) + return true; + + // If all lanes of a vector shift are undefined the whole shift is. + if (isa<ConstantVector>(C) || isa<ConstantDataVector>(C)) { + for (unsigned I = 0, E = C->getType()->getVectorNumElements(); I != E; ++I) + if (!isUndefShift(C->getAggregateElement(I))) + return false; + return true; + } + + return false; +} + +/// Given operands for an Shl, LShr or AShr, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, + Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; + + // 0 shift by X -> 0 + if (match(Op0, m_Zero())) + return Constant::getNullValue(Op0->getType()); + + // X shift by 0 -> X + // Shift-by-sign-extended bool must be shift-by-0 because shift-by-all-ones + // would be poison. + Value *X; + if (match(Op1, m_Zero()) || + (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))) + return Op0; + + // Fold undefined shifts. 
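+  // E.g. shifting an i32 by 32 or more, or by an undef amount, is undefined.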
+  if (isUndefShift(Op1))
+    return UndefValue::get(Op0->getType());
+
+  // If the operation is with the result of a select instruction, check whether
+  // operating on either branch of the select always yields the same value.
+  if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1))
+    if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse))
+      return V;
+
+  // If the operation is with the result of a phi instruction, check whether
+  // operating on all incoming values of the phi always yields the same value.
+  if (isa<PHINode>(Op0) || isa<PHINode>(Op1))
+    if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse))
+      return V;
+
+  // If any bits in the shift amount make that value greater than or equal to
+  // the number of bits in the type, the shift is undefined.
+  KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+  if (Known.One.getLimitedValue() >= Known.getBitWidth())
+    return UndefValue::get(Op0->getType());
+
+  // If all valid bits in the shift amount are known zero, the first operand is
+  // unchanged.
+  unsigned NumValidShiftBits = Log2_32_Ceil(Known.getBitWidth());
+  if (Known.countMinTrailingZeros() >= NumValidShiftBits)
+    return Op0;
+
+  return nullptr;
+}
+
+/// Given operands for an LShr or AShr, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
+                                 Value *Op1, bool isExact, const SimplifyQuery &Q,
+                                 unsigned MaxRecurse) {
+  if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse))
+    return V;
+
+  // X >> X -> 0
+  if (Op0 == Op1)
+    return Constant::getNullValue(Op0->getType());
+
+  // undef >> X -> 0
+  // undef >> X -> undef (if it's exact)
+  if (match(Op0, m_Undef()))
+    return isExact ? Op0 : Constant::getNullValue(Op0->getType());
+
+  // The low bit cannot be shifted out of an exact shift if it is set.
+  if (isExact) {
+    KnownBits Op0Known = computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+    if (Op0Known.One[0])
+      return Op0;
+  }
+
+  return nullptr;
+}
+
+/// Given operands for an Shl, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+                              const SimplifyQuery &Q, unsigned MaxRecurse) {
+  if (Value *V = SimplifyShift(Instruction::Shl, Op0, Op1, Q, MaxRecurse))
+    return V;
+
+  // undef << X -> 0
+  // undef << X -> undef (if it's NSW/NUW)
+  if (match(Op0, m_Undef()))
+    return isNSW || isNUW ? Op0 : Constant::getNullValue(Op0->getType());
+
+  // (X >> A) << A -> X
+  Value *X;
+  if (Q.IIQ.UseInstrInfo &&
+      match(Op0, m_Exact(m_Shr(m_Value(X), m_Specific(Op1)))))
+    return X;
+
+  // shl nuw i8 C, %x -> C iff C has sign bit set.
+  if (isNUW && match(Op0, m_Negative()))
+    return Op0;
+  // NOTE: could use computeKnownBits() / LazyValueInfo,
+  // but the cost-benefit analysis suggests it isn't worth it.
+
+  return nullptr;
+}
+
+Value *llvm::SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+                             const SimplifyQuery &Q) {
+  return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit);
+}
+
+/// Given operands for an LShr, see if we can fold the result.
+/// If not, this returns null.
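+/// For example, "lshr (shl nuw i32 %x, %a), %a" folds back to %x, because the
+/// nuw flag guarantees that the shl dropped no bits.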
+static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Value *V = SimplifyRightShift(Instruction::LShr, Op0, Op1, isExact, Q, + MaxRecurse)) + return V; + + // (X << A) >> A -> X + Value *X; + if (match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1)))) + return X; + + // ((X << A) | Y) >> A -> X if effective width of Y is not larger than A. + // We can return X as we do in the above case since OR alters no bits in X. + // SimplifyDemandedBits in InstCombine can do more general optimization for + // bit manipulation. This pattern aims to provide opportunities for other + // optimizers by supporting a simple but common case in InstSimplify. + Value *Y; + const APInt *ShRAmt, *ShLAmt; + if (match(Op1, m_APInt(ShRAmt)) && + match(Op0, m_c_Or(m_NUWShl(m_Value(X), m_APInt(ShLAmt)), m_Value(Y))) && + *ShRAmt == *ShLAmt) { + const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + const unsigned Width = Op0->getType()->getScalarSizeInBits(); + const unsigned EffWidthY = Width - YKnown.countMinLeadingZeros(); + if (ShRAmt->uge(EffWidthY)) + return X; + } + + return nullptr; +} + +Value *llvm::SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, + const SimplifyQuery &Q) { + return ::SimplifyLShrInst(Op0, Op1, isExact, Q, RecursionLimit); +} + +/// Given operands for an AShr, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Value *V = SimplifyRightShift(Instruction::AShr, Op0, Op1, isExact, Q, + MaxRecurse)) + return V; + + // all ones >>a X -> -1 + // Do not return Op0 because it may contain undef elements if it's a vector. + if (match(Op0, m_AllOnes())) + return Constant::getAllOnesValue(Op0->getType()); + + // (X << A) >> A -> X + Value *X; + if (Q.IIQ.UseInstrInfo && match(Op0, m_NSWShl(m_Value(X), m_Specific(Op1)))) + return X; + + // Arithmetic shifting an all-sign-bit value is a no-op. + unsigned NumSignBits = ComputeNumSignBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (NumSignBits == Op0->getType()->getScalarSizeInBits()) + return Op0; + + return nullptr; +} + +Value *llvm::SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, + const SimplifyQuery &Q) { + return ::SimplifyAShrInst(Op0, Op1, isExact, Q, RecursionLimit); +} + +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. 
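+///
+/// A representative fold handled here: "icmp ult %a, %b || icmp ne %b, 0"
+/// simplifies to "icmp ne %b, 0", because %a <u %b already implies %b != 0.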
+static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp, + ICmpInst *UnsignedICmp, bool IsAnd, + const SimplifyQuery &Q) { + Value *X, *Y; + + ICmpInst::Predicate EqPred; + if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(Y), m_Zero())) || + !ICmpInst::isEquality(EqPred)) + return nullptr; + + ICmpInst::Predicate UnsignedPred; + + Value *A, *B; + // Y = (A - B); + if (match(Y, m_Sub(m_Value(A), m_Value(B)))) { + if (match(UnsignedICmp, + m_c_ICmp(UnsignedPred, m_Specific(A), m_Specific(B))) && + ICmpInst::isUnsigned(UnsignedPred)) { + if (UnsignedICmp->getOperand(0) != A) + UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred); + + // A >=/<= B || (A - B) != 0 <--> true + if ((UnsignedPred == ICmpInst::ICMP_UGE || + UnsignedPred == ICmpInst::ICMP_ULE) && + EqPred == ICmpInst::ICMP_NE && !IsAnd) + return ConstantInt::getTrue(UnsignedICmp->getType()); + // A </> B && (A - B) == 0 <--> false + if ((UnsignedPred == ICmpInst::ICMP_ULT || + UnsignedPred == ICmpInst::ICMP_UGT) && + EqPred == ICmpInst::ICMP_EQ && IsAnd) + return ConstantInt::getFalse(UnsignedICmp->getType()); + + // A </> B && (A - B) != 0 <--> A </> B + // A </> B || (A - B) != 0 <--> (A - B) != 0 + if (EqPred == ICmpInst::ICMP_NE && (UnsignedPred == ICmpInst::ICMP_ULT || + UnsignedPred == ICmpInst::ICMP_UGT)) + return IsAnd ? UnsignedICmp : ZeroICmp; + + // A <=/>= B && (A - B) == 0 <--> (A - B) == 0 + // A <=/>= B || (A - B) == 0 <--> A <=/>= B + if (EqPred == ICmpInst::ICMP_EQ && (UnsignedPred == ICmpInst::ICMP_ULE || + UnsignedPred == ICmpInst::ICMP_UGE)) + return IsAnd ? ZeroICmp : UnsignedICmp; + } + + // Given Y = (A - B) + // Y >= A && Y != 0 --> Y >= A iff B != 0 + // Y < A || Y == 0 --> Y < A iff B != 0 + if (match(UnsignedICmp, + m_c_ICmp(UnsignedPred, m_Specific(Y), m_Specific(A)))) { + if (UnsignedICmp->getOperand(0) != Y) + UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred); + + if (UnsignedPred == ICmpInst::ICMP_UGE && IsAnd && + EqPred == ICmpInst::ICMP_NE && + isKnownNonZero(B, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT)) + return UnsignedICmp; + if (UnsignedPred == ICmpInst::ICMP_ULT && !IsAnd && + EqPred == ICmpInst::ICMP_EQ && + isKnownNonZero(B, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT)) + return UnsignedICmp; + } + } + + if (match(UnsignedICmp, m_ICmp(UnsignedPred, m_Value(X), m_Specific(Y))) && + ICmpInst::isUnsigned(UnsignedPred)) + ; + else if (match(UnsignedICmp, + m_ICmp(UnsignedPred, m_Specific(Y), m_Value(X))) && + ICmpInst::isUnsigned(UnsignedPred)) + UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred); + else + return nullptr; + + // X < Y && Y != 0 --> X < Y + // X < Y || Y != 0 --> Y != 0 + if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_NE) + return IsAnd ? UnsignedICmp : ZeroICmp; + + // X <= Y && Y != 0 --> X <= Y iff X != 0 + // X <= Y || Y != 0 --> Y != 0 iff X != 0 + if (UnsignedPred == ICmpInst::ICMP_ULE && EqPred == ICmpInst::ICMP_NE && + isKnownNonZero(X, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT)) + return IsAnd ? UnsignedICmp : ZeroICmp; + + // X >= Y && Y == 0 --> Y == 0 + // X >= Y || Y == 0 --> X >= Y + if (UnsignedPred == ICmpInst::ICMP_UGE && EqPred == ICmpInst::ICMP_EQ) + return IsAnd ? ZeroICmp : UnsignedICmp; + + // X > Y && Y == 0 --> Y == 0 iff X != 0 + // X > Y || Y == 0 --> X > Y iff X != 0 + if (UnsignedPred == ICmpInst::ICMP_UGT && EqPred == ICmpInst::ICMP_EQ && + isKnownNonZero(X, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT)) + return IsAnd ? 
ZeroICmp : UnsignedICmp; + + // X < Y && Y == 0 --> false + if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_EQ && + IsAnd) + return getFalse(UnsignedICmp->getType()); + + // X >= Y || Y != 0 --> true + if (UnsignedPred == ICmpInst::ICMP_UGE && EqPred == ICmpInst::ICMP_NE && + !IsAnd) + return getTrue(UnsignedICmp->getType()); + + return nullptr; +} + +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. +static Value *simplifyAndOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { + ICmpInst::Predicate Pred0, Pred1; + Value *A ,*B; + if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || + !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) + return nullptr; + + // We have (icmp Pred0, A, B) & (icmp Pred1, A, B). + // If Op1 is always implied true by Op0, then Op0 is a subset of Op1, and we + // can eliminate Op1 from this 'and'. + if (ICmpInst::isImpliedTrueByMatchingCmp(Pred0, Pred1)) + return Op0; + + // Check for any combination of predicates that are guaranteed to be disjoint. + if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) || + (Pred0 == ICmpInst::ICMP_EQ && ICmpInst::isFalseWhenEqual(Pred1)) || + (Pred0 == ICmpInst::ICMP_SLT && Pred1 == ICmpInst::ICMP_SGT) || + (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_UGT)) + return getFalse(Op0->getType()); + + return nullptr; +} + +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. +static Value *simplifyOrOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { + ICmpInst::Predicate Pred0, Pred1; + Value *A ,*B; + if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || + !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) + return nullptr; + + // We have (icmp Pred0, A, B) | (icmp Pred1, A, B). + // If Op1 is always implied true by Op0, then Op0 is a subset of Op1, and we + // can eliminate Op0 from this 'or'. + if (ICmpInst::isImpliedTrueByMatchingCmp(Pred0, Pred1)) + return Op1; + + // Check for any combination of predicates that cover the entire range of + // possibilities. + if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) || + (Pred0 == ICmpInst::ICMP_NE && ICmpInst::isTrueWhenEqual(Pred1)) || + (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGE) || + (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGE)) + return getTrue(Op0->getType()); + + return nullptr; +} + +/// Test if a pair of compares with a shared operand and 2 constants has an +/// empty set intersection, full set union, or if one compare is a superset of +/// the other. +static Value *simplifyAndOrOfICmpsWithConstants(ICmpInst *Cmp0, ICmpInst *Cmp1, + bool IsAnd) { + // Look for this pattern: {and/or} (icmp X, C0), (icmp X, C1)). 
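+  // E.g. "(icmp ult %x, 4) && (icmp ugt %x, 10)" intersects two disjoint
+  // ranges and folds to false, while "(icmp ule %x, 10) || (icmp ugt %x, 5)"
+  // covers the whole range and folds to true.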
+ if (Cmp0->getOperand(0) != Cmp1->getOperand(0)) + return nullptr; + + const APInt *C0, *C1; + if (!match(Cmp0->getOperand(1), m_APInt(C0)) || + !match(Cmp1->getOperand(1), m_APInt(C1))) + return nullptr; + + auto Range0 = ConstantRange::makeExactICmpRegion(Cmp0->getPredicate(), *C0); + auto Range1 = ConstantRange::makeExactICmpRegion(Cmp1->getPredicate(), *C1); + + // For and-of-compares, check if the intersection is empty: + // (icmp X, C0) && (icmp X, C1) --> empty set --> false + if (IsAnd && Range0.intersectWith(Range1).isEmptySet()) + return getFalse(Cmp0->getType()); + + // For or-of-compares, check if the union is full: + // (icmp X, C0) || (icmp X, C1) --> full set --> true + if (!IsAnd && Range0.unionWith(Range1).isFullSet()) + return getTrue(Cmp0->getType()); + + // Is one range a superset of the other? + // If this is and-of-compares, take the smaller set: + // (icmp sgt X, 4) && (icmp sgt X, 42) --> icmp sgt X, 42 + // If this is or-of-compares, take the larger set: + // (icmp sgt X, 4) || (icmp sgt X, 42) --> icmp sgt X, 4 + if (Range0.contains(Range1)) + return IsAnd ? Cmp1 : Cmp0; + if (Range1.contains(Range0)) + return IsAnd ? Cmp0 : Cmp1; + + return nullptr; +} + +static Value *simplifyAndOrOfICmpsWithZero(ICmpInst *Cmp0, ICmpInst *Cmp1, + bool IsAnd) { + ICmpInst::Predicate P0 = Cmp0->getPredicate(), P1 = Cmp1->getPredicate(); + if (!match(Cmp0->getOperand(1), m_Zero()) || + !match(Cmp1->getOperand(1), m_Zero()) || P0 != P1) + return nullptr; + + if ((IsAnd && P0 != ICmpInst::ICMP_NE) || (!IsAnd && P1 != ICmpInst::ICMP_EQ)) + return nullptr; + + // We have either "(X == 0 || Y == 0)" or "(X != 0 && Y != 0)". + Value *X = Cmp0->getOperand(0); + Value *Y = Cmp1->getOperand(0); + + // If one of the compares is a masked version of a (not) null check, then + // that compare implies the other, so we eliminate the other. Optionally, look + // through a pointer-to-int cast to match a null check of a pointer type. + + // (X == 0) || (([ptrtoint] X & ?) == 0) --> ([ptrtoint] X & ?) == 0 + // (X == 0) || ((? & [ptrtoint] X) == 0) --> (? & [ptrtoint] X) == 0 + // (X != 0) && (([ptrtoint] X & ?) != 0) --> ([ptrtoint] X & ?) != 0 + // (X != 0) && ((? & [ptrtoint] X) != 0) --> (? & [ptrtoint] X) != 0 + if (match(Y, m_c_And(m_Specific(X), m_Value())) || + match(Y, m_c_And(m_PtrToInt(m_Specific(X)), m_Value()))) + return Cmp1; + + // (([ptrtoint] Y & ?) == 0) || (Y == 0) --> ([ptrtoint] Y & ?) == 0 + // ((? & [ptrtoint] Y) == 0) || (Y == 0) --> (? & [ptrtoint] Y) == 0 + // (([ptrtoint] Y & ?) != 0) && (Y != 0) --> ([ptrtoint] Y & ?) != 0 + // ((? & [ptrtoint] Y) != 0) && (Y != 0) --> (? 
& [ptrtoint] Y) != 0 + if (match(X, m_c_And(m_Specific(Y), m_Value())) || + match(X, m_c_And(m_PtrToInt(m_Specific(Y)), m_Value()))) + return Cmp0; + + return nullptr; +} + +static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1, + const InstrInfoQuery &IIQ) { + // (icmp (add V, C0), C1) & (icmp V, C0) + ICmpInst::Predicate Pred0, Pred1; + const APInt *C0, *C1; + Value *V; + if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1)))) + return nullptr; + + if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value()))) + return nullptr; + + auto *AddInst = cast<OverflowingBinaryOperator>(Op0->getOperand(0)); + if (AddInst->getOperand(1) != Op1->getOperand(1)) + return nullptr; + + Type *ITy = Op0->getType(); + bool isNSW = IIQ.hasNoSignedWrap(AddInst); + bool isNUW = IIQ.hasNoUnsignedWrap(AddInst); + + const APInt Delta = *C1 - *C0; + if (C0->isStrictlyPositive()) { + if (Delta == 2) { + if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_SGT) + return getFalse(ITy); + if (Pred0 == ICmpInst::ICMP_SLT && Pred1 == ICmpInst::ICMP_SGT && isNSW) + return getFalse(ITy); + } + if (Delta == 1) { + if (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_SGT) + return getFalse(ITy); + if (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGT && isNSW) + return getFalse(ITy); + } + } + if (C0->getBoolValue() && isNUW) { + if (Delta == 2) + if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_UGT) + return getFalse(ITy); + if (Delta == 1) + if (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGT) + return getFalse(ITy); + } + + return nullptr; +} + +static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1, + const SimplifyQuery &Q) { + if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true, Q)) + return X; + if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/true, Q)) + return X; + + if (Value *X = simplifyAndOfICmpsWithSameOperands(Op0, Op1)) + return X; + if (Value *X = simplifyAndOfICmpsWithSameOperands(Op1, Op0)) + return X; + + if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, true)) + return X; + + if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, true)) + return X; + + if (Value *X = simplifyAndOfICmpsWithAdd(Op0, Op1, Q.IIQ)) + return X; + if (Value *X = simplifyAndOfICmpsWithAdd(Op1, Op0, Q.IIQ)) + return X; + + return nullptr; +} + +static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1, + const InstrInfoQuery &IIQ) { + // (icmp (add V, C0), C1) | (icmp V, C0) + ICmpInst::Predicate Pred0, Pred1; + const APInt *C0, *C1; + Value *V; + if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1)))) + return nullptr; + + if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value()))) + return nullptr; + + auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); + if (AddInst->getOperand(1) != Op1->getOperand(1)) + return nullptr; + + Type *ITy = Op0->getType(); + bool isNSW = IIQ.hasNoSignedWrap(AddInst); + bool isNUW = IIQ.hasNoUnsignedWrap(AddInst); + + const APInt Delta = *C1 - *C0; + if (C0->isStrictlyPositive()) { + if (Delta == 2) { + if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_SLE) + return getTrue(ITy); + if (Pred0 == ICmpInst::ICMP_SGE && Pred1 == ICmpInst::ICMP_SLE && isNSW) + return getTrue(ITy); + } + if (Delta == 1) { + if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_SLE) + return getTrue(ITy); + if (Pred0 == ICmpInst::ICMP_SGT && Pred1 == ICmpInst::ICMP_SLE && isNSW) + return getTrue(ITy); + } + } + if (C0->getBoolValue() && isNUW) { + if 
(Delta == 2) + if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_ULE) + return getTrue(ITy); + if (Delta == 1) + if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_ULE) + return getTrue(ITy); + } + + return nullptr; +} + +static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1, + const SimplifyQuery &Q) { + if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/false, Q)) + return X; + if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/false, Q)) + return X; + + if (Value *X = simplifyOrOfICmpsWithSameOperands(Op0, Op1)) + return X; + if (Value *X = simplifyOrOfICmpsWithSameOperands(Op1, Op0)) + return X; + + if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, false)) + return X; + + if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, false)) + return X; + + if (Value *X = simplifyOrOfICmpsWithAdd(Op0, Op1, Q.IIQ)) + return X; + if (Value *X = simplifyOrOfICmpsWithAdd(Op1, Op0, Q.IIQ)) + return X; + + return nullptr; +} + +static Value *simplifyAndOrOfFCmps(const TargetLibraryInfo *TLI, + FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) { + Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); + Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); + if (LHS0->getType() != RHS0->getType()) + return nullptr; + + FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + if ((PredL == FCmpInst::FCMP_ORD && PredR == FCmpInst::FCMP_ORD && IsAnd) || + (PredL == FCmpInst::FCMP_UNO && PredR == FCmpInst::FCMP_UNO && !IsAnd)) { + // (fcmp ord NNAN, X) & (fcmp ord X, Y) --> fcmp ord X, Y + // (fcmp ord NNAN, X) & (fcmp ord Y, X) --> fcmp ord Y, X + // (fcmp ord X, NNAN) & (fcmp ord X, Y) --> fcmp ord X, Y + // (fcmp ord X, NNAN) & (fcmp ord Y, X) --> fcmp ord Y, X + // (fcmp uno NNAN, X) | (fcmp uno X, Y) --> fcmp uno X, Y + // (fcmp uno NNAN, X) | (fcmp uno Y, X) --> fcmp uno Y, X + // (fcmp uno X, NNAN) | (fcmp uno X, Y) --> fcmp uno X, Y + // (fcmp uno X, NNAN) | (fcmp uno Y, X) --> fcmp uno Y, X + if ((isKnownNeverNaN(LHS0, TLI) && (LHS1 == RHS0 || LHS1 == RHS1)) || + (isKnownNeverNaN(LHS1, TLI) && (LHS0 == RHS0 || LHS0 == RHS1))) + return RHS; + + // (fcmp ord X, Y) & (fcmp ord NNAN, X) --> fcmp ord X, Y + // (fcmp ord Y, X) & (fcmp ord NNAN, X) --> fcmp ord Y, X + // (fcmp ord X, Y) & (fcmp ord X, NNAN) --> fcmp ord X, Y + // (fcmp ord Y, X) & (fcmp ord X, NNAN) --> fcmp ord Y, X + // (fcmp uno X, Y) | (fcmp uno NNAN, X) --> fcmp uno X, Y + // (fcmp uno Y, X) | (fcmp uno NNAN, X) --> fcmp uno Y, X + // (fcmp uno X, Y) | (fcmp uno X, NNAN) --> fcmp uno X, Y + // (fcmp uno Y, X) | (fcmp uno X, NNAN) --> fcmp uno Y, X + if ((isKnownNeverNaN(RHS0, TLI) && (RHS1 == LHS0 || RHS1 == LHS1)) || + (isKnownNeverNaN(RHS1, TLI) && (RHS0 == LHS0 || RHS0 == LHS1))) + return LHS; + } + + return nullptr; +} + +static Value *simplifyAndOrOfCmps(const SimplifyQuery &Q, + Value *Op0, Value *Op1, bool IsAnd) { + // Look through casts of the 'and' operands to find compares. + auto *Cast0 = dyn_cast<CastInst>(Op0); + auto *Cast1 = dyn_cast<CastInst>(Op1); + if (Cast0 && Cast1 && Cast0->getOpcode() == Cast1->getOpcode() && + Cast0->getSrcTy() == Cast1->getSrcTy()) { + Op0 = Cast0->getOperand(0); + Op1 = Cast1->getOperand(0); + } + + Value *V = nullptr; + auto *ICmp0 = dyn_cast<ICmpInst>(Op0); + auto *ICmp1 = dyn_cast<ICmpInst>(Op1); + if (ICmp0 && ICmp1) + V = IsAnd ? 
simplifyAndOfICmps(ICmp0, ICmp1, Q) + : simplifyOrOfICmps(ICmp0, ICmp1, Q); + + auto *FCmp0 = dyn_cast<FCmpInst>(Op0); + auto *FCmp1 = dyn_cast<FCmpInst>(Op1); + if (FCmp0 && FCmp1) + V = simplifyAndOrOfFCmps(Q.TLI, FCmp0, FCmp1, IsAnd); + + if (!V) + return nullptr; + if (!Cast0) + return V; + + // If we looked through casts, we can only handle a constant simplification + // because we are not allowed to create a cast instruction here. + if (auto *C = dyn_cast<Constant>(V)) + return ConstantExpr::getCast(Cast0->getOpcode(), C, Cast0->getType()); + + return nullptr; +} + +/// Check that Op1 is in the expected form, i.e.: +/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???) +/// %Op1 = extractvalue { i4, i1 } %Agg, 1 +static bool omitCheckForZeroBeforeMulWithOverflowInternal(Value *Op1, + Value *X) { + auto *Extract = dyn_cast<ExtractValueInst>(Op1); + // We should only be extracting the overflow bit. + if (!Extract || !Extract->getIndices().equals(1)) + return false; + Value *Agg = Extract->getAggregateOperand(); + // This should be a multiplication-with-overflow intrinsic. + if (!match(Agg, m_CombineOr(m_Intrinsic<Intrinsic::umul_with_overflow>(), + m_Intrinsic<Intrinsic::smul_with_overflow>()))) + return false; + // One of its multipliers should be the value we checked for zero before. + if (!match(Agg, m_CombineOr(m_Argument<0>(m_Specific(X)), + m_Argument<1>(m_Specific(X))))) + return false; + return true; +} + +/// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some +/// other form of check, e.g. one that was using division; it may have been +/// guarded against division-by-zero. We can drop that check now. +/// Look for: +/// %Op0 = icmp ne i4 %X, 0 +/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???) +/// %Op1 = extractvalue { i4, i1 } %Agg, 1 +/// %??? = and i1 %Op0, %Op1 +/// We can just return %Op1 +static Value *omitCheckForZeroBeforeMulWithOverflow(Value *Op0, Value *Op1) { + ICmpInst::Predicate Pred; + Value *X; + if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())) || + Pred != ICmpInst::Predicate::ICMP_NE) + return nullptr; + // Is Op1 in the expected form? + if (!omitCheckForZeroBeforeMulWithOverflowInternal(Op1, X)) + return nullptr; + // Can omit 'and', and just return the overflow bit. + return Op1; +} + +/// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some +/// other form of check, e.g. one that was using division; it may have been +/// guarded against division-by-zero. We can drop that check now. +/// Look for: +/// %Op0 = icmp eq i4 %X, 0 +/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???) +/// %Op1 = extractvalue { i4, i1 } %Agg, 1 +/// %NotOp1 = xor i1 %Op1, true +/// %or = or i1 %Op0, %NotOp1 +/// We can just return %NotOp1 +static Value *omitCheckForZeroBeforeInvertedMulWithOverflow(Value *Op0, + Value *NotOp1) { + ICmpInst::Predicate Pred; + Value *X; + if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())) || + Pred != ICmpInst::Predicate::ICMP_EQ) + return nullptr; + // We expect the other operand of the 'or' to be a 'not'. + Value *Op1; + if (!match(NotOp1, m_Not(m_Value(Op1)))) + return nullptr; + // Is Op1 in the expected form? + if (!omitCheckForZeroBeforeMulWithOverflowInternal(Op1, X)) + return nullptr; + // Can omit the 'or', and just return the inverted overflow bit. + return NotOp1; +} + +/// Given operands for an And, see if we can fold the result. +/// If not, this returns null.
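+/// For example (hand-written IR, with an arbitrarily chosen i8 type):
+///   and i8 %x, 0  --> 0
+///   and i8 %x, -1 --> %x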
+static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, + unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::And, Op0, Op1, Q)) + return C; + + // X & undef -> 0 + if (match(Op1, m_Undef())) + return Constant::getNullValue(Op0->getType()); + + // X & X = X + if (Op0 == Op1) + return Op0; + + // X & 0 = 0 + if (match(Op1, m_Zero())) + return Constant::getNullValue(Op0->getType()); + + // X & -1 = X + if (match(Op1, m_AllOnes())) + return Op0; + + // A & ~A = ~A & A = 0 + if (match(Op0, m_Not(m_Specific(Op1))) || + match(Op1, m_Not(m_Specific(Op0)))) + return Constant::getNullValue(Op0->getType()); + + // (A | ?) & A = A + if (match(Op0, m_c_Or(m_Specific(Op1), m_Value()))) + return Op1; + + // A & (A | ?) = A + if (match(Op1, m_c_Or(m_Specific(Op0), m_Value()))) + return Op0; + + // A mask that only clears known zeros of a shifted value is a no-op. + Value *X; + const APInt *Mask; + const APInt *ShAmt; + if (match(Op1, m_APInt(Mask))) { + // If all bits in the inverted and shifted mask are clear: + // and (shl X, ShAmt), Mask --> shl X, ShAmt + if (match(Op0, m_Shl(m_Value(X), m_APInt(ShAmt))) && + (~(*Mask)).lshr(*ShAmt).isNullValue()) + return Op0; + + // If all bits in the inverted and shifted mask are clear: + // and (lshr X, ShAmt), Mask --> lshr X, ShAmt + if (match(Op0, m_LShr(m_Value(X), m_APInt(ShAmt))) && + (~(*Mask)).shl(*ShAmt).isNullValue()) + return Op0; + } + + // If we have a multiplication overflow check that is being 'and'ed with a + // check that one of the multipliers is not zero, we can omit the 'and', and + // only keep the overflow check. + if (Value *V = omitCheckForZeroBeforeMulWithOverflow(Op0, Op1)) + return V; + if (Value *V = omitCheckForZeroBeforeMulWithOverflow(Op1, Op0)) + return V; + + // A & (-A) = A if A is a power of two or zero. + if (match(Op0, m_Neg(m_Specific(Op1))) || + match(Op1, m_Neg(m_Specific(Op0)))) { + if (isKnownToBeAPowerOfTwo(Op0, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, + Q.DT)) + return Op0; + if (isKnownToBeAPowerOfTwo(Op1, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, + Q.DT)) + return Op1; + } + + // This is a similar pattern used for checking if a value is a power-of-2: + // (A - 1) & A --> 0 (if A is a power-of-2 or 0) + // A & (A - 1) --> 0 (if A is a power-of-2 or 0) + if (match(Op0, m_Add(m_Specific(Op1), m_AllOnes())) && + isKnownToBeAPowerOfTwo(Op1, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT)) + return Constant::getNullValue(Op1->getType()); + if (match(Op1, m_Add(m_Specific(Op0), m_AllOnes())) && + isKnownToBeAPowerOfTwo(Op0, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT)) + return Constant::getNullValue(Op0->getType()); + + if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, true)) + return V; + + // Try some generic simplifications for associative operations. + if (Value *V = SimplifyAssociativeBinOp(Instruction::And, Op0, Op1, Q, + MaxRecurse)) + return V; + + // And distributes over Or. Try some generic simplifications based on this. + if (Value *V = ExpandBinOp(Instruction::And, Op0, Op1, Instruction::Or, + Q, MaxRecurse)) + return V; + + // And distributes over Xor. Try some generic simplifications based on this. + if (Value *V = ExpandBinOp(Instruction::And, Op0, Op1, Instruction::Xor, + Q, MaxRecurse)) + return V; + + // If the operation is with the result of a select instruction, check whether + // operating on either branch of the select always yields the same value. 
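+  // E.g. (hand-written illustration): "and i8 (select i1 %c, i8 7, i8 15), 1"
+  // yields 1 on either branch (7 & 1 == 15 & 1 == 1), so it folds to 1.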
+ if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) + if (Value *V = ThreadBinOpOverSelect(Instruction::And, Op0, Op1, Q, + MaxRecurse)) + return V; + + // If the operation is with the result of a phi instruction, check whether + // operating on all incoming values of the phi always yields the same value. + if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) + if (Value *V = ThreadBinOpOverPHI(Instruction::And, Op0, Op1, Q, + MaxRecurse)) + return V; + + // Assuming the effective width of Y is not larger than A, i.e. all bits + // from X and Y are disjoint in (X << A) | Y, + // if the mask of this AND op covers all bits of X or Y, while it covers + // no bits from the other, we can bypass this AND op. E.g., + // ((X << A) | Y) & Mask -> Y, + // if Mask = ((1 << effective_width_of(Y)) - 1) + // ((X << A) | Y) & Mask -> X << A, + // if Mask = ((1 << effective_width_of(X)) - 1) << A + // SimplifyDemandedBits in InstCombine can optimize the general case. + // This pattern aims to help other passes for a common case. + Value *Y, *XShifted; + if (match(Op1, m_APInt(Mask)) && + match(Op0, m_c_Or(m_CombineAnd(m_NUWShl(m_Value(X), m_APInt(ShAmt)), + m_Value(XShifted)), + m_Value(Y)))) { + const unsigned Width = Op0->getType()->getScalarSizeInBits(); + const unsigned ShftCnt = ShAmt->getLimitedValue(Width); + const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + const unsigned EffWidthY = Width - YKnown.countMinLeadingZeros(); + if (EffWidthY <= ShftCnt) { + const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, + Q.DT); + const unsigned EffWidthX = Width - XKnown.countMinLeadingZeros(); + const APInt EffBitsY = APInt::getLowBitsSet(Width, EffWidthY); + const APInt EffBitsX = APInt::getLowBitsSet(Width, EffWidthX) << ShftCnt; + // If the mask is extracting all bits from X or Y as is, we can skip + // this AND op. + if (EffBitsY.isSubsetOf(*Mask) && !EffBitsX.intersects(*Mask)) + return Y; + if (EffBitsX.isSubsetOf(*Mask) && !EffBitsY.intersects(*Mask)) + return XShifted; + } + } + + return nullptr; +} + +Value *llvm::SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyAndInst(Op0, Op1, Q, RecursionLimit); +} + +/// Given operands for an Or, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, + unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::Or, Op0, Op1, Q)) + return C; + + // X | undef -> -1 + // X | -1 = -1 + // Do not return Op1 because it may contain undef elements if it's a vector. + if (match(Op1, m_Undef()) || match(Op1, m_AllOnes())) + return Constant::getAllOnesValue(Op0->getType()); + + // X | X = X + // X | 0 = X + if (Op0 == Op1 || match(Op1, m_Zero())) + return Op0; + + // A | ~A = ~A | A = -1 + if (match(Op0, m_Not(m_Specific(Op1))) || + match(Op1, m_Not(m_Specific(Op0)))) + return Constant::getAllOnesValue(Op0->getType()); + + // (A & ?) | A = A + if (match(Op0, m_c_And(m_Specific(Op1), m_Value()))) + return Op1; + + // A | (A & ?) = A + if (match(Op1, m_c_And(m_Specific(Op0), m_Value()))) + return Op0; + + // ~(A & ?) | A = -1 + if (match(Op0, m_Not(m_c_And(m_Specific(Op1), m_Value())))) + return Constant::getAllOnesValue(Op1->getType()); + + // A | ~(A & ?) 
= -1 + if (match(Op1, m_Not(m_c_And(m_Specific(Op0), m_Value())))) + return Constant::getAllOnesValue(Op0->getType()); + + Value *A, *B; + // (A & ~B) | (A ^ B) -> (A ^ B) + // (~B & A) | (A ^ B) -> (A ^ B) + // (A & ~B) | (B ^ A) -> (B ^ A) + // (~B & A) | (B ^ A) -> (B ^ A) + if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && + (match(Op0, m_c_And(m_Specific(A), m_Not(m_Specific(B)))) || + match(Op0, m_c_And(m_Not(m_Specific(A)), m_Specific(B))))) + return Op1; + + // Commute the 'or' operands. + // (A ^ B) | (A & ~B) -> (A ^ B) + // (A ^ B) | (~B & A) -> (A ^ B) + // (B ^ A) | (A & ~B) -> (B ^ A) + // (B ^ A) | (~B & A) -> (B ^ A) + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + (match(Op1, m_c_And(m_Specific(A), m_Not(m_Specific(B)))) || + match(Op1, m_c_And(m_Not(m_Specific(A)), m_Specific(B))))) + return Op0; + + // (A & B) | (~A ^ B) -> (~A ^ B) + // (B & A) | (~A ^ B) -> (~A ^ B) + // (A & B) | (B ^ ~A) -> (B ^ ~A) + // (B & A) | (B ^ ~A) -> (B ^ ~A) + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + (match(Op1, m_c_Xor(m_Specific(A), m_Not(m_Specific(B)))) || + match(Op1, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B))))) + return Op1; + + // (~A ^ B) | (A & B) -> (~A ^ B) + // (~A ^ B) | (B & A) -> (~A ^ B) + // (B ^ ~A) | (A & B) -> (B ^ ~A) + // (B ^ ~A) | (B & A) -> (B ^ ~A) + if (match(Op1, m_And(m_Value(A), m_Value(B))) && + (match(Op0, m_c_Xor(m_Specific(A), m_Not(m_Specific(B)))) || + match(Op0, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B))))) + return Op0; + + if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, false)) + return V; + + // If we have an inverted multiplication overflow check that is being 'or'ed + // with a check that one of the multipliers is zero, we can omit the 'or', + // and only keep the inverted overflow check. + if (Value *V = omitCheckForZeroBeforeInvertedMulWithOverflow(Op0, Op1)) + return V; + if (Value *V = omitCheckForZeroBeforeInvertedMulWithOverflow(Op1, Op0)) + return V; + + // Try some generic simplifications for associative operations. + if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q, + MaxRecurse)) + return V; + + // Or distributes over And. Try some generic simplifications based on this. + if (Value *V = ExpandBinOp(Instruction::Or, Op0, Op1, Instruction::And, Q, + MaxRecurse)) + return V; + + // If the operation is with the result of a select instruction, check whether + // operating on either branch of the select always yields the same value. + if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) + if (Value *V = ThreadBinOpOverSelect(Instruction::Or, Op0, Op1, Q, + MaxRecurse)) + return V; + + // (A & C1)|(B & C2) + const APInt *C1, *C2; + if (match(Op0, m_And(m_Value(A), m_APInt(C1))) && + match(Op1, m_And(m_Value(B), m_APInt(C2)))) { + if (*C1 == ~*C2) { + // (A & C1)|(B & C2) + // If we have: ((V + N) & C1) | (V & C2) + // .. and C2 = ~C1 and C2 is 0+1+ and (N & C2) == 0 + // replace with V+N. + Value *N; + if (C2->isMask() && // C2 == 0+1+ + match(A, m_c_Add(m_Specific(B), m_Value(N)))) { + // Add commutes, try both ways. + if (MaskedValueIsZero(N, *C2, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + return A; + } + // Or commutes, try both ways. + if (C1->isMask() && + match(B, m_c_Add(m_Specific(A), m_Value(N)))) { + // Add commutes, try both ways. + if (MaskedValueIsZero(N, *C1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + return B; + } + } + } + + // If the operation is with the result of a phi instruction, check whether + // operating on all incoming values of the phi always yields the same value.
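+  // E.g. (hand-written illustration): with %p = phi i8 [ 2, %bb0 ], [ 3, %bb1 ],
+  // "or i8 %p, 1" yields 3 for every incoming value (2 | 1 == 3 | 1 == 3),
+  // so it folds to 3.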
+ if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) + if (Value *V = ThreadBinOpOverPHI(Instruction::Or, Op0, Op1, Q, MaxRecurse)) + return V; + + return nullptr; +} + +Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyOrInst(Op0, Op1, Q, RecursionLimit); +} + +/// Given operands for a Xor, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, + unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::Xor, Op0, Op1, Q)) + return C; + + // A ^ undef -> undef + if (match(Op1, m_Undef())) + return Op1; + + // A ^ 0 = A + if (match(Op1, m_Zero())) + return Op0; + + // A ^ A = 0 + if (Op0 == Op1) + return Constant::getNullValue(Op0->getType()); + + // A ^ ~A = ~A ^ A = -1 + if (match(Op0, m_Not(m_Specific(Op1))) || + match(Op1, m_Not(m_Specific(Op0)))) + return Constant::getAllOnesValue(Op0->getType()); + + // Try some generic simplifications for associative operations. + if (Value *V = SimplifyAssociativeBinOp(Instruction::Xor, Op0, Op1, Q, + MaxRecurse)) + return V; + + // Threading Xor over selects and phi nodes is pointless, so don't bother. + // Threading over the select in "A ^ select(cond, B, C)" means evaluating + // "A^B" and "A^C" and seeing if they are equal; but they are equal if and + // only if B and C are equal. If B and C are equal then (since we assume + // that operands have already been simplified) "select(cond, B, C)" should + // have been simplified to the common value of B and C already. Analysing + // "A^B" and "A^C" thus gains nothing, but costs compile time. Similarly + // for threading over phi nodes. + + return nullptr; +} + +Value *llvm::SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyXorInst(Op0, Op1, Q, RecursionLimit); +} + + +static Type *GetCompareTy(Value *Op) { + return CmpInst::makeCmpResultType(Op->getType()); +} + +/// Rummage around inside V looking for something equivalent to the comparison +/// "LHS Pred RHS". Return such a value if found, otherwise return null. +/// Helper function for analyzing max/min idioms. +static Value *ExtractEquivalentCondition(Value *V, CmpInst::Predicate Pred, + Value *LHS, Value *RHS) { + SelectInst *SI = dyn_cast<SelectInst>(V); + if (!SI) + return nullptr; + CmpInst *Cmp = dyn_cast<CmpInst>(SI->getCondition()); + if (!Cmp) + return nullptr; + Value *CmpLHS = Cmp->getOperand(0), *CmpRHS = Cmp->getOperand(1); + if (Pred == Cmp->getPredicate() && LHS == CmpLHS && RHS == CmpRHS) + return Cmp; + if (Pred == CmpInst::getSwappedPredicate(Cmp->getPredicate()) && + LHS == CmpRHS && RHS == CmpLHS) + return Cmp; + return nullptr; +} + +// A significant optimization not implemented here is assuming that alloca +// addresses are not equal to incoming argument values. They don't *alias*, +// as we say, but that doesn't mean they aren't equal, so we take a +// conservative approach. +// +// This is inspired in part by C++11 5.10p1: +// "Two pointers of the same type compare equal if and only if they are both +// null, both point to the same function, or both represent the same +// address." +// +// This is pretty permissive. 
+// +// It's also partly due to C11 6.5.9p6: +// "Two pointers compare equal if and only if both are null pointers, both are +// pointers to the same object (including a pointer to an object and a +// subobject at its beginning) or function, both are pointers to one past the +// last element of the same array object, or one is a pointer to one past the +// end of one array object and the other is a pointer to the start of a +// different array object that happens to immediately follow the first array +// object in the address space." +// +// C11's version is more restrictive; however, there's no reason why an argument +// couldn't be a one-past-the-end value for a stack object in the caller and be +// equal to the beginning of a stack object in the callee. +// +// If the C and C++ standards are ever made sufficiently restrictive in this +// area, it may be possible to update LLVM's semantics accordingly and reinstate +// this optimization. +static Constant * +computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, + const DominatorTree *DT, CmpInst::Predicate Pred, + AssumptionCache *AC, const Instruction *CxtI, + const InstrInfoQuery &IIQ, Value *LHS, Value *RHS) { + // First, skip past any trivial no-ops. + LHS = LHS->stripPointerCasts(); + RHS = RHS->stripPointerCasts(); + + // A non-null pointer is not equal to a null pointer. + if (llvm::isKnownNonZero(LHS, DL, 0, nullptr, nullptr, nullptr, + IIQ.UseInstrInfo) && + isa<ConstantPointerNull>(RHS) && + (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE)) + return ConstantInt::get(GetCompareTy(LHS), + !CmpInst::isTrueWhenEqual(Pred)); + + // We can only fold certain predicates on pointer comparisons. + switch (Pred) { + default: + return nullptr; + + // Equality comparisons are easy to fold. + case CmpInst::ICMP_EQ: + case CmpInst::ICMP_NE: + break; + + // We can only handle unsigned relational comparisons because 'inbounds' on + // a GEP only protects against unsigned wrapping. + case CmpInst::ICMP_UGT: + case CmpInst::ICMP_UGE: + case CmpInst::ICMP_ULT: + case CmpInst::ICMP_ULE: + // However, we have to switch them to their signed variants to handle + // negative indices from the base pointer. + Pred = ICmpInst::getSignedPredicate(Pred); + break; + } + + // Strip off any constant offsets so that we can reason about them. + // It's tempting to use getUnderlyingObject or even just stripInBoundsOffsets + // here and compare base addresses like AliasAnalysis does; however, there are + // numerous hazards. AliasAnalysis and its utilities rely on special rules + // governing loads and stores which don't apply to icmps. Also, AliasAnalysis + // doesn't need to guarantee pointer inequality when it says NoAlias. + Constant *LHSOffset = stripAndComputeConstantOffsets(DL, LHS); + Constant *RHSOffset = stripAndComputeConstantOffsets(DL, RHS); + + // If LHS and RHS are related via constant offsets to the same base + // value, we can replace it with an icmp which just compares the offsets. + if (LHS == RHS) + return ConstantExpr::getICmp(Pred, LHSOffset, RHSOffset); + + // Various optimizations for (in)equality comparisons. + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE) { + // Different non-empty allocations that exist at the same time have + // different addresses (if the program can tell). Global variables always + // exist, so they always exist during the lifetime of each other and all + // allocas. Two different allocas usually have different addresses...
+ // + // However, if there's an @llvm.stackrestore dynamically in between two + // allocas, they may have the same address. It's tempting to reduce the + // scope of the problem by only looking at *static* allocas here. That would + // cover the majority of allocas while significantly reducing the likelihood + // of having an @llvm.stackrestore pop up in the middle. However, it's not + // actually impossible for an @llvm.stackrestore to pop up in the middle of + // the entry block. Also, if we have a block that's not attached to a + // function, we can't tell if it's "static" under the current definition. + // Theoretically, this problem could be fixed by creating a new instruction + // kind specifically for static allocas. Such a new instruction + // could be required to be at the top of the entry block, thus preventing it + // from being subject to a @llvm.stackrestore. Instcombine could even + // convert regular allocas into these special allocas. It'd be nifty. + // However, until then, this problem remains open. + // + // So, we'll assume that two non-empty allocas have different addresses + // for now. + // + // With all that, if the offsets are within the bounds of their allocations + // (and not one-past-the-end! so we can't use inbounds!), and their + // allocations aren't the same, the pointers are not equal. + // + // Note that it's not necessary to check for LHS being a global variable + // address, due to canonicalization and constant folding. + if (isa<AllocaInst>(LHS) && + (isa<AllocaInst>(RHS) || isa<GlobalVariable>(RHS))) { + ConstantInt *LHSOffsetCI = dyn_cast<ConstantInt>(LHSOffset); + ConstantInt *RHSOffsetCI = dyn_cast<ConstantInt>(RHSOffset); + uint64_t LHSSize, RHSSize; + ObjectSizeOpts Opts; + Opts.NullIsUnknownSize = + NullPointerIsDefined(cast<AllocaInst>(LHS)->getFunction()); + if (LHSOffsetCI && RHSOffsetCI && + getObjectSize(LHS, LHSSize, DL, TLI, Opts) && + getObjectSize(RHS, RHSSize, DL, TLI, Opts)) { + const APInt &LHSOffsetValue = LHSOffsetCI->getValue(); + const APInt &RHSOffsetValue = RHSOffsetCI->getValue(); + if (!LHSOffsetValue.isNegative() && + !RHSOffsetValue.isNegative() && + LHSOffsetValue.ult(LHSSize) && + RHSOffsetValue.ult(RHSSize)) { + return ConstantInt::get(GetCompareTy(LHS), + !CmpInst::isTrueWhenEqual(Pred)); + } + } + + // Repeat the above check but this time without depending on DataLayout + // or being able to compute a precise size. + if (!cast<PointerType>(LHS->getType())->isEmptyTy() && + !cast<PointerType>(RHS->getType())->isEmptyTy() && + LHSOffset->isNullValue() && + RHSOffset->isNullValue()) + return ConstantInt::get(GetCompareTy(LHS), + !CmpInst::isTrueWhenEqual(Pred)); + } + + // Even if a non-inbounds GEP occurs along the path we can still optimize + // equality comparisons concerning the result. We avoid walking the whole + // chain again by starting where the last calls to + // stripAndComputeConstantOffsets left off and accumulating the offsets.
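+    // E.g. (hand-written illustration): if %p and %q are non-inbounds GEPs off
+    // the same %base with constant offsets 1 and 2, "icmp eq %p, %q" is reduced
+    // to "icmp eq 1, 2" on the accumulated offsets and folds to false.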
+ Constant *LHSNoBound = stripAndComputeConstantOffsets(DL, LHS, true); + Constant *RHSNoBound = stripAndComputeConstantOffsets(DL, RHS, true); + if (LHS == RHS) + return ConstantExpr::getICmp(Pred, + ConstantExpr::getAdd(LHSOffset, LHSNoBound), + ConstantExpr::getAdd(RHSOffset, RHSNoBound)); + + // If one side of the equality comparison must come from a noalias call + // (meaning a system memory allocation function), and the other side must + // come from a pointer that cannot overlap with dynamically-allocated + // memory within the lifetime of the current function (allocas, byval + // arguments, globals), then determine the comparison result here. + SmallVector<const Value *, 8> LHSUObjs, RHSUObjs; + GetUnderlyingObjects(LHS, LHSUObjs, DL); + GetUnderlyingObjects(RHS, RHSUObjs, DL); + + // Is the set of underlying objects all noalias calls? + auto IsNAC = [](ArrayRef<const Value *> Objects) { + return all_of(Objects, isNoAliasCall); + }; + + // Is the set of underlying objects all things which must be disjoint from + // noalias calls? For allocas, we consider only static ones (dynamic + // allocas might be transformed into calls to malloc not simultaneously + // live with the compared-to allocation). For globals, we exclude symbols + // that might be resolved lazily to symbols in another dynamically-loaded + // library (and, thus, could be malloc'ed by the implementation). + auto IsAllocDisjoint = [](ArrayRef<const Value *> Objects) { + return all_of(Objects, [](const Value *V) { + if (const AllocaInst *AI = dyn_cast<AllocaInst>(V)) + return AI->getParent() && AI->getFunction() && AI->isStaticAlloca(); + if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) + return (GV->hasLocalLinkage() || GV->hasHiddenVisibility() || + GV->hasProtectedVisibility() || GV->hasGlobalUnnamedAddr()) && + !GV->isThreadLocal(); + if (const Argument *A = dyn_cast<Argument>(V)) + return A->hasByValAttr(); + return false; + }); + }; + + if ((IsNAC(LHSUObjs) && IsAllocDisjoint(RHSUObjs)) || + (IsNAC(RHSUObjs) && IsAllocDisjoint(LHSUObjs))) + return ConstantInt::get(GetCompareTy(LHS), + !CmpInst::isTrueWhenEqual(Pred)); + + // Fold comparisons for a non-escaping pointer even if the allocation call + // cannot be elided. We cannot fold a malloc comparison to null. Also, the + // dynamic allocation call could be either of the operands. + Value *MI = nullptr; + if (isAllocLikeFn(LHS, TLI) && + llvm::isKnownNonZero(RHS, DL, 0, nullptr, CxtI, DT)) + MI = LHS; + else if (isAllocLikeFn(RHS, TLI) && + llvm::isKnownNonZero(LHS, DL, 0, nullptr, CxtI, DT)) + MI = RHS; + // FIXME: We should also fold the compare when the pointer escapes, but the + // compare dominates the pointer escape. + if (MI && !PointerMayBeCaptured(MI, true, true)) + return ConstantInt::get(GetCompareTy(LHS), + CmpInst::isFalseWhenEqual(Pred)); + } + + // Otherwise, fail. + return nullptr; +} + +/// Fold an icmp when its operands have i1 scalar type. +static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, + Value *RHS, const SimplifyQuery &Q) { + Type *ITy = GetCompareTy(LHS); // The return type. + Type *OpTy = LHS->getType(); // The operand type. + if (!OpTy->isIntOrIntVectorTy(1)) + return nullptr; + + // A boolean compared to true/false can be simplified in 14 out of the 20 + // (10 predicates * 2 constants) possible combinations. Cases not handled here + // require a 'not' of the LHS, so those must be transformed in InstCombine.
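+  // E.g. (hand-written i1 illustrations):
+  //   icmp ne i1 %b, false --> %b
+  //   icmp ult i1 %b, false --> false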
+ if (match(RHS, m_Zero())) { + switch (Pred) { + case CmpInst::ICMP_NE: // X != 0 -> X + case CmpInst::ICMP_UGT: // X >u 0 -> X + case CmpInst::ICMP_SLT: // X <s 0 -> X + return LHS; + + case CmpInst::ICMP_ULT: // X <u 0 -> false + case CmpInst::ICMP_SGT: // X >s 0 -> false + return getFalse(ITy); + + case CmpInst::ICMP_UGE: // X >=u 0 -> true + case CmpInst::ICMP_SLE: // X <=s 0 -> true + return getTrue(ITy); + + default: break; + } + } else if (match(RHS, m_One())) { + switch (Pred) { + case CmpInst::ICMP_EQ: // X == 1 -> X + case CmpInst::ICMP_UGE: // X >=u 1 -> X + case CmpInst::ICMP_SLE: // X <=s -1 -> X + return LHS; + + case CmpInst::ICMP_UGT: // X >u 1 -> false + case CmpInst::ICMP_SLT: // X <s -1 -> false + return getFalse(ITy); + + case CmpInst::ICMP_ULE: // X <=u 1 -> true + case CmpInst::ICMP_SGE: // X >=s -1 -> true + return getTrue(ITy); + + default: break; + } + } + + switch (Pred) { + default: + break; + case ICmpInst::ICMP_UGE: + if (isImpliedCondition(RHS, LHS, Q.DL).getValueOr(false)) + return getTrue(ITy); + break; + case ICmpInst::ICMP_SGE: + /// For signed comparison, the values for an i1 are 0 and -1 + /// respectively. This maps into a truth table of: + /// LHS | RHS | LHS >=s RHS | LHS implies RHS + /// 0 | 0 | 1 (0 >= 0) | 1 + /// 0 | 1 | 1 (0 >= -1) | 1 + /// 1 | 0 | 0 (-1 >= 0) | 0 + /// 1 | 1 | 1 (-1 >= -1) | 1 + if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) + return getTrue(ITy); + break; + case ICmpInst::ICMP_ULE: + if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) + return getTrue(ITy); + break; + } + + return nullptr; +} + +/// Try hard to fold icmp with zero RHS because this is a common case. +static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, + Value *RHS, const SimplifyQuery &Q) { + if (!match(RHS, m_Zero())) + return nullptr; + + Type *ITy = GetCompareTy(LHS); // The return type. 
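+  // E.g. (hand-written illustration): "icmp ult i8 %x, 0" folds to false and
+  // "icmp uge i8 %x, 0" folds to true, with no knowledge about %x required.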
+ switch (Pred) { + default: + llvm_unreachable("Unknown ICmp predicate!"); + case ICmpInst::ICMP_ULT: + return getFalse(ITy); + case ICmpInst::ICMP_UGE: + return getTrue(ITy); + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_ULE: + if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo)) + return getFalse(ITy); + break; + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_UGT: + if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo)) + return getTrue(ITy); + break; + case ICmpInst::ICMP_SLT: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) + return getTrue(ITy); + if (LHSKnown.isNonNegative()) + return getFalse(ITy); + break; + } + case ICmpInst::ICMP_SLE: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) + return getTrue(ITy); + if (LHSKnown.isNonNegative() && + isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + return getFalse(ITy); + break; + } + case ICmpInst::ICMP_SGE: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) + return getFalse(ITy); + if (LHSKnown.isNonNegative()) + return getTrue(ITy); + break; + } + case ICmpInst::ICMP_SGT: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) + return getFalse(ITy); + if (LHSKnown.isNonNegative() && + isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + return getTrue(ITy); + break; + } + } + + return nullptr; +} + +static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, + Value *RHS, const InstrInfoQuery &IIQ) { + Type *ITy = GetCompareTy(RHS); // The return type. + + Value *X; + // Sign-bit checks can be optimized to true/false after unsigned + // floating-point casts: + // icmp slt (bitcast (uitofp X)), 0 --> false + // icmp sgt (bitcast (uitofp X)), -1 --> true + if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) { + if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero())) + return ConstantInt::getFalse(ITy); + if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes())) + return ConstantInt::getTrue(ITy); + } + + const APInt *C; + if (!match(RHS, m_APInt(C))) + return nullptr; + + // Rule out tautological comparisons (e.g., ult 0 or uge 0). + ConstantRange RHS_CR = ConstantRange::makeExactICmpRegion(Pred, *C); + if (RHS_CR.isEmptySet()) + return ConstantInt::getFalse(ITy); + if (RHS_CR.isFullSet()) + return ConstantInt::getTrue(ITy); + + ConstantRange LHS_CR = computeConstantRange(LHS, IIQ.UseInstrInfo); + if (!LHS_CR.isFullSet()) { + if (RHS_CR.contains(LHS_CR)) + return ConstantInt::getTrue(ITy); + if (RHS_CR.inverse().contains(LHS_CR)) + return ConstantInt::getFalse(ITy); + } + + return nullptr; +} + +/// TODO: A large part of this logic is duplicated in InstCombine's +/// foldICmpBinOp(). We should be able to share that and avoid the code +/// duplication. +static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, + Value *RHS, const SimplifyQuery &Q, + unsigned MaxRecurse) { + Type *ITy = GetCompareTy(LHS); // The return type. + + BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS); + BinaryOperator *RBO = dyn_cast<BinaryOperator>(RHS); + if (MaxRecurse && (LBO || RBO)) { + // Analyze the case when either LHS or RHS is an add instruction. + Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; + // LHS = A + B (or A and B are null); RHS = C + D (or C and D are null).
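+    // E.g. (hand-written illustration): "icmp eq (add i8 %x, 1), %x" recurses
+    // into "icmp eq i8 1, 0" and folds to false.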
+ bool NoLHSWrapProblem = false, NoRHSWrapProblem = false; + if (LBO && LBO->getOpcode() == Instruction::Add) { + A = LBO->getOperand(0); + B = LBO->getOperand(1); + NoLHSWrapProblem = + ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && + Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(LBO))) || + (CmpInst::isSigned(Pred) && + Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(LBO))); + } + if (RBO && RBO->getOpcode() == Instruction::Add) { + C = RBO->getOperand(0); + D = RBO->getOperand(1); + NoRHSWrapProblem = + ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && + Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(RBO))) || + (CmpInst::isSigned(Pred) && + Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(RBO))); + } + + // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. + if ((A == RHS || B == RHS) && NoLHSWrapProblem) + if (Value *V = SimplifyICmpInst(Pred, A == RHS ? B : A, + Constant::getNullValue(RHS->getType()), Q, + MaxRecurse - 1)) + return V; + + // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. + if ((C == LHS || D == LHS) && NoRHSWrapProblem) + if (Value *V = + SimplifyICmpInst(Pred, Constant::getNullValue(LHS->getType()), + C == LHS ? D : C, Q, MaxRecurse - 1)) + return V; + + // icmp (X+Y), (X+Z) -> icmp Y,Z for equalities or if there is no overflow. + if (A && C && (A == C || A == D || B == C || B == D) && NoLHSWrapProblem && + NoRHSWrapProblem) { + // Determine Y and Z in the form icmp (X+Y), (X+Z). + Value *Y, *Z; + if (A == C) { + // C + B == C + D -> B == D + Y = B; + Z = D; + } else if (A == D) { + // D + B == C + D -> B == C + Y = B; + Z = C; + } else if (B == C) { + // A + C == C + D -> A == D + Y = A; + Z = D; + } else { + assert(B == D); + // A + D == C + D -> A == C + Y = A; + Z = C; + } + if (Value *V = SimplifyICmpInst(Pred, Y, Z, Q, MaxRecurse - 1)) + return V; + } + } + + { + Value *Y = nullptr; + // icmp pred (or X, Y), X + if (LBO && match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) { + if (Pred == ICmpInst::ICMP_ULT) + return getFalse(ITy); + if (Pred == ICmpInst::ICMP_UGE) + return getTrue(ITy); + + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) { + KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (RHSKnown.isNonNegative() && YKnown.isNegative()) + return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy); + if (RHSKnown.isNegative() || YKnown.isNonNegative()) + return Pred == ICmpInst::ICMP_SLT ? getFalse(ITy) : getTrue(ITy); + } + } + // icmp pred X, (or X, Y) + if (RBO && match(RBO, m_c_Or(m_Value(Y), m_Specific(LHS)))) { + if (Pred == ICmpInst::ICMP_ULE) + return getTrue(ITy); + if (Pred == ICmpInst::ICMP_UGT) + return getFalse(ITy); + + if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE) { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNonNegative() && YKnown.isNegative()) + return Pred == ICmpInst::ICMP_SGT ? getTrue(ITy) : getFalse(ITy); + if (LHSKnown.isNegative() || YKnown.isNonNegative()) + return Pred == ICmpInst::ICMP_SGT ? 
getFalse(ITy) : getTrue(ITy); + } + } + } + + // icmp pred (and X, Y), X + if (LBO && match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) { + if (Pred == ICmpInst::ICMP_UGT) + return getFalse(ITy); + if (Pred == ICmpInst::ICMP_ULE) + return getTrue(ITy); + } + // icmp pred X, (and X, Y) + if (RBO && match(RBO, m_c_And(m_Value(), m_Specific(LHS)))) { + if (Pred == ICmpInst::ICMP_UGE) + return getTrue(ITy); + if (Pred == ICmpInst::ICMP_ULT) + return getFalse(ITy); + } + + // 0 - (zext X) pred C + if (!CmpInst::isUnsigned(Pred) && match(LHS, m_Neg(m_ZExt(m_Value())))) { + if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) { + if (RHSC->getValue().isStrictlyPositive()) { + if (Pred == ICmpInst::ICMP_SLT) + return ConstantInt::getTrue(RHSC->getContext()); + if (Pred == ICmpInst::ICMP_SGE) + return ConstantInt::getFalse(RHSC->getContext()); + if (Pred == ICmpInst::ICMP_EQ) + return ConstantInt::getFalse(RHSC->getContext()); + if (Pred == ICmpInst::ICMP_NE) + return ConstantInt::getTrue(RHSC->getContext()); + } + if (RHSC->getValue().isNonNegative()) { + if (Pred == ICmpInst::ICMP_SLE) + return ConstantInt::getTrue(RHSC->getContext()); + if (Pred == ICmpInst::ICMP_SGT) + return ConstantInt::getFalse(RHSC->getContext()); + } + } + } + + // icmp pred (urem X, Y), Y + if (LBO && match(LBO, m_URem(m_Value(), m_Specific(RHS)))) { + switch (Pred) { + default: + break; + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: { + KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) + break; + LLVM_FALLTHROUGH; + } + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + return getFalse(ITy); + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: { + KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) + break; + LLVM_FALLTHROUGH; + } + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + return getTrue(ITy); + } + } + + // icmp pred X, (urem Y, X) + if (RBO && match(RBO, m_URem(m_Value(), m_Specific(LHS)))) { + switch (Pred) { + default: + break; + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: { + KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) + break; + LLVM_FALLTHROUGH; + } + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + return getTrue(ITy); + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: { + KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) + break; + LLVM_FALLTHROUGH; + } + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + return getFalse(ITy); + } + } + + // x >> y <=u x + // x udiv y <=u x. + if (LBO && (match(LBO, m_LShr(m_Specific(RHS), m_Value())) || + match(LBO, m_UDiv(m_Specific(RHS), m_Value())))) { + // icmp pred (X op Y), X + if (Pred == ICmpInst::ICMP_UGT) + return getFalse(ITy); + if (Pred == ICmpInst::ICMP_ULE) + return getTrue(ITy); + } + + // x >=u x >> y + // x >=u x udiv y. 
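+  // E.g. (hand-written illustration): "icmp uge i8 %x, (lshr i8 %x, %y)"
+  // folds to true.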
+ if (RBO && (match(RBO, m_LShr(m_Specific(LHS), m_Value())) || + match(RBO, m_UDiv(m_Specific(LHS), m_Value())))) { + // icmp pred X, (X op Y) + if (Pred == ICmpInst::ICMP_ULT) + return getFalse(ITy); + if (Pred == ICmpInst::ICMP_UGE) + return getTrue(ITy); + } + + // Handle: + // CI2 << X == CI + // CI2 << X != CI + // + // where CI2 is a power of 2 and CI isn't + if (auto *CI = dyn_cast<ConstantInt>(RHS)) { + const APInt *CI2Val, *CIVal = &CI->getValue(); + if (LBO && match(LBO, m_Shl(m_APInt(CI2Val), m_Value())) && + CI2Val->isPowerOf2()) { + if (!CIVal->isPowerOf2()) { + // CI2 << X can equal zero in some circumstances, so this + // simplification is unsafe if CI is zero. + // + // We know it is safe if: + // - The shift is nsw, so we can't shift out the one bit. + // - The shift is nuw, so we can't shift out the one bit. + // - CI2 is one + // - CI isn't zero + if (Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(LBO)) || + Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(LBO)) || + CI2Val->isOneValue() || !CI->isZero()) { + if (Pred == ICmpInst::ICMP_EQ) + return ConstantInt::getFalse(RHS->getContext()); + if (Pred == ICmpInst::ICMP_NE) + return ConstantInt::getTrue(RHS->getContext()); + } + } + if (CIVal->isSignMask() && CI2Val->isOneValue()) { + if (Pred == ICmpInst::ICMP_UGT) + return ConstantInt::getFalse(RHS->getContext()); + if (Pred == ICmpInst::ICMP_ULE) + return ConstantInt::getTrue(RHS->getContext()); + } + } + } + + if (MaxRecurse && LBO && RBO && LBO->getOpcode() == RBO->getOpcode() && + LBO->getOperand(1) == RBO->getOperand(1)) { + switch (LBO->getOpcode()) { + default: + break; + case Instruction::UDiv: + case Instruction::LShr: + if (ICmpInst::isSigned(Pred) || !Q.IIQ.isExact(LBO) || + !Q.IIQ.isExact(RBO)) + break; + if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + RBO->getOperand(0), Q, MaxRecurse - 1)) + return V; + break; + case Instruction::SDiv: + if (!ICmpInst::isEquality(Pred) || !Q.IIQ.isExact(LBO) || + !Q.IIQ.isExact(RBO)) + break; + if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + RBO->getOperand(0), Q, MaxRecurse - 1)) + return V; + break; + case Instruction::AShr: + if (!Q.IIQ.isExact(LBO) || !Q.IIQ.isExact(RBO)) + break; + if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + RBO->getOperand(0), Q, MaxRecurse - 1)) + return V; + break; + case Instruction::Shl: { + bool NUW = Q.IIQ.hasNoUnsignedWrap(LBO) && Q.IIQ.hasNoUnsignedWrap(RBO); + bool NSW = Q.IIQ.hasNoSignedWrap(LBO) && Q.IIQ.hasNoSignedWrap(RBO); + if (!NUW && !NSW) + break; + if (!NSW && ICmpInst::isSigned(Pred)) + break; + if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + RBO->getOperand(0), Q, MaxRecurse - 1)) + return V; + break; + } + } + } + return nullptr; +} + +/// Simplify integer comparisons where at least one operand of the compare +/// matches an integer min/max idiom. +static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, + Value *RHS, const SimplifyQuery &Q, + unsigned MaxRecurse) { + Type *ITy = GetCompareTy(LHS); // The return type. + Value *A, *B; + CmpInst::Predicate P = CmpInst::BAD_ICMP_PREDICATE; + CmpInst::Predicate EqP; // Chosen so that "A == max/min(A,B)" iff "A EqP B". + + // Signed variants on "max(a,b)>=a -> true". + if (match(LHS, m_SMax(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) { + if (A != RHS) + std::swap(A, B); // smax(A, B) pred A. + EqP = CmpInst::ICMP_SGE; // "A == smax(A, B)" iff "A sge B". + // We analyze this as smax(A, B) pred A.
+ P = Pred; + } else if (match(RHS, m_SMax(m_Value(A), m_Value(B))) && + (A == LHS || B == LHS)) { + if (A != LHS) + std::swap(A, B); // A pred smax(A, B). + EqP = CmpInst::ICMP_SGE; // "A == smax(A, B)" iff "A sge B". + // We analyze this as smax(A, B) swapped-pred A. + P = CmpInst::getSwappedPredicate(Pred); + } else if (match(LHS, m_SMin(m_Value(A), m_Value(B))) && + (A == RHS || B == RHS)) { + if (A != RHS) + std::swap(A, B); // smin(A, B) pred A. + EqP = CmpInst::ICMP_SLE; // "A == smin(A, B)" iff "A sle B". + // We analyze this as smax(-A, -B) swapped-pred -A. + // Note that we do not need to actually form -A or -B thanks to EqP. + P = CmpInst::getSwappedPredicate(Pred); + } else if (match(RHS, m_SMin(m_Value(A), m_Value(B))) && + (A == LHS || B == LHS)) { + if (A != LHS) + std::swap(A, B); // A pred smin(A, B). + EqP = CmpInst::ICMP_SLE; // "A == smin(A, B)" iff "A sle B". + // We analyze this as smax(-A, -B) pred -A. + // Note that we do not need to actually form -A or -B thanks to EqP. + P = Pred; + } + if (P != CmpInst::BAD_ICMP_PREDICATE) { + // Cases correspond to "max(A, B) p A". + switch (P) { + default: + break; + case CmpInst::ICMP_EQ: + case CmpInst::ICMP_SLE: + // Equivalent to "A EqP B". This may be the same as the condition tested + // in the max/min; if so, we can just return that. + if (Value *V = ExtractEquivalentCondition(LHS, EqP, A, B)) + return V; + if (Value *V = ExtractEquivalentCondition(RHS, EqP, A, B)) + return V; + // Otherwise, see if "A EqP B" simplifies. + if (MaxRecurse) + if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1)) + return V; + break; + case CmpInst::ICMP_NE: + case CmpInst::ICMP_SGT: { + CmpInst::Predicate InvEqP = CmpInst::getInversePredicate(EqP); + // Equivalent to "A InvEqP B". This may be the same as the condition + // tested in the max/min; if so, we can just return that. + if (Value *V = ExtractEquivalentCondition(LHS, InvEqP, A, B)) + return V; + if (Value *V = ExtractEquivalentCondition(RHS, InvEqP, A, B)) + return V; + // Otherwise, see if "A InvEqP B" simplifies. + if (MaxRecurse) + if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1)) + return V; + break; + } + case CmpInst::ICMP_SGE: + // Always true. + return getTrue(ITy); + case CmpInst::ICMP_SLT: + // Always false. + return getFalse(ITy); + } + } + + // Unsigned variants on "max(a,b)>=a -> true". + P = CmpInst::BAD_ICMP_PREDICATE; + if (match(LHS, m_UMax(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) { + if (A != RHS) + std::swap(A, B); // umax(A, B) pred A. + EqP = CmpInst::ICMP_UGE; // "A == umax(A, B)" iff "A uge B". + // We analyze this as umax(A, B) pred A. + P = Pred; + } else if (match(RHS, m_UMax(m_Value(A), m_Value(B))) && + (A == LHS || B == LHS)) { + if (A != LHS) + std::swap(A, B); // A pred umax(A, B). + EqP = CmpInst::ICMP_UGE; // "A == umax(A, B)" iff "A uge B". + // We analyze this as umax(A, B) swapped-pred A. + P = CmpInst::getSwappedPredicate(Pred); + } else if (match(LHS, m_UMin(m_Value(A), m_Value(B))) && + (A == RHS || B == RHS)) { + if (A != RHS) + std::swap(A, B); // umin(A, B) pred A. + EqP = CmpInst::ICMP_ULE; // "A == umin(A, B)" iff "A ule B". + // We analyze this as umax(-A, -B) swapped-pred -A. + // Note that we do not need to actually form -A or -B thanks to EqP. + P = CmpInst::getSwappedPredicate(Pred); + } else if (match(RHS, m_UMin(m_Value(A), m_Value(B))) && + (A == LHS || B == LHS)) { + if (A != LHS) + std::swap(A, B); // A pred umin(A, B). + EqP = CmpInst::ICMP_ULE; // "A == umin(A, B)" iff "A ule B". 
+ // We analyze this as umax(-A, -B) pred -A. + // Note that we do not need to actually form -A or -B thanks to EqP. + P = Pred; + } + if (P != CmpInst::BAD_ICMP_PREDICATE) { + // Cases correspond to "max(A, B) p A". + switch (P) { + default: + break; + case CmpInst::ICMP_EQ: + case CmpInst::ICMP_ULE: + // Equivalent to "A EqP B". This may be the same as the condition tested + // in the max/min; if so, we can just return that. + if (Value *V = ExtractEquivalentCondition(LHS, EqP, A, B)) + return V; + if (Value *V = ExtractEquivalentCondition(RHS, EqP, A, B)) + return V; + // Otherwise, see if "A EqP B" simplifies. + if (MaxRecurse) + if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1)) + return V; + break; + case CmpInst::ICMP_NE: + case CmpInst::ICMP_UGT: { + CmpInst::Predicate InvEqP = CmpInst::getInversePredicate(EqP); + // Equivalent to "A InvEqP B". This may be the same as the condition + // tested in the max/min; if so, we can just return that. + if (Value *V = ExtractEquivalentCondition(LHS, InvEqP, A, B)) + return V; + if (Value *V = ExtractEquivalentCondition(RHS, InvEqP, A, B)) + return V; + // Otherwise, see if "A InvEqP B" simplifies. + if (MaxRecurse) + if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1)) + return V; + break; + } + case CmpInst::ICMP_UGE: + // Always true. + return getTrue(ITy); + case CmpInst::ICMP_ULT: + // Always false. + return getFalse(ITy); + } + } + + // Variants on "max(x,y) >= min(x,z)". + Value *C, *D; + if (match(LHS, m_SMax(m_Value(A), m_Value(B))) && + match(RHS, m_SMin(m_Value(C), m_Value(D))) && + (A == C || A == D || B == C || B == D)) { + // max(x, ?) pred min(x, ?). + if (Pred == CmpInst::ICMP_SGE) + // Always true. + return getTrue(ITy); + if (Pred == CmpInst::ICMP_SLT) + // Always false. + return getFalse(ITy); + } else if (match(LHS, m_SMin(m_Value(A), m_Value(B))) && + match(RHS, m_SMax(m_Value(C), m_Value(D))) && + (A == C || A == D || B == C || B == D)) { + // min(x, ?) pred max(x, ?). + if (Pred == CmpInst::ICMP_SLE) + // Always true. + return getTrue(ITy); + if (Pred == CmpInst::ICMP_SGT) + // Always false. + return getFalse(ITy); + } else if (match(LHS, m_UMax(m_Value(A), m_Value(B))) && + match(RHS, m_UMin(m_Value(C), m_Value(D))) && + (A == C || A == D || B == C || B == D)) { + // max(x, ?) pred min(x, ?). + if (Pred == CmpInst::ICMP_UGE) + // Always true. + return getTrue(ITy); + if (Pred == CmpInst::ICMP_ULT) + // Always false. + return getFalse(ITy); + } else if (match(LHS, m_UMin(m_Value(A), m_Value(B))) && + match(RHS, m_UMax(m_Value(C), m_Value(D))) && + (A == C || A == D || B == C || B == D)) { + // min(x, ?) pred max(x, ?). + if (Pred == CmpInst::ICMP_ULE) + // Always true. + return getTrue(ITy); + if (Pred == CmpInst::ICMP_UGT) + // Always false. + return getFalse(ITy); + } + + return nullptr; +} + +/// Given operands for an ICmpInst, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, + const SimplifyQuery &Q, unsigned MaxRecurse) { + CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; + assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!"); + + if (Constant *CLHS = dyn_cast<Constant>(LHS)) { + if (Constant *CRHS = dyn_cast<Constant>(RHS)) + return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI); + + // If we have a constant, make sure it is on the RHS. 
+ std::swap(LHS, RHS); + Pred = CmpInst::getSwappedPredicate(Pred); + } + assert(!isa<UndefValue>(LHS) && "Unexpected icmp undef,%X"); + + Type *ITy = GetCompareTy(LHS); // The return type. + + // For EQ and NE, we can always pick a value for the undef to make the + // predicate pass or fail, so we can return undef. + // Matches behavior in llvm::ConstantFoldCompareInstruction. + if (isa<UndefValue>(RHS) && ICmpInst::isEquality(Pred)) + return UndefValue::get(ITy); + + // icmp X, X -> true/false + // icmp X, undef -> true/false because undef could be X. + if (LHS == RHS || isa<UndefValue>(RHS)) + return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred)); + + if (Value *V = simplifyICmpOfBools(Pred, LHS, RHS, Q)) + return V; + + if (Value *V = simplifyICmpWithZero(Pred, LHS, RHS, Q)) + return V; + + if (Value *V = simplifyICmpWithConstant(Pred, LHS, RHS, Q.IIQ)) + return V; + + // If both operands have range metadata, use the metadata + // to simplify the comparison. + if (isa<Instruction>(RHS) && isa<Instruction>(LHS)) { + auto RHS_Instr = cast<Instruction>(RHS); + auto LHS_Instr = cast<Instruction>(LHS); + + if (Q.IIQ.getMetadata(RHS_Instr, LLVMContext::MD_range) && + Q.IIQ.getMetadata(LHS_Instr, LLVMContext::MD_range)) { + auto RHS_CR = getConstantRangeFromMetadata( + *RHS_Instr->getMetadata(LLVMContext::MD_range)); + auto LHS_CR = getConstantRangeFromMetadata( + *LHS_Instr->getMetadata(LLVMContext::MD_range)); + + auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR); + if (Satisfied_CR.contains(LHS_CR)) + return ConstantInt::getTrue(RHS->getContext()); + + auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion( + CmpInst::getInversePredicate(Pred), RHS_CR); + if (InversedSatisfied_CR.contains(LHS_CR)) + return ConstantInt::getFalse(RHS->getContext()); + } + } + + // Compare of cast, for example (zext X) != 0 -> X != 0 + if (isa<CastInst>(LHS) && (isa<Constant>(RHS) || isa<CastInst>(RHS))) { + Instruction *LI = cast<CastInst>(LHS); + Value *SrcOp = LI->getOperand(0); + Type *SrcTy = SrcOp->getType(); + Type *DstTy = LI->getType(); + + // Turn icmp (ptrtoint x), (ptrtoint/constant) into a compare of the input + // if the integer type is the same size as the pointer type. + if (MaxRecurse && isa<PtrToIntInst>(LI) && + Q.DL.getTypeSizeInBits(SrcTy) == DstTy->getPrimitiveSizeInBits()) { + if (Constant *RHSC = dyn_cast<Constant>(RHS)) { + // Transfer the cast to the constant. + if (Value *V = SimplifyICmpInst(Pred, SrcOp, + ConstantExpr::getIntToPtr(RHSC, SrcTy), + Q, MaxRecurse-1)) + return V; + } else if (PtrToIntInst *RI = dyn_cast<PtrToIntInst>(RHS)) { + if (RI->getOperand(0)->getType() == SrcTy) + // Compare without the cast. + if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), + Q, MaxRecurse-1)) + return V; + } + } + + if (isa<ZExtInst>(LHS)) { + // Turn icmp (zext X), (zext Y) into a compare of X and Y if they have the + // same type. + if (ZExtInst *RI = dyn_cast<ZExtInst>(RHS)) { + if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) + // Compare X and Y. Note that signed predicates become unsigned. + if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), + SrcOp, RI->getOperand(0), Q, + MaxRecurse-1)) + return V; + } + // Turn icmp (zext X), Cst into a compare of X and Cst if Cst is extended + // too. If not, then try to deduce the result of the comparison. + else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { + // Compute the constant that would happen if we truncated to SrcTy then + // reextended to DstTy. 
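+        // For example, comparing "zext i8 %x to i32" with i32 300: truncating
+        // 300 to i8 gives 44, and zero-extending 44 back to i32 gives 44, not
+        // 300, so 300 can never equal a zero-extended i8 value.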
+        Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy);
+        Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy);
+
+        // If the re-extended constant didn't change then this is effectively
+        // also a case of comparing two zero-extended values.
+        if (RExt == CI && MaxRecurse)
+          if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred),
+                                          SrcOp, Trunc, Q, MaxRecurse-1))
+            return V;
+
+        // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit
+        // there. Use this to work out the result of the comparison.
+        if (RExt != CI) {
+          switch (Pred) {
+          default: llvm_unreachable("Unknown ICmp predicate!");
+          // LHS <u RHS.
+          case ICmpInst::ICMP_EQ:
+          case ICmpInst::ICMP_UGT:
+          case ICmpInst::ICMP_UGE:
+            return ConstantInt::getFalse(CI->getContext());
+
+          case ICmpInst::ICMP_NE:
+          case ICmpInst::ICMP_ULT:
+          case ICmpInst::ICMP_ULE:
+            return ConstantInt::getTrue(CI->getContext());
+
+          // LHS is non-negative. If RHS is negative then LHS >s RHS. If RHS
+          // is non-negative then LHS <s RHS.
+          case ICmpInst::ICMP_SGT:
+          case ICmpInst::ICMP_SGE:
+            return CI->getValue().isNegative() ?
+              ConstantInt::getTrue(CI->getContext()) :
+              ConstantInt::getFalse(CI->getContext());
+
+          case ICmpInst::ICMP_SLT:
+          case ICmpInst::ICMP_SLE:
+            return CI->getValue().isNegative() ?
+              ConstantInt::getFalse(CI->getContext()) :
+              ConstantInt::getTrue(CI->getContext());
+          }
+        }
+      }
+    }
+
+    if (isa<SExtInst>(LHS)) {
+      // Turn icmp (sext X), (sext Y) into a compare of X and Y if they have the
+      // same type.
+      if (SExtInst *RI = dyn_cast<SExtInst>(RHS)) {
+        if (MaxRecurse && SrcTy == RI->getOperand(0)->getType())
+          // Compare X and Y. Note that the predicate does not change.
+          if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0),
+                                          Q, MaxRecurse-1))
+            return V;
+      }
+      // Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended
+      // too. If not, then try to deduce the result of the comparison.
+      else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
+        // Compute the constant that would happen if we truncated to SrcTy then
+        // reextended to DstTy.
+        Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy);
+        Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy);
+
+        // If the re-extended constant didn't change then this is effectively
+        // also a case of comparing two sign-extended values.
+        if (RExt == CI && MaxRecurse)
+          if (Value *V = SimplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse-1))
+            return V;
+
+        // Otherwise the upper bits of LHS are all equal, while RHS has varying
+        // bits there. Use this to work out the result of the comparison.
+        if (RExt != CI) {
+          switch (Pred) {
+          default: llvm_unreachable("Unknown ICmp predicate!");
+          case ICmpInst::ICMP_EQ:
+            return ConstantInt::getFalse(CI->getContext());
+          case ICmpInst::ICMP_NE:
+            return ConstantInt::getTrue(CI->getContext());
+
+          // If RHS is non-negative then LHS <s RHS. If RHS is negative then
+          // LHS >s RHS.
+          case ICmpInst::ICMP_SGT:
+          case ICmpInst::ICMP_SGE:
+            return CI->getValue().isNegative() ?
+              ConstantInt::getTrue(CI->getContext()) :
+              ConstantInt::getFalse(CI->getContext());
+          case ICmpInst::ICMP_SLT:
+          case ICmpInst::ICMP_SLE:
+            return CI->getValue().isNegative() ?
+              ConstantInt::getFalse(CI->getContext()) :
+              ConstantInt::getTrue(CI->getContext());
+
+          // If LHS is non-negative then LHS <u RHS. If LHS is negative then
+          // LHS >u RHS.
+          case ICmpInst::ICMP_UGT:
+          case ICmpInst::ICMP_UGE:
+            // Comparison is true iff the LHS <s 0.
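+            // (Viewed as unsigned, "sext X" is either small (X non-negative)
+            // or near the top of the range (X negative); a constant that is
+            // not itself a sign-extension lies strictly between the two, so
+            // the unsigned compare reduces to the sign of X.)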
+            if (MaxRecurse)
+              if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SLT, SrcOp,
+                                              Constant::getNullValue(SrcTy),
+                                              Q, MaxRecurse-1))
+                return V;
+            break;
+          case ICmpInst::ICMP_ULT:
+          case ICmpInst::ICMP_ULE:
+            // Comparison is true iff the LHS >=s 0.
+            if (MaxRecurse)
+              if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SGE, SrcOp,
+                                              Constant::getNullValue(SrcTy),
+                                              Q, MaxRecurse-1))
+                return V;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // icmp eq|ne X, Y -> false|true if X != Y
+  if (ICmpInst::isEquality(Pred) &&
+      isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo)) {
+    return Pred == ICmpInst::ICMP_NE ? getTrue(ITy) : getFalse(ITy);
+  }
+
+  if (Value *V = simplifyICmpWithBinOp(Pred, LHS, RHS, Q, MaxRecurse))
+    return V;
+
+  if (Value *V = simplifyICmpWithMinMax(Pred, LHS, RHS, Q, MaxRecurse))
+    return V;
+
+  // Simplify comparisons of related pointers using a powerful, recursive
+  // GEP-walk when we have target data available.
+  if (LHS->getType()->isPointerTy())
+    if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI,
+                                     Q.IIQ, LHS, RHS))
+      return C;
+  if (auto *CLHS = dyn_cast<PtrToIntOperator>(LHS))
+    if (auto *CRHS = dyn_cast<PtrToIntOperator>(RHS))
+      if (Q.DL.getTypeSizeInBits(CLHS->getPointerOperandType()) ==
+              Q.DL.getTypeSizeInBits(CLHS->getType()) &&
+          Q.DL.getTypeSizeInBits(CRHS->getPointerOperandType()) ==
+              Q.DL.getTypeSizeInBits(CRHS->getType()))
+        if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI,
+                                         Q.IIQ, CLHS->getPointerOperand(),
+                                         CRHS->getPointerOperand()))
+          return C;
+
+  if (GetElementPtrInst *GLHS = dyn_cast<GetElementPtrInst>(LHS)) {
+    if (GEPOperator *GRHS = dyn_cast<GEPOperator>(RHS)) {
+      if (GLHS->getPointerOperand() == GRHS->getPointerOperand() &&
+          GLHS->hasAllConstantIndices() && GRHS->hasAllConstantIndices() &&
+          (ICmpInst::isEquality(Pred) ||
+           (GLHS->isInBounds() && GRHS->isInBounds() &&
+            Pred == ICmpInst::getSignedPredicate(Pred)))) {
+        // The bases are equal and the indices are constant. Build a constant
+        // expression GEP with the same indices and a null base pointer to see
+        // what constant folding can make out of it.
+        Constant *Null = Constant::getNullValue(GLHS->getPointerOperandType());
+        SmallVector<Value *, 4> IndicesLHS(GLHS->idx_begin(), GLHS->idx_end());
+        Constant *NewLHS = ConstantExpr::getGetElementPtr(
+            GLHS->getSourceElementType(), Null, IndicesLHS);
+
+        SmallVector<Value *, 4> IndicesRHS(GRHS->idx_begin(), GRHS->idx_end());
+        Constant *NewRHS = ConstantExpr::getGetElementPtr(
+            GLHS->getSourceElementType(), Null, IndicesRHS);
+        return ConstantExpr::getICmp(Pred, NewLHS, NewRHS);
+      }
+    }
+  }
+
+  // If the comparison is with the result of a select instruction, check whether
+  // comparing with either branch of the select always yields the same value.
+  if (isa<SelectInst>(LHS) || isa<SelectInst>(RHS))
+    if (Value *V = ThreadCmpOverSelect(Pred, LHS, RHS, Q, MaxRecurse))
+      return V;
+
+  // If the comparison is with the result of a phi instruction, check whether
+  // doing the compare with each incoming phi value yields a common result.
+  if (isa<PHINode>(LHS) || isa<PHINode>(RHS))
+    if (Value *V = ThreadCmpOverPHI(Pred, LHS, RHS, Q, MaxRecurse))
+      return V;
+
+  return nullptr;
+}
+
+Value *llvm::SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+                              const SimplifyQuery &Q) {
+  return ::SimplifyICmpInst(Predicate, LHS, RHS, Q, RecursionLimit);
+}
+
+/// Given operands for an FCmpInst, see if we can fold the result.
+/// If not, this returns null.
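+/// For example, "fcmp false %x, %y" always folds to false, and
+/// "fcmp ord %x, %y" folds to true when neither operand can be NaN.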
+static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+                               FastMathFlags FMF, const SimplifyQuery &Q,
+                               unsigned MaxRecurse) {
+  CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate;
+  assert(CmpInst::isFPPredicate(Pred) && "Not an FP compare!");
+
+  if (Constant *CLHS = dyn_cast<Constant>(LHS)) {
+    if (Constant *CRHS = dyn_cast<Constant>(RHS))
+      return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI);
+
+    // If we have a constant, make sure it is on the RHS.
+    std::swap(LHS, RHS);
+    Pred = CmpInst::getSwappedPredicate(Pred);
+  }
+
+  // Fold trivial predicates.
+  Type *RetTy = GetCompareTy(LHS);
+  if (Pred == FCmpInst::FCMP_FALSE)
+    return getFalse(RetTy);
+  if (Pred == FCmpInst::FCMP_TRUE)
+    return getTrue(RetTy);
+
+  // Fold (un)ordered comparison if we can determine there are no NaNs.
+  if (Pred == FCmpInst::FCMP_UNO || Pred == FCmpInst::FCMP_ORD)
+    if (FMF.noNaNs() ||
+        (isKnownNeverNaN(LHS, Q.TLI) && isKnownNeverNaN(RHS, Q.TLI)))
+      return ConstantInt::get(RetTy, Pred == FCmpInst::FCMP_ORD);
+
+  // NaN is unordered; NaN is not ordered.
+  assert((FCmpInst::isOrdered(Pred) || FCmpInst::isUnordered(Pred)) &&
+         "Comparison must be either ordered or unordered");
+  if (match(RHS, m_NaN()))
+    return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred));
+
+  // fcmp pred x, undef and fcmp pred undef, x
+  // fold to true if unordered, false if ordered
+  if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) {
+    // Choosing NaN for the undef will always make unordered comparison succeed
+    // and ordered comparison fail.
+    return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred));
+  }
+
+  // fcmp x,x -> true/false. Not all compares are foldable.
+  if (LHS == RHS) {
+    if (CmpInst::isTrueWhenEqual(Pred))
+      return getTrue(RetTy);
+    if (CmpInst::isFalseWhenEqual(Pred))
+      return getFalse(RetTy);
+  }
+
+  // Handle fcmp with constant RHS.
+  // TODO: Use match with a specific FP value, so these work with vectors with
+  // undef lanes.
+  const APFloat *C;
+  if (match(RHS, m_APFloat(C))) {
+    // Check whether the constant is an infinity.
+    if (C->isInfinity()) {
+      if (C->isNegative()) {
+        switch (Pred) {
+        case FCmpInst::FCMP_OLT:
+          // No value is ordered and less than negative infinity.
+          return getFalse(RetTy);
+        case FCmpInst::FCMP_UGE:
+          // Every value is either unordered with, or at least, negative
+          // infinity.
+          return getTrue(RetTy);
+        default:
+          break;
+        }
+      } else {
+        switch (Pred) {
+        case FCmpInst::FCMP_OGT:
+          // No value is ordered and greater than infinity.
+          return getFalse(RetTy);
+        case FCmpInst::FCMP_ULE:
+          // Every value is either unordered with, or at most, infinity.
+          return getTrue(RetTy);
+        default:
+          break;
+        }
+      }
+    }
+    if (C->isNegative() && !C->isNegZero()) {
+      assert(!C->isNaN() && "Unexpected NaN constant!");
+      // TODO: We can catch more cases by using a range check rather than
+      // relying on CannotBeOrderedLessThanZero.
+      switch (Pred) {
+      case FCmpInst::FCMP_UGE:
+      case FCmpInst::FCMP_UGT:
+      case FCmpInst::FCMP_UNE:
+        // (X >= 0) implies (X > C) when (C < 0)
+        if (CannotBeOrderedLessThanZero(LHS, Q.TLI))
+          return getTrue(RetTy);
+        break;
+      case FCmpInst::FCMP_OEQ:
+      case FCmpInst::FCMP_OLE:
+      case FCmpInst::FCMP_OLT:
+        // (X >= 0) implies !(X < C) when (C < 0)
+        if (CannotBeOrderedLessThanZero(LHS, Q.TLI))
+          return getFalse(RetTy);
+        break;
+      default:
+        break;
+      }
+    }
+
+    // Check comparison of [minnum/maxnum with constant] with other constant.
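+    // For example, "fcmp olt (minnum %x, 1.0), 2.0" is always true: the
+    // minnum result is at most 1.0 (even if %x is NaN), and 1.0 < 2.0.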
+ const APFloat *C2; + if ((match(LHS, m_Intrinsic<Intrinsic::minnum>(m_Value(), m_APFloat(C2))) && + C2->compare(*C) == APFloat::cmpLessThan) || + (match(LHS, m_Intrinsic<Intrinsic::maxnum>(m_Value(), m_APFloat(C2))) && + C2->compare(*C) == APFloat::cmpGreaterThan)) { + bool IsMaxNum = + cast<IntrinsicInst>(LHS)->getIntrinsicID() == Intrinsic::maxnum; + // The ordered relationship and minnum/maxnum guarantee that we do not + // have NaN constants, so ordered/unordered preds are handled the same. + switch (Pred) { + case FCmpInst::FCMP_OEQ: case FCmpInst::FCMP_UEQ: + // minnum(X, LesserC) == C --> false + // maxnum(X, GreaterC) == C --> false + return getFalse(RetTy); + case FCmpInst::FCMP_ONE: case FCmpInst::FCMP_UNE: + // minnum(X, LesserC) != C --> true + // maxnum(X, GreaterC) != C --> true + return getTrue(RetTy); + case FCmpInst::FCMP_OGE: case FCmpInst::FCMP_UGE: + case FCmpInst::FCMP_OGT: case FCmpInst::FCMP_UGT: + // minnum(X, LesserC) >= C --> false + // minnum(X, LesserC) > C --> false + // maxnum(X, GreaterC) >= C --> true + // maxnum(X, GreaterC) > C --> true + return ConstantInt::get(RetTy, IsMaxNum); + case FCmpInst::FCMP_OLE: case FCmpInst::FCMP_ULE: + case FCmpInst::FCMP_OLT: case FCmpInst::FCMP_ULT: + // minnum(X, LesserC) <= C --> true + // minnum(X, LesserC) < C --> true + // maxnum(X, GreaterC) <= C --> false + // maxnum(X, GreaterC) < C --> false + return ConstantInt::get(RetTy, !IsMaxNum); + default: + // TRUE/FALSE/ORD/UNO should be handled before this. + llvm_unreachable("Unexpected fcmp predicate"); + } + } + } + + if (match(RHS, m_AnyZeroFP())) { + switch (Pred) { + case FCmpInst::FCMP_OGE: + case FCmpInst::FCMP_ULT: + // Positive or zero X >= 0.0 --> true + // Positive or zero X < 0.0 --> false + if ((FMF.noNaNs() || isKnownNeverNaN(LHS, Q.TLI)) && + CannotBeOrderedLessThanZero(LHS, Q.TLI)) + return Pred == FCmpInst::FCMP_OGE ? getTrue(RetTy) : getFalse(RetTy); + break; + case FCmpInst::FCMP_UGE: + case FCmpInst::FCMP_OLT: + // Positive or zero or nan X >= 0.0 --> true + // Positive or zero or nan X < 0.0 --> false + if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) + return Pred == FCmpInst::FCMP_UGE ? getTrue(RetTy) : getFalse(RetTy); + break; + default: + break; + } + } + + // If the comparison is with the result of a select instruction, check whether + // comparing with either branch of the select always yields the same value. + if (isa<SelectInst>(LHS) || isa<SelectInst>(RHS)) + if (Value *V = ThreadCmpOverSelect(Pred, LHS, RHS, Q, MaxRecurse)) + return V; + + // If the comparison is with the result of a phi instruction, check whether + // doing the compare with each incoming phi value yields a common result. + if (isa<PHINode>(LHS) || isa<PHINode>(RHS)) + if (Value *V = ThreadCmpOverPHI(Pred, LHS, RHS, Q, MaxRecurse)) + return V; + + return nullptr; +} + +Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, + FastMathFlags FMF, const SimplifyQuery &Q) { + return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit); +} + +/// See if V simplifies when its operand Op is replaced with RepOp. +static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, + const SimplifyQuery &Q, + unsigned MaxRecurse) { + // Trivial replacement. + if (V == Op) + return RepOp; + + // We cannot replace a constant, and shouldn't even try. + if (isa<Constant>(Op)) + return nullptr; + + auto *I = dyn_cast<Instruction>(V); + if (!I) + return nullptr; + + // If this is a binary operator, try to simplify it with the replaced op. 
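+  // For example, if V is "sub %x, %y" and we replace Op = %y with RepOp = %x,
+  // the operation becomes "sub %x, %x", which folds to 0 (unless wrap flags
+  // block the replacement below).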
+ if (auto *B = dyn_cast<BinaryOperator>(I)) { + // Consider: + // %cmp = icmp eq i32 %x, 2147483647 + // %add = add nsw i32 %x, 1 + // %sel = select i1 %cmp, i32 -2147483648, i32 %add + // + // We can't replace %sel with %add unless we strip away the flags. + // TODO: This is an unusual limitation because better analysis results in + // worse simplification. InstCombine can do this fold more generally + // by dropping the flags. Remove this fold to save compile-time? + if (isa<OverflowingBinaryOperator>(B)) + if (Q.IIQ.hasNoSignedWrap(B) || Q.IIQ.hasNoUnsignedWrap(B)) + return nullptr; + if (isa<PossiblyExactOperator>(B) && Q.IIQ.isExact(B)) + return nullptr; + + if (MaxRecurse) { + if (B->getOperand(0) == Op) + return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), Q, + MaxRecurse - 1); + if (B->getOperand(1) == Op) + return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, Q, + MaxRecurse - 1); + } + } + + // Same for CmpInsts. + if (CmpInst *C = dyn_cast<CmpInst>(I)) { + if (MaxRecurse) { + if (C->getOperand(0) == Op) + return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), Q, + MaxRecurse - 1); + if (C->getOperand(1) == Op) + return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, Q, + MaxRecurse - 1); + } + } + + // Same for GEPs. + if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { + if (MaxRecurse) { + SmallVector<Value *, 8> NewOps(GEP->getNumOperands()); + transform(GEP->operands(), NewOps.begin(), + [&](Value *V) { return V == Op ? RepOp : V; }); + return SimplifyGEPInst(GEP->getSourceElementType(), NewOps, Q, + MaxRecurse - 1); + } + } + + // TODO: We could hand off more cases to instsimplify here. + + // If all operands are constant after substituting Op for RepOp then we can + // constant fold the instruction. + if (Constant *CRepOp = dyn_cast<Constant>(RepOp)) { + // Build a list of all constant operands. + SmallVector<Constant *, 8> ConstOps; + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { + if (I->getOperand(i) == Op) + ConstOps.push_back(CRepOp); + else if (Constant *COp = dyn_cast<Constant>(I->getOperand(i))) + ConstOps.push_back(COp); + else + break; + } + + // All operands were constants, fold it. + if (ConstOps.size() == I->getNumOperands()) { + if (CmpInst *C = dyn_cast<CmpInst>(I)) + return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0], + ConstOps[1], Q.DL, Q.TLI); + + if (LoadInst *LI = dyn_cast<LoadInst>(I)) + if (!LI->isVolatile()) + return ConstantFoldLoadFromConstPtr(ConstOps[0], LI->getType(), Q.DL); + + return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI); + } + } + + return nullptr; +} + +/// Try to simplify a select instruction when its condition operand is an +/// integer comparison where one operand of the compare is a constant. +static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X, + const APInt *Y, bool TrueWhenUnset) { + const APInt *C; + + // (X & Y) == 0 ? X & ~Y : X --> X + // (X & Y) != 0 ? X & ~Y : X --> X & ~Y + if (FalseVal == X && match(TrueVal, m_And(m_Specific(X), m_APInt(C))) && + *Y == ~*C) + return TrueWhenUnset ? FalseVal : TrueVal; + + // (X & Y) == 0 ? X : X & ~Y --> X & ~Y + // (X & Y) != 0 ? X : X & ~Y --> X + if (TrueVal == X && match(FalseVal, m_And(m_Specific(X), m_APInt(C))) && + *Y == ~*C) + return TrueWhenUnset ? FalseVal : TrueVal; + + if (Y->isPowerOf2()) { + // (X & Y) == 0 ? X | Y : X --> X | Y + // (X & Y) != 0 ? 
X | Y : X --> X + if (FalseVal == X && match(TrueVal, m_Or(m_Specific(X), m_APInt(C))) && + *Y == *C) + return TrueWhenUnset ? TrueVal : FalseVal; + + // (X & Y) == 0 ? X : X | Y --> X + // (X & Y) != 0 ? X : X | Y --> X | Y + if (TrueVal == X && match(FalseVal, m_Or(m_Specific(X), m_APInt(C))) && + *Y == *C) + return TrueWhenUnset ? TrueVal : FalseVal; + } + + return nullptr; +} + +/// An alternative way to test if a bit is set or not uses sgt/slt instead of +/// eq/ne. +static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS, + ICmpInst::Predicate Pred, + Value *TrueVal, Value *FalseVal) { + Value *X; + APInt Mask; + if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask)) + return nullptr; + + return simplifySelectBitTest(TrueVal, FalseVal, X, &Mask, + Pred == ICmpInst::ICMP_EQ); +} + +/// Try to simplify a select instruction when its condition operand is an +/// integer comparison. +static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, + Value *FalseVal, const SimplifyQuery &Q, + unsigned MaxRecurse) { + ICmpInst::Predicate Pred; + Value *CmpLHS, *CmpRHS; + if (!match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) + return nullptr; + + if (ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero())) { + Value *X; + const APInt *Y; + if (match(CmpLHS, m_And(m_Value(X), m_APInt(Y)))) + if (Value *V = simplifySelectBitTest(TrueVal, FalseVal, X, Y, + Pred == ICmpInst::ICMP_EQ)) + return V; + + // Test for a bogus zero-shift-guard-op around funnel-shift or rotate. + Value *ShAmt; + auto isFsh = m_CombineOr(m_Intrinsic<Intrinsic::fshl>(m_Value(X), m_Value(), + m_Value(ShAmt)), + m_Intrinsic<Intrinsic::fshr>(m_Value(), m_Value(X), + m_Value(ShAmt))); + // (ShAmt == 0) ? fshl(X, *, ShAmt) : X --> X + // (ShAmt == 0) ? fshr(*, X, ShAmt) : X --> X + if (match(TrueVal, isFsh) && FalseVal == X && CmpLHS == ShAmt && + Pred == ICmpInst::ICMP_EQ) + return X; + // (ShAmt != 0) ? X : fshl(X, *, ShAmt) --> X + // (ShAmt != 0) ? X : fshr(*, X, ShAmt) --> X + if (match(FalseVal, isFsh) && TrueVal == X && CmpLHS == ShAmt && + Pred == ICmpInst::ICMP_NE) + return X; + + // Test for a zero-shift-guard-op around rotates. These are used to + // avoid UB from oversized shifts in raw IR rotate patterns, but the + // intrinsics do not have that problem. + // We do not allow this transform for the general funnel shift case because + // that would not preserve the poison safety of the original code. + auto isRotate = m_CombineOr(m_Intrinsic<Intrinsic::fshl>(m_Value(X), + m_Deferred(X), + m_Value(ShAmt)), + m_Intrinsic<Intrinsic::fshr>(m_Value(X), + m_Deferred(X), + m_Value(ShAmt))); + // (ShAmt != 0) ? fshl(X, X, ShAmt) : X --> fshl(X, X, ShAmt) + // (ShAmt != 0) ? fshr(X, X, ShAmt) : X --> fshr(X, X, ShAmt) + if (match(TrueVal, isRotate) && FalseVal == X && CmpLHS == ShAmt && + Pred == ICmpInst::ICMP_NE) + return TrueVal; + // (ShAmt == 0) ? X : fshl(X, X, ShAmt) --> fshl(X, X, ShAmt) + // (ShAmt == 0) ? X : fshr(X, X, ShAmt) --> fshr(X, X, ShAmt) + if (match(FalseVal, isRotate) && TrueVal == X && CmpLHS == ShAmt && + Pred == ICmpInst::ICMP_EQ) + return FalseVal; + } + + // Check for other compares that behave like bit test. + if (Value *V = simplifySelectWithFakeICmpEq(CmpLHS, CmpRHS, Pred, + TrueVal, FalseVal)) + return V; + + // If we have an equality comparison, then we know the value in one of the + // arms of the select. See if substituting this value into the arm and + // simplifying the result yields the same value as the other arm. 
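+  // For example, "select (icmp eq %x, 0), 0, %x" simplifies to %x:
+  // substituting 0 for %x in the false arm yields 0, which matches the true
+  // arm, so the select returns %x either way.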
+  if (Pred == ICmpInst::ICMP_EQ) {
+    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
+            TrueVal ||
+        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
+            TrueVal)
+      return FalseVal;
+    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
+            FalseVal ||
+        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
+            FalseVal)
+      return FalseVal;
+  } else if (Pred == ICmpInst::ICMP_NE) {
+    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
+            FalseVal ||
+        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
+            FalseVal)
+      return TrueVal;
+    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
+            TrueVal ||
+        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
+            TrueVal)
+      return TrueVal;
+  }
+
+  return nullptr;
+}
+
+/// Try to simplify a select instruction when its condition operand is a
+/// floating-point comparison.
+static Value *simplifySelectWithFCmp(Value *Cond, Value *T, Value *F) {
+  FCmpInst::Predicate Pred;
+  if (!match(Cond, m_FCmp(Pred, m_Specific(T), m_Specific(F))) &&
+      !match(Cond, m_FCmp(Pred, m_Specific(F), m_Specific(T))))
+    return nullptr;
+
+  // TODO: The transform may not be valid with -0.0. An incomplete way of
+  // testing for that possibility is to check if at least one operand is a
+  // non-zero constant.
+  const APFloat *C;
+  if ((match(T, m_APFloat(C)) && C->isNonZero()) ||
+      (match(F, m_APFloat(C)) && C->isNonZero())) {
+    // (T == F) ? T : F --> F
+    // (F == T) ? T : F --> F
+    if (Pred == FCmpInst::FCMP_OEQ)
+      return F;
+
+    // (T != F) ? T : F --> T
+    // (F != T) ? T : F --> T
+    if (Pred == FCmpInst::FCMP_UNE)
+      return T;
+  }
+
+  return nullptr;
+}
+
+/// Given operands for a SelectInst, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
+                                 const SimplifyQuery &Q, unsigned MaxRecurse) {
+  if (auto *CondC = dyn_cast<Constant>(Cond)) {
+    if (auto *TrueC = dyn_cast<Constant>(TrueVal))
+      if (auto *FalseC = dyn_cast<Constant>(FalseVal))
+        return ConstantFoldSelectInstruction(CondC, TrueC, FalseC);
+
+    // select undef, X, Y -> X or Y
+    if (isa<UndefValue>(CondC))
+      return isa<Constant>(FalseVal) ? FalseVal : TrueVal;
+
+    // TODO: Vector constants with undef elements don't simplify.
+
+    // select true, X, Y -> X
+    if (CondC->isAllOnesValue())
+      return TrueVal;
+    // select false, X, Y -> Y
+    if (CondC->isNullValue())
+      return FalseVal;
+  }
+
+  // select ?, X, X -> X
+  if (TrueVal == FalseVal)
+    return TrueVal;
+
+  if (isa<UndefValue>(TrueVal))   // select ?, undef, X -> X
+    return FalseVal;
+  if (isa<UndefValue>(FalseVal))  // select ?, X, undef -> X
+    return TrueVal;
+
+  if (Value *V =
+          simplifySelectWithICmpCond(Cond, TrueVal, FalseVal, Q, MaxRecurse))
+    return V;
+
+  if (Value *V = simplifySelectWithFCmp(Cond, TrueVal, FalseVal))
+    return V;
+
+  if (Value *V = foldSelectWithBinaryOp(Cond, TrueVal, FalseVal))
+    return V;
+
+  Optional<bool> Imp = isImpliedByDomCondition(Cond, Q.CxtI, Q.DL);
+  if (Imp)
+    return *Imp ? TrueVal : FalseVal;
+
+  return nullptr;
+}
+
+Value *llvm::SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
+                                const SimplifyQuery &Q) {
+  return ::SimplifySelectInst(Cond, TrueVal, FalseVal, Q, RecursionLimit);
+}
+
+/// Given operands for a GetElementPtrInst, see if we can fold the result.
+/// If not, this returns null.
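+/// For example, "getelementptr T, T* %p, i64 0" folds to %p, and a GEP with
+/// an undef base pointer folds to undef.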
+static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, + const SimplifyQuery &Q, unsigned) { + // The type of the GEP pointer operand. + unsigned AS = + cast<PointerType>(Ops[0]->getType()->getScalarType())->getAddressSpace(); + + // getelementptr P -> P. + if (Ops.size() == 1) + return Ops[0]; + + // Compute the (pointer) type returned by the GEP instruction. + Type *LastType = GetElementPtrInst::getIndexedType(SrcTy, Ops.slice(1)); + Type *GEPTy = PointerType::get(LastType, AS); + if (VectorType *VT = dyn_cast<VectorType>(Ops[0]->getType())) + GEPTy = VectorType::get(GEPTy, VT->getNumElements()); + else if (VectorType *VT = dyn_cast<VectorType>(Ops[1]->getType())) + GEPTy = VectorType::get(GEPTy, VT->getNumElements()); + + if (isa<UndefValue>(Ops[0])) + return UndefValue::get(GEPTy); + + if (Ops.size() == 2) { + // getelementptr P, 0 -> P. + if (match(Ops[1], m_Zero()) && Ops[0]->getType() == GEPTy) + return Ops[0]; + + Type *Ty = SrcTy; + if (Ty->isSized()) { + Value *P; + uint64_t C; + uint64_t TyAllocSize = Q.DL.getTypeAllocSize(Ty); + // getelementptr P, N -> P if P points to a type of zero size. + if (TyAllocSize == 0 && Ops[0]->getType() == GEPTy) + return Ops[0]; + + // The following transforms are only safe if the ptrtoint cast + // doesn't truncate the pointers. + if (Ops[1]->getType()->getScalarSizeInBits() == + Q.DL.getIndexSizeInBits(AS)) { + auto PtrToIntOrZero = [GEPTy](Value *P) -> Value * { + if (match(P, m_Zero())) + return Constant::getNullValue(GEPTy); + Value *Temp; + if (match(P, m_PtrToInt(m_Value(Temp)))) + if (Temp->getType() == GEPTy) + return Temp; + return nullptr; + }; + + // getelementptr V, (sub P, V) -> P if P points to a type of size 1. + if (TyAllocSize == 1 && + match(Ops[1], m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))))) + if (Value *R = PtrToIntOrZero(P)) + return R; + + // getelementptr V, (ashr (sub P, V), C) -> Q + // if P points to a type of size 1 << C. + if (match(Ops[1], + m_AShr(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))), + m_ConstantInt(C))) && + TyAllocSize == 1ULL << C) + if (Value *R = PtrToIntOrZero(P)) + return R; + + // getelementptr V, (sdiv (sub P, V), C) -> Q + // if P points to a type of size C. + if (match(Ops[1], + m_SDiv(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))), + m_SpecificInt(TyAllocSize)))) + if (Value *R = PtrToIntOrZero(P)) + return R; + } + } + } + + if (Q.DL.getTypeAllocSize(LastType) == 1 && + all_of(Ops.slice(1).drop_back(1), + [](Value *Idx) { return match(Idx, m_Zero()); })) { + unsigned IdxWidth = + Q.DL.getIndexSizeInBits(Ops[0]->getType()->getPointerAddressSpace()); + if (Q.DL.getTypeSizeInBits(Ops.back()->getType()) == IdxWidth) { + APInt BasePtrOffset(IdxWidth, 0); + Value *StrippedBasePtr = + Ops[0]->stripAndAccumulateInBoundsConstantOffsets(Q.DL, + BasePtrOffset); + + // gep (gep V, C), (sub 0, V) -> C + if (match(Ops.back(), + m_Sub(m_Zero(), m_PtrToInt(m_Specific(StrippedBasePtr))))) { + auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset); + return ConstantExpr::getIntToPtr(CI, GEPTy); + } + // gep (gep V, C), (xor V, -1) -> C-1 + if (match(Ops.back(), + m_Xor(m_PtrToInt(m_Specific(StrippedBasePtr)), m_AllOnes()))) { + auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset - 1); + return ConstantExpr::getIntToPtr(CI, GEPTy); + } + } + } + + // Check to see if this is constant foldable. 
+ if (!all_of(Ops, [](Value *V) { return isa<Constant>(V); })) + return nullptr; + + auto *CE = ConstantExpr::getGetElementPtr(SrcTy, cast<Constant>(Ops[0]), + Ops.slice(1)); + if (auto *CEFolded = ConstantFoldConstant(CE, Q.DL)) + return CEFolded; + return CE; +} + +Value *llvm::SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, + const SimplifyQuery &Q) { + return ::SimplifyGEPInst(SrcTy, Ops, Q, RecursionLimit); +} + +/// Given operands for an InsertValueInst, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyInsertValueInst(Value *Agg, Value *Val, + ArrayRef<unsigned> Idxs, const SimplifyQuery &Q, + unsigned) { + if (Constant *CAgg = dyn_cast<Constant>(Agg)) + if (Constant *CVal = dyn_cast<Constant>(Val)) + return ConstantFoldInsertValueInstruction(CAgg, CVal, Idxs); + + // insertvalue x, undef, n -> x + if (match(Val, m_Undef())) + return Agg; + + // insertvalue x, (extractvalue y, n), n + if (ExtractValueInst *EV = dyn_cast<ExtractValueInst>(Val)) + if (EV->getAggregateOperand()->getType() == Agg->getType() && + EV->getIndices() == Idxs) { + // insertvalue undef, (extractvalue y, n), n -> y + if (match(Agg, m_Undef())) + return EV->getAggregateOperand(); + + // insertvalue y, (extractvalue y, n), n -> y + if (Agg == EV->getAggregateOperand()) + return Agg; + } + + return nullptr; +} + +Value *llvm::SimplifyInsertValueInst(Value *Agg, Value *Val, + ArrayRef<unsigned> Idxs, + const SimplifyQuery &Q) { + return ::SimplifyInsertValueInst(Agg, Val, Idxs, Q, RecursionLimit); +} + +Value *llvm::SimplifyInsertElementInst(Value *Vec, Value *Val, Value *Idx, + const SimplifyQuery &Q) { + // Try to constant fold. + auto *VecC = dyn_cast<Constant>(Vec); + auto *ValC = dyn_cast<Constant>(Val); + auto *IdxC = dyn_cast<Constant>(Idx); + if (VecC && ValC && IdxC) + return ConstantFoldInsertElementInstruction(VecC, ValC, IdxC); + + // Fold into undef if index is out of bounds. + if (auto *CI = dyn_cast<ConstantInt>(Idx)) { + uint64_t NumElements = cast<VectorType>(Vec->getType())->getNumElements(); + if (CI->uge(NumElements)) + return UndefValue::get(Vec->getType()); + } + + // If index is undef, it might be out of bounds (see above case) + if (isa<UndefValue>(Idx)) + return UndefValue::get(Vec->getType()); + + // Inserting an undef scalar? Assume it is the same value as the existing + // vector element. + if (isa<UndefValue>(Val)) + return Vec; + + // If we are extracting a value from a vector, then inserting it into the same + // place, that's the input vector: + // insertelt Vec, (extractelt Vec, Idx), Idx --> Vec + if (match(Val, m_ExtractElement(m_Specific(Vec), m_Specific(Idx)))) + return Vec; + + return nullptr; +} + +/// Given operands for an ExtractValueInst, see if we can fold the result. +/// If not, this returns null. 
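+/// For example, "extractvalue (insertvalue %agg, %elt, 0), 0" folds to %elt.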
+static Value *SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs,
+                                       const SimplifyQuery &, unsigned) {
+  if (auto *CAgg = dyn_cast<Constant>(Agg))
+    return ConstantFoldExtractValueInstruction(CAgg, Idxs);
+
+  // extractvalue (insertvalue y, elt, n), n -> elt
+  unsigned NumIdxs = Idxs.size();
+  for (auto *IVI = dyn_cast<InsertValueInst>(Agg); IVI != nullptr;
+       IVI = dyn_cast<InsertValueInst>(IVI->getAggregateOperand())) {
+    ArrayRef<unsigned> InsertValueIdxs = IVI->getIndices();
+    unsigned NumInsertValueIdxs = InsertValueIdxs.size();
+    unsigned NumCommonIdxs = std::min(NumInsertValueIdxs, NumIdxs);
+    if (InsertValueIdxs.slice(0, NumCommonIdxs) ==
+        Idxs.slice(0, NumCommonIdxs)) {
+      if (NumIdxs == NumInsertValueIdxs)
+        return IVI->getInsertedValueOperand();
+      break;
+    }
+  }
+
+  return nullptr;
+}
+
+Value *llvm::SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs,
+                                      const SimplifyQuery &Q) {
+  return ::SimplifyExtractValueInst(Agg, Idxs, Q, RecursionLimit);
+}
+
+/// Given operands for an ExtractElementInst, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const SimplifyQuery &,
+                                         unsigned) {
+  if (auto *CVec = dyn_cast<Constant>(Vec)) {
+    if (auto *CIdx = dyn_cast<Constant>(Idx))
+      return ConstantFoldExtractElementInstruction(CVec, CIdx);
+
+    // The index is not relevant if our vector is a splat.
+    if (auto *Splat = CVec->getSplatValue())
+      return Splat;
+
+    if (isa<UndefValue>(Vec))
+      return UndefValue::get(Vec->getType()->getVectorElementType());
+  }
+
+  // If extracting a specified index from the vector, see if we can recursively
+  // find a previously computed scalar that was inserted into the vector.
+  if (auto *IdxC = dyn_cast<ConstantInt>(Idx)) {
+    if (IdxC->getValue().uge(Vec->getType()->getVectorNumElements()))
+      // definitely out of bounds, thus undefined result
+      return UndefValue::get(Vec->getType()->getVectorElementType());
+    if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue()))
+      return Elt;
+  }
+
+  // An undef extract index can be arbitrarily chosen to be an out-of-range
+  // index value, which would result in the instruction being undef.
+  if (isa<UndefValue>(Idx))
+    return UndefValue::get(Vec->getType()->getVectorElementType());
+
+  return nullptr;
+}
+
+Value *llvm::SimplifyExtractElementInst(Value *Vec, Value *Idx,
+                                        const SimplifyQuery &Q) {
+  return ::SimplifyExtractElementInst(Vec, Idx, Q, RecursionLimit);
+}
+
+/// See if we can fold the given phi. If not, returns null.
+static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) {
+  // If all of the PHI's incoming values are the same then replace the PHI node
+  // with the common value.
+  Value *CommonValue = nullptr;
+  bool HasUndefInput = false;
+  for (Value *Incoming : PN->incoming_values()) {
+    // If the incoming value is the phi node itself, it can safely be skipped.
+    if (Incoming == PN) continue;
+    if (isa<UndefValue>(Incoming)) {
+      // Remember that we saw an undef value, but otherwise ignore them.
+      HasUndefInput = true;
+      continue;
+    }
+    if (CommonValue && Incoming != CommonValue)
+      return nullptr; // Not the same, bail out.
+    CommonValue = Incoming;
+  }
+
+  // If CommonValue is null then all of the incoming values were either undef or
+  // equal to the phi node itself.
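+  // For example, "phi [undef, %bb0], [undef, %bb1]" folds to undef, while
+  // "phi [%x, %bb0], [%x, %bb1]" folds to %x.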
+ if (!CommonValue) + return UndefValue::get(PN->getType()); + + // If we have a PHI node like phi(X, undef, X), where X is defined by some + // instruction, we cannot return X as the result of the PHI node unless it + // dominates the PHI block. + if (HasUndefInput) + return valueDominatesPHI(CommonValue, PN, Q.DT) ? CommonValue : nullptr; + + return CommonValue; +} + +static Value *SimplifyCastInst(unsigned CastOpc, Value *Op, + Type *Ty, const SimplifyQuery &Q, unsigned MaxRecurse) { + if (auto *C = dyn_cast<Constant>(Op)) + return ConstantFoldCastOperand(CastOpc, C, Ty, Q.DL); + + if (auto *CI = dyn_cast<CastInst>(Op)) { + auto *Src = CI->getOperand(0); + Type *SrcTy = Src->getType(); + Type *MidTy = CI->getType(); + Type *DstTy = Ty; + if (Src->getType() == Ty) { + auto FirstOp = static_cast<Instruction::CastOps>(CI->getOpcode()); + auto SecondOp = static_cast<Instruction::CastOps>(CastOpc); + Type *SrcIntPtrTy = + SrcTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(SrcTy) : nullptr; + Type *MidIntPtrTy = + MidTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(MidTy) : nullptr; + Type *DstIntPtrTy = + DstTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(DstTy) : nullptr; + if (CastInst::isEliminableCastPair(FirstOp, SecondOp, SrcTy, MidTy, DstTy, + SrcIntPtrTy, MidIntPtrTy, + DstIntPtrTy) == Instruction::BitCast) + return Src; + } + } + + // bitcast x -> x + if (CastOpc == Instruction::BitCast) + if (Op->getType() == Ty) + return Op; + + return nullptr; +} + +Value *llvm::SimplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, + const SimplifyQuery &Q) { + return ::SimplifyCastInst(CastOpc, Op, Ty, Q, RecursionLimit); +} + +/// For the given destination element of a shuffle, peek through shuffles to +/// match a root vector source operand that contains that element in the same +/// vector lane (ie, the same mask index), so we can eliminate the shuffle(s). +static Value *foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1, + int MaskVal, Value *RootVec, + unsigned MaxRecurse) { + if (!MaxRecurse--) + return nullptr; + + // Bail out if any mask value is undefined. That kind of shuffle may be + // simplified further based on demanded bits or other folds. + if (MaskVal == -1) + return nullptr; + + // The mask value chooses which source operand we need to look at next. + int InVecNumElts = Op0->getType()->getVectorNumElements(); + int RootElt = MaskVal; + Value *SourceOp = Op0; + if (MaskVal >= InVecNumElts) { + RootElt = MaskVal - InVecNumElts; + SourceOp = Op1; + } + + // If the source operand is a shuffle itself, look through it to find the + // matching root vector. + if (auto *SourceShuf = dyn_cast<ShuffleVectorInst>(SourceOp)) { + return foldIdentityShuffles( + DestElt, SourceShuf->getOperand(0), SourceShuf->getOperand(1), + SourceShuf->getMaskValue(RootElt), RootVec, MaxRecurse); + } + + // TODO: Look through bitcasts? What if the bitcast changes the vector element + // size? + + // The source operand is not a shuffle. Initialize the root vector value for + // this shuffle if that has not been done yet. + if (!RootVec) + RootVec = SourceOp; + + // Give up as soon as a source operand does not match the existing root value. + if (RootVec != SourceOp) + return nullptr; + + // The element must be coming from the same lane in the source vector + // (although it may have crossed lanes in intermediate shuffles). 
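+  // For example, in an identity shuffle every destination lane selects the
+  // same lane of a single source, so RootElt == DestElt holds for all lanes
+  // and the caller can replace the shuffle with that source vector.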
+  if (RootElt != DestElt)
+    return nullptr;
+
+  return RootVec;
+}
+
+static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask,
+                                        Type *RetTy, const SimplifyQuery &Q,
+                                        unsigned MaxRecurse) {
+  if (isa<UndefValue>(Mask))
+    return UndefValue::get(RetTy);
+
+  Type *InVecTy = Op0->getType();
+  unsigned MaskNumElts = Mask->getType()->getVectorNumElements();
+  unsigned InVecNumElts = InVecTy->getVectorNumElements();
+
+  SmallVector<int, 32> Indices;
+  ShuffleVectorInst::getShuffleMask(Mask, Indices);
+  assert(MaskNumElts == Indices.size() &&
+         "Size of Indices not same as number of mask elements?");
+
+  // Canonicalization: If mask does not select elements from an input vector,
+  // replace that input vector with undef.
+  bool MaskSelects0 = false, MaskSelects1 = false;
+  for (unsigned i = 0; i != MaskNumElts; ++i) {
+    if (Indices[i] == -1)
+      continue;
+    if ((unsigned)Indices[i] < InVecNumElts)
+      MaskSelects0 = true;
+    else
+      MaskSelects1 = true;
+  }
+  if (!MaskSelects0)
+    Op0 = UndefValue::get(InVecTy);
+  if (!MaskSelects1)
+    Op1 = UndefValue::get(InVecTy);
+
+  auto *Op0Const = dyn_cast<Constant>(Op0);
+  auto *Op1Const = dyn_cast<Constant>(Op1);
+
+  // If all operands are constant, constant fold the shuffle.
+  if (Op0Const && Op1Const)
+    return ConstantFoldShuffleVectorInstruction(Op0Const, Op1Const, Mask);
+
+  // Canonicalization: if only one input vector is constant, it shall be the
+  // second one.
+  if (Op0Const && !Op1Const) {
+    std::swap(Op0, Op1);
+    ShuffleVectorInst::commuteShuffleMask(Indices, InVecNumElts);
+  }
+
+  // A shuffle of a splat is always the splat itself. Legal if the shuffle's
+  // value type is the same as the input vectors' type.
+  if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op0))
+    if (isa<UndefValue>(Op1) && RetTy == InVecTy &&
+        OpShuf->getMask()->getSplatValue())
+      return Op0;
+
+  // Don't fold a shuffle with undef mask elements. This may get folded in a
+  // better way using demanded bits or other analysis.
+  // TODO: Should we allow this?
+  if (find(Indices, -1) != Indices.end())
+    return nullptr;
+
+  // Check if every element of this shuffle can be mapped back to the
+  // corresponding element of a single root vector. If so, we don't need this
+  // shuffle. This handles simple identity shuffles as well as chains of
+  // shuffles that may widen/narrow and/or move elements across lanes and back.
+  Value *RootVec = nullptr;
+  for (unsigned i = 0; i != MaskNumElts; ++i) {
+    // Note that recursion is limited for each vector element, so if any element
+    // exceeds the limit, this will fail to simplify.
+    RootVec =
+        foldIdentityShuffles(i, Op0, Op1, Indices[i], RootVec, MaxRecurse);
+
+    // We can't replace a widening/narrowing shuffle with one of its operands.
+    if (!RootVec || RootVec->getType() != RetTy)
+      return nullptr;
+  }
+  return RootVec;
+}
+
+/// Given operands for a ShuffleVectorInst, fold the result or return null.
+Value *llvm::SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask,
+                                       Type *RetTy, const SimplifyQuery &Q) {
+  return ::SimplifyShuffleVectorInst(Op0, Op1, Mask, RetTy, Q, RecursionLimit);
+}
+
+static Constant *foldConstant(Instruction::UnaryOps Opcode,
+                              Value *&Op, const SimplifyQuery &Q) {
+  if (auto *C = dyn_cast<Constant>(Op))
+    return ConstantFoldUnaryOpOperand(Opcode, C, Q.DL);
+  return nullptr;
+}
+
+/// Given the operand for an FNeg, see if we can fold the result. If not, this
+/// returns null.
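+/// For example, "fneg (fneg %x)" folds to %x.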
+static Value *simplifyFNegInst(Value *Op, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldConstant(Instruction::FNeg, Op, Q)) + return C; + + Value *X; + // fneg (fneg X) ==> X + if (match(Op, m_FNeg(m_Value(X)))) + return X; + + return nullptr; +} + +Value *llvm::SimplifyFNegInst(Value *Op, FastMathFlags FMF, + const SimplifyQuery &Q) { + return ::simplifyFNegInst(Op, FMF, Q, RecursionLimit); +} + +static Constant *propagateNaN(Constant *In) { + // If the input is a vector with undef elements, just return a default NaN. + if (!In->isNaN()) + return ConstantFP::getNaN(In->getType()); + + // Propagate the existing NaN constant when possible. + // TODO: Should we quiet a signaling NaN? + return In; +} + +/// Perform folds that are common to any floating-point operation. This implies +/// transforms based on undef/NaN because the operation itself makes no +/// difference to the result. +static Constant *simplifyFPOp(ArrayRef<Value *> Ops) { + if (any_of(Ops, [](Value *V) { return isa<UndefValue>(V); })) + return ConstantFP::getNaN(Ops[0]->getType()); + + for (Value *V : Ops) + if (match(V, m_NaN())) + return propagateNaN(cast<Constant>(V)); + + return nullptr; +} + +/// Given operands for an FAdd, see if we can fold the result. If not, this +/// returns null. +static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q)) + return C; + + if (Constant *C = simplifyFPOp({Op0, Op1})) + return C; + + // fadd X, -0 ==> X + if (match(Op1, m_NegZeroFP())) + return Op0; + + // fadd X, 0 ==> X, when we know X is not -0 + if (match(Op1, m_PosZeroFP()) && + (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) + return Op0; + + // With nnan: -X + X --> 0.0 (and commuted variant) + // We don't have to explicitly exclude infinities (ninf): INF + -INF == NaN. + // Negative zeros are allowed because we always end up with positive zero: + // X = -0.0: (-0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0 + // X = -0.0: ( 0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0 + // X = 0.0: (-0.0 - ( 0.0)) + ( 0.0) == (-0.0) + ( 0.0) == 0.0 + // X = 0.0: ( 0.0 - ( 0.0)) + ( 0.0) == ( 0.0) + ( 0.0) == 0.0 + if (FMF.noNaNs()) { + if (match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1))) || + match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0)))) + return ConstantFP::getNullValue(Op0->getType()); + + if (match(Op0, m_FNeg(m_Specific(Op1))) || + match(Op1, m_FNeg(m_Specific(Op0)))) + return ConstantFP::getNullValue(Op0->getType()); + } + + // (X - Y) + Y --> X + // Y + (X - Y) --> X + Value *X; + if (FMF.noSignedZeros() && FMF.allowReassoc() && + (match(Op0, m_FSub(m_Value(X), m_Specific(Op1))) || + match(Op1, m_FSub(m_Value(X), m_Specific(Op0))))) + return X; + + return nullptr; +} + +/// Given operands for an FSub, see if we can fold the result. If not, this +/// returns null. 
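+/// For example, "fsub %x, 0.0" folds to %x.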
+static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) + return C; + + if (Constant *C = simplifyFPOp({Op0, Op1})) + return C; + + // fsub X, +0 ==> X + if (match(Op1, m_PosZeroFP())) + return Op0; + + // fsub X, -0 ==> X, when we know X is not -0 + if (match(Op1, m_NegZeroFP()) && + (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) + return Op0; + + // fsub -0.0, (fsub -0.0, X) ==> X + // fsub -0.0, (fneg X) ==> X + Value *X; + if (match(Op0, m_NegZeroFP()) && + match(Op1, m_FNeg(m_Value(X)))) + return X; + + // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored. + // fsub 0.0, (fneg X) ==> X if signed zeros are ignored. + if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) && + (match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))) || + match(Op1, m_FNeg(m_Value(X))))) + return X; + + // fsub nnan x, x ==> 0.0 + if (FMF.noNaNs() && Op0 == Op1) + return Constant::getNullValue(Op0->getType()); + + // Y - (Y - X) --> X + // (X + Y) - Y --> X + if (FMF.noSignedZeros() && FMF.allowReassoc() && + (match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) || + match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X))))) + return X; + + return nullptr; +} + +static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = simplifyFPOp({Op0, Op1})) + return C; + + // fmul X, 1.0 ==> X + if (match(Op1, m_FPOne())) + return Op0; + + // fmul 1.0, X ==> X + if (match(Op0, m_FPOne())) + return Op1; + + // fmul nnan nsz X, 0 ==> 0 + if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZeroFP())) + return ConstantFP::getNullValue(Op0->getType()); + + // fmul nnan nsz 0, X ==> 0 + if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZeroFP())) + return ConstantFP::getNullValue(Op1->getType()); + + // sqrt(X) * sqrt(X) --> X, if we can: + // 1. Remove the intermediate rounding (reassociate). + // 2. Ignore non-zero negative numbers because sqrt would produce NAN. + // 3. Ignore -0.0 because sqrt(-0.0) == -0.0, but -0.0 * -0.0 == 0.0. + Value *X; + if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) && + FMF.allowReassoc() && FMF.noNaNs() && FMF.noSignedZeros()) + return X; + + return nullptr; +} + +/// Given the operands for an FMul, see if we can fold the result +static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q)) + return C; + + // Now apply simplifications that do not require rounding. 
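+  // (For example, "fmul %x, 1.0" is exact, so it folds to %x with no
+  // rounding involved.)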
+ return SimplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse); +} + +Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q) { + return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit); +} + + +Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q) { + return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit); +} + +Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q) { + return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit); +} + +Value *llvm::SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q) { + return ::SimplifyFMAFMul(Op0, Op1, FMF, Q, RecursionLimit); +} + +static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned) { + if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q)) + return C; + + if (Constant *C = simplifyFPOp({Op0, Op1})) + return C; + + // X / 1.0 -> X + if (match(Op1, m_FPOne())) + return Op0; + + // 0 / X -> 0 + // Requires that NaNs are off (X could be zero) and signed zeroes are + // ignored (X could be positive or negative, so the output sign is unknown). + if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZeroFP())) + return ConstantFP::getNullValue(Op0->getType()); + + if (FMF.noNaNs()) { + // X / X -> 1.0 is legal when NaNs are ignored. + // We can ignore infinities because INF/INF is NaN. + if (Op0 == Op1) + return ConstantFP::get(Op0->getType(), 1.0); + + // (X * Y) / Y --> X if we can reassociate to the above form. + Value *X; + if (FMF.allowReassoc() && match(Op0, m_c_FMul(m_Value(X), m_Specific(Op1)))) + return X; + + // -X / X -> -1.0 and + // X / -X -> -1.0 are legal when NaNs are ignored. + // We can ignore signed zeros because +-0.0/+-0.0 is NaN and ignored. + if (match(Op0, m_FNegNSZ(m_Specific(Op1))) || + match(Op1, m_FNegNSZ(m_Specific(Op0)))) + return ConstantFP::get(Op0->getType(), -1.0); + } + + return nullptr; +} + +Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q) { + return ::SimplifyFDivInst(Op0, Op1, FMF, Q, RecursionLimit); +} + +static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned) { + if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q)) + return C; + + if (Constant *C = simplifyFPOp({Op0, Op1})) + return C; + + // Unlike fdiv, the result of frem always matches the sign of the dividend. + // The constant match may include undef elements in a vector, so return a full + // zero constant as the result. + if (FMF.noNaNs()) { + // +0 % X -> 0 + if (match(Op0, m_PosZeroFP())) + return ConstantFP::getNullValue(Op0->getType()); + // -0 % X -> -0 + if (match(Op0, m_NegZeroFP())) + return ConstantFP::getNegativeZero(Op0->getType()); + } + + return nullptr; +} + +Value *llvm::SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q) { + return ::SimplifyFRemInst(Op0, Op1, FMF, Q, RecursionLimit); +} + +//=== Helper functions for higher up the class hierarchy. + +/// Given the operand for a UnaryOperator, see if we can fold the result. +/// If not, this returns null. 
+static Value *simplifyUnOp(unsigned Opcode, Value *Op, const SimplifyQuery &Q, + unsigned MaxRecurse) { + switch (Opcode) { + case Instruction::FNeg: + return simplifyFNegInst(Op, FastMathFlags(), Q, MaxRecurse); + default: + llvm_unreachable("Unexpected opcode"); + } +} + +/// Given the operand for a UnaryOperator, see if we can fold the result. +/// If not, this returns null. +/// Try to use FastMathFlags when folding the result. +static Value *simplifyFPUnOp(unsigned Opcode, Value *Op, + const FastMathFlags &FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + switch (Opcode) { + case Instruction::FNeg: + return simplifyFNegInst(Op, FMF, Q, MaxRecurse); + default: + return simplifyUnOp(Opcode, Op, Q, MaxRecurse); + } +} + +Value *llvm::SimplifyUnOp(unsigned Opcode, Value *Op, const SimplifyQuery &Q) { + return ::simplifyUnOp(Opcode, Op, Q, RecursionLimit); +} + +Value *llvm::SimplifyUnOp(unsigned Opcode, Value *Op, FastMathFlags FMF, + const SimplifyQuery &Q) { + return ::simplifyFPUnOp(Opcode, Op, FMF, Q, RecursionLimit); +} + +/// Given operands for a BinaryOperator, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, + const SimplifyQuery &Q, unsigned MaxRecurse) { + switch (Opcode) { + case Instruction::Add: + return SimplifyAddInst(LHS, RHS, false, false, Q, MaxRecurse); + case Instruction::Sub: + return SimplifySubInst(LHS, RHS, false, false, Q, MaxRecurse); + case Instruction::Mul: + return SimplifyMulInst(LHS, RHS, Q, MaxRecurse); + case Instruction::SDiv: + return SimplifySDivInst(LHS, RHS, Q, MaxRecurse); + case Instruction::UDiv: + return SimplifyUDivInst(LHS, RHS, Q, MaxRecurse); + case Instruction::SRem: + return SimplifySRemInst(LHS, RHS, Q, MaxRecurse); + case Instruction::URem: + return SimplifyURemInst(LHS, RHS, Q, MaxRecurse); + case Instruction::Shl: + return SimplifyShlInst(LHS, RHS, false, false, Q, MaxRecurse); + case Instruction::LShr: + return SimplifyLShrInst(LHS, RHS, false, Q, MaxRecurse); + case Instruction::AShr: + return SimplifyAShrInst(LHS, RHS, false, Q, MaxRecurse); + case Instruction::And: + return SimplifyAndInst(LHS, RHS, Q, MaxRecurse); + case Instruction::Or: + return SimplifyOrInst(LHS, RHS, Q, MaxRecurse); + case Instruction::Xor: + return SimplifyXorInst(LHS, RHS, Q, MaxRecurse); + case Instruction::FAdd: + return SimplifyFAddInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + case Instruction::FSub: + return SimplifyFSubInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + case Instruction::FMul: + return SimplifyFMulInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + case Instruction::FDiv: + return SimplifyFDivInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + case Instruction::FRem: + return SimplifyFRemInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + default: + llvm_unreachable("Unexpected opcode"); + } +} + +/// Given operands for a BinaryOperator, see if we can fold the result. +/// If not, this returns null. +/// Try to use FastMathFlags when folding the result. 
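+/// For example, with the nnan flag an FAdd of "(fneg %x)" and "%x" can fold
+/// to 0.0, a fold that would be unsound if %x could be NaN.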
+static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, + const FastMathFlags &FMF, const SimplifyQuery &Q, + unsigned MaxRecurse) { + switch (Opcode) { + case Instruction::FAdd: + return SimplifyFAddInst(LHS, RHS, FMF, Q, MaxRecurse); + case Instruction::FSub: + return SimplifyFSubInst(LHS, RHS, FMF, Q, MaxRecurse); + case Instruction::FMul: + return SimplifyFMulInst(LHS, RHS, FMF, Q, MaxRecurse); + case Instruction::FDiv: + return SimplifyFDivInst(LHS, RHS, FMF, Q, MaxRecurse); + default: + return SimplifyBinOp(Opcode, LHS, RHS, Q, MaxRecurse); + } +} + +Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, + const SimplifyQuery &Q) { + return ::SimplifyBinOp(Opcode, LHS, RHS, Q, RecursionLimit); +} + +Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, + FastMathFlags FMF, const SimplifyQuery &Q) { + return ::SimplifyBinOp(Opcode, LHS, RHS, FMF, Q, RecursionLimit); +} + +/// Given operands for a CmpInst, see if we can fold the result. +static Value *SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (CmpInst::isIntPredicate((CmpInst::Predicate)Predicate)) + return SimplifyICmpInst(Predicate, LHS, RHS, Q, MaxRecurse); + return SimplifyFCmpInst(Predicate, LHS, RHS, FastMathFlags(), Q, MaxRecurse); +} + +Value *llvm::SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, + const SimplifyQuery &Q) { + return ::SimplifyCmpInst(Predicate, LHS, RHS, Q, RecursionLimit); +} + +static bool IsIdempotent(Intrinsic::ID ID) { + switch (ID) { + default: return false; + + // Unary idempotent: f(f(x)) = f(x) + case Intrinsic::fabs: + case Intrinsic::floor: + case Intrinsic::ceil: + case Intrinsic::trunc: + case Intrinsic::rint: + case Intrinsic::nearbyint: + case Intrinsic::round: + case Intrinsic::canonicalize: + return true; + } +} + +static Value *SimplifyRelativeLoad(Constant *Ptr, Constant *Offset, + const DataLayout &DL) { + GlobalValue *PtrSym; + APInt PtrOffset; + if (!IsConstantOffsetFromGlobal(Ptr, PtrSym, PtrOffset, DL)) + return nullptr; + + Type *Int8PtrTy = Type::getInt8PtrTy(Ptr->getContext()); + Type *Int32Ty = Type::getInt32Ty(Ptr->getContext()); + Type *Int32PtrTy = Int32Ty->getPointerTo(); + Type *Int64Ty = Type::getInt64Ty(Ptr->getContext()); + + auto *OffsetConstInt = dyn_cast<ConstantInt>(Offset); + if (!OffsetConstInt || OffsetConstInt->getType()->getBitWidth() > 64) + return nullptr; + + uint64_t OffsetInt = OffsetConstInt->getSExtValue(); + if (OffsetInt % 4 != 0) + return nullptr; + + Constant *C = ConstantExpr::getGetElementPtr( + Int32Ty, ConstantExpr::getBitCast(Ptr, Int32PtrTy), + ConstantInt::get(Int64Ty, OffsetInt / 4)); + Constant *Loaded = ConstantFoldLoadFromConstPtr(C, Int32Ty, DL); + if (!Loaded) + return nullptr; + + auto *LoadedCE = dyn_cast<ConstantExpr>(Loaded); + if (!LoadedCE) + return nullptr; + + if (LoadedCE->getOpcode() == Instruction::Trunc) { + LoadedCE = dyn_cast<ConstantExpr>(LoadedCE->getOperand(0)); + if (!LoadedCE) + return nullptr; + } + + if (LoadedCE->getOpcode() != Instruction::Sub) + return nullptr; + + auto *LoadedLHS = dyn_cast<ConstantExpr>(LoadedCE->getOperand(0)); + if (!LoadedLHS || LoadedLHS->getOpcode() != Instruction::PtrToInt) + return nullptr; + auto *LoadedLHSPtr = LoadedLHS->getOperand(0); + + Constant *LoadedRHS = LoadedCE->getOperand(1); + GlobalValue *LoadedRHSSym; + APInt LoadedRHSOffset; + if (!IsConstantOffsetFromGlobal(LoadedRHS, LoadedRHSSym, LoadedRHSOffset, + DL) || + PtrSym != LoadedRHSSym || PtrOffset != 
LoadedRHSOffset) + return nullptr; + + return ConstantExpr::getBitCast(LoadedLHSPtr, Int8PtrTy); +} + +static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, + const SimplifyQuery &Q) { + // Idempotent functions return the same result when called repeatedly. + Intrinsic::ID IID = F->getIntrinsicID(); + if (IsIdempotent(IID)) + if (auto *II = dyn_cast<IntrinsicInst>(Op0)) + if (II->getIntrinsicID() == IID) + return II; + + Value *X; + switch (IID) { + case Intrinsic::fabs: + if (SignBitMustBeZero(Op0, Q.TLI)) return Op0; + break; + case Intrinsic::bswap: + // bswap(bswap(x)) -> x + if (match(Op0, m_BSwap(m_Value(X)))) return X; + break; + case Intrinsic::bitreverse: + // bitreverse(bitreverse(x)) -> x + if (match(Op0, m_BitReverse(m_Value(X)))) return X; + break; + case Intrinsic::exp: + // exp(log(x)) -> x + if (Q.CxtI->hasAllowReassoc() && + match(Op0, m_Intrinsic<Intrinsic::log>(m_Value(X)))) return X; + break; + case Intrinsic::exp2: + // exp2(log2(x)) -> x + if (Q.CxtI->hasAllowReassoc() && + match(Op0, m_Intrinsic<Intrinsic::log2>(m_Value(X)))) return X; + break; + case Intrinsic::log: + // log(exp(x)) -> x + if (Q.CxtI->hasAllowReassoc() && + match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X)))) return X; + break; + case Intrinsic::log2: + // log2(exp2(x)) -> x + if (Q.CxtI->hasAllowReassoc() && + (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) || + match(Op0, m_Intrinsic<Intrinsic::pow>(m_SpecificFP(2.0), + m_Value(X))))) return X; + break; + case Intrinsic::log10: + // log10(pow(10.0, x)) -> x + if (Q.CxtI->hasAllowReassoc() && + match(Op0, m_Intrinsic<Intrinsic::pow>(m_SpecificFP(10.0), + m_Value(X)))) return X; + break; + case Intrinsic::floor: + case Intrinsic::trunc: + case Intrinsic::ceil: + case Intrinsic::round: + case Intrinsic::nearbyint: + case Intrinsic::rint: { + // floor (sitofp x) -> sitofp x + // floor (uitofp x) -> uitofp x + // + // Converting from int always results in a finite integral number or + // infinity. For either of those inputs, these rounding functions always + // return the same value, so the rounding can be eliminated. 
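+    // + // For example, %f = sitofp i32 %i to double is always an integral + // (or infinite) value, so + //   %r = call double @llvm.floor.f64(double %f) + // can be folded to %f itself; the same holds for the other rounding + // intrinsics handled here.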
+ if (match(Op0, m_SIToFP(m_Value())) || match(Op0, m_UIToFP(m_Value()))) + return Op0; + break; + } + default: + break; + } + + return nullptr; +} + +static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, + const SimplifyQuery &Q) { + Intrinsic::ID IID = F->getIntrinsicID(); + Type *ReturnType = F->getReturnType(); + switch (IID) { + case Intrinsic::usub_with_overflow: + case Intrinsic::ssub_with_overflow: + // X - X -> { 0, false } + if (Op0 == Op1) + return Constant::getNullValue(ReturnType); + LLVM_FALLTHROUGH; + case Intrinsic::uadd_with_overflow: + case Intrinsic::sadd_with_overflow: + // X - undef -> { undef, false } + // undef - X -> { undef, false } + // X + undef -> { undef, false } + // undef + x -> { undef, false } + if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1)) { + return ConstantStruct::get( + cast<StructType>(ReturnType), + {UndefValue::get(ReturnType->getStructElementType(0)), + Constant::getNullValue(ReturnType->getStructElementType(1))}); + } + break; + case Intrinsic::umul_with_overflow: + case Intrinsic::smul_with_overflow: + // 0 * X -> { 0, false } + // X * 0 -> { 0, false } + if (match(Op0, m_Zero()) || match(Op1, m_Zero())) + return Constant::getNullValue(ReturnType); + // undef * X -> { 0, false } + // X * undef -> { 0, false } + if (match(Op0, m_Undef()) || match(Op1, m_Undef())) + return Constant::getNullValue(ReturnType); + break; + case Intrinsic::uadd_sat: + // sat(MAX + X) -> MAX + // sat(X + MAX) -> MAX + if (match(Op0, m_AllOnes()) || match(Op1, m_AllOnes())) + return Constant::getAllOnesValue(ReturnType); + LLVM_FALLTHROUGH; + case Intrinsic::sadd_sat: + // sat(X + undef) -> -1 + // sat(undef + X) -> -1 + // For unsigned: Assume undef is MAX, thus we saturate to MAX (-1). + // For signed: Assume undef is ~X, in which case X + ~X = -1. + if (match(Op0, m_Undef()) || match(Op1, m_Undef())) + return Constant::getAllOnesValue(ReturnType); + + // X + 0 -> X + if (match(Op1, m_Zero())) + return Op0; + // 0 + X -> X + if (match(Op0, m_Zero())) + return Op1; + break; + case Intrinsic::usub_sat: + // sat(0 - X) -> 0, sat(X - MAX) -> 0 + if (match(Op0, m_Zero()) || match(Op1, m_AllOnes())) + return Constant::getNullValue(ReturnType); + LLVM_FALLTHROUGH; + case Intrinsic::ssub_sat: + // X - X -> 0, X - undef -> 0, undef - X -> 0 + if (Op0 == Op1 || match(Op0, m_Undef()) || match(Op1, m_Undef())) + return Constant::getNullValue(ReturnType); + // X - 0 -> X + if (match(Op1, m_Zero())) + return Op0; + break; + case Intrinsic::load_relative: + if (auto *C0 = dyn_cast<Constant>(Op0)) + if (auto *C1 = dyn_cast<Constant>(Op1)) + return SimplifyRelativeLoad(C0, C1, Q.DL); + break; + case Intrinsic::powi: + if (auto *Power = dyn_cast<ConstantInt>(Op1)) { + // powi(x, 0) -> 1.0 + if (Power->isZero()) + return ConstantFP::get(Op0->getType(), 1.0); + // powi(x, 1) -> x + if (Power->isOne()) + return Op0; + } + break; + case Intrinsic::maxnum: + case Intrinsic::minnum: + case Intrinsic::maximum: + case Intrinsic::minimum: { + // If the arguments are the same, this is a no-op. + if (Op0 == Op1) return Op0; + + // If one argument is undef, return the other argument. + if (match(Op0, m_Undef())) + return Op1; + if (match(Op1, m_Undef())) + return Op0; + + // If one argument is NaN, return other or NaN appropriately. + bool PropagateNaN = IID == Intrinsic::minimum || IID == Intrinsic::maximum; + if (match(Op0, m_NaN())) + return PropagateNaN ? Op0 : Op1; + if (match(Op1, m_NaN())) + return PropagateNaN ? 
Op1 : Op0; + + // Min/max of the same operation with common operand: + // m(m(X, Y)), X --> m(X, Y) (4 commuted variants) + if (auto *M0 = dyn_cast<IntrinsicInst>(Op0)) + if (M0->getIntrinsicID() == IID && + (M0->getOperand(0) == Op1 || M0->getOperand(1) == Op1)) + return Op0; + if (auto *M1 = dyn_cast<IntrinsicInst>(Op1)) + if (M1->getIntrinsicID() == IID && + (M1->getOperand(0) == Op0 || M1->getOperand(1) == Op0)) + return Op1; + + // min(X, -Inf) --> -Inf (and commuted variant) + // max(X, +Inf) --> +Inf (and commuted variant) + bool UseNegInf = IID == Intrinsic::minnum || IID == Intrinsic::minimum; + const APFloat *C; + if ((match(Op0, m_APFloat(C)) && C->isInfinity() && + C->isNegative() == UseNegInf) || + (match(Op1, m_APFloat(C)) && C->isInfinity() && + C->isNegative() == UseNegInf)) + return ConstantFP::getInfinity(ReturnType, UseNegInf); + + // TODO: minnum(nnan x, inf) -> x + // TODO: minnum(nnan ninf x, flt_max) -> x + // TODO: maxnum(nnan x, -inf) -> x + // TODO: maxnum(nnan ninf x, -flt_max) -> x + break; + } + default: + break; + } + + return nullptr; +} + +static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { + + // Intrinsics with no operands have some kind of side effect. Don't simplify. + unsigned NumOperands = Call->getNumArgOperands(); + if (!NumOperands) + return nullptr; + + Function *F = cast<Function>(Call->getCalledFunction()); + Intrinsic::ID IID = F->getIntrinsicID(); + if (NumOperands == 1) + return simplifyUnaryIntrinsic(F, Call->getArgOperand(0), Q); + + if (NumOperands == 2) + return simplifyBinaryIntrinsic(F, Call->getArgOperand(0), + Call->getArgOperand(1), Q); + + // Handle intrinsics with 3 or more arguments. + switch (IID) { + case Intrinsic::masked_load: + case Intrinsic::masked_gather: { + Value *MaskArg = Call->getArgOperand(2); + Value *PassthruArg = Call->getArgOperand(3); + // If the mask is all zeros or undef, the "passthru" argument is the result. + if (maskIsAllZeroOrUndef(MaskArg)) + return PassthruArg; + return nullptr; + } + case Intrinsic::fshl: + case Intrinsic::fshr: { + Value *Op0 = Call->getArgOperand(0), *Op1 = Call->getArgOperand(1), + *ShAmtArg = Call->getArgOperand(2); + + // If both operands are undef, the result is undef. + if (match(Op0, m_Undef()) && match(Op1, m_Undef())) + return UndefValue::get(F->getReturnType()); + + // If shift amount is undef, assume it is zero. + if (match(ShAmtArg, m_Undef())) + return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1); + + const APInt *ShAmtC; + if (match(ShAmtArg, m_APInt(ShAmtC))) { + // If there's effectively no shift, return the 1st arg or 2nd arg. + APInt BitWidth = APInt(ShAmtC->getBitWidth(), ShAmtC->getBitWidth()); + if (ShAmtC->urem(BitWidth).isNullValue()) + return Call->getArgOperand(IID == Intrinsic::fshl ? 
0 : 1); + } + return nullptr; + } + case Intrinsic::fma: + case Intrinsic::fmuladd: { + Value *Op0 = Call->getArgOperand(0); + Value *Op1 = Call->getArgOperand(1); + Value *Op2 = Call->getArgOperand(2); + if (Value *V = simplifyFPOp({ Op0, Op1, Op2 })) + return V; + return nullptr; + } + default: + return nullptr; + } +} + +Value *llvm::SimplifyCall(CallBase *Call, const SimplifyQuery &Q) { + Value *Callee = Call->getCalledValue(); + + // call undef -> undef + // call null -> undef + if (isa<UndefValue>(Callee) || isa<ConstantPointerNull>(Callee)) + return UndefValue::get(Call->getType()); + + Function *F = dyn_cast<Function>(Callee); + if (!F) + return nullptr; + + if (F->isIntrinsic()) + if (Value *Ret = simplifyIntrinsic(Call, Q)) + return Ret; + + if (!canConstantFoldCallTo(Call, F)) + return nullptr; + + SmallVector<Constant *, 4> ConstantArgs; + unsigned NumArgs = Call->getNumArgOperands(); + ConstantArgs.reserve(NumArgs); + for (auto &Arg : Call->args()) { + Constant *C = dyn_cast<Constant>(&Arg); + if (!C) + return nullptr; + ConstantArgs.push_back(C); + } + + return ConstantFoldCall(Call, F, ConstantArgs, Q.TLI); +} + +/// See if we can compute a simplified version of this instruction. +/// If not, this returns null. + +Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { + const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I); + Value *Result; + + switch (I->getOpcode()) { + default: + Result = ConstantFoldInstruction(I, Q.DL, Q.TLI); + break; + case Instruction::FNeg: + Result = SimplifyFNegInst(I->getOperand(0), I->getFastMathFlags(), Q); + break; + case Instruction::FAdd: + Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), + I->getFastMathFlags(), Q); + break; + case Instruction::Add: + Result = + SimplifyAddInst(I->getOperand(0), I->getOperand(1), + Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), + Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); + break; + case Instruction::FSub: + Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1), + I->getFastMathFlags(), Q); + break; + case Instruction::Sub: + Result = + SimplifySubInst(I->getOperand(0), I->getOperand(1), + Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), + Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); + break; + case Instruction::FMul: + Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1), + I->getFastMathFlags(), Q); + break; + case Instruction::Mul: + Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), Q); + break; + case Instruction::SDiv: + Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), Q); + break; + case Instruction::UDiv: + Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), Q); + break; + case Instruction::FDiv: + Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), + I->getFastMathFlags(), Q); + break; + case Instruction::SRem: + Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), Q); + break; + case Instruction::URem: + Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), Q); + break; + case Instruction::FRem: + Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), + I->getFastMathFlags(), Q); + break; + case Instruction::Shl: + Result = + SimplifyShlInst(I->getOperand(0), I->getOperand(1), + Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), + Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); + break; + case Instruction::LShr: + Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1), + Q.IIQ.isExact(cast<BinaryOperator>(I)), 
Q); + break; + case Instruction::AShr: + Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1), + Q.IIQ.isExact(cast<BinaryOperator>(I)), Q); + break; + case Instruction::And: + Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), Q); + break; + case Instruction::Or: + Result = SimplifyOrInst(I->getOperand(0), I->getOperand(1), Q); + break; + case Instruction::Xor: + Result = SimplifyXorInst(I->getOperand(0), I->getOperand(1), Q); + break; + case Instruction::ICmp: + Result = SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), + I->getOperand(0), I->getOperand(1), Q); + break; + case Instruction::FCmp: + Result = + SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), I->getOperand(0), + I->getOperand(1), I->getFastMathFlags(), Q); + break; + case Instruction::Select: + Result = SimplifySelectInst(I->getOperand(0), I->getOperand(1), + I->getOperand(2), Q); + break; + case Instruction::GetElementPtr: { + SmallVector<Value *, 8> Ops(I->op_begin(), I->op_end()); + Result = SimplifyGEPInst(cast<GetElementPtrInst>(I)->getSourceElementType(), + Ops, Q); + break; + } + case Instruction::InsertValue: { + InsertValueInst *IV = cast<InsertValueInst>(I); + Result = SimplifyInsertValueInst(IV->getAggregateOperand(), + IV->getInsertedValueOperand(), + IV->getIndices(), Q); + break; + } + case Instruction::InsertElement: { + auto *IE = cast<InsertElementInst>(I); + Result = SimplifyInsertElementInst(IE->getOperand(0), IE->getOperand(1), + IE->getOperand(2), Q); + break; + } + case Instruction::ExtractValue: { + auto *EVI = cast<ExtractValueInst>(I); + Result = SimplifyExtractValueInst(EVI->getAggregateOperand(), + EVI->getIndices(), Q); + break; + } + case Instruction::ExtractElement: { + auto *EEI = cast<ExtractElementInst>(I); + Result = SimplifyExtractElementInst(EEI->getVectorOperand(), + EEI->getIndexOperand(), Q); + break; + } + case Instruction::ShuffleVector: { + auto *SVI = cast<ShuffleVectorInst>(I); + Result = SimplifyShuffleVectorInst(SVI->getOperand(0), SVI->getOperand(1), + SVI->getMask(), SVI->getType(), Q); + break; + } + case Instruction::PHI: + Result = SimplifyPHINode(cast<PHINode>(I), Q); + break; + case Instruction::Call: { + Result = SimplifyCall(cast<CallInst>(I), Q); + break; + } +#define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc: +#include "llvm/IR/Instruction.def" +#undef HANDLE_CAST_INST + Result = + SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), Q); + break; + case Instruction::Alloca: + // No simplifications for Alloca and it can't be constant folded. + Result = nullptr; + break; + } + + // In general, it is possible for computeKnownBits to determine all bits in a + // value even when the operands are not all constants. + if (!Result && I->getType()->isIntOrIntVectorTy()) { + KnownBits Known = computeKnownBits(I, Q.DL, /*Depth*/ 0, Q.AC, I, Q.DT, ORE); + if (Known.isConstant()) + Result = ConstantInt::get(I->getType(), Known.getConstant()); + } + + /// If called on unreachable code, the above logic may report that the + /// instruction simplified to itself. Make life easier for users by + /// detecting that case here, returning a safe value instead. + return Result == I ? UndefValue::get(I->getType()) : Result; +} + +/// Implementation of recursive simplification through an instruction's +/// uses. +/// +/// This is the common implementation of the recursive simplification routines. +/// If we have a pre-simplified value in 'SimpleV', that is forcibly used to +/// replace the instruction 'I'. 
Otherwise, we simply add 'I' to the list of +/// instructions to process and attempt to simplify it using +/// InstructionSimplify. Recursively visited users which could not be +/// simplified themselves are added to the optional UnsimplifiedUsers set for +/// further processing by the caller. +/// +/// This routine returns 'true' only when *it* simplifies something. The passed +/// in simplified value does not count toward this. +static bool replaceAndRecursivelySimplifyImpl( + Instruction *I, Value *SimpleV, const TargetLibraryInfo *TLI, + const DominatorTree *DT, AssumptionCache *AC, + SmallSetVector<Instruction *, 8> *UnsimplifiedUsers = nullptr) { + bool Simplified = false; + SmallSetVector<Instruction *, 8> Worklist; + const DataLayout &DL = I->getModule()->getDataLayout(); + + // If we have an explicit value to collapse to, do that round of the + // simplification loop by hand initially. + if (SimpleV) { + for (User *U : I->users()) + if (U != I) + Worklist.insert(cast<Instruction>(U)); + + // Replace the instruction with its simplified value. + I->replaceAllUsesWith(SimpleV); + + // Gracefully handle edge cases where the instruction is not wired into any + // parent block. + if (I->getParent() && !I->isEHPad() && !I->isTerminator() && + !I->mayHaveSideEffects()) + I->eraseFromParent(); + } else { + Worklist.insert(I); + } + + // Note that we must test the size on each iteration; the worklist can grow. + for (unsigned Idx = 0; Idx != Worklist.size(); ++Idx) { + I = Worklist[Idx]; + + // See if this instruction simplifies. + SimpleV = SimplifyInstruction(I, {DL, TLI, DT, AC}); + if (!SimpleV) { + if (UnsimplifiedUsers) + UnsimplifiedUsers->insert(I); + continue; + } + + Simplified = true; + + // Stash away all the uses of the old instruction so we can check them for + // recursive simplifications after a RAUW. This is cheaper than checking all + // uses of SimpleV on the recursive step in most cases. + for (User *U : I->users()) + Worklist.insert(cast<Instruction>(U)); + + // Replace the instruction with its simplified value. + I->replaceAllUsesWith(SimpleV); + + // Gracefully handle edge cases where the instruction is not wired into any + // parent block. + if (I->getParent() && !I->isEHPad() && !I->isTerminator() && + !I->mayHaveSideEffects()) + I->eraseFromParent(); + } + return Simplified; +} + +bool llvm::recursivelySimplifyInstruction(Instruction *I, + const TargetLibraryInfo *TLI, + const DominatorTree *DT, + AssumptionCache *AC) { + return replaceAndRecursivelySimplifyImpl(I, nullptr, TLI, DT, AC, nullptr); +} + +bool llvm::replaceAndRecursivelySimplify( + Instruction *I, Value *SimpleV, const TargetLibraryInfo *TLI, + const DominatorTree *DT, AssumptionCache *AC, + SmallSetVector<Instruction *, 8> *UnsimplifiedUsers) { + assert(I != SimpleV && "replaceAndRecursivelySimplify(X,X) is not valid!"); + assert(SimpleV && "Must provide a simplified value."); + return replaceAndRecursivelySimplifyImpl(I, SimpleV, TLI, DT, AC, + UnsimplifiedUsers); +} + +namespace llvm { +const SimplifyQuery getBestSimplifyQuery(Pass &P, Function &F) { + auto *DTWP = P.getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; + auto *TLIWP = P.getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); + auto *TLI = TLIWP ? &TLIWP->getTLI(F) : nullptr; + auto *ACWP = P.getAnalysisIfAvailable<AssumptionCacheTracker>(); + auto *AC = ACWP ? 
&ACWP->getAssumptionCache(F) : nullptr; + return {F.getParent()->getDataLayout(), TLI, DT, AC}; +} + +const SimplifyQuery getBestSimplifyQuery(LoopStandardAnalysisResults &AR, + const DataLayout &DL) { + return {DL, &AR.TLI, &AR.DT, &AR.AC}; +} + +template <class T, class... TArgs> +const SimplifyQuery getBestSimplifyQuery(AnalysisManager<T, TArgs...> &AM, + Function &F) { + auto *DT = AM.template getCachedResult<DominatorTreeAnalysis>(F); + auto *TLI = AM.template getCachedResult<TargetLibraryAnalysis>(F); + auto *AC = AM.template getCachedResult<AssumptionAnalysis>(F); + return {F.getParent()->getDataLayout(), TLI, DT, AC}; +} +template const SimplifyQuery getBestSimplifyQuery(AnalysisManager<Function> &, + Function &); +} diff --git a/llvm/lib/Analysis/Interval.cpp b/llvm/lib/Analysis/Interval.cpp new file mode 100644 index 000000000000..07d6e27c13be --- /dev/null +++ b/llvm/lib/Analysis/Interval.cpp @@ -0,0 +1,51 @@ +//===- Interval.cpp - Interval class code ---------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definition of the Interval class, which represents a +// portion of an interval partition of a control flow graph. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Interval.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +//===----------------------------------------------------------------------===// +// Interval Implementation +//===----------------------------------------------------------------------===// + +// isLoop - Find out if there is a back edge in this interval... +bool Interval::isLoop() const { + // There is a loop in this interval iff one of the predecessors of the header + // node lives in the interval. + for (::pred_iterator I = ::pred_begin(HeaderNode), E = ::pred_end(HeaderNode); + I != E; ++I) + if (contains(*I)) + return true; + return false; +} + +void Interval::print(raw_ostream &OS) const { + OS << "-------------------------------------------------------------\n" + << "Interval Contents:\n"; + + // Print out all of the basic blocks in the interval... + for (const BasicBlock *Node : Nodes) + OS << *Node << "\n"; + + OS << "Interval Predecessors:\n"; + for (const BasicBlock *Predecessor : Predecessors) + OS << *Predecessor << "\n"; + + OS << "Interval Successors:\n"; + for (const BasicBlock *Successor : Successors) + OS << *Successor << "\n"; +} diff --git a/llvm/lib/Analysis/IntervalPartition.cpp b/llvm/lib/Analysis/IntervalPartition.cpp new file mode 100644 index 000000000000..d12db010db6a --- /dev/null +++ b/llvm/lib/Analysis/IntervalPartition.cpp @@ -0,0 +1,113 @@ +//===- IntervalPartition.cpp - Interval Partition module code -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definition of the IntervalPartition class, which +// calculates and represents the interval partition of a function. 
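+// An interval is a set of CFG nodes in which every node other than the +// interval's header has all of its predecessors inside the interval; the +// partition assigns every basic block of the function to exactly one interval.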
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/IntervalPartition.h" +#include "llvm/Analysis/Interval.h" +#include "llvm/Analysis/IntervalIterator.h" +#include "llvm/Pass.h" +#include <cassert> +#include <utility> + +using namespace llvm; + +char IntervalPartition::ID = 0; + +INITIALIZE_PASS(IntervalPartition, "intervals", + "Interval Partition Construction", true, true) + +//===----------------------------------------------------------------------===// +// IntervalPartition Implementation +//===----------------------------------------------------------------------===// + +// releaseMemory - Reset state back to before function was analyzed +void IntervalPartition::releaseMemory() { + for (unsigned i = 0, e = Intervals.size(); i != e; ++i) + delete Intervals[i]; + IntervalMap.clear(); + Intervals.clear(); + RootInterval = nullptr; +} + +void IntervalPartition::print(raw_ostream &O, const Module*) const { + for (unsigned i = 0, e = Intervals.size(); i != e; ++i) + Intervals[i]->print(O); +} + +// addIntervalToPartition - Add an interval to the internal list of intervals, +// and then add mappings from all of the basic blocks in the interval to the +// interval itself (in the IntervalMap). +void IntervalPartition::addIntervalToPartition(Interval *I) { + Intervals.push_back(I); + + // Add mappings for all of the basic blocks in I to the IntervalPartition + for (Interval::node_iterator It = I->Nodes.begin(), End = I->Nodes.end(); + It != End; ++It) + IntervalMap.insert(std::make_pair(*It, I)); +} + +// updatePredecessors - Interval generation only sets the successor fields of +// the interval data structures. After interval generation is complete, +// run through all of the intervals and propagate successor info as +// predecessor info. +void IntervalPartition::updatePredecessors(Interval *Int) { + BasicBlock *Header = Int->getHeaderNode(); + for (BasicBlock *Successor : Int->Successors) + getBlockInterval(Successor)->Predecessors.push_back(Header); +} + +// IntervalPartition ctor - Build the first level interval partition for the +// specified function... +bool IntervalPartition::runOnFunction(Function &F) { + // Pass false to intervals_begin because we take ownership of its memory + function_interval_iterator I = intervals_begin(&F, false); + assert(I != intervals_end(&F) && "No intervals in function!?!?!"); + + addIntervalToPartition(RootInterval = *I); + + ++I; // After the first one... + + // Add the rest of the intervals to the partition. + for (function_interval_iterator E = intervals_end(&F); I != E; ++I) + addIntervalToPartition(*I); + + // Now that we know all of the successor information, propagate this to the + // predecessors for each block. + for (unsigned i = 0, e = Intervals.size(); i != e; ++i) + updatePredecessors(Intervals[i]); + return false; +} + +// IntervalPartition ctor - Build a reduced interval partition from an +// existing interval graph. This takes an additional boolean parameter to +// distinguish it from a copy constructor. Always pass in false for now. 
+IntervalPartition::IntervalPartition(IntervalPartition &IP, bool) + : FunctionPass(ID) { + assert(IP.getRootInterval() && "Cannot operate on empty IntervalPartitions!"); + + // Pass false to intervals_begin because we take ownership of its memory + interval_part_interval_iterator I = intervals_begin(IP, false); + assert(I != intervals_end(IP) && "No intervals in interval partition!?!?!"); + + addIntervalToPartition(RootInterval = *I); + + ++I; // After the first one... + + // Add the rest of the intervals to the partition. + for (interval_part_interval_iterator E = intervals_end(IP); I != E; ++I) + addIntervalToPartition(*I); + + // Now that we know all of the successor information, propagate this to the + // predecessors for each block. + for (unsigned i = 0, e = Intervals.size(); i != e; ++i) + updatePredecessors(Intervals[i]); +} diff --git a/llvm/lib/Analysis/LazyBlockFrequencyInfo.cpp b/llvm/lib/Analysis/LazyBlockFrequencyInfo.cpp new file mode 100644 index 000000000000..439758560284 --- /dev/null +++ b/llvm/lib/Analysis/LazyBlockFrequencyInfo.cpp @@ -0,0 +1,71 @@ +//===- LazyBlockFrequencyInfo.cpp - Lazy Block Frequency Analysis ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is an alternative analysis pass to BlockFrequencyInfoWrapperPass. The +// difference is that with this pass the block frequencies are not computed when +// the analysis pass is executed but rather when the BFI result is explicitly +// requested by the analysis client. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LazyBlockFrequencyInfo.h" +#include "llvm/Analysis/LazyBranchProbabilityInfo.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/Dominators.h" + +using namespace llvm; + +#define DEBUG_TYPE "lazy-block-freq" + +INITIALIZE_PASS_BEGIN(LazyBlockFrequencyInfoPass, DEBUG_TYPE, + "Lazy Block Frequency Analysis", true, true) +INITIALIZE_PASS_DEPENDENCY(LazyBPIPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(LazyBlockFrequencyInfoPass, DEBUG_TYPE, + "Lazy Block Frequency Analysis", true, true) + +char LazyBlockFrequencyInfoPass::ID = 0; + +LazyBlockFrequencyInfoPass::LazyBlockFrequencyInfoPass() : FunctionPass(ID) { + initializeLazyBlockFrequencyInfoPassPass(*PassRegistry::getPassRegistry()); +} + +void LazyBlockFrequencyInfoPass::print(raw_ostream &OS, const Module *) const { + LBFI.getCalculated().print(OS); +} + +void LazyBlockFrequencyInfoPass::getAnalysisUsage(AnalysisUsage &AU) const { + LazyBranchProbabilityInfoPass::getLazyBPIAnalysisUsage(AU); + // We require DT so it's available when LI is available. The LI updating code + // asserts that DT is also present, so if we don't make sure that we have DT + // here, that assert will trigger. 
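+ // This analysis computes nothing eagerly and never mutates the IR, which is + // why it can unconditionally call AU.setPreservesAll() below.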
+ AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.setPreservesAll(); +} + +void LazyBlockFrequencyInfoPass::releaseMemory() { LBFI.releaseMemory(); } + +bool LazyBlockFrequencyInfoPass::runOnFunction(Function &F) { + auto &BPIPass = getAnalysis<LazyBranchProbabilityInfoPass>(); + LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + LBFI.setAnalysis(&F, &BPIPass, &LI); + return false; +} + +void LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AnalysisUsage &AU) { + LazyBranchProbabilityInfoPass::getLazyBPIAnalysisUsage(AU); + AU.addRequired<LazyBlockFrequencyInfoPass>(); + AU.addRequired<LoopInfoWrapperPass>(); +} + +void llvm::initializeLazyBFIPassPass(PassRegistry &Registry) { + initializeLazyBPIPassPass(Registry); + INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass); + INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass); +} diff --git a/llvm/lib/Analysis/LazyBranchProbabilityInfo.cpp b/llvm/lib/Analysis/LazyBranchProbabilityInfo.cpp new file mode 100644 index 000000000000..e727de468a0d --- /dev/null +++ b/llvm/lib/Analysis/LazyBranchProbabilityInfo.cpp @@ -0,0 +1,74 @@ +//===- LazyBranchProbabilityInfo.cpp - Lazy Branch Probability Analysis ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is an alternative analysis pass to BranchProbabilityInfoWrapperPass. +// The difference is that with this pass the branch probabilities are not +// computed when the analysis pass is executed but rather when the BPI result +// is explicitly requested by the analysis client. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LazyBranchProbabilityInfo.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Dominators.h" + +using namespace llvm; + +#define DEBUG_TYPE "lazy-branch-prob" + +INITIALIZE_PASS_BEGIN(LazyBranchProbabilityInfoPass, DEBUG_TYPE, + "Lazy Branch Probability Analysis", true, true) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(LazyBranchProbabilityInfoPass, DEBUG_TYPE, + "Lazy Branch Probability Analysis", true, true) + +char LazyBranchProbabilityInfoPass::ID = 0; + +LazyBranchProbabilityInfoPass::LazyBranchProbabilityInfoPass() + : FunctionPass(ID) { + initializeLazyBranchProbabilityInfoPassPass(*PassRegistry::getPassRegistry()); +} + +void LazyBranchProbabilityInfoPass::print(raw_ostream &OS, + const Module *) const { + LBPI->getCalculated().print(OS); +} + +void LazyBranchProbabilityInfoPass::getAnalysisUsage(AnalysisUsage &AU) const { + // We require DT so it's available when LI is available. The LI updating code + // asserts that DT is also present, so if we don't make sure that we have DT + // here, that assert will trigger. 
+ AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.setPreservesAll(); +} + +void LazyBranchProbabilityInfoPass::releaseMemory() { LBPI.reset(); } + +bool LazyBranchProbabilityInfoPass::runOnFunction(Function &F) { + LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + TargetLibraryInfo &TLI = + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + LBPI = std::make_unique<LazyBranchProbabilityInfo>(&F, &LI, &TLI); + return false; +} + +void LazyBranchProbabilityInfoPass::getLazyBPIAnalysisUsage(AnalysisUsage &AU) { + AU.addRequired<LazyBranchProbabilityInfoPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); +} + +void llvm::initializeLazyBPIPassPass(PassRegistry &Registry) { + INITIALIZE_PASS_DEPENDENCY(LazyBranchProbabilityInfoPass); + INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass); + INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass); +} diff --git a/llvm/lib/Analysis/LazyCallGraph.cpp b/llvm/lib/Analysis/LazyCallGraph.cpp new file mode 100644 index 000000000000..ef31c1e0ba8c --- /dev/null +++ b/llvm/lib/Analysis/LazyCallGraph.cpp @@ -0,0 +1,1816 @@ +//===- LazyCallGraph.cpp - Analysis of a Module's call graph --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <iterator> +#include <string> +#include <tuple> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "lcg" + +void LazyCallGraph::EdgeSequence::insertEdgeInternal(Node &TargetN, + Edge::Kind EK) { + EdgeIndexMap.insert({&TargetN, Edges.size()}); + Edges.emplace_back(TargetN, EK); +} + +void LazyCallGraph::EdgeSequence::setEdgeKind(Node &TargetN, Edge::Kind EK) { + Edges[EdgeIndexMap.find(&TargetN)->second].setKind(EK); +} + +bool LazyCallGraph::EdgeSequence::removeEdgeInternal(Node &TargetN) { + auto IndexMapI = EdgeIndexMap.find(&TargetN); + if (IndexMapI == EdgeIndexMap.end()) + return false; + + Edges[IndexMapI->second] = Edge(); + EdgeIndexMap.erase(IndexMapI); + return true; +} + +static void addEdge(SmallVectorImpl<LazyCallGraph::Edge> &Edges, + DenseMap<LazyCallGraph::Node *, int> &EdgeIndexMap, + LazyCallGraph::Node &N, LazyCallGraph::Edge::Kind EK) { + if (!EdgeIndexMap.insert({&N, Edges.size()}).second) + return; + + LLVM_DEBUG(dbgs() << " Added callable function: " << N.getName() << "\n"); + Edges.emplace_back(LazyCallGraph::Edge(N, EK)); +} + +LazyCallGraph::EdgeSequence 
&LazyCallGraph::Node::populateSlow() { + assert(!Edges && "Must not have already populated the edges for this node!"); + + LLVM_DEBUG(dbgs() << " Adding functions called by '" << getName() + << "' to the graph.\n"); + + Edges = EdgeSequence(); + + SmallVector<Constant *, 16> Worklist; + SmallPtrSet<Function *, 4> Callees; + SmallPtrSet<Constant *, 16> Visited; + + // Find all the potential call graph edges in this function. We track both + // actual call edges and indirect references to functions. The direct calls + // are trivially added, but to accumulate the latter we walk the instructions + // and add every operand which is a constant to the worklist to process + // afterward. + // + // Note that we consider *any* function with a definition to be a viable + // edge. Even if the function's definition is subject to replacement by + // some other module (say, a weak definition) there may still be + // optimizations which essentially speculate based on the definition and + // a way to check that the specific definition is in fact the one being + // used. For example, this could be done by moving the weak definition to + // a strong (internal) definition and making the weak definition be an + // alias. Then a test of the address of the weak function against the new + // strong definition's address would be an effective way to determine the + // safety of optimizing a direct call edge. + for (BasicBlock &BB : *F) + for (Instruction &I : BB) { + if (auto CS = CallSite(&I)) + if (Function *Callee = CS.getCalledFunction()) + if (!Callee->isDeclaration()) + if (Callees.insert(Callee).second) { + Visited.insert(Callee); + addEdge(Edges->Edges, Edges->EdgeIndexMap, G->get(*Callee), + LazyCallGraph::Edge::Call); + } + + for (Value *Op : I.operand_values()) + if (Constant *C = dyn_cast<Constant>(Op)) + if (Visited.insert(C).second) + Worklist.push_back(C); + } + + // We've collected all the constant (and thus potentially function or + // function containing) operands to all of the instructions in the function. + // Process them (recursively) collecting every function found. + visitReferences(Worklist, Visited, [&](Function &F) { + addEdge(Edges->Edges, Edges->EdgeIndexMap, G->get(F), + LazyCallGraph::Edge::Ref); + }); + + // Add implicit reference edges to any defined libcall functions (if we + // haven't found an explicit edge). + for (auto *F : G->LibFunctions) + if (!Visited.count(F)) + addEdge(Edges->Edges, Edges->EdgeIndexMap, G->get(*F), + LazyCallGraph::Edge::Ref); + + return *Edges; +} + +void LazyCallGraph::Node::replaceFunction(Function &NewF) { + assert(F != &NewF && "Must not replace a function with itself!"); + F = &NewF; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LazyCallGraph::Node::dump() const { + dbgs() << *this << '\n'; +} +#endif + +static bool isKnownLibFunction(Function &F, TargetLibraryInfo &TLI) { + LibFunc LF; + + // Either this is a normal library function or a "vectorizable" function. + return TLI.getLibFunc(F, LF) || TLI.isFunctionVectorizable(F.getName()); +} + +LazyCallGraph::LazyCallGraph( + Module &M, function_ref<TargetLibraryInfo &(Function &)> GetTLI) { + LLVM_DEBUG(dbgs() << "Building CG for module: " << M.getModuleIdentifier() + << "\n"); + for (Function &F : M) { + if (F.isDeclaration()) + continue; + // If this function is a known lib function to LLVM then we want to + // synthesize reference edges to it to model the fact that LLVM can turn + // arbitrary code into a library function call. 
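+    // (For example, LoopIdiomRecognize can rewrite a hand-written + // initialization loop into a call to memset, so a call edge to a defined + // memset can materialize even though no such call appears in the IR yet.)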
+ if (isKnownLibFunction(F, GetTLI(F))) + LibFunctions.insert(&F); + + if (F.hasLocalLinkage()) + continue; + + // External linkage defined functions have edges to them from other + // modules. + LLVM_DEBUG(dbgs() << " Adding '" << F.getName() + << "' to entry set of the graph.\n"); + addEdge(EntryEdges.Edges, EntryEdges.EdgeIndexMap, get(F), Edge::Ref); + } + + // Externally visible aliases of internal functions are also viable entry + // edges to the module. + for (auto &A : M.aliases()) { + if (A.hasLocalLinkage()) + continue; + if (Function* F = dyn_cast<Function>(A.getAliasee())) { + LLVM_DEBUG(dbgs() << " Adding '" << F->getName() + << "' with alias '" << A.getName() + << "' to entry set of the graph.\n"); + addEdge(EntryEdges.Edges, EntryEdges.EdgeIndexMap, get(*F), Edge::Ref); + } + } + + // Now add entry nodes for functions reachable via initializers to globals. + SmallVector<Constant *, 16> Worklist; + SmallPtrSet<Constant *, 16> Visited; + for (GlobalVariable &GV : M.globals()) + if (GV.hasInitializer()) + if (Visited.insert(GV.getInitializer()).second) + Worklist.push_back(GV.getInitializer()); + + LLVM_DEBUG( + dbgs() << " Adding functions referenced by global initializers to the " + "entry set.\n"); + visitReferences(Worklist, Visited, [&](Function &F) { + addEdge(EntryEdges.Edges, EntryEdges.EdgeIndexMap, get(F), + LazyCallGraph::Edge::Ref); + }); +} + +LazyCallGraph::LazyCallGraph(LazyCallGraph &&G) + : BPA(std::move(G.BPA)), NodeMap(std::move(G.NodeMap)), + EntryEdges(std::move(G.EntryEdges)), SCCBPA(std::move(G.SCCBPA)), + SCCMap(std::move(G.SCCMap)), + LibFunctions(std::move(G.LibFunctions)) { + updateGraphPtrs(); +} + +LazyCallGraph &LazyCallGraph::operator=(LazyCallGraph &&G) { + BPA = std::move(G.BPA); + NodeMap = std::move(G.NodeMap); + EntryEdges = std::move(G.EntryEdges); + SCCBPA = std::move(G.SCCBPA); + SCCMap = std::move(G.SCCMap); + LibFunctions = std::move(G.LibFunctions); + updateGraphPtrs(); + return *this; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LazyCallGraph::SCC::dump() const { + dbgs() << *this << '\n'; +} +#endif + +#ifndef NDEBUG +void LazyCallGraph::SCC::verify() { + assert(OuterRefSCC && "Can't have a null RefSCC!"); + assert(!Nodes.empty() && "Can't have an empty SCC!"); + + for (Node *N : Nodes) { + assert(N && "Can't have a null node!"); + assert(OuterRefSCC->G->lookupSCC(*N) == this && + "Node does not map to this SCC!"); + assert(N->DFSNumber == -1 && + "Must set DFS numbers to -1 when adding a node to an SCC!"); + assert(N->LowLink == -1 && + "Must set low link to -1 when adding a node to an SCC!"); + for (Edge &E : **N) + assert(E.getNode().isPopulated() && "Can't have an unpopulated node!"); + } +} +#endif + +bool LazyCallGraph::SCC::isParentOf(const SCC &C) const { + if (this == &C) + return false; + + for (Node &N : *this) + for (Edge &E : N->calls()) + if (OuterRefSCC->G->lookupSCC(E.getNode()) == &C) + return true; + + // No edges found. + return false; +} + +bool LazyCallGraph::SCC::isAncestorOf(const SCC &TargetC) const { + if (this == &TargetC) + return false; + + LazyCallGraph &G = *OuterRefSCC->G; + + // Start with this SCC. + SmallPtrSet<const SCC *, 16> Visited = {this}; + SmallVector<const SCC *, 16> Worklist = {this}; + + // Walk down the graph until we run out of edges or find a path to TargetC. 
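+  // This is a standard worklist traversal of the SCC DAG; the Visited set + // guarantees each SCC is enqueued at most once, so the walk terminates.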
+ do { + const SCC &C = *Worklist.pop_back_val(); + for (Node &N : C) + for (Edge &E : N->calls()) { + SCC *CalleeC = G.lookupSCC(E.getNode()); + if (!CalleeC) + continue; + + // If the callee's SCC is the TargetC, we're done. + if (CalleeC == &TargetC) + return true; + + // If this is the first time we've reached this SCC, put it on the + // worklist to recurse through. + if (Visited.insert(CalleeC).second) + Worklist.push_back(CalleeC); + } + } while (!Worklist.empty()); + + // No paths found. + return false; +} + +LazyCallGraph::RefSCC::RefSCC(LazyCallGraph &G) : G(&G) {} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LazyCallGraph::RefSCC::dump() const { + dbgs() << *this << '\n'; +} +#endif + +#ifndef NDEBUG +void LazyCallGraph::RefSCC::verify() { + assert(G && "Can't have a null graph!"); + assert(!SCCs.empty() && "Can't have an empty SCC!"); + + // Verify basic properties of the SCCs. + SmallPtrSet<SCC *, 4> SCCSet; + for (SCC *C : SCCs) { + assert(C && "Can't have a null SCC!"); + C->verify(); + assert(&C->getOuterRefSCC() == this && + "SCC doesn't think it is inside this RefSCC!"); + bool Inserted = SCCSet.insert(C).second; + assert(Inserted && "Found a duplicate SCC!"); + auto IndexIt = SCCIndices.find(C); + assert(IndexIt != SCCIndices.end() && + "Found an SCC that doesn't have an index!"); + } + + // Check that our indices map correctly. + for (auto &SCCIndexPair : SCCIndices) { + SCC *C = SCCIndexPair.first; + int i = SCCIndexPair.second; + assert(C && "Can't have a null SCC in the indices!"); + assert(SCCSet.count(C) && "Found an index for an SCC not in the RefSCC!"); + assert(SCCs[i] == C && "Index doesn't point to SCC!"); + } + + // Check that the SCCs are in fact in post-order. + for (int i = 0, Size = SCCs.size(); i < Size; ++i) { + SCC &SourceSCC = *SCCs[i]; + for (Node &N : SourceSCC) + for (Edge &E : *N) { + if (!E.isCall()) + continue; + SCC &TargetSCC = *G->lookupSCC(E.getNode()); + if (&TargetSCC.getOuterRefSCC() == this) { + assert(SCCIndices.find(&TargetSCC)->second <= i && + "Edge between SCCs violates post-order relationship."); + continue; + } + } + } +} +#endif + +bool LazyCallGraph::RefSCC::isParentOf(const RefSCC &RC) const { + if (&RC == this) + return false; + + // Search all edges to see if this is a parent. + for (SCC &C : *this) + for (Node &N : C) + for (Edge &E : *N) + if (G->lookupRefSCC(E.getNode()) == &RC) + return true; + + return false; +} + +bool LazyCallGraph::RefSCC::isAncestorOf(const RefSCC &RC) const { + if (&RC == this) + return false; + + // For each descendant of this RefSCC, see if one of its children is the + // argument. If not, add that descendant to the worklist and continue + // searching. + SmallVector<const RefSCC *, 4> Worklist; + SmallPtrSet<const RefSCC *, 4> Visited; + Worklist.push_back(this); + Visited.insert(this); + do { + const RefSCC &DescendantRC = *Worklist.pop_back_val(); + for (SCC &C : DescendantRC) + for (Node &N : C) + for (Edge &E : *N) { + auto *ChildRC = G->lookupRefSCC(E.getNode()); + if (ChildRC == &RC) + return true; + if (!ChildRC || !Visited.insert(ChildRC).second) + continue; + Worklist.push_back(ChildRC); + } + } while (!Worklist.empty()); + + return false; +} + +/// Generic helper that updates a postorder sequence of SCCs for a potentially +/// cycle-introducing edge insertion. +/// +/// A postorder sequence of SCCs of a directed graph has one fundamental +/// property: all edges in the DAG of SCCs point "up" the sequence. 
That is, +/// all edges in the SCC DAG point to prior SCCs in the sequence. +/// +/// This routine both updates a postorder sequence and uses that sequence to +/// compute the set of SCCs connected into a cycle. It should only be called to +/// insert a "downward" edge which will require changing the sequence to +/// restore it to a postorder. +/// +/// When inserting an edge from an earlier SCC to a later SCC in some postorder +/// sequence, all of the SCCs which may be impacted are in the closed range of +/// those two within the postorder sequence. The algorithm used here to restore +/// the state is as follows: +/// +/// 1) Starting from the source SCC, construct a set of SCCs which reach the +/// source SCC consisting of just the source SCC. Then scan toward the +/// target SCC in postorder and for each SCC, if it has an edge to an SCC +/// in the set, add it to the set. Otherwise it cannot reach the source +/// SCC; move it in the postorder sequence to immediately before the source +/// SCC, shifting the source SCC and all SCCs in the set one position +/// toward the target SCC. Stop scanning after processing the target SCC. +/// 2) If the source SCC is now past the target SCC in the postorder sequence, +/// and thus the new edge will flow toward the start, we are done. +/// 3) Otherwise, starting from the target SCC, walk all edges which reach an +/// SCC between the source and the target, and add them to the set of +/// connected SCCs, then recurse through them. Once a complete set of the +/// SCCs the target connects to is known, hoist the remaining SCCs between +/// the source and the target to be above the target. Note that there is no +/// need to process the source SCC; it is already known to connect. +/// 4) At this point, all of the SCCs in the closed range between the source +/// SCC and the target SCC in the postorder sequence are connected, +/// including the target SCC and the source SCC. Inserting the edge from +/// the source SCC to the target SCC will form a cycle out of precisely +/// these SCCs. Thus we can merge all of the SCCs in this closed range into +/// a single SCC. +/// +/// This process has various important properties: +/// - Only mutates the SCCs when adding the edge actually changes the SCC +/// structure. +/// - Never mutates SCCs which are unaffected by the change. +/// - Updates the postorder sequence to correctly satisfy the postorder +/// constraint after the edge is inserted. +/// - Only reorders SCCs in the closed postorder sequence from the source to +/// the target, so it is easy to bound how much has changed even in the +/// ordering. +/// - Big-O is the number of edges in the closed postorder range of SCCs from +/// source to target. +/// +/// This helper routine, in addition to updating the postorder sequence itself, +/// will also update a map from SCCs to indices within that sequence. +/// +/// The sequence and the map must operate on pointers to the SCC type. +/// +/// Two callbacks must be provided. The first computes the subset of SCCs in +/// the postorder closed range from the source to the target which connect to +/// the source SCC via some (transitive) set of edges. The second computes the +/// subset of the same range which the target SCC connects to via some +/// (transitive) set of edges. Both callbacks should populate the set argument +/// provided. 
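+/// +/// As an illustrative sketch (not part of the routine's contract): given the +/// postorder sequence [S, A, T] and a new edge S -> T, if neither A nor T can +/// reach S then both are shifted before S, yielding [A, T, S]; the new edge +/// then points "up" the sequence and nothing is merged. If instead T reaches +/// S, the SCCs remaining between S and T (inclusive of S) are returned as the +/// range to be merged into the target by the caller.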
+template <typename SCCT, typename PostorderSequenceT, typename SCCIndexMapT, + typename ComputeSourceConnectedSetCallableT, + typename ComputeTargetConnectedSetCallableT> +static iterator_range<typename PostorderSequenceT::iterator> +updatePostorderSequenceForEdgeInsertion( + SCCT &SourceSCC, SCCT &TargetSCC, PostorderSequenceT &SCCs, + SCCIndexMapT &SCCIndices, + ComputeSourceConnectedSetCallableT ComputeSourceConnectedSet, + ComputeTargetConnectedSetCallableT ComputeTargetConnectedSet) { + int SourceIdx = SCCIndices[&SourceSCC]; + int TargetIdx = SCCIndices[&TargetSCC]; + assert(SourceIdx < TargetIdx && "Cannot have equal indices here!"); + + SmallPtrSet<SCCT *, 4> ConnectedSet; + + // Compute the SCCs which (transitively) reach the source. + ComputeSourceConnectedSet(ConnectedSet); + + // Partition the SCCs in this part of the postorder sequence so only SCCs + // connecting to the source remain between it and the target. This is + // a benign partition as it preserves postorder. + auto SourceI = std::stable_partition( + SCCs.begin() + SourceIdx, SCCs.begin() + TargetIdx + 1, + [&ConnectedSet](SCCT *C) { return !ConnectedSet.count(C); }); + for (int i = SourceIdx, e = TargetIdx + 1; i < e; ++i) + SCCIndices.find(SCCs[i])->second = i; + + // If the target doesn't connect to the source, then we've corrected the + // post-order and there are no cycles formed. + if (!ConnectedSet.count(&TargetSCC)) { + assert(SourceI > (SCCs.begin() + SourceIdx) && + "Must have moved the source to fix the post-order."); + assert(*std::prev(SourceI) == &TargetSCC && + "Last SCC to move should have been the target."); + + // Return an empty range at the target SCC indicating there is nothing to + // merge. + return make_range(std::prev(SourceI), std::prev(SourceI)); + } + + assert(SCCs[TargetIdx] == &TargetSCC && + "Should not have moved target if connected!"); + SourceIdx = SourceI - SCCs.begin(); + assert(SCCs[SourceIdx] == &SourceSCC && + "Bad updated index computation for the source SCC!"); + + // See whether there are any remaining intervening SCCs between the source + // and target. If so, we need to make sure they all are reachable from the + // target. + if (SourceIdx + 1 < TargetIdx) { + ConnectedSet.clear(); + ComputeTargetConnectedSet(ConnectedSet); + + // Partition SCCs so that only SCCs reached from the target remain between + // the source and the target. This preserves postorder. + auto TargetI = std::stable_partition( + SCCs.begin() + SourceIdx + 1, SCCs.begin() + TargetIdx + 1, + [&ConnectedSet](SCCT *C) { return ConnectedSet.count(C); }); + for (int i = SourceIdx + 1, e = TargetIdx + 1; i < e; ++i) + SCCIndices.find(SCCs[i])->second = i; + TargetIdx = std::prev(TargetI) - SCCs.begin(); + assert(SCCs[TargetIdx] == &TargetSCC && + "Should always end with the target!"); + } + + // At this point, we know that connecting source to target forms a cycle + // because target connects back to source, and we know that all of the SCCs + // between the source and target in the postorder sequence participate in that + // cycle. + return make_range(SCCs.begin() + SourceIdx, SCCs.begin() + TargetIdx); +} + +bool +LazyCallGraph::RefSCC::switchInternalEdgeToCall( + Node &SourceN, Node &TargetN, + function_ref<void(ArrayRef<SCC *> MergeSCCs)> MergeCB) { + assert(!(*SourceN)[TargetN].isCall() && "Must start with a ref edge!"); + SmallVector<SCC *, 1> DeletedSCCs; + +#ifndef NDEBUG + // In a debug build, verify the RefSCC is valid to start with and when this + // routine finishes. 
+ verify(); + auto VerifyOnExit = make_scope_exit([&]() { verify(); }); +#endif + + SCC &SourceSCC = *G->lookupSCC(SourceN); + SCC &TargetSCC = *G->lookupSCC(TargetN); + + // If the two nodes are already part of the same SCC, we're also done as + // we've just added more connectivity. + if (&SourceSCC == &TargetSCC) { + SourceN->setEdgeKind(TargetN, Edge::Call); + return false; // No new cycle. + } + + // At this point we leverage the postorder list of SCCs to detect when the + // insertion of an edge changes the SCC structure in any way. + // + // First and foremost, we can eliminate the need for any changes when the + // edge is toward the beginning of the postorder sequence because all edges + // flow in that direction already. Thus adding a new one cannot form a cycle. + int SourceIdx = SCCIndices[&SourceSCC]; + int TargetIdx = SCCIndices[&TargetSCC]; + if (TargetIdx < SourceIdx) { + SourceN->setEdgeKind(TargetN, Edge::Call); + return false; // No new cycle. + } + + // Compute the SCCs which (transitively) reach the source. + auto ComputeSourceConnectedSet = [&](SmallPtrSetImpl<SCC *> &ConnectedSet) { +#ifndef NDEBUG + // Check that the RefSCC is still valid before computing this as the + // results will be nonsensical if we've broken its invariants. + verify(); +#endif + ConnectedSet.insert(&SourceSCC); + auto IsConnected = [&](SCC &C) { + for (Node &N : C) + for (Edge &E : N->calls()) + if (ConnectedSet.count(G->lookupSCC(E.getNode()))) + return true; + + return false; + }; + + for (SCC *C : + make_range(SCCs.begin() + SourceIdx + 1, SCCs.begin() + TargetIdx + 1)) + if (IsConnected(*C)) + ConnectedSet.insert(C); + }; + + // Use a normal worklist to find which SCCs the target connects to. We still + // bound the search based on the range in the postorder list we care about, + // but because this is forward connectivity we just "recurse" through the + // edges. + auto ComputeTargetConnectedSet = [&](SmallPtrSetImpl<SCC *> &ConnectedSet) { +#ifndef NDEBUG + // Check that the RefSCC is still valid before computing this as the + // results will be nonsensical if we've broken its invariants. + verify(); +#endif + ConnectedSet.insert(&TargetSCC); + SmallVector<SCC *, 4> Worklist; + Worklist.push_back(&TargetSCC); + do { + SCC &C = *Worklist.pop_back_val(); + for (Node &N : C) + for (Edge &E : *N) { + if (!E.isCall()) + continue; + SCC &EdgeC = *G->lookupSCC(E.getNode()); + if (&EdgeC.getOuterRefSCC() != this) + // Not in this RefSCC... + continue; + if (SCCIndices.find(&EdgeC)->second <= SourceIdx) + // Not in the postorder sequence between source and target. + continue; + + if (ConnectedSet.insert(&EdgeC).second) + Worklist.push_back(&EdgeC); + } + } while (!Worklist.empty()); + }; + + // Use a generic helper to update the postorder sequence of SCCs and return + // a range of any SCCs connected into a cycle by inserting this edge. This + // routine will also take care of updating the indices into the postorder + // sequence. + auto MergeRange = updatePostorderSequenceForEdgeInsertion( + SourceSCC, TargetSCC, SCCs, SCCIndices, ComputeSourceConnectedSet, + ComputeTargetConnectedSet); + + // Run the user's callback on the merged SCCs before we actually merge them. + if (MergeCB) + MergeCB(makeArrayRef(MergeRange.begin(), MergeRange.end())); + + // If the merge range is empty, then adding the edge didn't actually form any + // new cycles. We're done. + if (MergeRange.empty()) { + // Now that the SCC structure is finalized, flip the kind to call. 
+ SourceN->setEdgeKind(TargetN, Edge::Call); + return false; // No new cycle. + } + +#ifndef NDEBUG + // Before merging, check that the RefSCC remains valid after all the + // postorder updates. + verify(); +#endif + + // Otherwise we need to merge all of the SCCs in the cycle into a single + // result SCC. + // + // NB: We merge into the target because all of these functions were already + // reachable from the target, meaning any SCC-wide properties deduced about it + // other than the set of functions within it will not have changed. + for (SCC *C : MergeRange) { + assert(C != &TargetSCC && + "We merge *into* the target and shouldn't process it here!"); + SCCIndices.erase(C); + TargetSCC.Nodes.append(C->Nodes.begin(), C->Nodes.end()); + for (Node *N : C->Nodes) + G->SCCMap[N] = &TargetSCC; + C->clear(); + DeletedSCCs.push_back(C); + } + + // Erase the merged SCCs from the list and update the indices of the + // remaining SCCs. + int IndexOffset = MergeRange.end() - MergeRange.begin(); + auto EraseEnd = SCCs.erase(MergeRange.begin(), MergeRange.end()); + for (SCC *C : make_range(EraseEnd, SCCs.end())) + SCCIndices[C] -= IndexOffset; + + // Now that the SCC structure is finalized, flip the kind to call. + SourceN->setEdgeKind(TargetN, Edge::Call); + + // And we're done, but we did form a new cycle. + return true; +} + +void LazyCallGraph::RefSCC::switchTrivialInternalEdgeToRef(Node &SourceN, + Node &TargetN) { + assert((*SourceN)[TargetN].isCall() && "Must start with a call edge!"); + +#ifndef NDEBUG + // In a debug build, verify the RefSCC is valid to start with and when this + // routine finishes. + verify(); + auto VerifyOnExit = make_scope_exit([&]() { verify(); }); +#endif + + assert(G->lookupRefSCC(SourceN) == this && + "Source must be in this RefSCC."); + assert(G->lookupRefSCC(TargetN) == this && + "Target must be in this RefSCC."); + assert(G->lookupSCC(SourceN) != G->lookupSCC(TargetN) && + "Source and Target must be in separate SCCs for this to be trivial!"); + + // Set the edge kind. + SourceN->setEdgeKind(TargetN, Edge::Ref); +} + +iterator_range<LazyCallGraph::RefSCC::iterator> +LazyCallGraph::RefSCC::switchInternalEdgeToRef(Node &SourceN, Node &TargetN) { + assert((*SourceN)[TargetN].isCall() && "Must start with a call edge!"); + +#ifndef NDEBUG + // In a debug build, verify the RefSCC is valid to start with and when this + // routine finishes. + verify(); + auto VerifyOnExit = make_scope_exit([&]() { verify(); }); +#endif + + assert(G->lookupRefSCC(SourceN) == this && + "Source must be in this RefSCC."); + assert(G->lookupRefSCC(TargetN) == this && + "Target must be in this RefSCC."); + + SCC &TargetSCC = *G->lookupSCC(TargetN); + assert(G->lookupSCC(SourceN) == &TargetSCC && "Source and Target must be in " + "the same SCC to require the " + "full CG update."); + + // Set the edge kind. + SourceN->setEdgeKind(TargetN, Edge::Ref); + + // Otherwise we are removing a call edge from a single SCC. This may break + // the cycle. In order to compute the new set of SCCs, we need to do a small + // DFS over the nodes within the SCC to form any sub-cycles that remain as + // distinct SCCs and compute a postorder over the resulting SCCs. + // + // However, we specially handle the target node. The target node is known to + // reach all other nodes in the original SCC by definition. This means that + // we want the old SCC to be replaced with an SCC containing that node as it + // will be the root of whatever SCC DAG results from the DFS. 
Assumptions + // about an SCC such as the set of functions called will continue to hold, + // etc. + + SCC &OldSCC = TargetSCC; + SmallVector<std::pair<Node *, EdgeSequence::call_iterator>, 16> DFSStack; + SmallVector<Node *, 16> PendingSCCStack; + SmallVector<SCC *, 4> NewSCCs; + + // Prepare the nodes for a fresh DFS. + SmallVector<Node *, 16> Worklist; + Worklist.swap(OldSCC.Nodes); + for (Node *N : Worklist) { + N->DFSNumber = N->LowLink = 0; + G->SCCMap.erase(N); + } + + // Force the target node to be in the old SCC. This also enables us to take + // a very significant short-cut in the standard Tarjan walk to re-form SCCs + // below: whenever we build an edge that reaches the target node, we know + // that the target node eventually connects back to all other nodes in our + // walk. As a consequence, we can detect and handle participants in that + // cycle without walking all the edges that form this connection, and instead + // by relying on the fundamental guarantee coming into this operation (all + // nodes are reachable from the target due to previously forming an SCC). + TargetN.DFSNumber = TargetN.LowLink = -1; + OldSCC.Nodes.push_back(&TargetN); + G->SCCMap[&TargetN] = &OldSCC; + + // Scan down the stack and DFS across the call edges. + for (Node *RootN : Worklist) { + assert(DFSStack.empty() && + "Cannot begin a new root with a non-empty DFS stack!"); + assert(PendingSCCStack.empty() && + "Cannot begin a new root with pending nodes for an SCC!"); + + // Skip any nodes we've already reached in the DFS. + if (RootN->DFSNumber != 0) { + assert(RootN->DFSNumber == -1 && + "Shouldn't have any mid-DFS root nodes!"); + continue; + } + + RootN->DFSNumber = RootN->LowLink = 1; + int NextDFSNumber = 2; + + DFSStack.push_back({RootN, (*RootN)->call_begin()}); + do { + Node *N; + EdgeSequence::call_iterator I; + std::tie(N, I) = DFSStack.pop_back_val(); + auto E = (*N)->call_end(); + while (I != E) { + Node &ChildN = I->getNode(); + if (ChildN.DFSNumber == 0) { + // We haven't yet visited this child, so descend, pushing the current + // node onto the stack. + DFSStack.push_back({N, I}); + + assert(!G->SCCMap.count(&ChildN) && + "Found a node with 0 DFS number but already in an SCC!"); + ChildN.DFSNumber = ChildN.LowLink = NextDFSNumber++; + N = &ChildN; + I = (*N)->call_begin(); + E = (*N)->call_end(); + continue; + } + + // Check for the child already being part of some component. + if (ChildN.DFSNumber == -1) { + if (G->lookupSCC(ChildN) == &OldSCC) { + // If the child is part of the old SCC, we know that it can reach + // every other node, so we have formed a cycle. Pull the entire DFS + // and pending stacks into it. See the comment above about setting + // up the old SCC for why we do this. + int OldSize = OldSCC.size(); + OldSCC.Nodes.push_back(N); + OldSCC.Nodes.append(PendingSCCStack.begin(), PendingSCCStack.end()); + PendingSCCStack.clear(); + while (!DFSStack.empty()) + OldSCC.Nodes.push_back(DFSStack.pop_back_val().first); + for (Node &N : make_range(OldSCC.begin() + OldSize, OldSCC.end())) { + N.DFSNumber = N.LowLink = -1; + G->SCCMap[&N] = &OldSCC; + } + N = nullptr; + break; + } + + // If the child has already been added to some child component, it + // couldn't impact the low-link of this parent because it isn't + // connected, and thus its low-link isn't relevant so skip it. + ++I; + continue; + } + + // Track the lowest linked child as the lowest link for this node. 
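+          // For example, if ChildN is still on the active DFS stack with
+          // low-link 1, adopting that value here records that N can reach the
+          // walk's root and so will be folded into the same SCC when the
+          // stack unwinds.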
+ assert(ChildN.LowLink > 0 && "Must have a positive low-link number!"); + if (ChildN.LowLink < N->LowLink) + N->LowLink = ChildN.LowLink; + + // Move to the next edge. + ++I; + } + if (!N) + // Cleared the DFS early, start another round. + break; + + // We've finished processing N and its descendants, put it on our pending + // SCC stack to eventually get merged into an SCC of nodes. + PendingSCCStack.push_back(N); + + // If this node is linked to some lower entry, continue walking up the + // stack. + if (N->LowLink != N->DFSNumber) + continue; + + // Otherwise, we've completed an SCC. Append it to our post order list of + // SCCs. + int RootDFSNumber = N->DFSNumber; + // Find the range of the node stack by walking down until we pass the + // root DFS number. + auto SCCNodes = make_range( + PendingSCCStack.rbegin(), + find_if(reverse(PendingSCCStack), [RootDFSNumber](const Node *N) { + return N->DFSNumber < RootDFSNumber; + })); + + // Form a new SCC out of these nodes and then clear them off our pending + // stack. + NewSCCs.push_back(G->createSCC(*this, SCCNodes)); + for (Node &N : *NewSCCs.back()) { + N.DFSNumber = N.LowLink = -1; + G->SCCMap[&N] = NewSCCs.back(); + } + PendingSCCStack.erase(SCCNodes.end().base(), PendingSCCStack.end()); + } while (!DFSStack.empty()); + } + + // Insert the remaining SCCs before the old one. The old SCC can reach all + // other SCCs we form because it contains the target node of the removed edge + // of the old SCC. This means that we will have edges into all of the new + // SCCs, which means the old one must come last for postorder. + int OldIdx = SCCIndices[&OldSCC]; + SCCs.insert(SCCs.begin() + OldIdx, NewSCCs.begin(), NewSCCs.end()); + + // Update the mapping from SCC* to index to use the new SCC*s, and remove the + // old SCC from the mapping. + for (int Idx = OldIdx, Size = SCCs.size(); Idx < Size; ++Idx) + SCCIndices[SCCs[Idx]] = Idx; + + return make_range(SCCs.begin() + OldIdx, + SCCs.begin() + OldIdx + NewSCCs.size()); +} + +void LazyCallGraph::RefSCC::switchOutgoingEdgeToCall(Node &SourceN, + Node &TargetN) { + assert(!(*SourceN)[TargetN].isCall() && "Must start with a ref edge!"); + + assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); + assert(G->lookupRefSCC(TargetN) != this && + "Target must not be in this RefSCC."); +#ifdef EXPENSIVE_CHECKS + assert(G->lookupRefSCC(TargetN)->isDescendantOf(*this) && + "Target must be a descendant of the Source."); +#endif + + // Edges between RefSCCs are the same regardless of call or ref, so we can + // just flip the edge here. + SourceN->setEdgeKind(TargetN, Edge::Call); + +#ifndef NDEBUG + // Check that the RefSCC is still valid. + verify(); +#endif +} + +void LazyCallGraph::RefSCC::switchOutgoingEdgeToRef(Node &SourceN, + Node &TargetN) { + assert((*SourceN)[TargetN].isCall() && "Must start with a call edge!"); + + assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); + assert(G->lookupRefSCC(TargetN) != this && + "Target must not be in this RefSCC."); +#ifdef EXPENSIVE_CHECKS + assert(G->lookupRefSCC(TargetN)->isDescendantOf(*this) && + "Target must be a descendant of the Source."); +#endif + + // Edges between RefSCCs are the same regardless of call or ref, so we can + // just flip the edge here. + SourceN->setEdgeKind(TargetN, Edge::Ref); + +#ifndef NDEBUG + // Check that the RefSCC is still valid. 
+ verify(); +#endif +} + +void LazyCallGraph::RefSCC::insertInternalRefEdge(Node &SourceN, + Node &TargetN) { + assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); + assert(G->lookupRefSCC(TargetN) == this && "Target must be in this RefSCC."); + + SourceN->insertEdgeInternal(TargetN, Edge::Ref); + +#ifndef NDEBUG + // Check that the RefSCC is still valid. + verify(); +#endif +} + +void LazyCallGraph::RefSCC::insertOutgoingEdge(Node &SourceN, Node &TargetN, + Edge::Kind EK) { + // First insert it into the caller. + SourceN->insertEdgeInternal(TargetN, EK); + + assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); + + assert(G->lookupRefSCC(TargetN) != this && + "Target must not be in this RefSCC."); +#ifdef EXPENSIVE_CHECKS + assert(G->lookupRefSCC(TargetN)->isDescendantOf(*this) && + "Target must be a descendant of the Source."); +#endif + +#ifndef NDEBUG + // Check that the RefSCC is still valid. + verify(); +#endif +} + +SmallVector<LazyCallGraph::RefSCC *, 1> +LazyCallGraph::RefSCC::insertIncomingRefEdge(Node &SourceN, Node &TargetN) { + assert(G->lookupRefSCC(TargetN) == this && "Target must be in this RefSCC."); + RefSCC &SourceC = *G->lookupRefSCC(SourceN); + assert(&SourceC != this && "Source must not be in this RefSCC."); +#ifdef EXPENSIVE_CHECKS + assert(SourceC.isDescendantOf(*this) && + "Source must be a descendant of the Target."); +#endif + + SmallVector<RefSCC *, 1> DeletedRefSCCs; + +#ifndef NDEBUG + // In a debug build, verify the RefSCC is valid to start with and when this + // routine finishes. + verify(); + auto VerifyOnExit = make_scope_exit([&]() { verify(); }); +#endif + + int SourceIdx = G->RefSCCIndices[&SourceC]; + int TargetIdx = G->RefSCCIndices[this]; + assert(SourceIdx < TargetIdx && + "Postorder list doesn't see edge as incoming!"); + + // Compute the RefSCCs which (transitively) reach the source. We do this by + // working backwards from the source using the parent set in each RefSCC, + // skipping any RefSCCs that don't fall in the postorder range. This has the + // advantage of walking the sparser parent edge (in high fan-out graphs) but + // more importantly this removes examining all forward edges in all RefSCCs + // within the postorder range which aren't in fact connected. Only connected + // RefSCCs (and their edges) are visited here. + auto ComputeSourceConnectedSet = [&](SmallPtrSetImpl<RefSCC *> &Set) { + Set.insert(&SourceC); + auto IsConnected = [&](RefSCC &RC) { + for (SCC &C : RC) + for (Node &N : C) + for (Edge &E : *N) + if (Set.count(G->lookupRefSCC(E.getNode()))) + return true; + + return false; + }; + + for (RefSCC *C : make_range(G->PostOrderRefSCCs.begin() + SourceIdx + 1, + G->PostOrderRefSCCs.begin() + TargetIdx + 1)) + if (IsConnected(*C)) + Set.insert(C); + }; + + // Use a normal worklist to find which SCCs the target connects to. We still + // bound the search based on the range in the postorder list we care about, + // but because this is forward connectivity we just "recurse" through the + // edges. + auto ComputeTargetConnectedSet = [&](SmallPtrSetImpl<RefSCC *> &Set) { + Set.insert(this); + SmallVector<RefSCC *, 4> Worklist; + Worklist.push_back(this); + do { + RefSCC &RC = *Worklist.pop_back_val(); + for (SCC &C : RC) + for (Node &N : C) + for (Edge &E : *N) { + RefSCC &EdgeRC = *G->lookupRefSCC(E.getNode()); + if (G->getRefSCCIndex(EdgeRC) <= SourceIdx) + // Not in the postorder sequence between source and target. 
+ continue; + + if (Set.insert(&EdgeRC).second) + Worklist.push_back(&EdgeRC); + } + } while (!Worklist.empty()); + }; + + // Use a generic helper to update the postorder sequence of RefSCCs and return + // a range of any RefSCCs connected into a cycle by inserting this edge. This + // routine will also take care of updating the indices into the postorder + // sequence. + iterator_range<SmallVectorImpl<RefSCC *>::iterator> MergeRange = + updatePostorderSequenceForEdgeInsertion( + SourceC, *this, G->PostOrderRefSCCs, G->RefSCCIndices, + ComputeSourceConnectedSet, ComputeTargetConnectedSet); + + // Build a set so we can do fast tests for whether a RefSCC will end up as + // part of the merged RefSCC. + SmallPtrSet<RefSCC *, 16> MergeSet(MergeRange.begin(), MergeRange.end()); + + // This RefSCC will always be part of that set, so just insert it here. + MergeSet.insert(this); + + // Now that we have identified all of the SCCs which need to be merged into + // a connected set with the inserted edge, merge all of them into this SCC. + SmallVector<SCC *, 16> MergedSCCs; + int SCCIndex = 0; + for (RefSCC *RC : MergeRange) { + assert(RC != this && "We're merging into the target RefSCC, so it " + "shouldn't be in the range."); + + // Walk the inner SCCs to update their up-pointer and walk all the edges to + // update any parent sets. + // FIXME: We should try to find a way to avoid this (rather expensive) edge + // walk by updating the parent sets in some other manner. + for (SCC &InnerC : *RC) { + InnerC.OuterRefSCC = this; + SCCIndices[&InnerC] = SCCIndex++; + for (Node &N : InnerC) + G->SCCMap[&N] = &InnerC; + } + + // Now merge in the SCCs. We can actually move here so try to reuse storage + // the first time through. + if (MergedSCCs.empty()) + MergedSCCs = std::move(RC->SCCs); + else + MergedSCCs.append(RC->SCCs.begin(), RC->SCCs.end()); + RC->SCCs.clear(); + DeletedRefSCCs.push_back(RC); + } + + // Append our original SCCs to the merged list and move it into place. + for (SCC &InnerC : *this) + SCCIndices[&InnerC] = SCCIndex++; + MergedSCCs.append(SCCs.begin(), SCCs.end()); + SCCs = std::move(MergedSCCs); + + // Remove the merged away RefSCCs from the post order sequence. + for (RefSCC *RC : MergeRange) + G->RefSCCIndices.erase(RC); + int IndexOffset = MergeRange.end() - MergeRange.begin(); + auto EraseEnd = + G->PostOrderRefSCCs.erase(MergeRange.begin(), MergeRange.end()); + for (RefSCC *RC : make_range(EraseEnd, G->PostOrderRefSCCs.end())) + G->RefSCCIndices[RC] -= IndexOffset; + + // At this point we have a merged RefSCC with a post-order SCCs list, just + // connect the nodes to form the new edge. + SourceN->insertEdgeInternal(TargetN, Edge::Ref); + + // We return the list of SCCs which were merged so that callers can + // invalidate any data they have associated with those SCCs. Note that these + // SCCs are no longer in an interesting state (they are totally empty) but + // the pointers will remain stable for the life of the graph itself. + return DeletedRefSCCs; +} + +void LazyCallGraph::RefSCC::removeOutgoingEdge(Node &SourceN, Node &TargetN) { + assert(G->lookupRefSCC(SourceN) == this && + "The source must be a member of this RefSCC."); + assert(G->lookupRefSCC(TargetN) != this && + "The target must not be a member of this RefSCC"); + +#ifndef NDEBUG + // In a debug build, verify the RefSCC is valid to start with and when this + // routine finishes. + verify(); + auto VerifyOnExit = make_scope_exit([&]() { verify(); }); +#endif + + // First remove it from the node. 
+ bool Removed = SourceN->removeEdgeInternal(TargetN); + (void)Removed; + assert(Removed && "Target not in the edge set for this caller?"); +} + +SmallVector<LazyCallGraph::RefSCC *, 1> +LazyCallGraph::RefSCC::removeInternalRefEdge(Node &SourceN, + ArrayRef<Node *> TargetNs) { + // We return a list of the resulting *new* RefSCCs in post-order. + SmallVector<RefSCC *, 1> Result; + +#ifndef NDEBUG + // In a debug build, verify the RefSCC is valid to start with and that either + // we return an empty list of result RefSCCs and this RefSCC remains valid, + // or we return new RefSCCs and this RefSCC is dead. + verify(); + auto VerifyOnExit = make_scope_exit([&]() { + // If we didn't replace our RefSCC with new ones, check that this one + // remains valid. + if (G) + verify(); + }); +#endif + + // First remove the actual edges. + for (Node *TargetN : TargetNs) { + assert(!(*SourceN)[*TargetN].isCall() && + "Cannot remove a call edge, it must first be made a ref edge"); + + bool Removed = SourceN->removeEdgeInternal(*TargetN); + (void)Removed; + assert(Removed && "Target not in the edge set for this caller?"); + } + + // Direct self references don't impact the ref graph at all. + if (llvm::all_of(TargetNs, + [&](Node *TargetN) { return &SourceN == TargetN; })) + return Result; + + // If all targets are in the same SCC as the source, because no call edges + // were removed there is no RefSCC structure change. + SCC &SourceC = *G->lookupSCC(SourceN); + if (llvm::all_of(TargetNs, [&](Node *TargetN) { + return G->lookupSCC(*TargetN) == &SourceC; + })) + return Result; + + // We build somewhat synthetic new RefSCCs by providing a postorder mapping + // for each inner SCC. We store these inside the low-link field of the nodes + // rather than associated with SCCs because this saves a round-trip through + // the node->SCC map and in the common case, SCCs are small. We will verify + // that we always give the same number to every node in the SCC such that + // these are equivalent. + int PostOrderNumber = 0; + + // Reset all the other nodes to prepare for a DFS over them, and add them to + // our worklist. + SmallVector<Node *, 8> Worklist; + for (SCC *C : SCCs) { + for (Node &N : *C) + N.DFSNumber = N.LowLink = 0; + + Worklist.append(C->Nodes.begin(), C->Nodes.end()); + } + + // Track the number of nodes in this RefSCC so that we can quickly recognize + // an important special case of the edge removal not breaking the cycle of + // this RefSCC. + const int NumRefSCCNodes = Worklist.size(); + + SmallVector<std::pair<Node *, EdgeSequence::iterator>, 4> DFSStack; + SmallVector<Node *, 4> PendingRefSCCStack; + do { + assert(DFSStack.empty() && + "Cannot begin a new root with a non-empty DFS stack!"); + assert(PendingRefSCCStack.empty() && + "Cannot begin a new root with pending nodes for an SCC!"); + + Node *RootN = Worklist.pop_back_val(); + // Skip any nodes we've already reached in the DFS. 
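+    // (The DFS numbering convention used throughout: 0 means not yet
+    // visited, -1 means already assigned to a component, and a positive
+    // number means currently part of the active DFS walk.)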
+ if (RootN->DFSNumber != 0) { + assert(RootN->DFSNumber == -1 && + "Shouldn't have any mid-DFS root nodes!"); + continue; + } + + RootN->DFSNumber = RootN->LowLink = 1; + int NextDFSNumber = 2; + + DFSStack.push_back({RootN, (*RootN)->begin()}); + do { + Node *N; + EdgeSequence::iterator I; + std::tie(N, I) = DFSStack.pop_back_val(); + auto E = (*N)->end(); + + assert(N->DFSNumber != 0 && "We should always assign a DFS number " + "before processing a node."); + + while (I != E) { + Node &ChildN = I->getNode(); + if (ChildN.DFSNumber == 0) { + // Mark that we should start at this child when next this node is the + // top of the stack. We don't start at the next child to ensure this + // child's lowlink is reflected. + DFSStack.push_back({N, I}); + + // Continue, resetting to the child node. + ChildN.LowLink = ChildN.DFSNumber = NextDFSNumber++; + N = &ChildN; + I = ChildN->begin(); + E = ChildN->end(); + continue; + } + if (ChildN.DFSNumber == -1) { + // If this child isn't currently in this RefSCC, no need to process + // it. + ++I; + continue; + } + + // Track the lowest link of the children, if any are still in the stack. + // Any child not on the stack will have a LowLink of -1. + assert(ChildN.LowLink != 0 && + "Low-link must not be zero with a non-zero DFS number."); + if (ChildN.LowLink >= 0 && ChildN.LowLink < N->LowLink) + N->LowLink = ChildN.LowLink; + ++I; + } + + // We've finished processing N and its descendants, put it on our pending + // stack to eventually get merged into a RefSCC. + PendingRefSCCStack.push_back(N); + + // If this node is linked to some lower entry, continue walking up the + // stack. + if (N->LowLink != N->DFSNumber) { + assert(!DFSStack.empty() && + "We never found a viable root for a RefSCC to pop off!"); + continue; + } + + // Otherwise, form a new RefSCC from the top of the pending node stack. + int RefSCCNumber = PostOrderNumber++; + int RootDFSNumber = N->DFSNumber; + + // Find the range of the node stack by walking down until we pass the + // root DFS number. Update the DFS numbers and low link numbers in the + // process to avoid re-walking this list where possible. + auto StackRI = find_if(reverse(PendingRefSCCStack), [&](Node *N) { + if (N->DFSNumber < RootDFSNumber) + // We've found the bottom. + return true; + + // Update this node and keep scanning. + N->DFSNumber = -1; + // Save the post-order number in the lowlink field so that we can use + // it to map SCCs into new RefSCCs after we finish the DFS. + N->LowLink = RefSCCNumber; + return false; + }); + auto RefSCCNodes = make_range(StackRI.base(), PendingRefSCCStack.end()); + + // If we find a cycle containing all nodes originally in this RefSCC then + // the removal hasn't changed the structure at all. This is an important + // special case and we can directly exit the entire routine more + // efficiently as soon as we discover it. + if (llvm::size(RefSCCNodes) == NumRefSCCNodes) { + // Clear out the low link field as we won't need it. + for (Node *N : RefSCCNodes) + N->LowLink = -1; + // Return the empty result immediately. + return Result; + } + + // We've already marked the nodes internally with the RefSCC number so + // just clear them off the stack and continue. 
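+      // (The numbers parked in the low-link fields are translated into
+      // actual RefSCC objects once the whole DFS completes; see the
+      // radix-sort style mapping built from PostOrderNumber below.)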
+ PendingRefSCCStack.erase(RefSCCNodes.begin(), PendingRefSCCStack.end()); + } while (!DFSStack.empty()); + + assert(DFSStack.empty() && "Didn't flush the entire DFS stack!"); + assert(PendingRefSCCStack.empty() && "Didn't flush all pending nodes!"); + } while (!Worklist.empty()); + + assert(PostOrderNumber > 1 && + "Should never finish the DFS when the existing RefSCC remains valid!"); + + // Otherwise we create a collection of new RefSCC nodes and build + // a radix-sort style map from postorder number to these new RefSCCs. We then + // append SCCs to each of these RefSCCs in the order they occurred in the + // original SCCs container. + for (int i = 0; i < PostOrderNumber; ++i) + Result.push_back(G->createRefSCC(*G)); + + // Insert the resulting postorder sequence into the global graph postorder + // sequence before the current RefSCC in that sequence, and then remove the + // current one. + // + // FIXME: It'd be nice to change the APIs so that we returned an iterator + // range over the global postorder sequence and generally use that sequence + // rather than building a separate result vector here. + int Idx = G->getRefSCCIndex(*this); + G->PostOrderRefSCCs.erase(G->PostOrderRefSCCs.begin() + Idx); + G->PostOrderRefSCCs.insert(G->PostOrderRefSCCs.begin() + Idx, Result.begin(), + Result.end()); + for (int i : seq<int>(Idx, G->PostOrderRefSCCs.size())) + G->RefSCCIndices[G->PostOrderRefSCCs[i]] = i; + + for (SCC *C : SCCs) { + // We store the SCC number in the node's low-link field above. + int SCCNumber = C->begin()->LowLink; + // Clear out all of the SCC's node's low-link fields now that we're done + // using them as side-storage. + for (Node &N : *C) { + assert(N.LowLink == SCCNumber && + "Cannot have different numbers for nodes in the same SCC!"); + N.LowLink = -1; + } + + RefSCC &RC = *Result[SCCNumber]; + int SCCIndex = RC.SCCs.size(); + RC.SCCs.push_back(C); + RC.SCCIndices[C] = SCCIndex; + C->OuterRefSCC = &RC; + } + + // Now that we've moved things into the new RefSCCs, clear out our current + // one. + G = nullptr; + SCCs.clear(); + SCCIndices.clear(); + +#ifndef NDEBUG + // Verify the new RefSCCs we've built. + for (RefSCC *RC : Result) + RC->verify(); +#endif + + // Return the new list of SCCs. + return Result; +} + +void LazyCallGraph::RefSCC::handleTrivialEdgeInsertion(Node &SourceN, + Node &TargetN) { + // The only trivial case that requires any graph updates is when we add new + // ref edge and may connect different RefSCCs along that path. This is only + // because of the parents set. Every other part of the graph remains constant + // after this edge insertion. + assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); + RefSCC &TargetRC = *G->lookupRefSCC(TargetN); + if (&TargetRC == this) + return; + +#ifdef EXPENSIVE_CHECKS + assert(TargetRC.isDescendantOf(*this) && + "Target must be a descendant of the Source."); +#endif +} + +void LazyCallGraph::RefSCC::insertTrivialCallEdge(Node &SourceN, + Node &TargetN) { +#ifndef NDEBUG + // Check that the RefSCC is still valid when we finish. + auto ExitVerifier = make_scope_exit([this] { verify(); }); + +#ifdef EXPENSIVE_CHECKS + // Check that we aren't breaking some invariants of the SCC graph. Note that + // this is quadratic in the number of edges in the call graph! 
+  SCC &SourceC = *G->lookupSCC(SourceN);
+  SCC &TargetC = *G->lookupSCC(TargetN);
+  if (&SourceC != &TargetC)
+    assert(SourceC.isAncestorOf(TargetC) &&
+           "Call edge is not trivial in the SCC graph!");
+#endif // EXPENSIVE_CHECKS
+#endif // NDEBUG
+
+  // First insert it into the source or find the existing edge.
+  auto InsertResult =
+      SourceN->EdgeIndexMap.insert({&TargetN, SourceN->Edges.size()});
+  if (!InsertResult.second) {
+    // Already an edge, just update it.
+    Edge &E = SourceN->Edges[InsertResult.first->second];
+    if (E.isCall())
+      return; // Nothing to do!
+    E.setKind(Edge::Call);
+  } else {
+    // Create the new edge.
+    SourceN->Edges.emplace_back(TargetN, Edge::Call);
+  }
+
+  // Now that we have the edge, handle the graph fallout.
+  handleTrivialEdgeInsertion(SourceN, TargetN);
+}
+
+void LazyCallGraph::RefSCC::insertTrivialRefEdge(Node &SourceN, Node &TargetN) {
+#ifndef NDEBUG
+  // Check that the RefSCC is still valid when we finish.
+  auto ExitVerifier = make_scope_exit([this] { verify(); });
+
+#ifdef EXPENSIVE_CHECKS
+  // Check that we aren't breaking some invariants of the RefSCC graph.
+  RefSCC &SourceRC = *G->lookupRefSCC(SourceN);
+  RefSCC &TargetRC = *G->lookupRefSCC(TargetN);
+  if (&SourceRC != &TargetRC)
+    assert(SourceRC.isAncestorOf(TargetRC) &&
+           "Ref edge is not trivial in the RefSCC graph!");
+#endif // EXPENSIVE_CHECKS
+#endif // NDEBUG
+
+  // First insert it into the source or find the existing edge.
+  auto InsertResult =
+      SourceN->EdgeIndexMap.insert({&TargetN, SourceN->Edges.size()});
+  if (!InsertResult.second)
+    // Already an edge, we're done.
+    return;
+
+  // Create the new edge.
+  SourceN->Edges.emplace_back(TargetN, Edge::Ref);
+
+  // Now that we have the edge, handle the graph fallout.
+  handleTrivialEdgeInsertion(SourceN, TargetN);
+}
+
+void LazyCallGraph::RefSCC::replaceNodeFunction(Node &N, Function &NewF) {
+  Function &OldF = N.getFunction();
+
+#ifndef NDEBUG
+  // Check that the RefSCC is still valid when we finish.
+  auto ExitVerifier = make_scope_exit([this] { verify(); });
+
+  assert(G->lookupRefSCC(N) == this &&
+         "Cannot replace the function of a node outside this RefSCC.");
+
+  assert(G->NodeMap.find(&NewF) == G->NodeMap.end() &&
+         "Must not have already walked the new function!");
+
+  // It is important that this replacement not introduce graph changes so we
+  // insist that the caller has already removed every use of the original
+  // function and that all uses of the new function correspond to existing
+  // edges in the graph. The common and expected way to use this is when
+  // replacing the function itself in the IR without changing the call graph
+  // shape and just updating the analysis based on that.
+  assert(&OldF != &NewF && "Cannot replace a function with itself!");
+  assert(OldF.use_empty() &&
+         "Must have moved all uses from the old function to the new!");
+#endif
+
+  N.replaceFunction(NewF);
+
+  // Update various call graph maps.
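+  // Only the function-keyed NodeMap needs fixing up here; the SCC and RefSCC
+  // maps are keyed on the node itself, which this operation leaves unchanged.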
+ G->NodeMap.erase(&OldF); + G->NodeMap[&NewF] = &N; +} + +void LazyCallGraph::insertEdge(Node &SourceN, Node &TargetN, Edge::Kind EK) { + assert(SCCMap.empty() && + "This method cannot be called after SCCs have been formed!"); + + return SourceN->insertEdgeInternal(TargetN, EK); +} + +void LazyCallGraph::removeEdge(Node &SourceN, Node &TargetN) { + assert(SCCMap.empty() && + "This method cannot be called after SCCs have been formed!"); + + bool Removed = SourceN->removeEdgeInternal(TargetN); + (void)Removed; + assert(Removed && "Target not in the edge set for this caller?"); +} + +void LazyCallGraph::removeDeadFunction(Function &F) { + // FIXME: This is unnecessarily restrictive. We should be able to remove + // functions which recursively call themselves. + assert(F.use_empty() && + "This routine should only be called on trivially dead functions!"); + + // We shouldn't remove library functions as they are never really dead while + // the call graph is in use -- every function definition refers to them. + assert(!isLibFunction(F) && + "Must not remove lib functions from the call graph!"); + + auto NI = NodeMap.find(&F); + if (NI == NodeMap.end()) + // Not in the graph at all! + return; + + Node &N = *NI->second; + NodeMap.erase(NI); + + // Remove this from the entry edges if present. + EntryEdges.removeEdgeInternal(N); + + if (SCCMap.empty()) { + // No SCCs have been formed, so removing this is fine and there is nothing + // else necessary at this point but clearing out the node. + N.clear(); + return; + } + + // Cannot remove a function which has yet to be visited in the DFS walk, so + // if we have a node at all then we must have an SCC and RefSCC. + auto CI = SCCMap.find(&N); + assert(CI != SCCMap.end() && + "Tried to remove a node without an SCC after DFS walk started!"); + SCC &C = *CI->second; + SCCMap.erase(CI); + RefSCC &RC = C.getOuterRefSCC(); + + // This node must be the only member of its SCC as it has no callers, and + // that SCC must be the only member of a RefSCC as it has no references. + // Validate these properties first. + assert(C.size() == 1 && "Dead functions must be in a singular SCC"); + assert(RC.size() == 1 && "Dead functions must be in a singular RefSCC"); + + auto RCIndexI = RefSCCIndices.find(&RC); + int RCIndex = RCIndexI->second; + PostOrderRefSCCs.erase(PostOrderRefSCCs.begin() + RCIndex); + RefSCCIndices.erase(RCIndexI); + for (int i = RCIndex, Size = PostOrderRefSCCs.size(); i < Size; ++i) + RefSCCIndices[PostOrderRefSCCs[i]] = i; + + // Finally clear out all the data structures from the node down through the + // components. + N.clear(); + N.G = nullptr; + N.F = nullptr; + C.clear(); + RC.clear(); + RC.G = nullptr; + + // Nothing to delete as all the objects are allocated in stable bump pointer + // allocators. +} + +LazyCallGraph::Node &LazyCallGraph::insertInto(Function &F, Node *&MappedN) { + return *new (MappedN = BPA.Allocate()) Node(*this, F); +} + +void LazyCallGraph::updateGraphPtrs() { + // Walk the node map to update their graph pointers. While this iterates in + // an unstable order, the order has no effect so it remains correct. 
+ for (auto &FunctionNodePair : NodeMap) + FunctionNodePair.second->G = this; + + for (auto *RC : PostOrderRefSCCs) + RC->G = this; +} + +template <typename RootsT, typename GetBeginT, typename GetEndT, + typename GetNodeT, typename FormSCCCallbackT> +void LazyCallGraph::buildGenericSCCs(RootsT &&Roots, GetBeginT &&GetBegin, + GetEndT &&GetEnd, GetNodeT &&GetNode, + FormSCCCallbackT &&FormSCC) { + using EdgeItT = decltype(GetBegin(std::declval<Node &>())); + + SmallVector<std::pair<Node *, EdgeItT>, 16> DFSStack; + SmallVector<Node *, 16> PendingSCCStack; + + // Scan down the stack and DFS across the call edges. + for (Node *RootN : Roots) { + assert(DFSStack.empty() && + "Cannot begin a new root with a non-empty DFS stack!"); + assert(PendingSCCStack.empty() && + "Cannot begin a new root with pending nodes for an SCC!"); + + // Skip any nodes we've already reached in the DFS. + if (RootN->DFSNumber != 0) { + assert(RootN->DFSNumber == -1 && + "Shouldn't have any mid-DFS root nodes!"); + continue; + } + + RootN->DFSNumber = RootN->LowLink = 1; + int NextDFSNumber = 2; + + DFSStack.push_back({RootN, GetBegin(*RootN)}); + do { + Node *N; + EdgeItT I; + std::tie(N, I) = DFSStack.pop_back_val(); + auto E = GetEnd(*N); + while (I != E) { + Node &ChildN = GetNode(I); + if (ChildN.DFSNumber == 0) { + // We haven't yet visited this child, so descend, pushing the current + // node onto the stack. + DFSStack.push_back({N, I}); + + ChildN.DFSNumber = ChildN.LowLink = NextDFSNumber++; + N = &ChildN; + I = GetBegin(*N); + E = GetEnd(*N); + continue; + } + + // If the child has already been added to some child component, it + // couldn't impact the low-link of this parent because it isn't + // connected, and thus its low-link isn't relevant so skip it. + if (ChildN.DFSNumber == -1) { + ++I; + continue; + } + + // Track the lowest linked child as the lowest link for this node. + assert(ChildN.LowLink > 0 && "Must have a positive low-link number!"); + if (ChildN.LowLink < N->LowLink) + N->LowLink = ChildN.LowLink; + + // Move to the next edge. + ++I; + } + + // We've finished processing N and its descendants, put it on our pending + // SCC stack to eventually get merged into an SCC of nodes. + PendingSCCStack.push_back(N); + + // If this node is linked to some lower entry, continue walking up the + // stack. + if (N->LowLink != N->DFSNumber) + continue; + + // Otherwise, we've completed an SCC. Append it to our post order list of + // SCCs. + int RootDFSNumber = N->DFSNumber; + // Find the range of the node stack by walking down until we pass the + // root DFS number. + auto SCCNodes = make_range( + PendingSCCStack.rbegin(), + find_if(reverse(PendingSCCStack), [RootDFSNumber](const Node *N) { + return N->DFSNumber < RootDFSNumber; + })); + // Form a new SCC out of these nodes and then clear them off our pending + // stack. + FormSCC(SCCNodes); + PendingSCCStack.erase(SCCNodes.end().base(), PendingSCCStack.end()); + } while (!DFSStack.empty()); + } +} + +/// Build the internal SCCs for a RefSCC from a sequence of nodes. +/// +/// Appends the SCCs to the provided vector and updates the map with their +/// indices. Both the vector and map must be empty when passed into this +/// routine. 
+void LazyCallGraph::buildSCCs(RefSCC &RC, node_stack_range Nodes) {
+  assert(RC.SCCs.empty() && "Already built SCCs!");
+  assert(RC.SCCIndices.empty() && "Already mapped SCC indices!");
+
+  for (Node *N : Nodes) {
+    assert(N->LowLink >= (*Nodes.begin())->LowLink &&
+           "We cannot have a low link in an SCC lower than its root on the "
+           "stack!");
+
+    // This node will go into the next RefSCC; clear out its DFS and low link
+    // as we scan.
+    N->DFSNumber = N->LowLink = 0;
+  }
+
+  // Each RefSCC contains a DAG of the call SCCs. To build these, we do
+  // a direct walk of the call edges using Tarjan's algorithm. We reuse the
+  // internal storage as we won't need it for the outer graph's DFS any longer.
+  buildGenericSCCs(
+      Nodes, [](Node &N) { return N->call_begin(); },
+      [](Node &N) { return N->call_end(); },
+      [](EdgeSequence::call_iterator I) -> Node & { return I->getNode(); },
+      [this, &RC](node_stack_range Nodes) {
+        RC.SCCs.push_back(createSCC(RC, Nodes));
+        for (Node &N : *RC.SCCs.back()) {
+          N.DFSNumber = N.LowLink = -1;
+          SCCMap[&N] = RC.SCCs.back();
+        }
+      });
+
+  // Wire up the SCC indices.
+  for (int i = 0, Size = RC.SCCs.size(); i < Size; ++i)
+    RC.SCCIndices[RC.SCCs[i]] = i;
+}
+
+void LazyCallGraph::buildRefSCCs() {
+  if (EntryEdges.empty() || !PostOrderRefSCCs.empty())
+    // RefSCCs are either non-existent or already built!
+    return;
+
+  assert(RefSCCIndices.empty() && "Already mapped RefSCC indices!");
+
+  SmallVector<Node *, 16> Roots;
+  for (Edge &E : *this)
+    Roots.push_back(&E.getNode());
+
+  // The roots will be popped off a stack, so use reverse to get a less
+  // surprising order. This doesn't change any of the semantics anywhere.
+  std::reverse(Roots.begin(), Roots.end());
+
+  buildGenericSCCs(
+      Roots,
+      [](Node &N) {
+        // We need to populate each node as we begin to walk its edges.
+        N.populate();
+        return N->begin();
+      },
+      [](Node &N) { return N->end(); },
+      [](EdgeSequence::iterator I) -> Node & { return I->getNode(); },
+      [this](node_stack_range Nodes) {
+        RefSCC *NewRC = createRefSCC(*this);
+        buildSCCs(*NewRC, Nodes);
+
+        // Push the new node into the postorder list and remember its position
+        // in the index map.
+        bool Inserted =
+            RefSCCIndices.insert({NewRC, PostOrderRefSCCs.size()}).second;
+        (void)Inserted;
+        assert(Inserted && "Cannot already have this RefSCC in the index map!");
+        PostOrderRefSCCs.push_back(NewRC);
+#ifndef NDEBUG
+        NewRC->verify();
+#endif
+      });
+}
+
+AnalysisKey LazyCallGraphAnalysis::Key;
+
+LazyCallGraphPrinterPass::LazyCallGraphPrinterPass(raw_ostream &OS) : OS(OS) {}
+
+static void printNode(raw_ostream &OS, LazyCallGraph::Node &N) {
+  OS << "  Edges in function: " << N.getFunction().getName() << "\n";
+  for (LazyCallGraph::Edge &E : N.populate())
+    OS << "    " << (E.isCall() ? "call" : "ref ") << " -> "
+       << E.getFunction().getName() << "\n";
+
+  OS << "\n";
+}
+
+static void printSCC(raw_ostream &OS, LazyCallGraph::SCC &C) {
+  OS << "  SCC with " << C.size() << " functions:\n";
+
+  for (LazyCallGraph::Node &N : C)
+    OS << "    " << N.getFunction().getName() << "\n";
+}
+
+static void printRefSCC(raw_ostream &OS, LazyCallGraph::RefSCC &C) {
+  OS << "  RefSCC with " << C.size() << " call SCCs:\n";
+
+  for (LazyCallGraph::SCC &InnerC : C)
+    printSCC(OS, InnerC);
+
+  OS << "\n";
+}
+
+PreservedAnalyses LazyCallGraphPrinterPass::run(Module &M,
+                                                ModuleAnalysisManager &AM) {
+  LazyCallGraph &G = AM.getResult<LazyCallGraphAnalysis>(M);
+
+  OS << "Printing the call graph for module: " << M.getModuleIdentifier()
+     << "\n\n";
+
+  for (Function &F : M)
+    printNode(OS, G.get(F));
+
+  G.buildRefSCCs();
+  for (LazyCallGraph::RefSCC &C : G.postorder_ref_sccs())
+    printRefSCC(OS, C);
+
+  return PreservedAnalyses::all();
+}
+
+LazyCallGraphDOTPrinterPass::LazyCallGraphDOTPrinterPass(raw_ostream &OS)
+    : OS(OS) {}
+
+static void printNodeDOT(raw_ostream &OS, LazyCallGraph::Node &N) {
+  std::string Name = "\"" + DOT::EscapeString(N.getFunction().getName()) + "\"";
+
+  for (LazyCallGraph::Edge &E : N.populate()) {
+    OS << "  " << Name << " -> \""
+       << DOT::EscapeString(E.getFunction().getName()) << "\"";
+    if (!E.isCall()) // It is a ref edge.
+      OS << " [style=dashed,label=\"ref\"]";
+    OS << ";\n";
+  }
+
+  OS << "\n";
+}
+
+PreservedAnalyses LazyCallGraphDOTPrinterPass::run(Module &M,
+                                                   ModuleAnalysisManager &AM) {
+  LazyCallGraph &G = AM.getResult<LazyCallGraphAnalysis>(M);
+
+  OS << "digraph \"" << DOT::EscapeString(M.getModuleIdentifier()) << "\" {\n";
+
+  for (Function &F : M)
+    printNodeDOT(OS, G.get(F));
+
+  OS << "}\n";
+
+  return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp
new file mode 100644
index 000000000000..96722f32e355
--- /dev/null
+++ b/llvm/lib/Analysis/LazyValueInfo.cpp
@@ -0,0 +1,2042 @@
+//===- LazyValueInfo.cpp - Value constraint analysis ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the interface for lazy computation of value constraint
+// information.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LazyValueInfo.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/ValueLattice.h"
+#include "llvm/IR/AssemblyAnnotationWriter.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormattedStream.h"
+#include "llvm/Support/raw_ostream.h"
+#include <map>
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "lazy-value-info"
+
+// This is the number of worklist items we will process to try to discover an
+// answer for a given value.
+static const unsigned MaxProcessedPerValue = 500;
+
+char LazyValueInfoWrapperPass::ID = 0;
+INITIALIZE_PASS_BEGIN(LazyValueInfoWrapperPass, "lazy-value-info",
+                "Lazy Value Information Analysis", false, true)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(LazyValueInfoWrapperPass, "lazy-value-info",
+                "Lazy Value Information Analysis", false, true)
+
+namespace llvm {
+  FunctionPass *createLazyValueInfoPass() { return new LazyValueInfoWrapperPass(); }
+}
+
+AnalysisKey LazyValueAnalysis::Key;
+
+/// Returns true if this lattice value represents at most one possible value.
+/// This is as precise as any lattice value can get while still representing
+/// reachable code.
+static bool hasSingleValue(const ValueLatticeElement &Val) {
+  if (Val.isConstantRange() &&
+      Val.getConstantRange().isSingleElement())
+    // Integer constants are single element ranges
+    return true;
+  if (Val.isConstant())
+    // Non integer constants
+    return true;
+  return false;
+}
+
+/// Combine two sets of facts about the same value into a single set of
+/// facts. Note that this method is not suitable for merging facts along
+/// different paths in a CFG; that's what the mergeIn function is for. This
+/// is for merging facts gathered about the same value at the same location
+/// through two independent means.
+/// Notes:
+/// * This method does not promise to return the most precise possible lattice
+///   value implied by A and B. It is allowed to return any lattice element
+///   which is at least as strong as *either* A or B (unless our facts
+///   conflict, see below).
+/// * Due to unreachable code, the intersection of two lattice values could be
+///   contradictory. If this happens, we return some valid lattice value so as
+///   not to confuse the rest of LVI. Ideally, we'd always return Undefined,
+///   but we do not make this guarantee. TODO: This would be a useful
+///   enhancement.
+static ValueLatticeElement intersect(const ValueLatticeElement &A,
+                                     const ValueLatticeElement &B) {
+  // Undefined is the strongest state. It means the value is known to be along
+  // an unreachable path.
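+  // As a concrete illustration: intersect(undefined, X) is undefined for any
+  // X, intersect(overdefined, X) is X, and intersecting the constant ranges
+  // [0, 10) and [5, 20) yields the narrower range [5, 10).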
+  if (A.isUndefined())
+    return A;
+  if (B.isUndefined())
+    return B;
+
+  // If we gave up for one, but got a usable fact from the other, use it.
+  if (A.isOverdefined())
+    return B;
+  if (B.isOverdefined())
+    return A;
+
+  // Can't get any more precise than constants.
+  if (hasSingleValue(A))
+    return A;
+  if (hasSingleValue(B))
+    return B;
+
+  // Could be either constant range or not constant here.
+  if (!A.isConstantRange() || !B.isConstantRange()) {
+    // TODO: Arbitrary choice, could be improved
+    return A;
+  }
+
+  // Intersect two constant ranges
+  ConstantRange Range =
+      A.getConstantRange().intersectWith(B.getConstantRange());
+  // Note: An empty range is implicitly converted to overdefined internally.
+  // TODO: We could instead use Undefined here since we've proven a conflict
+  // and thus know this path must be unreachable.
+  return ValueLatticeElement::getRange(std::move(Range));
+}
+
+//===----------------------------------------------------------------------===//
+//                          LazyValueInfoCache Decl
+//===----------------------------------------------------------------------===//
+
+namespace {
+  /// A callback value handle updates the cache when values are erased.
+  class LazyValueInfoCache;
+  struct LVIValueHandle final : public CallbackVH {
+    // Needs to access getValPtr(), which is protected.
+    friend struct DenseMapInfo<LVIValueHandle>;
+
+    LazyValueInfoCache *Parent;
+
+    LVIValueHandle(Value *V, LazyValueInfoCache *P)
+        : CallbackVH(V), Parent(P) { }
+
+    void deleted() override;
+    void allUsesReplacedWith(Value *V) override {
+      deleted();
+    }
+  };
+} // end anonymous namespace
+
+namespace {
+  /// This is the cache kept by LazyValueInfo which maintains cached results
+  /// across the clients' queries.
+  class LazyValueInfoCache {
+    /// This is all of the cached block information for exactly one Value*.
+    /// The entries are sorted by the BasicBlock* of the
+    /// entries, allowing us to do a lookup with a binary search.
+    /// Over-defined lattice values are recorded in OverDefinedCache to reduce
+    /// memory overhead.
+    struct ValueCacheEntryTy {
+      ValueCacheEntryTy(Value *V, LazyValueInfoCache *P) : Handle(V, P) {}
+      LVIValueHandle Handle;
+      SmallDenseMap<PoisoningVH<BasicBlock>, ValueLatticeElement, 4> BlockVals;
+    };
+
+    /// This tracks, on a per-block basis, the set of values that are
+    /// over-defined at the end of that block.
+    typedef DenseMap<PoisoningVH<BasicBlock>, SmallPtrSet<Value *, 4>>
+        OverDefinedCacheTy;
+    /// Keep track of all blocks that we have ever seen, so we
+    /// don't spend time removing unused blocks from our caches.
+    DenseSet<PoisoningVH<BasicBlock> > SeenBlocks;
+
+    /// This is all of the cached information for all values,
+    /// mapped from Value* to key information.
+    DenseMap<Value *, std::unique_ptr<ValueCacheEntryTy>> ValueCache;
+    OverDefinedCacheTy OverDefinedCache;
+
+
+  public:
+    void insertResult(Value *Val, BasicBlock *BB,
+                      const ValueLatticeElement &Result) {
+      SeenBlocks.insert(BB);
+
+      // Insert over-defined values into their own cache to reduce memory
+      // overhead.
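+      // (Overdefined is the common result in practice, so a per-block set of
+      // Value pointers is much cheaper than storing a full lattice element
+      // for every (value, block) pair.)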
+ if (Result.isOverdefined()) + OverDefinedCache[BB].insert(Val); + else { + auto It = ValueCache.find_as(Val); + if (It == ValueCache.end()) { + ValueCache[Val] = std::make_unique<ValueCacheEntryTy>(Val, this); + It = ValueCache.find_as(Val); + assert(It != ValueCache.end() && "Val was just added to the map!"); + } + It->second->BlockVals[BB] = Result; + } + } + + bool isOverdefined(Value *V, BasicBlock *BB) const { + auto ODI = OverDefinedCache.find(BB); + + if (ODI == OverDefinedCache.end()) + return false; + + return ODI->second.count(V); + } + + bool hasCachedValueInfo(Value *V, BasicBlock *BB) const { + if (isOverdefined(V, BB)) + return true; + + auto I = ValueCache.find_as(V); + if (I == ValueCache.end()) + return false; + + return I->second->BlockVals.count(BB); + } + + ValueLatticeElement getCachedValueInfo(Value *V, BasicBlock *BB) const { + if (isOverdefined(V, BB)) + return ValueLatticeElement::getOverdefined(); + + auto I = ValueCache.find_as(V); + if (I == ValueCache.end()) + return ValueLatticeElement(); + auto BBI = I->second->BlockVals.find(BB); + if (BBI == I->second->BlockVals.end()) + return ValueLatticeElement(); + return BBI->second; + } + + /// clear - Empty the cache. + void clear() { + SeenBlocks.clear(); + ValueCache.clear(); + OverDefinedCache.clear(); + } + + /// Inform the cache that a given value has been deleted. + void eraseValue(Value *V); + + /// This is part of the update interface to inform the cache + /// that a block has been deleted. + void eraseBlock(BasicBlock *BB); + + /// Updates the cache to remove any influence an overdefined value in + /// OldSucc might have (unless also overdefined in NewSucc). This just + /// flushes elements from the cache and does not add any. + void threadEdgeImpl(BasicBlock *OldSucc,BasicBlock *NewSucc); + + friend struct LVIValueHandle; + }; +} + +void LazyValueInfoCache::eraseValue(Value *V) { + for (auto I = OverDefinedCache.begin(), E = OverDefinedCache.end(); I != E;) { + // Copy and increment the iterator immediately so we can erase behind + // ourselves. + auto Iter = I++; + SmallPtrSetImpl<Value *> &ValueSet = Iter->second; + ValueSet.erase(V); + if (ValueSet.empty()) + OverDefinedCache.erase(Iter); + } + + ValueCache.erase(V); +} + +void LVIValueHandle::deleted() { + // This erasure deallocates *this, so it MUST happen after we're done + // using any and all members of *this. + Parent->eraseValue(*this); +} + +void LazyValueInfoCache::eraseBlock(BasicBlock *BB) { + // Shortcut if we have never seen this block. + DenseSet<PoisoningVH<BasicBlock> >::iterator I = SeenBlocks.find(BB); + if (I == SeenBlocks.end()) + return; + SeenBlocks.erase(I); + + auto ODI = OverDefinedCache.find(BB); + if (ODI != OverDefinedCache.end()) + OverDefinedCache.erase(ODI); + + for (auto &I : ValueCache) + I.second->BlockVals.erase(BB); +} + +void LazyValueInfoCache::threadEdgeImpl(BasicBlock *OldSucc, + BasicBlock *NewSucc) { + // When an edge in the graph has been threaded, values that we could not + // determine a value for before (i.e. were marked overdefined) may be + // possible to solve now. We do NOT try to proactively update these values. + // Instead, we clear their entries from the cache, and allow lazy updating to + // recompute them when needed. + + // The updating process is fairly simple: we need to drop cached info + // for all values that were marked overdefined in OldSucc, and for those same + // values in any successor of OldSucc (except NewSucc) in which they were + // also marked overdefined. 
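+  //
+  // For example, suppose %v was overdefined in OldSucc and also in a block B2
+  // reachable only via OldSucc. Both entries are flushed here, so the next
+  // query for %v recomputes it, now taking the threaded edge into account.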
+  std::vector<BasicBlock*> worklist;
+  worklist.push_back(OldSucc);
+
+  auto I = OverDefinedCache.find(OldSucc);
+  if (I == OverDefinedCache.end())
+    return; // Nothing to process here.
+  SmallVector<Value *, 4> ValsToClear(I->second.begin(), I->second.end());
+
+  // Use a worklist to perform a depth-first search of OldSucc's successors.
+  // NOTE: We do not need a visited list since any blocks we have already
+  // visited will have had their overdefined markers cleared already, and we
+  // thus won't loop to their successors.
+  while (!worklist.empty()) {
+    BasicBlock *ToUpdate = worklist.back();
+    worklist.pop_back();
+
+    // Skip blocks only accessible through NewSucc.
+    if (ToUpdate == NewSucc) continue;
+
+    // If a value was marked overdefined in OldSucc, and is here too...
+    auto OI = OverDefinedCache.find(ToUpdate);
+    if (OI == OverDefinedCache.end())
+      continue;
+    SmallPtrSetImpl<Value *> &ValueSet = OI->second;
+
+    bool changed = false;
+    for (Value *V : ValsToClear) {
+      if (!ValueSet.erase(V))
+        continue;
+
+      // If we removed anything, then we potentially need to update the
+      // block's successors too.
+      changed = true;
+
+      if (ValueSet.empty()) {
+        OverDefinedCache.erase(OI);
+        break;
+      }
+    }
+
+    if (!changed) continue;
+
+    worklist.insert(worklist.end(), succ_begin(ToUpdate), succ_end(ToUpdate));
+  }
+}
+
+
+namespace {
+/// An assembly annotator class to print LazyValueCache information in
+/// comments.
+class LazyValueInfoImpl;
+class LazyValueInfoAnnotatedWriter : public AssemblyAnnotationWriter {
+  LazyValueInfoImpl *LVIImpl;
+  // While analyzing which blocks we can solve values for, we need the dominator
+  // information. Since this is an optional parameter in LVI, we require this
+  // DomTreeAnalysis pass in the printer pass, and pass the dominator
+  // tree to the LazyValueInfoAnnotatedWriter.
+  DominatorTree &DT;
+
+public:
+  LazyValueInfoAnnotatedWriter(LazyValueInfoImpl *L, DominatorTree &DTree)
+      : LVIImpl(L), DT(DTree) {}
+
+  virtual void emitBasicBlockStartAnnot(const BasicBlock *BB,
+                                        formatted_raw_ostream &OS);
+
+  virtual void emitInstructionAnnot(const Instruction *I,
+                                    formatted_raw_ostream &OS);
+};
+}
+namespace {
+  // The actual implementation of the lazy analysis and update. Note that the
+  // inheritance from LazyValueInfoCache is intended to be temporary while
+  // splitting the code and then transitioning to a has-a relationship.
+  class LazyValueInfoImpl {
+
+    /// Cached results from previous queries
+    LazyValueInfoCache TheCache;
+
+    /// This stack holds the state of the value solver during a query.
+    /// It basically emulates the callstack of the naive
+    /// recursive value lookup process.
+    SmallVector<std::pair<BasicBlock*, Value*>, 8> BlockValueStack;
+
+    /// Keeps track of which block-value pairs are in BlockValueStack.
+    DenseSet<std::pair<BasicBlock*, Value*> > BlockValueSet;
+
+    /// Push BV onto BlockValueStack unless it's already in there.
+    /// Returns true on success.
+    bool pushBlockValue(const std::pair<BasicBlock *, Value *> &BV) {
+      if (!BlockValueSet.insert(BV).second)
+        return false;  // It's already in the stack.
+
+      LLVM_DEBUG(dbgs() << "PUSH: " << *BV.second << " in "
+                        << BV.first->getName() << "\n");
+      BlockValueStack.push_back(BV);
+      return true;
+    }
+
+    AssumptionCache *AC;  ///< A pointer to the cache of @llvm.assume calls.
+    const DataLayout &DL; ///< A mandatory DataLayout
+    DominatorTree *DT;    ///< An optional DT pointer.
+    DominatorTree *DisabledDT; ///< Stores DT if it's disabled.
+
+    ValueLatticeElement getBlockValue(Value *Val, BasicBlock *BB);
+    bool getEdgeValue(Value *V, BasicBlock *F, BasicBlock *T,
+                      ValueLatticeElement &Result, Instruction *CxtI = nullptr);
+    bool hasBlockValue(Value *Val, BasicBlock *BB);
+
+    // These methods process one work item and may add more. A false value
+    // returned means that the work item was not completely processed and must
+    // be revisited after going through the new items.
+    bool solveBlockValue(Value *Val, BasicBlock *BB);
+    bool solveBlockValueImpl(ValueLatticeElement &Res, Value *Val,
+                             BasicBlock *BB);
+    bool solveBlockValueNonLocal(ValueLatticeElement &BBLV, Value *Val,
+                                 BasicBlock *BB);
+    bool solveBlockValuePHINode(ValueLatticeElement &BBLV, PHINode *PN,
+                                BasicBlock *BB);
+    bool solveBlockValueSelect(ValueLatticeElement &BBLV, SelectInst *S,
+                               BasicBlock *BB);
+    Optional<ConstantRange> getRangeForOperand(unsigned Op, Instruction *I,
+                                               BasicBlock *BB);
+    bool solveBlockValueBinaryOpImpl(
+        ValueLatticeElement &BBLV, Instruction *I, BasicBlock *BB,
+        std::function<ConstantRange(const ConstantRange &,
+                                    const ConstantRange &)> OpFn);
+    bool solveBlockValueBinaryOp(ValueLatticeElement &BBLV, BinaryOperator *BBI,
+                                 BasicBlock *BB);
+    bool solveBlockValueCast(ValueLatticeElement &BBLV, CastInst *CI,
+                             BasicBlock *BB);
+    bool solveBlockValueOverflowIntrinsic(
+        ValueLatticeElement &BBLV, WithOverflowInst *WO, BasicBlock *BB);
+    bool solveBlockValueIntrinsic(ValueLatticeElement &BBLV, IntrinsicInst *II,
+                                  BasicBlock *BB);
+    bool solveBlockValueExtractValue(ValueLatticeElement &BBLV,
+                                     ExtractValueInst *EVI, BasicBlock *BB);
+    void intersectAssumeOrGuardBlockValueConstantRange(Value *Val,
+                                                       ValueLatticeElement &BBLV,
+                                                       Instruction *BBI);
+
+    void solve();
+
+  public:
+    /// This is the query interface to determine the lattice
+    /// value for the specified Value* at the end of the specified block.
+    ValueLatticeElement getValueInBlock(Value *V, BasicBlock *BB,
+                                        Instruction *CxtI = nullptr);
+
+    /// This is the query interface to determine the lattice
+    /// value for the specified Value* at the specified instruction (generally
+    /// from an assume intrinsic).
+    ValueLatticeElement getValueAt(Value *V, Instruction *CxtI);
+
+    /// This is the query interface to determine the lattice
+    /// value for the specified Value* that is true on the specified edge.
+    ValueLatticeElement getValueOnEdge(Value *V, BasicBlock *FromBB,
+                                       BasicBlock *ToBB,
+                                       Instruction *CxtI = nullptr);
+
+    /// Completely flush all previously computed values.
+    void clear() {
+      TheCache.clear();
+    }
+
+    /// Printing the LazyValueInfo Analysis.
+    void printLVI(Function &F, DominatorTree &DTree, raw_ostream &OS) {
+      LazyValueInfoAnnotatedWriter Writer(this, DTree);
+      F.print(OS, &Writer);
+    }
+
+    /// This is part of the update interface to inform the cache
+    /// that a block has been deleted.
+    void eraseBlock(BasicBlock *BB) {
+      TheCache.eraseBlock(BB);
+    }
+
+    /// Disables use of the DominatorTree within LVI.
+    void disableDT() {
+      if (DT) {
+        assert(!DisabledDT && "Both DT and DisabledDT are not nullptr!");
+        std::swap(DT, DisabledDT);
+      }
+    }
+
+    /// Enables use of the DominatorTree within LVI. Does nothing if the class
+    /// instance was initialized without a DT pointer.
+    void enableDT() {
+      if (DisabledDT) {
+        assert(!DT && "Both DT and DisabledDT are not nullptr!");
+        std::swap(DT, DisabledDT);
+      }
+    }
+
+    /// This is the update interface to inform the cache that an edge from
+    /// PredBB to OldSucc has been threaded to be from PredBB to NewSucc.
+ void threadEdge(BasicBlock *PredBB,BasicBlock *OldSucc,BasicBlock *NewSucc); + + LazyValueInfoImpl(AssumptionCache *AC, const DataLayout &DL, + DominatorTree *DT = nullptr) + : AC(AC), DL(DL), DT(DT), DisabledDT(nullptr) {} + }; +} // end anonymous namespace + + +void LazyValueInfoImpl::solve() { + SmallVector<std::pair<BasicBlock *, Value *>, 8> StartingStack( + BlockValueStack.begin(), BlockValueStack.end()); + + unsigned processedCount = 0; + while (!BlockValueStack.empty()) { + processedCount++; + // Abort if we have to process too many values to get a result for this one. + // Because of the design of the overdefined cache currently being per-block + // to avoid naming-related issues (IE it wants to try to give different + // results for the same name in different blocks), overdefined results don't + // get cached globally, which in turn means we will often try to rediscover + // the same overdefined result again and again. Once something like + // PredicateInfo is used in LVI or CVP, we should be able to make the + // overdefined cache global, and remove this throttle. + if (processedCount > MaxProcessedPerValue) { + LLVM_DEBUG( + dbgs() << "Giving up on stack because we are getting too deep\n"); + // Fill in the original values + while (!StartingStack.empty()) { + std::pair<BasicBlock *, Value *> &e = StartingStack.back(); + TheCache.insertResult(e.second, e.first, + ValueLatticeElement::getOverdefined()); + StartingStack.pop_back(); + } + BlockValueSet.clear(); + BlockValueStack.clear(); + return; + } + std::pair<BasicBlock *, Value *> e = BlockValueStack.back(); + assert(BlockValueSet.count(e) && "Stack value should be in BlockValueSet!"); + + if (solveBlockValue(e.second, e.first)) { + // The work item was completely processed. + assert(BlockValueStack.back() == e && "Nothing should have been pushed!"); + assert(TheCache.hasCachedValueInfo(e.second, e.first) && + "Result should be in cache!"); + + LLVM_DEBUG( + dbgs() << "POP " << *e.second << " in " << e.first->getName() << " = " + << TheCache.getCachedValueInfo(e.second, e.first) << "\n"); + + BlockValueStack.pop_back(); + BlockValueSet.erase(e); + } else { + // More work needs to be done before revisiting. + assert(BlockValueStack.back() != e && "Stack should have been pushed!"); + } + } +} + +bool LazyValueInfoImpl::hasBlockValue(Value *Val, BasicBlock *BB) { + // If already a constant, there is nothing to compute. + if (isa<Constant>(Val)) + return true; + + return TheCache.hasCachedValueInfo(Val, BB); +} + +ValueLatticeElement LazyValueInfoImpl::getBlockValue(Value *Val, + BasicBlock *BB) { + // If already a constant, there is nothing to compute. + if (Constant *VC = dyn_cast<Constant>(Val)) + return ValueLatticeElement::get(VC); + + return TheCache.getCachedValueInfo(Val, BB); +} + +static ValueLatticeElement getFromRangeMetadata(Instruction *BBI) { + switch (BBI->getOpcode()) { + default: break; + case Instruction::Load: + case Instruction::Call: + case Instruction::Invoke: + if (MDNode *Ranges = BBI->getMetadata(LLVMContext::MD_range)) + if (isa<IntegerType>(BBI->getType())) { + return ValueLatticeElement::getRange( + getConstantRangeFromMetadata(*Ranges)); + } + break; + }; + // Nothing known - will be intersected with other facts + return ValueLatticeElement::getOverdefined(); +} + +bool LazyValueInfoImpl::solveBlockValue(Value *Val, BasicBlock *BB) { + if (isa<Constant>(Val)) + return true; + + if (TheCache.hasCachedValueInfo(Val, BB)) { + // If we have a cached value, use that. 
+    LLVM_DEBUG(dbgs() << "  reuse BB '" << BB->getName() << "' val="
+                      << TheCache.getCachedValueInfo(Val, BB) << '\n');
+
+    // Since we're reusing a cached value, we don't need to update the
+    // OverDefinedCache. The cache will have been properly updated whenever the
+    // cached value was inserted.
+    return true;
+  }
+
+  // Hold off inserting this value into the Cache in case we have to return
+  // false and come back later.
+  ValueLatticeElement Res;
+  if (!solveBlockValueImpl(Res, Val, BB))
+    // Work pushed, will revisit
+    return false;
+
+  TheCache.insertResult(Val, BB, Res);
+  return true;
+}
+
+bool LazyValueInfoImpl::solveBlockValueImpl(ValueLatticeElement &Res,
+                                            Value *Val, BasicBlock *BB) {
+
+  Instruction *BBI = dyn_cast<Instruction>(Val);
+  if (!BBI || BBI->getParent() != BB)
+    return solveBlockValueNonLocal(Res, Val, BB);
+
+  if (PHINode *PN = dyn_cast<PHINode>(BBI))
+    return solveBlockValuePHINode(Res, PN, BB);
+
+  if (auto *SI = dyn_cast<SelectInst>(BBI))
+    return solveBlockValueSelect(Res, SI, BB);
+
+  // If this value is a nonnull pointer, record its range and bail out. Note
+  // that for all other pointer typed values, we terminate the search at the
+  // definition. We could easily extend this to look through geps, bitcasts,
+  // and the like to prove non-nullness, but it's not clear that's worth it
+  // compile-time-wise. The context-insensitive value walk done inside
+  // isKnownNonZero gets most of the profitable cases at much less expense.
+  // This does mean that we have a sensitivity to where the defining
+  // instruction is placed, even if it could legally be hoisted much higher.
+  // That is unfortunate.
+  PointerType *PT = dyn_cast<PointerType>(BBI->getType());
+  if (PT && isKnownNonZero(BBI, DL)) {
+    Res = ValueLatticeElement::getNot(ConstantPointerNull::get(PT));
+    return true;
+  }
+  if (BBI->getType()->isIntegerTy()) {
+    if (auto *CI = dyn_cast<CastInst>(BBI))
+      return solveBlockValueCast(Res, CI, BB);
+
+    if (BinaryOperator *BO = dyn_cast<BinaryOperator>(BBI))
+      return solveBlockValueBinaryOp(Res, BO, BB);
+
+    if (auto *EVI = dyn_cast<ExtractValueInst>(BBI))
+      return solveBlockValueExtractValue(Res, EVI, BB);
+
+    if (auto *II = dyn_cast<IntrinsicInst>(BBI))
+      return solveBlockValueIntrinsic(Res, II, BB);
+  }
+
+  LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName()
+                    << "' - unknown inst def found.\n");
+  Res = getFromRangeMetadata(BBI);
+  return true;
+}
+
+static bool InstructionDereferencesPointer(Instruction *I, Value *Ptr) {
+  if (LoadInst *L = dyn_cast<LoadInst>(I)) {
+    return L->getPointerAddressSpace() == 0 &&
+           GetUnderlyingObject(L->getPointerOperand(),
+                               L->getModule()->getDataLayout()) == Ptr;
+  }
+  if (StoreInst *S = dyn_cast<StoreInst>(I)) {
+    return S->getPointerAddressSpace() == 0 &&
+           GetUnderlyingObject(S->getPointerOperand(),
+                               S->getModule()->getDataLayout()) == Ptr;
+  }
+  if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
+    if (MI->isVolatile()) return false;
+
+    // FIXME: check whether it has a value range that excludes zero?
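+    // For example (illustrative, not implemented here): a non-constant length
+    // carrying !range metadata such as !{i64 1, i64 4096} would also prove
+    // that the intrinsic touches at least one byte through the pointer.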
+ ConstantInt *Len = dyn_cast<ConstantInt>(MI->getLength()); + if (!Len || Len->isZero()) return false; + + if (MI->getDestAddressSpace() == 0) + if (GetUnderlyingObject(MI->getRawDest(), + MI->getModule()->getDataLayout()) == Ptr) + return true; + if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) + if (MTI->getSourceAddressSpace() == 0) + if (GetUnderlyingObject(MTI->getRawSource(), + MTI->getModule()->getDataLayout()) == Ptr) + return true; + } + return false; +} + +/// Return true if the allocation associated with Val is ever dereferenced +/// within the given basic block. This establishes the fact Val is not null, +/// but does not imply that the memory at Val is dereferenceable. (Val may +/// point off the end of the dereferenceable part of the object.) +static bool isObjectDereferencedInBlock(Value *Val, BasicBlock *BB) { + assert(Val->getType()->isPointerTy()); + + const DataLayout &DL = BB->getModule()->getDataLayout(); + Value *UnderlyingVal = GetUnderlyingObject(Val, DL); + // If 'GetUnderlyingObject' didn't converge, skip it. It won't converge + // inside InstructionDereferencesPointer either. + if (UnderlyingVal == GetUnderlyingObject(UnderlyingVal, DL, 1)) + for (Instruction &I : *BB) + if (InstructionDereferencesPointer(&I, UnderlyingVal)) + return true; + return false; +} + +bool LazyValueInfoImpl::solveBlockValueNonLocal(ValueLatticeElement &BBLV, + Value *Val, BasicBlock *BB) { + ValueLatticeElement Result; // Start Undefined. + + // If this is the entry block, we must be asking about an argument. The + // value is overdefined. + if (BB == &BB->getParent()->getEntryBlock()) { + assert(isa<Argument>(Val) && "Unknown live-in to the entry block"); + // Before giving up, see if we can prove the pointer non-null local to + // this particular block. + PointerType *PTy = dyn_cast<PointerType>(Val->getType()); + if (PTy && + (isKnownNonZero(Val, DL) || + (isObjectDereferencedInBlock(Val, BB) && + !NullPointerIsDefined(BB->getParent(), PTy->getAddressSpace())))) { + Result = ValueLatticeElement::getNot(ConstantPointerNull::get(PTy)); + } else { + Result = ValueLatticeElement::getOverdefined(); + } + BBLV = Result; + return true; + } + + // Loop over all of our predecessors, merging what we know from them into + // result. If we encounter an unexplored predecessor, we eagerly explore it + // in a depth first manner. In practice, this has the effect of discovering + // paths we can't analyze eagerly without spending compile times analyzing + // other paths. This heuristic benefits from the fact that predecessors are + // frequently arranged such that dominating ones come first and we quickly + // find a path to function entry. TODO: We should consider explicitly + // canonicalizing to make this true rather than relying on this happy + // accident. + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { + ValueLatticeElement EdgeResult; + if (!getEdgeValue(Val, *PI, BB, EdgeResult)) + // Explore that input, then return here + return false; + + Result.mergeIn(EdgeResult, DL); + + // If we hit overdefined, exit early. The BlockVals entry is already set + // to overdefined. + if (Result.isOverdefined()) { + LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName() + << "' - overdefined because of pred (non local).\n"); + // Before giving up, see if we can prove the pointer non-null local to + // this particular block. 
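+      // For example (illustrative): if this block contains
+      //   store i32 0, i32* %val
+      // in address space 0, and null is not a defined address in this
+      // function, then %val must be non-null throughout this block.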
+ PointerType *PTy = dyn_cast<PointerType>(Val->getType()); + if (PTy && isObjectDereferencedInBlock(Val, BB) && + !NullPointerIsDefined(BB->getParent(), PTy->getAddressSpace())) { + Result = ValueLatticeElement::getNot(ConstantPointerNull::get(PTy)); + } + + BBLV = Result; + return true; + } + } + + // Return the merged value, which is more precise than 'overdefined'. + assert(!Result.isOverdefined()); + BBLV = Result; + return true; +} + +bool LazyValueInfoImpl::solveBlockValuePHINode(ValueLatticeElement &BBLV, + PHINode *PN, BasicBlock *BB) { + ValueLatticeElement Result; // Start Undefined. + + // Loop over all of our predecessors, merging what we know from them into + // result. See the comment about the chosen traversal order in + // solveBlockValueNonLocal; the same reasoning applies here. + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + BasicBlock *PhiBB = PN->getIncomingBlock(i); + Value *PhiVal = PN->getIncomingValue(i); + ValueLatticeElement EdgeResult; + // Note that we can provide PN as the context value to getEdgeValue, even + // though the results will be cached, because PN is the value being used as + // the cache key in the caller. + if (!getEdgeValue(PhiVal, PhiBB, BB, EdgeResult, PN)) + // Explore that input, then return here + return false; + + Result.mergeIn(EdgeResult, DL); + + // If we hit overdefined, exit early. The BlockVals entry is already set + // to overdefined. + if (Result.isOverdefined()) { + LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName() + << "' - overdefined because of pred (local).\n"); + + BBLV = Result; + return true; + } + } + + // Return the merged value, which is more precise than 'overdefined'. + assert(!Result.isOverdefined() && "Possible PHI in entry block?"); + BBLV = Result; + return true; +} + +static ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond, + bool isTrueDest = true); + +// If we can determine a constraint on the value given conditions assumed by +// the program, intersect those constraints with BBLV +void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange( + Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) { + BBI = BBI ? BBI : dyn_cast<Instruction>(Val); + if (!BBI) + return; + + for (auto &AssumeVH : AC->assumptionsFor(Val)) { + if (!AssumeVH) + continue; + auto *I = cast<CallInst>(AssumeVH); + if (!isValidAssumeForContext(I, BBI, DT)) + continue; + + BBLV = intersect(BBLV, getValueFromCondition(Val, I->getArgOperand(0))); + } + + // If guards are not used in the module, don't spend time looking for them + auto *GuardDecl = BBI->getModule()->getFunction( + Intrinsic::getName(Intrinsic::experimental_guard)); + if (!GuardDecl || GuardDecl->use_empty()) + return; + + if (BBI->getIterator() == BBI->getParent()->begin()) + return; + for (Instruction &I : make_range(std::next(BBI->getIterator().getReverse()), + BBI->getParent()->rend())) { + Value *Cond = nullptr; + if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>(m_Value(Cond)))) + BBLV = intersect(BBLV, getValueFromCondition(Val, Cond)); + } +} + +bool LazyValueInfoImpl::solveBlockValueSelect(ValueLatticeElement &BBLV, + SelectInst *SI, BasicBlock *BB) { + + // Recurse on our inputs if needed + if (!hasBlockValue(SI->getTrueValue(), BB)) { + if (pushBlockValue(std::make_pair(BB, SI->getTrueValue()))) + return false; + BBLV = ValueLatticeElement::getOverdefined(); + return true; + } + ValueLatticeElement TrueVal = getBlockValue(SI->getTrueValue(), BB); + // If we hit overdefined, don't ask more queries. 
We want to avoid poisoning + // extra slots in the table if we can. + if (TrueVal.isOverdefined()) { + BBLV = ValueLatticeElement::getOverdefined(); + return true; + } + + if (!hasBlockValue(SI->getFalseValue(), BB)) { + if (pushBlockValue(std::make_pair(BB, SI->getFalseValue()))) + return false; + BBLV = ValueLatticeElement::getOverdefined(); + return true; + } + ValueLatticeElement FalseVal = getBlockValue(SI->getFalseValue(), BB); + // If we hit overdefined, don't ask more queries. We want to avoid poisoning + // extra slots in the table if we can. + if (FalseVal.isOverdefined()) { + BBLV = ValueLatticeElement::getOverdefined(); + return true; + } + + if (TrueVal.isConstantRange() && FalseVal.isConstantRange()) { + const ConstantRange &TrueCR = TrueVal.getConstantRange(); + const ConstantRange &FalseCR = FalseVal.getConstantRange(); + Value *LHS = nullptr; + Value *RHS = nullptr; + SelectPatternResult SPR = matchSelectPattern(SI, LHS, RHS); + // Is this a min specifically of our two inputs? (Avoid the risk of + // ValueTracking getting smarter looking back past our immediate inputs.) + if (SelectPatternResult::isMinOrMax(SPR.Flavor) && + LHS == SI->getTrueValue() && RHS == SI->getFalseValue()) { + ConstantRange ResultCR = [&]() { + switch (SPR.Flavor) { + default: + llvm_unreachable("unexpected minmax type!"); + case SPF_SMIN: /// Signed minimum + return TrueCR.smin(FalseCR); + case SPF_UMIN: /// Unsigned minimum + return TrueCR.umin(FalseCR); + case SPF_SMAX: /// Signed maximum + return TrueCR.smax(FalseCR); + case SPF_UMAX: /// Unsigned maximum + return TrueCR.umax(FalseCR); + }; + }(); + BBLV = ValueLatticeElement::getRange(ResultCR); + return true; + } + + if (SPR.Flavor == SPF_ABS) { + if (LHS == SI->getTrueValue()) { + BBLV = ValueLatticeElement::getRange(TrueCR.abs()); + return true; + } + if (LHS == SI->getFalseValue()) { + BBLV = ValueLatticeElement::getRange(FalseCR.abs()); + return true; + } + } + + if (SPR.Flavor == SPF_NABS) { + ConstantRange Zero(APInt::getNullValue(TrueCR.getBitWidth())); + if (LHS == SI->getTrueValue()) { + BBLV = ValueLatticeElement::getRange(Zero.sub(TrueCR.abs())); + return true; + } + if (LHS == SI->getFalseValue()) { + BBLV = ValueLatticeElement::getRange(Zero.sub(FalseCR.abs())); + return true; + } + } + } + + // Can we constrain the facts about the true and false values by using the + // condition itself? This shows up with idioms like e.g. select(a > 5, a, 5). + // TODO: We could potentially refine an overdefined true value above. + Value *Cond = SI->getCondition(); + TrueVal = intersect(TrueVal, + getValueFromCondition(SI->getTrueValue(), Cond, true)); + FalseVal = intersect(FalseVal, + getValueFromCondition(SI->getFalseValue(), Cond, false)); + + // Handle clamp idioms such as: + // %24 = constantrange<0, 17> + // %39 = icmp eq i32 %24, 0 + // %40 = add i32 %24, -1 + // %siv.next = select i1 %39, i32 16, i32 %40 + // %siv.next = constantrange<0, 17> not <-1, 17> + // In general, this can handle any clamp idiom which tests the edge + // condition via an equality or inequality. 
+  if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
+    ICmpInst::Predicate Pred = ICI->getPredicate();
+    Value *A = ICI->getOperand(0);
+    if (ConstantInt *CIBase = dyn_cast<ConstantInt>(ICI->getOperand(1))) {
+      auto addConstants = [](ConstantInt *A, ConstantInt *B) {
+        assert(A->getType() == B->getType());
+        return ConstantInt::get(A->getType(), A->getValue() + B->getValue());
+      };
+      // See if either input is A + C2, subject to the constraint from the
+      // condition that A != C when that input is used. We can assume that
+      // that input doesn't include C + C2.
+      ConstantInt *CIAdded;
+      switch (Pred) {
+      default: break;
+      case ICmpInst::ICMP_EQ:
+        if (match(SI->getFalseValue(), m_Add(m_Specific(A),
+                                             m_ConstantInt(CIAdded)))) {
+          auto ResNot = addConstants(CIBase, CIAdded);
+          FalseVal = intersect(FalseVal,
+                               ValueLatticeElement::getNot(ResNot));
+        }
+        break;
+      case ICmpInst::ICMP_NE:
+        if (match(SI->getTrueValue(), m_Add(m_Specific(A),
+                                            m_ConstantInt(CIAdded)))) {
+          auto ResNot = addConstants(CIBase, CIAdded);
+          TrueVal = intersect(TrueVal,
+                              ValueLatticeElement::getNot(ResNot));
+        }
+        break;
+      }
+    }
+  }
+
+  ValueLatticeElement Result; // Start Undefined.
+  Result.mergeIn(TrueVal, DL);
+  Result.mergeIn(FalseVal, DL);
+  BBLV = Result;
+  return true;
+}
+
+Optional<ConstantRange> LazyValueInfoImpl::getRangeForOperand(unsigned Op,
+                                                              Instruction *I,
+                                                              BasicBlock *BB) {
+  if (!hasBlockValue(I->getOperand(Op), BB))
+    if (pushBlockValue(std::make_pair(BB, I->getOperand(Op))))
+      return None;
+
+  const unsigned OperandBitWidth =
+      DL.getTypeSizeInBits(I->getOperand(Op)->getType());
+  ConstantRange Range = ConstantRange::getFull(OperandBitWidth);
+  if (hasBlockValue(I->getOperand(Op), BB)) {
+    ValueLatticeElement Val = getBlockValue(I->getOperand(Op), BB);
+    intersectAssumeOrGuardBlockValueConstantRange(I->getOperand(Op), Val, I);
+    if (Val.isConstantRange())
+      Range = Val.getConstantRange();
+  }
+  return Range;
+}
+
+bool LazyValueInfoImpl::solveBlockValueCast(ValueLatticeElement &BBLV,
+                                            CastInst *CI,
+                                            BasicBlock *BB) {
+  if (!CI->getOperand(0)->getType()->isSized()) {
+    // Without knowing how wide the input is, we can't analyze it in any useful
+    // way.
+    BBLV = ValueLatticeElement::getOverdefined();
+    return true;
+  }
+
+  // Filter out casts we don't know how to reason about before attempting to
+  // recurse on our operand. This can cut a long search short if we know we're
+  // not going to be able to get any useful information anyway.
+  switch (CI->getOpcode()) {
+  case Instruction::Trunc:
+  case Instruction::SExt:
+  case Instruction::ZExt:
+  case Instruction::BitCast:
+    break;
+  default:
+    // Unhandled instructions are overdefined.
+    LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName()
+                      << "' - overdefined (unknown cast).\n");
+    BBLV = ValueLatticeElement::getOverdefined();
+    return true;
+  }
+
+  // Figure out the range of the LHS. If that fails, we still apply the
+  // transfer rule on the full set since we may be able to locally infer
+  // interesting facts.
+  Optional<ConstantRange> LHSRes = getRangeForOperand(0, CI, BB);
+  if (!LHSRes.hasValue())
+    // More work to do before applying this transfer rule.
+    return false;
+  ConstantRange LHSRange = LHSRes.getValue();
+
+  const unsigned ResultBitWidth = CI->getType()->getIntegerBitWidth();
+
+  // NOTE: We're currently limited by the set of operations that ConstantRange
+  // can evaluate symbolically. Enhancing that set will allow us to analyze
+  // more definitions.
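+  // For example (illustrative): given an i8 operand %x known to be in
+  // [0, 100),
+  //   %c = zext i8 %x to i32  ; yields %c in [0, 100)
+  //   %t = trunc i8 %x to i4  ; yields the full i4 set, since the truncated
+  //                           ; values 0..99 wrap around all 4-bit values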
+ BBLV = ValueLatticeElement::getRange(LHSRange.castOp(CI->getOpcode(), + ResultBitWidth)); + return true; +} + +bool LazyValueInfoImpl::solveBlockValueBinaryOpImpl( + ValueLatticeElement &BBLV, Instruction *I, BasicBlock *BB, + std::function<ConstantRange(const ConstantRange &, + const ConstantRange &)> OpFn) { + // Figure out the ranges of the operands. If that fails, use a + // conservative range, but apply the transfer rule anyways. This + // lets us pick up facts from expressions like "and i32 (call i32 + // @foo()), 32" + Optional<ConstantRange> LHSRes = getRangeForOperand(0, I, BB); + Optional<ConstantRange> RHSRes = getRangeForOperand(1, I, BB); + if (!LHSRes.hasValue() || !RHSRes.hasValue()) + // More work to do before applying this transfer rule. + return false; + + ConstantRange LHSRange = LHSRes.getValue(); + ConstantRange RHSRange = RHSRes.getValue(); + BBLV = ValueLatticeElement::getRange(OpFn(LHSRange, RHSRange)); + return true; +} + +bool LazyValueInfoImpl::solveBlockValueBinaryOp(ValueLatticeElement &BBLV, + BinaryOperator *BO, + BasicBlock *BB) { + + assert(BO->getOperand(0)->getType()->isSized() && + "all operands to binary operators are sized"); + if (BO->getOpcode() == Instruction::Xor) { + // Xor is the only operation not supported by ConstantRange::binaryOp(). + LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName() + << "' - overdefined (unknown binary operator).\n"); + BBLV = ValueLatticeElement::getOverdefined(); + return true; + } + + return solveBlockValueBinaryOpImpl(BBLV, BO, BB, + [BO](const ConstantRange &CR1, const ConstantRange &CR2) { + return CR1.binaryOp(BO->getOpcode(), CR2); + }); +} + +bool LazyValueInfoImpl::solveBlockValueOverflowIntrinsic( + ValueLatticeElement &BBLV, WithOverflowInst *WO, BasicBlock *BB) { + return solveBlockValueBinaryOpImpl(BBLV, WO, BB, + [WO](const ConstantRange &CR1, const ConstantRange &CR2) { + return CR1.binaryOp(WO->getBinaryOp(), CR2); + }); +} + +bool LazyValueInfoImpl::solveBlockValueIntrinsic( + ValueLatticeElement &BBLV, IntrinsicInst *II, BasicBlock *BB) { + switch (II->getIntrinsicID()) { + case Intrinsic::uadd_sat: + return solveBlockValueBinaryOpImpl(BBLV, II, BB, + [](const ConstantRange &CR1, const ConstantRange &CR2) { + return CR1.uadd_sat(CR2); + }); + case Intrinsic::usub_sat: + return solveBlockValueBinaryOpImpl(BBLV, II, BB, + [](const ConstantRange &CR1, const ConstantRange &CR2) { + return CR1.usub_sat(CR2); + }); + case Intrinsic::sadd_sat: + return solveBlockValueBinaryOpImpl(BBLV, II, BB, + [](const ConstantRange &CR1, const ConstantRange &CR2) { + return CR1.sadd_sat(CR2); + }); + case Intrinsic::ssub_sat: + return solveBlockValueBinaryOpImpl(BBLV, II, BB, + [](const ConstantRange &CR1, const ConstantRange &CR2) { + return CR1.ssub_sat(CR2); + }); + default: + LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName() + << "' - overdefined (unknown intrinsic).\n"); + BBLV = ValueLatticeElement::getOverdefined(); + return true; + } +} + +bool LazyValueInfoImpl::solveBlockValueExtractValue( + ValueLatticeElement &BBLV, ExtractValueInst *EVI, BasicBlock *BB) { + if (auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand())) + if (EVI->getNumIndices() == 1 && *EVI->idx_begin() == 0) + return solveBlockValueOverflowIntrinsic(BBLV, WO, BB); + + // Handle extractvalue of insertvalue to allow further simplification + // based on replaced with.overflow intrinsics. 
+  if (Value *V = SimplifyExtractValueInst(
+          EVI->getAggregateOperand(), EVI->getIndices(),
+          EVI->getModule()->getDataLayout())) {
+    if (!hasBlockValue(V, BB)) {
+      if (pushBlockValue({ BB, V }))
+        return false;
+      BBLV = ValueLatticeElement::getOverdefined();
+      return true;
+    }
+    BBLV = getBlockValue(V, BB);
+    return true;
+  }
+
+  LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName()
+                    << "' - overdefined (unknown extractvalue).\n");
+  BBLV = ValueLatticeElement::getOverdefined();
+  return true;
+}
+
+static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
+                                                     bool isTrueDest) {
+  Value *LHS = ICI->getOperand(0);
+  Value *RHS = ICI->getOperand(1);
+  CmpInst::Predicate Predicate = ICI->getPredicate();
+
+  if (isa<Constant>(RHS)) {
+    if (ICI->isEquality() && LHS == Val) {
+      // We know that V has the RHS constant if this is a true SETEQ or
+      // false SETNE.
+      if (isTrueDest == (Predicate == ICmpInst::ICMP_EQ))
+        return ValueLatticeElement::get(cast<Constant>(RHS));
+      else
+        return ValueLatticeElement::getNot(cast<Constant>(RHS));
+    }
+  }
+
+  if (!Val->getType()->isIntegerTy())
+    return ValueLatticeElement::getOverdefined();
+
+  // Use ConstantRange::makeAllowedICmpRegion to determine the possible range
+  // of Val guaranteed by the condition. Recognize comparisons in the form of:
+  //  icmp <pred> Val, ...
+  //  icmp <pred> (add Val, Offset), ...
+  // The latter is the range checking idiom that InstCombine produces. Subtract
+  // the offset from the allowed range for RHS in this case.
+
+  // Val or (add Val, Offset) can be on either hand of the comparison
+  if (LHS != Val && !match(LHS, m_Add(m_Specific(Val), m_ConstantInt()))) {
+    std::swap(LHS, RHS);
+    Predicate = CmpInst::getSwappedPredicate(Predicate);
+  }
+
+  ConstantInt *Offset = nullptr;
+  if (LHS != Val)
+    match(LHS, m_Add(m_Specific(Val), m_ConstantInt(Offset)));
+
+  if (LHS == Val || Offset) {
+    // Calculate the range of values that are allowed by the comparison
+    ConstantRange RHSRange(RHS->getType()->getIntegerBitWidth(),
+                           /*isFullSet=*/true);
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS))
+      RHSRange = ConstantRange(CI->getValue());
+    else if (Instruction *I = dyn_cast<Instruction>(RHS))
+      if (auto *Ranges = I->getMetadata(LLVMContext::MD_range))
+        RHSRange = getConstantRangeFromMetadata(*Ranges);
+
+    // If we're interested in the false dest, invert the condition
+    CmpInst::Predicate Pred =
+        isTrueDest ? Predicate : CmpInst::getInversePredicate(Predicate);
+    ConstantRange TrueValues =
+        ConstantRange::makeAllowedICmpRegion(Pred, RHSRange);
+
+    if (Offset) // Apply the offset from above.
+      TrueValues = TrueValues.subtract(Offset->getValue());
+
+    return ValueLatticeElement::getRange(std::move(TrueValues));
+  }
+
+  return ValueLatticeElement::getOverdefined();
+}
+
+// Handle conditions of the form
+// extractvalue(op.with.overflow(%x, C), 1).
+static ValueLatticeElement getValueFromOverflowCondition(
+    Value *Val, WithOverflowInst *WO, bool IsTrueDest) {
+  // TODO: This only works with a constant RHS for now. We could also compute
+  // the range of the RHS, but this doesn't fit into the current structure of
+  // the edge value calculation.
+  const APInt *C;
+  if (WO->getLHS() != Val || !match(WO->getRHS(), m_APInt(C)))
+    return ValueLatticeElement::getOverdefined();
+
+  // Calculate the possible values of %x for which no overflow occurs.
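+  // For example (illustrative): for uadd.with.overflow(i8 %x, 1) the no-wrap
+  // region is [0, 255), so on the overflow (true) edge %x is exactly 255.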
+  ConstantRange NWR = ConstantRange::makeExactNoWrapRegion(
+      WO->getBinaryOp(), *C, WO->getNoWrapKind());
+
+  // If overflow is false, %x is constrained to NWR. If overflow is true, %x is
+  // constrained to its inverse (all values that might cause overflow).
+  if (IsTrueDest)
+    NWR = NWR.inverse();
+  return ValueLatticeElement::getRange(NWR);
+}
+
+static ValueLatticeElement
+getValueFromCondition(Value *Val, Value *Cond, bool isTrueDest,
+                      DenseMap<Value*, ValueLatticeElement> &Visited);
+
+static ValueLatticeElement
+getValueFromConditionImpl(Value *Val, Value *Cond, bool isTrueDest,
+                          DenseMap<Value*, ValueLatticeElement> &Visited) {
+  if (ICmpInst *ICI = dyn_cast<ICmpInst>(Cond))
+    return getValueFromICmpCondition(Val, ICI, isTrueDest);
+
+  if (auto *EVI = dyn_cast<ExtractValueInst>(Cond))
+    if (auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand()))
+      if (EVI->getNumIndices() == 1 && *EVI->idx_begin() == 1)
+        return getValueFromOverflowCondition(Val, WO, isTrueDest);
+
+  // Handle conditions of the form (cond1 && cond2): we know that on the
+  // true dest path both of the conditions hold. Similarly, for conditions of
+  // the form (cond1 || cond2), we know that on the false dest path neither
+  // condition holds.
+  BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond);
+  if (!BO || (isTrueDest && BO->getOpcode() != BinaryOperator::And) ||
+      (!isTrueDest && BO->getOpcode() != BinaryOperator::Or))
+    return ValueLatticeElement::getOverdefined();
+
+  // Prevent infinite recursion if Cond references itself as in this example:
+  //  Cond: "%tmp4 = and i1 %tmp4, undef"
+  //  BL: "%tmp4 = and i1 %tmp4, undef"
+  //  BR: "i1 undef"
+  Value *BL = BO->getOperand(0);
+  Value *BR = BO->getOperand(1);
+  if (BL == Cond || BR == Cond)
+    return ValueLatticeElement::getOverdefined();
+
+  return intersect(getValueFromCondition(Val, BL, isTrueDest, Visited),
+                   getValueFromCondition(Val, BR, isTrueDest, Visited));
+}
+
+static ValueLatticeElement
+getValueFromCondition(Value *Val, Value *Cond, bool isTrueDest,
+                      DenseMap<Value*, ValueLatticeElement> &Visited) {
+  auto I = Visited.find(Cond);
+  if (I != Visited.end())
+    return I->second;
+
+  auto Result = getValueFromConditionImpl(Val, Cond, isTrueDest, Visited);
+  Visited[Cond] = Result;
+  return Result;
+}
+
+ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond,
+                                          bool isTrueDest) {
+  assert(Cond && "precondition");
+  DenseMap<Value*, ValueLatticeElement> Visited;
+  return getValueFromCondition(Val, Cond, isTrueDest, Visited);
+}
+
+// Return true if Usr has Op as an operand, otherwise false.
+static bool usesOperand(User *Usr, Value *Op) {
+  return find(Usr->operands(), Op) != Usr->op_end();
+}
+
+// Return true if the instruction type of Val is supported by
+// constantFoldUser(). Currently CastInst and BinaryOperator only. Call this
+// before calling constantFoldUser() to find out if it's even worth attempting
+// to call it.
+static bool isOperationFoldable(User *Usr) {
+  return isa<CastInst>(Usr) || isa<BinaryOperator>(Usr);
+}
+
+// Check if Usr can be simplified to an integer constant when the value of one
+// of its operands Op is an integer constant OpConstVal. If so, return it as a
+// lattice value range with a single element or otherwise return an overdefined
+// lattice value.
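+// For example (illustrative): if Usr is "%r = add i8 %x, 1" and Op is %x with
+// OpConstVal == 7, SimplifyBinOp folds this to 8 and the result is the
+// single-element range constantrange<8, 9>.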
+static ValueLatticeElement constantFoldUser(User *Usr, Value *Op, + const APInt &OpConstVal, + const DataLayout &DL) { + assert(isOperationFoldable(Usr) && "Precondition"); + Constant* OpConst = Constant::getIntegerValue(Op->getType(), OpConstVal); + // Check if Usr can be simplified to a constant. + if (auto *CI = dyn_cast<CastInst>(Usr)) { + assert(CI->getOperand(0) == Op && "Operand 0 isn't Op"); + if (auto *C = dyn_cast_or_null<ConstantInt>( + SimplifyCastInst(CI->getOpcode(), OpConst, + CI->getDestTy(), DL))) { + return ValueLatticeElement::getRange(ConstantRange(C->getValue())); + } + } else if (auto *BO = dyn_cast<BinaryOperator>(Usr)) { + bool Op0Match = BO->getOperand(0) == Op; + bool Op1Match = BO->getOperand(1) == Op; + assert((Op0Match || Op1Match) && + "Operand 0 nor Operand 1 isn't a match"); + Value *LHS = Op0Match ? OpConst : BO->getOperand(0); + Value *RHS = Op1Match ? OpConst : BO->getOperand(1); + if (auto *C = dyn_cast_or_null<ConstantInt>( + SimplifyBinOp(BO->getOpcode(), LHS, RHS, DL))) { + return ValueLatticeElement::getRange(ConstantRange(C->getValue())); + } + } + return ValueLatticeElement::getOverdefined(); +} + +/// Compute the value of Val on the edge BBFrom -> BBTo. Returns false if +/// Val is not constrained on the edge. Result is unspecified if return value +/// is false. +static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, + BasicBlock *BBTo, ValueLatticeElement &Result) { + // TODO: Handle more complex conditionals. If (v == 0 || v2 < 1) is false, we + // know that v != 0. + if (BranchInst *BI = dyn_cast<BranchInst>(BBFrom->getTerminator())) { + // If this is a conditional branch and only one successor goes to BBTo, then + // we may be able to infer something from the condition. + if (BI->isConditional() && + BI->getSuccessor(0) != BI->getSuccessor(1)) { + bool isTrueDest = BI->getSuccessor(0) == BBTo; + assert(BI->getSuccessor(!isTrueDest) == BBTo && + "BBTo isn't a successor of BBFrom"); + Value *Condition = BI->getCondition(); + + // If V is the condition of the branch itself, then we know exactly what + // it is. + if (Condition == Val) { + Result = ValueLatticeElement::get(ConstantInt::get( + Type::getInt1Ty(Val->getContext()), isTrueDest)); + return true; + } + + // If the condition of the branch is an equality comparison, we may be + // able to infer the value. + Result = getValueFromCondition(Val, Condition, isTrueDest); + if (!Result.isOverdefined()) + return true; + + if (User *Usr = dyn_cast<User>(Val)) { + assert(Result.isOverdefined() && "Result isn't overdefined"); + // Check with isOperationFoldable() first to avoid linearly iterating + // over the operands unnecessarily which can be expensive for + // instructions with many operands. + if (isa<IntegerType>(Usr->getType()) && isOperationFoldable(Usr)) { + const DataLayout &DL = BBTo->getModule()->getDataLayout(); + if (usesOperand(Usr, Condition)) { + // If Val has Condition as an operand and Val can be folded into a + // constant with either Condition == true or Condition == false, + // propagate the constant. + // eg. + // ; %Val is true on the edge to %then. + // %Val = and i1 %Condition, true. + // br %Condition, label %then, label %else + APInt ConditionVal(1, isTrueDest ? 1 : 0); + Result = constantFoldUser(Usr, Condition, ConditionVal, DL); + } else { + // If one of Val's operand has an inferred value, we may be able to + // infer the value of Val. + // eg. + // ; %Val is 94 on the edge to %then. 
+ // %Val = add i8 %Op, 1 + // %Condition = icmp eq i8 %Op, 93 + // br i1 %Condition, label %then, label %else + for (unsigned i = 0; i < Usr->getNumOperands(); ++i) { + Value *Op = Usr->getOperand(i); + ValueLatticeElement OpLatticeVal = + getValueFromCondition(Op, Condition, isTrueDest); + if (Optional<APInt> OpConst = OpLatticeVal.asConstantInteger()) { + Result = constantFoldUser(Usr, Op, OpConst.getValue(), DL); + break; + } + } + } + } + } + if (!Result.isOverdefined()) + return true; + } + } + + // If the edge was formed by a switch on the value, then we may know exactly + // what it is. + if (SwitchInst *SI = dyn_cast<SwitchInst>(BBFrom->getTerminator())) { + Value *Condition = SI->getCondition(); + if (!isa<IntegerType>(Val->getType())) + return false; + bool ValUsesConditionAndMayBeFoldable = false; + if (Condition != Val) { + // Check if Val has Condition as an operand. + if (User *Usr = dyn_cast<User>(Val)) + ValUsesConditionAndMayBeFoldable = isOperationFoldable(Usr) && + usesOperand(Usr, Condition); + if (!ValUsesConditionAndMayBeFoldable) + return false; + } + assert((Condition == Val || ValUsesConditionAndMayBeFoldable) && + "Condition != Val nor Val doesn't use Condition"); + + bool DefaultCase = SI->getDefaultDest() == BBTo; + unsigned BitWidth = Val->getType()->getIntegerBitWidth(); + ConstantRange EdgesVals(BitWidth, DefaultCase/*isFullSet*/); + + for (auto Case : SI->cases()) { + APInt CaseValue = Case.getCaseValue()->getValue(); + ConstantRange EdgeVal(CaseValue); + if (ValUsesConditionAndMayBeFoldable) { + User *Usr = cast<User>(Val); + const DataLayout &DL = BBTo->getModule()->getDataLayout(); + ValueLatticeElement EdgeLatticeVal = + constantFoldUser(Usr, Condition, CaseValue, DL); + if (EdgeLatticeVal.isOverdefined()) + return false; + EdgeVal = EdgeLatticeVal.getConstantRange(); + } + if (DefaultCase) { + // It is possible that the default destination is the destination of + // some cases. We cannot perform difference for those cases. + // We know Condition != CaseValue in BBTo. In some cases we can use + // this to infer Val == f(Condition) is != f(CaseValue). For now, we + // only do this when f is identity (i.e. Val == Condition), but we + // should be able to do this for any injective f. + if (Case.getCaseSuccessor() != BBTo && Condition == Val) + EdgesVals = EdgesVals.difference(EdgeVal); + } else if (Case.getCaseSuccessor() == BBTo) + EdgesVals = EdgesVals.unionWith(EdgeVal); + } + Result = ValueLatticeElement::getRange(std::move(EdgesVals)); + return true; + } + return false; +} + +/// Compute the value of Val on the edge BBFrom -> BBTo or the value at +/// the basic block if the edge does not constrain Val. +bool LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom, + BasicBlock *BBTo, + ValueLatticeElement &Result, + Instruction *CxtI) { + // If already a constant, there is nothing to compute. + if (Constant *VC = dyn_cast<Constant>(Val)) { + Result = ValueLatticeElement::get(VC); + return true; + } + + ValueLatticeElement LocalResult; + if (!getEdgeValueLocal(Val, BBFrom, BBTo, LocalResult)) + // If we couldn't constrain the value on the edge, LocalResult doesn't + // provide any information. + LocalResult = ValueLatticeElement::getOverdefined(); + + if (hasSingleValue(LocalResult)) { + // Can't get any more precise here + Result = LocalResult; + return true; + } + + if (!hasBlockValue(Val, BBFrom)) { + if (pushBlockValue(std::make_pair(BBFrom, Val))) + return false; + // No new information. 
+ Result = LocalResult; + return true; + } + + // Try to intersect ranges of the BB and the constraint on the edge. + ValueLatticeElement InBlock = getBlockValue(Val, BBFrom); + intersectAssumeOrGuardBlockValueConstantRange(Val, InBlock, + BBFrom->getTerminator()); + // We can use the context instruction (generically the ultimate instruction + // the calling pass is trying to simplify) here, even though the result of + // this function is generally cached when called from the solve* functions + // (and that cached result might be used with queries using a different + // context instruction), because when this function is called from the solve* + // functions, the context instruction is not provided. When called from + // LazyValueInfoImpl::getValueOnEdge, the context instruction is provided, + // but then the result is not cached. + intersectAssumeOrGuardBlockValueConstantRange(Val, InBlock, CxtI); + + Result = intersect(LocalResult, InBlock); + return true; +} + +ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB, + Instruction *CxtI) { + LLVM_DEBUG(dbgs() << "LVI Getting block end value " << *V << " at '" + << BB->getName() << "'\n"); + + assert(BlockValueStack.empty() && BlockValueSet.empty()); + if (!hasBlockValue(V, BB)) { + pushBlockValue(std::make_pair(BB, V)); + solve(); + } + ValueLatticeElement Result = getBlockValue(V, BB); + intersectAssumeOrGuardBlockValueConstantRange(V, Result, CxtI); + + LLVM_DEBUG(dbgs() << " Result = " << Result << "\n"); + return Result; +} + +ValueLatticeElement LazyValueInfoImpl::getValueAt(Value *V, Instruction *CxtI) { + LLVM_DEBUG(dbgs() << "LVI Getting value " << *V << " at '" << CxtI->getName() + << "'\n"); + + if (auto *C = dyn_cast<Constant>(V)) + return ValueLatticeElement::get(C); + + ValueLatticeElement Result = ValueLatticeElement::getOverdefined(); + if (auto *I = dyn_cast<Instruction>(V)) + Result = getFromRangeMetadata(I); + intersectAssumeOrGuardBlockValueConstantRange(V, Result, CxtI); + + LLVM_DEBUG(dbgs() << " Result = " << Result << "\n"); + return Result; +} + +ValueLatticeElement LazyValueInfoImpl:: +getValueOnEdge(Value *V, BasicBlock *FromBB, BasicBlock *ToBB, + Instruction *CxtI) { + LLVM_DEBUG(dbgs() << "LVI Getting edge value " << *V << " from '" + << FromBB->getName() << "' to '" << ToBB->getName() + << "'\n"); + + ValueLatticeElement Result; + if (!getEdgeValue(V, FromBB, ToBB, Result, CxtI)) { + solve(); + bool WasFastQuery = getEdgeValue(V, FromBB, ToBB, Result, CxtI); + (void)WasFastQuery; + assert(WasFastQuery && "More work to do after problem solved?"); + } + + LLVM_DEBUG(dbgs() << " Result = " << Result << "\n"); + return Result; +} + +void LazyValueInfoImpl::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc, + BasicBlock *NewSucc) { + TheCache.threadEdgeImpl(OldSucc, NewSucc); +} + +//===----------------------------------------------------------------------===// +// LazyValueInfo Impl +//===----------------------------------------------------------------------===// + +/// This lazily constructs the LazyValueInfoImpl. 
+static LazyValueInfoImpl &getImpl(void *&PImpl, AssumptionCache *AC,
+                                  const DataLayout *DL,
+                                  DominatorTree *DT = nullptr) {
+  if (!PImpl) {
+    assert(DL && "getImpl() called with a null DataLayout");
+    PImpl = new LazyValueInfoImpl(AC, *DL, DT);
+  }
+  return *static_cast<LazyValueInfoImpl*>(PImpl);
+}
+
+bool LazyValueInfoWrapperPass::runOnFunction(Function &F) {
+  Info.AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+  const DataLayout &DL = F.getParent()->getDataLayout();
+
+  DominatorTreeWrapperPass *DTWP =
+      getAnalysisIfAvailable<DominatorTreeWrapperPass>();
+  Info.DT = DTWP ? &DTWP->getDomTree() : nullptr;
+  Info.TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+
+  if (Info.PImpl)
+    getImpl(Info.PImpl, Info.AC, &DL, Info.DT).clear();
+
+  // Fully lazy.
+  return false;
+}
+
+void LazyValueInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.setPreservesAll();
+  AU.addRequired<AssumptionCacheTracker>();
+  AU.addRequired<TargetLibraryInfoWrapperPass>();
+}
+
+LazyValueInfo &LazyValueInfoWrapperPass::getLVI() { return Info; }
+
+LazyValueInfo::~LazyValueInfo() { releaseMemory(); }
+
+void LazyValueInfo::releaseMemory() {
+  // If the cache was allocated, free it.
+  if (PImpl) {
+    delete &getImpl(PImpl, AC, nullptr);
+    PImpl = nullptr;
+  }
+}
+
+bool LazyValueInfo::invalidate(Function &F, const PreservedAnalyses &PA,
+                               FunctionAnalysisManager::Invalidator &Inv) {
+  // We need to invalidate if we have either failed to preserve this analysis's
+  // result directly or if any of its dependencies have been invalidated.
+  auto PAC = PA.getChecker<LazyValueAnalysis>();
+  if (!(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
+      (DT && Inv.invalidate<DominatorTreeAnalysis>(F, PA)))
+    return true;
+
+  return false;
+}
+
+void LazyValueInfoWrapperPass::releaseMemory() { Info.releaseMemory(); }
+
+LazyValueInfo LazyValueAnalysis::run(Function &F,
+                                     FunctionAnalysisManager &FAM) {
+  auto &AC = FAM.getResult<AssumptionAnalysis>(F);
+  auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
+  auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F);
+
+  return LazyValueInfo(&AC, &F.getParent()->getDataLayout(), &TLI, DT);
+}
+
+/// Returns true if we can statically tell that this value will never be a
+/// "useful" constant. In practice, this means we've got something like an
+/// alloca or a malloc call for which a comparison against a constant can
+/// only be guarding dead code. Note that we are potentially giving up some
+/// precision in dead code (a constant result) in favour of avoiding an
+/// expensive search for an easily answered common query.
+static bool isKnownNonConstant(Value *V) {
+  V = V->stripPointerCasts();
+  // The return value of an alloca cannot be a Constant.
+  if (isa<AllocaInst>(V))
+    return true;
+  return false;
+}
+
+Constant *LazyValueInfo::getConstant(Value *V, BasicBlock *BB,
+                                     Instruction *CxtI) {
+  // Bail out early if V is known not to be a Constant.
+ if (isKnownNonConstant(V)) + return nullptr; + + const DataLayout &DL = BB->getModule()->getDataLayout(); + ValueLatticeElement Result = + getImpl(PImpl, AC, &DL, DT).getValueInBlock(V, BB, CxtI); + + if (Result.isConstant()) + return Result.getConstant(); + if (Result.isConstantRange()) { + const ConstantRange &CR = Result.getConstantRange(); + if (const APInt *SingleVal = CR.getSingleElement()) + return ConstantInt::get(V->getContext(), *SingleVal); + } + return nullptr; +} + +ConstantRange LazyValueInfo::getConstantRange(Value *V, BasicBlock *BB, + Instruction *CxtI) { + assert(V->getType()->isIntegerTy()); + unsigned Width = V->getType()->getIntegerBitWidth(); + const DataLayout &DL = BB->getModule()->getDataLayout(); + ValueLatticeElement Result = + getImpl(PImpl, AC, &DL, DT).getValueInBlock(V, BB, CxtI); + if (Result.isUndefined()) + return ConstantRange::getEmpty(Width); + if (Result.isConstantRange()) + return Result.getConstantRange(); + // We represent ConstantInt constants as constant ranges but other kinds + // of integer constants, i.e. ConstantExpr will be tagged as constants + assert(!(Result.isConstant() && isa<ConstantInt>(Result.getConstant())) && + "ConstantInt value must be represented as constantrange"); + return ConstantRange::getFull(Width); +} + +/// Determine whether the specified value is known to be a +/// constant on the specified edge. Return null if not. +Constant *LazyValueInfo::getConstantOnEdge(Value *V, BasicBlock *FromBB, + BasicBlock *ToBB, + Instruction *CxtI) { + const DataLayout &DL = FromBB->getModule()->getDataLayout(); + ValueLatticeElement Result = + getImpl(PImpl, AC, &DL, DT).getValueOnEdge(V, FromBB, ToBB, CxtI); + + if (Result.isConstant()) + return Result.getConstant(); + if (Result.isConstantRange()) { + const ConstantRange &CR = Result.getConstantRange(); + if (const APInt *SingleVal = CR.getSingleElement()) + return ConstantInt::get(V->getContext(), *SingleVal); + } + return nullptr; +} + +ConstantRange LazyValueInfo::getConstantRangeOnEdge(Value *V, + BasicBlock *FromBB, + BasicBlock *ToBB, + Instruction *CxtI) { + unsigned Width = V->getType()->getIntegerBitWidth(); + const DataLayout &DL = FromBB->getModule()->getDataLayout(); + ValueLatticeElement Result = + getImpl(PImpl, AC, &DL, DT).getValueOnEdge(V, FromBB, ToBB, CxtI); + + if (Result.isUndefined()) + return ConstantRange::getEmpty(Width); + if (Result.isConstantRange()) + return Result.getConstantRange(); + // We represent ConstantInt constants as constant ranges but other kinds + // of integer constants, i.e. ConstantExpr will be tagged as constants + assert(!(Result.isConstant() && isa<ConstantInt>(Result.getConstant())) && + "ConstantInt value must be represented as constantrange"); + return ConstantRange::getFull(Width); +} + +static LazyValueInfo::Tristate +getPredicateResult(unsigned Pred, Constant *C, const ValueLatticeElement &Val, + const DataLayout &DL, TargetLibraryInfo *TLI) { + // If we know the value is a constant, evaluate the conditional. + Constant *Res = nullptr; + if (Val.isConstant()) { + Res = ConstantFoldCompareInstOperands(Pred, Val.getConstant(), C, DL, TLI); + if (ConstantInt *ResCI = dyn_cast<ConstantInt>(Res)) + return ResCI->isZero() ? 
LazyValueInfo::False : LazyValueInfo::True; + return LazyValueInfo::Unknown; + } + + if (Val.isConstantRange()) { + ConstantInt *CI = dyn_cast<ConstantInt>(C); + if (!CI) return LazyValueInfo::Unknown; + + const ConstantRange &CR = Val.getConstantRange(); + if (Pred == ICmpInst::ICMP_EQ) { + if (!CR.contains(CI->getValue())) + return LazyValueInfo::False; + + if (CR.isSingleElement()) + return LazyValueInfo::True; + } else if (Pred == ICmpInst::ICMP_NE) { + if (!CR.contains(CI->getValue())) + return LazyValueInfo::True; + + if (CR.isSingleElement()) + return LazyValueInfo::False; + } else { + // Handle more complex predicates. + ConstantRange TrueValues = ConstantRange::makeExactICmpRegion( + (ICmpInst::Predicate)Pred, CI->getValue()); + if (TrueValues.contains(CR)) + return LazyValueInfo::True; + if (TrueValues.inverse().contains(CR)) + return LazyValueInfo::False; + } + return LazyValueInfo::Unknown; + } + + if (Val.isNotConstant()) { + // If this is an equality comparison, we can try to fold it knowing that + // "V != C1". + if (Pred == ICmpInst::ICMP_EQ) { + // !C1 == C -> false iff C1 == C. + Res = ConstantFoldCompareInstOperands(ICmpInst::ICMP_NE, + Val.getNotConstant(), C, DL, + TLI); + if (Res->isNullValue()) + return LazyValueInfo::False; + } else if (Pred == ICmpInst::ICMP_NE) { + // !C1 != C -> true iff C1 == C. + Res = ConstantFoldCompareInstOperands(ICmpInst::ICMP_NE, + Val.getNotConstant(), C, DL, + TLI); + if (Res->isNullValue()) + return LazyValueInfo::True; + } + return LazyValueInfo::Unknown; + } + + return LazyValueInfo::Unknown; +} + +/// Determine whether the specified value comparison with a constant is known to +/// be true or false on the specified CFG edge. Pred is a CmpInst predicate. +LazyValueInfo::Tristate +LazyValueInfo::getPredicateOnEdge(unsigned Pred, Value *V, Constant *C, + BasicBlock *FromBB, BasicBlock *ToBB, + Instruction *CxtI) { + const DataLayout &DL = FromBB->getModule()->getDataLayout(); + ValueLatticeElement Result = + getImpl(PImpl, AC, &DL, DT).getValueOnEdge(V, FromBB, ToBB, CxtI); + + return getPredicateResult(Pred, C, Result, DL, TLI); +} + +LazyValueInfo::Tristate +LazyValueInfo::getPredicateAt(unsigned Pred, Value *V, Constant *C, + Instruction *CxtI) { + // Is or is not NonNull are common predicates being queried. If + // isKnownNonZero can tell us the result of the predicate, we can + // return it quickly. But this is only a fastpath, and falling + // through would still be correct. + const DataLayout &DL = CxtI->getModule()->getDataLayout(); + if (V->getType()->isPointerTy() && C->isNullValue() && + isKnownNonZero(V->stripPointerCastsSameRepresentation(), DL)) { + if (Pred == ICmpInst::ICMP_EQ) + return LazyValueInfo::False; + else if (Pred == ICmpInst::ICMP_NE) + return LazyValueInfo::True; + } + ValueLatticeElement Result = getImpl(PImpl, AC, &DL, DT).getValueAt(V, CxtI); + Tristate Ret = getPredicateResult(Pred, C, Result, DL, TLI); + if (Ret != Unknown) + return Ret; + + // Note: The following bit of code is somewhat distinct from the rest of LVI; + // LVI as a whole tries to compute a lattice value which is conservatively + // correct at a given location. In this case, we have a predicate which we + // weren't able to prove about the merged result, and we're pushing that + // predicate back along each incoming edge to see if we can prove it + // separately for each input. As a motivating example, consider: + // bb1: + // %v1 = ... ; constantrange<1, 5> + // br label %merge + // bb2: + // %v2 = ... 
; constantrange<10, 20> + // br label %merge + // merge: + // %phi = phi [%v1, %v2] ; constantrange<1,20> + // %pred = icmp eq i32 %phi, 8 + // We can't tell from the lattice value for '%phi' that '%pred' is false + // along each path, but by checking the predicate over each input separately, + // we can. + // We limit the search to one step backwards from the current BB and value. + // We could consider extending this to search further backwards through the + // CFG and/or value graph, but there are non-obvious compile time vs quality + // tradeoffs. + if (CxtI) { + BasicBlock *BB = CxtI->getParent(); + + // Function entry or an unreachable block. Bail to avoid confusing + // analysis below. + pred_iterator PI = pred_begin(BB), PE = pred_end(BB); + if (PI == PE) + return Unknown; + + // If V is a PHI node in the same block as the context, we need to ask + // questions about the predicate as applied to the incoming value along + // each edge. This is useful for eliminating cases where the predicate is + // known along all incoming edges. + if (auto *PHI = dyn_cast<PHINode>(V)) + if (PHI->getParent() == BB) { + Tristate Baseline = Unknown; + for (unsigned i = 0, e = PHI->getNumIncomingValues(); i < e; i++) { + Value *Incoming = PHI->getIncomingValue(i); + BasicBlock *PredBB = PHI->getIncomingBlock(i); + // Note that PredBB may be BB itself. + Tristate Result = getPredicateOnEdge(Pred, Incoming, C, PredBB, BB, + CxtI); + + // Keep going as long as we've seen a consistent known result for + // all inputs. + Baseline = (i == 0) ? Result /* First iteration */ + : (Baseline == Result ? Baseline : Unknown); /* All others */ + if (Baseline == Unknown) + break; + } + if (Baseline != Unknown) + return Baseline; + } + + // For a comparison where the V is outside this block, it's possible + // that we've branched on it before. Look to see if the value is known + // on all incoming edges. + if (!isa<Instruction>(V) || + cast<Instruction>(V)->getParent() != BB) { + // For predecessor edge, determine if the comparison is true or false + // on that edge. If they're all true or all false, we can conclude + // the value of the comparison in this block. + Tristate Baseline = getPredicateOnEdge(Pred, V, C, *PI, BB, CxtI); + if (Baseline != Unknown) { + // Check that all remaining incoming values match the first one. + while (++PI != PE) { + Tristate Ret = getPredicateOnEdge(Pred, V, C, *PI, BB, CxtI); + if (Ret != Baseline) break; + } + // If we terminated early, then one of the values didn't match. + if (PI == PE) { + return Baseline; + } + } + } + } + return Unknown; +} + +void LazyValueInfo::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc, + BasicBlock *NewSucc) { + if (PImpl) { + const DataLayout &DL = PredBB->getModule()->getDataLayout(); + getImpl(PImpl, AC, &DL, DT).threadEdge(PredBB, OldSucc, NewSucc); + } +} + +void LazyValueInfo::eraseBlock(BasicBlock *BB) { + if (PImpl) { + const DataLayout &DL = BB->getModule()->getDataLayout(); + getImpl(PImpl, AC, &DL, DT).eraseBlock(BB); + } +} + + +void LazyValueInfo::printLVI(Function &F, DominatorTree &DTree, raw_ostream &OS) { + if (PImpl) { + getImpl(PImpl, AC, DL, DT).printLVI(F, DTree, OS); + } +} + +void LazyValueInfo::disableDT() { + if (PImpl) + getImpl(PImpl, AC, DL, DT).disableDT(); +} + +void LazyValueInfo::enableDT() { + if (PImpl) + getImpl(PImpl, AC, DL, DT).enableDT(); +} + +// Print the LVI for the function arguments at the start of each basic block. 
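+// For example (illustrative output):
+//   ; LatticeVal for: 'i32 %n' is: constantrange<0, 10>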
+void LazyValueInfoAnnotatedWriter::emitBasicBlockStartAnnot(
+    const BasicBlock *BB, formatted_raw_ostream &OS) {
+  // Find if there are lattice values defined for arguments of the function.
+  auto *F = BB->getParent();
+  for (auto &Arg : F->args()) {
+    ValueLatticeElement Result = LVIImpl->getValueInBlock(
+        const_cast<Argument *>(&Arg), const_cast<BasicBlock *>(BB));
+    if (Result.isUndefined())
+      continue;
+    OS << "; LatticeVal for: '" << Arg << "' is: " << Result << "\n";
+  }
+}
+
+// This function prints the LVI analysis for the instruction I at the beginning
+// of various basic blocks. It relies on calculated values that are stored in
+// the LazyValueInfoCache and, in the absence of cached values, recalculates
+// the LazyValueInfo for `I` and prints that info.
+void LazyValueInfoAnnotatedWriter::emitInstructionAnnot(
+    const Instruction *I, formatted_raw_ostream &OS) {
+
+  auto *ParentBB = I->getParent();
+  SmallPtrSet<const BasicBlock*, 16> BlocksContainingLVI;
+  // We can generate (solve) LVI values only for blocks that are dominated by
+  // I's parent. However, to avoid generating LVI for all dominating blocks,
+  // which would contain redundant/uninteresting information, we only print LVI
+  // for blocks that may use this LVI information (such as immediate successor
+  // blocks, and blocks that contain uses of `I`).
+  auto printResult = [&](const BasicBlock *BB) {
+    if (!BlocksContainingLVI.insert(BB).second)
+      return;
+    ValueLatticeElement Result = LVIImpl->getValueInBlock(
+        const_cast<Instruction *>(I), const_cast<BasicBlock *>(BB));
+    OS << "; LatticeVal for: '" << *I << "' in BB: '";
+    BB->printAsOperand(OS, false);
+    OS << "' is: " << Result << "\n";
+  };
+
+  printResult(ParentBB);
+  // Print the LVI analysis results for the immediate successor blocks that
+  // are dominated by `ParentBB`.
+  for (auto *BBSucc : successors(ParentBB))
+    if (DT.dominates(ParentBB, BBSucc))
+      printResult(BBSucc);
+
+  // Print LVI in blocks where `I` is used.
+  for (auto *U : I->users())
+    if (auto *UseI = dyn_cast<Instruction>(U))
+      if (!isa<PHINode>(UseI) || DT.dominates(ParentBB, UseI->getParent()))
+        printResult(UseI->getParent());
+
+}
+
+namespace {
+// Printer class for LazyValueInfo results.
+class LazyValueInfoPrinter : public FunctionPass {
+public:
+  static char ID; // Pass identification, replacement for typeid
+  LazyValueInfoPrinter() : FunctionPass(ID) {
+    initializeLazyValueInfoPrinterPass(*PassRegistry::getPassRegistry());
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesAll();
+    AU.addRequired<LazyValueInfoWrapperPass>();
+    AU.addRequired<DominatorTreeWrapperPass>();
+  }
+
+  // Get the mandatory dominator tree analysis and pass this in to the
+  // LVIPrinter. We cannot rely on the LVI's DT, since it's optional.
+  bool runOnFunction(Function &F) override {
+    dbgs() << "LVI for function '" << F.getName() << "':\n";
+    auto &LVI = getAnalysis<LazyValueInfoWrapperPass>().getLVI();
+    auto &DTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    LVI.printLVI(F, DTree, dbgs());
+    return false;
+  }
+};
+}
+
+char LazyValueInfoPrinter::ID = 0;
+INITIALIZE_PASS_BEGIN(LazyValueInfoPrinter, "print-lazy-value-info",
+                "Lazy Value Info Printer Pass", false, false)
+INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass)
+INITIALIZE_PASS_END(LazyValueInfoPrinter, "print-lazy-value-info",
+                "Lazy Value Info Printer Pass", false, false)
diff --git a/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp b/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp
new file mode 100644
index 000000000000..7de9d2cbfddb
--- /dev/null
+++ b/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp
@@ -0,0 +1,404 @@
+//===- LegacyDivergenceAnalysis.cpp - Legacy Divergence Analysis ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements divergence analysis, which determines whether a branch
+// in a GPU program is divergent. It can help branch optimizations such as jump
+// threading and loop unswitching to make better decisions.
+//
+// GPU programs typically use the SIMD execution model, where multiple threads
+// in the same execution group have to execute in lock-step. Therefore, if the
+// code contains divergent branches (i.e., threads in a group do not agree on
+// which path of the branch to take), the group of threads has to execute all
+// the paths from that branch with different subsets of threads enabled until
+// they converge at the immediately post-dominating BB of the paths.
+//
+// Due to this execution model, some optimizations such as jump
+// threading and loop unswitching can unfortunately be harmful when performed
+// on divergent branches. Therefore, an analysis that computes which branches
+// in a GPU program are divergent can help the compiler to selectively run
+// these optimizations.
+//
+// This file defines divergence analysis which computes a conservative but
+// non-trivial approximation of all divergent branches in a GPU program. It
+// partially implements the approach described in
+//
+//   Divergence Analysis
+//   Sampaio, Souza, Collange, Pereira
+//   TOPLAS '13
+//
+// The divergence analysis identifies the sources of divergence (e.g., special
+// variables that hold the thread ID), and recursively marks variables that are
+// data or sync dependent on a source of divergence as divergent.
+//
+// While data dependency is a well-known concept, the notion of sync dependency
+// is worth more explanation. Sync dependence characterizes the control flow
+// aspect of the propagation of branch divergence. For example,
+//
+//   %cond = icmp slt i32 %tid, 10
+//   br i1 %cond, label %then, label %else
+// then:
+//   br label %merge
+// else:
+//   br label %merge
+// merge:
+//   %a = phi i32 [ 0, %then ], [ 1, %else ]
+//
+// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
+// because %tid is not on its use-def chains, %a is sync dependent on %tid
+// because the branch "br i1 %cond" depends on %tid and affects which value %a
+// is assigned to.
+//
+// The current implementation has the following limitations:
+// 1. intra-procedural. It conservatively considers the arguments of a
+//    non-kernel-entry function and the return value of a function call as
+//    divergent.
+// 2. memory as black box. It conservatively considers values loaded from
+//    generic or local addresses as divergent. This can be improved by
+//    leveraging pointer analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/DivergenceAnalysis.h"
+#include "llvm/Analysis/LegacyDivergenceAnalysis.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <vector>
+using namespace llvm;
+
+#define DEBUG_TYPE "divergence"
+
+// Transparently use the GPUDivergenceAnalysis.
+static cl::opt<bool> UseGPUDA("use-gpu-divergence-analysis", cl::init(false),
+                              cl::Hidden,
+                              cl::desc("turn the LegacyDivergenceAnalysis into "
+                                       "a wrapper for GPUDivergenceAnalysis"));
+
+namespace {
+
+class DivergencePropagator {
+public:
+  DivergencePropagator(Function &F, TargetTransformInfo &TTI, DominatorTree &DT,
+                       PostDominatorTree &PDT, DenseSet<const Value *> &DV,
+                       DenseSet<const Use *> &DU)
+      : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV), DU(DU) {}
+  void populateWithSourcesOfDivergence();
+  void propagate();
+
+private:
+  // A helper function that explores data dependents of V.
+  void exploreDataDependency(Value *V);
+  // A helper function that explores sync dependents of TI.
+  void exploreSyncDependency(Instruction *TI);
+  // Computes the influence region from Start to End. This region includes all
+  // basic blocks on any simple path from Start to End.
+  void computeInfluenceRegion(BasicBlock *Start, BasicBlock *End,
+                              DenseSet<BasicBlock *> &InfluenceRegion);
+  // Finds all users of I that are outside the influence region, and adds these
+  // users to Worklist.
+  void findUsersOutsideInfluenceRegion(
+      Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion);
+
+  Function &F;
+  TargetTransformInfo &TTI;
+  DominatorTree &DT;
+  PostDominatorTree &PDT;
+  std::vector<Value *> Worklist; // Stack for DFS.
+  DenseSet<const Value *> &DV;   // Stores all divergent values.
+  DenseSet<const Use *> &DU;     // Stores divergent uses of possibly uniform
+                                 // values.
+};
+
+void DivergencePropagator::populateWithSourcesOfDivergence() {
+  Worklist.clear();
+  DV.clear();
+  DU.clear();
+  for (auto &I : instructions(F)) {
+    if (TTI.isSourceOfDivergence(&I)) {
+      Worklist.push_back(&I);
+      DV.insert(&I);
+    }
+  }
+  for (auto &Arg : F.args()) {
+    if (TTI.isSourceOfDivergence(&Arg)) {
+      Worklist.push_back(&Arg);
+      DV.insert(&Arg);
+    }
+  }
+}
+
+void DivergencePropagator::exploreSyncDependency(Instruction *TI) {
+  // Propagation rule 1: if branch TI is divergent, all PHINodes in TI's
+  // immediate post dominator are divergent. This rule handles if-then-else
+  // patterns. For example,
+  //
+  //   if (tid < 5)
+  //     a1 = 1;
+  //   else
+  //     a2 = 2;
+  //   a = phi(a1, a2); // sync dependent on (tid < 5)
+  BasicBlock *ThisBB = TI->getParent();
+
+  // Unreachable blocks may not be in the dominator tree.
+  if (!DT.isReachableFromEntry(ThisBB))
+    return;
+
+  // If the function has no exit blocks or doesn't reach any exit blocks, the
+  // post dominator may be null.
+  DomTreeNode *ThisNode = PDT.getNode(ThisBB);
+  if (!ThisNode)
+    return;
+
+  BasicBlock *IPostDom = ThisNode->getIDom()->getBlock();
+  if (IPostDom == nullptr)
+    return;
+
+  for (auto I = IPostDom->begin(); isa<PHINode>(I); ++I) {
+    // A PHINode is uniform if it returns the same value no matter which path
+    // is taken.
+    if (!cast<PHINode>(I)->hasConstantOrUndefValue() && DV.insert(&*I).second)
+      Worklist.push_back(&*I);
+  }
+
+  // Propagation rule 2: if a value defined in a loop is used outside, the user
+  // is sync dependent on the condition of the loop exits that dominate the
+  // user. For example,
+  //
+  //   int i = 0;
+  //   do {
+  //     i++;
+  //     if (foo(i)) ... // uniform
+  //   } while (i < tid);
+  //   if (bar(i)) ...   // divergent
+  //
+  // A program may contain unstructured loops. Therefore, we cannot leverage
+  // LoopInfo, which only recognizes natural loops.
+  //
+  // The algorithm used here handles both natural and unstructured loops. Given
+  // a branch TI, we first compute its influence region, the union of all
+  // simple paths from TI to its immediate post dominator (IPostDom). Then, we
+  // search for all the values defined in the influence region but used
+  // outside. All these users are sync dependent on TI.
+  DenseSet<BasicBlock *> InfluenceRegion;
+  computeInfluenceRegion(ThisBB, IPostDom, InfluenceRegion);
+  // An insight that can speed up the search process is that all the in-region
+  // values that are used outside must dominate TI. Therefore, instead of
+  // searching every basic block in the influence region, we search all the
+  // dominators of TI until we are outside the influence region.
+  BasicBlock *InfluencedBB = ThisBB;
+  while (InfluenceRegion.count(InfluencedBB)) {
+    for (auto &I : *InfluencedBB) {
+      if (!DV.count(&I))
+        findUsersOutsideInfluenceRegion(I, InfluenceRegion);
+    }
+    DomTreeNode *IDomNode = DT.getNode(InfluencedBB)->getIDom();
+    if (IDomNode == nullptr)
+      break;
+    InfluencedBB = IDomNode->getBlock();
+  }
+}
+
+void DivergencePropagator::findUsersOutsideInfluenceRegion(
+    Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion) {
+  for (Use &Use : I.uses()) {
+    Instruction *UserInst = cast<Instruction>(Use.getUser());
+    if (!InfluenceRegion.count(UserInst->getParent())) {
+      DU.insert(&Use);
+      if (DV.insert(UserInst).second)
+        Worklist.push_back(UserInst);
+    }
+  }
+}
+
+// A helper function for computeInfluenceRegion that adds successors of "ThisBB"
+// to the influence region.
+static void
+addSuccessorsToInfluenceRegion(BasicBlock *ThisBB, BasicBlock *End,
+                               DenseSet<BasicBlock *> &InfluenceRegion,
+                               std::vector<BasicBlock *> &InfluenceStack) {
+  for (BasicBlock *Succ : successors(ThisBB)) {
+    if (Succ != End && InfluenceRegion.insert(Succ).second)
+      InfluenceStack.push_back(Succ);
+  }
+}
+
+void DivergencePropagator::computeInfluenceRegion(
+    BasicBlock *Start, BasicBlock *End,
+    DenseSet<BasicBlock *> &InfluenceRegion) {
+  assert(PDT.properlyDominates(End, Start) &&
+         "End does not properly dominate Start");
+
+  // The influence region starts from the end of "Start" to the beginning of
+  // "End". Therefore, "Start" should not be in the region unless "Start" is in
+  // a loop that doesn't contain "End".
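+  //
+  // For example, for the diamond CFG
+  //
+  //        Start
+  //        /   \
+  //     then   else
+  //        \   /
+  //         End
+  //
+  // the region computed below is {then, else}: Start's successors other than
+  // End are pushed and explored, and End itself is never inserted.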
+ std::vector<BasicBlock *> InfluenceStack; + addSuccessorsToInfluenceRegion(Start, End, InfluenceRegion, InfluenceStack); + while (!InfluenceStack.empty()) { + BasicBlock *BB = InfluenceStack.back(); + InfluenceStack.pop_back(); + addSuccessorsToInfluenceRegion(BB, End, InfluenceRegion, InfluenceStack); + } +} + +void DivergencePropagator::exploreDataDependency(Value *V) { + // Follow def-use chains of V. + for (User *U : V->users()) { + if (!TTI.isAlwaysUniform(U) && DV.insert(U).second) + Worklist.push_back(U); + } +} + +void DivergencePropagator::propagate() { + // Traverse the dependency graph using DFS. + while (!Worklist.empty()) { + Value *V = Worklist.back(); + Worklist.pop_back(); + if (Instruction *I = dyn_cast<Instruction>(V)) { + // Terminators with less than two successors won't introduce sync + // dependency. Ignore them. + if (I->isTerminator() && I->getNumSuccessors() > 1) + exploreSyncDependency(I); + } + exploreDataDependency(V); + } +} + +} // namespace + +// Register this pass. +char LegacyDivergenceAnalysis::ID = 0; +INITIALIZE_PASS_BEGIN(LegacyDivergenceAnalysis, "divergence", + "Legacy Divergence Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(LegacyDivergenceAnalysis, "divergence", + "Legacy Divergence Analysis", false, true) + +FunctionPass *llvm::createLegacyDivergenceAnalysisPass() { + return new LegacyDivergenceAnalysis(); +} + +void LegacyDivergenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<PostDominatorTreeWrapperPass>(); + if (UseGPUDA) + AU.addRequired<LoopInfoWrapperPass>(); + AU.setPreservesAll(); +} + +bool LegacyDivergenceAnalysis::shouldUseGPUDivergenceAnalysis( + const Function &F) const { + if (!UseGPUDA) + return false; + + // GPUDivergenceAnalysis requires a reducible CFG. + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + using RPOTraversal = ReversePostOrderTraversal<const Function *>; + RPOTraversal FuncRPOT(&F); + return !containsIrreducibleCFG<const BasicBlock *, const RPOTraversal, + const LoopInfo>(FuncRPOT, LI); +} + +bool LegacyDivergenceAnalysis::runOnFunction(Function &F) { + auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>(); + if (TTIWP == nullptr) + return false; + + TargetTransformInfo &TTI = TTIWP->getTTI(F); + // Fast path: if the target does not have branch divergence, we do not mark + // any branch as divergent. 
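+  // (hasBranchDivergence() is a TargetTransformInfo hook; GPU-like targets
+  // opt in, while typical CPU targets leave it false, which makes this pass
+  // a cheap no-op for them.)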
+  if (!TTI.hasBranchDivergence())
+    return false;
+
+  DivergentValues.clear();
+  DivergentUses.clear();
+  gpuDA = nullptr;
+
+  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
+
+  if (shouldUseGPUDivergenceAnalysis(F)) {
+    // run the new GPU divergence analysis
+    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+    gpuDA = std::make_unique<GPUDivergenceAnalysis>(F, DT, PDT, LI, TTI);
+
+  } else {
+    // run LLVM's existing DivergenceAnalysis
+    DivergencePropagator DP(F, TTI, DT, PDT, DivergentValues, DivergentUses);
+    DP.populateWithSourcesOfDivergence();
+    DP.propagate();
+  }
+
+  LLVM_DEBUG(dbgs() << "\nAfter divergence analysis on " << F.getName()
+                    << ":\n";
+             print(dbgs(), F.getParent()));
+
+  return false;
+}
+
+bool LegacyDivergenceAnalysis::isDivergent(const Value *V) const {
+  if (gpuDA) {
+    return gpuDA->isDivergent(*V);
+  }
+  return DivergentValues.count(V);
+}
+
+bool LegacyDivergenceAnalysis::isDivergentUse(const Use *U) const {
+  if (gpuDA) {
+    return gpuDA->isDivergentUse(*U);
+  }
+  return DivergentValues.count(U->get()) || DivergentUses.count(U);
+}
+
+void LegacyDivergenceAnalysis::print(raw_ostream &OS, const Module *) const {
+  if ((!gpuDA || !gpuDA->hasDivergence()) && DivergentValues.empty())
+    return;
+
+  const Function *F = nullptr;
+  if (!DivergentValues.empty()) {
+    const Value *FirstDivergentValue = *DivergentValues.begin();
+    if (const Argument *Arg = dyn_cast<Argument>(FirstDivergentValue)) {
+      F = Arg->getParent();
+    } else if (const Instruction *I =
+                   dyn_cast<Instruction>(FirstDivergentValue)) {
+      F = I->getParent()->getParent();
+    } else {
+      llvm_unreachable("Only arguments and instructions can be divergent");
+    }
+  } else if (gpuDA) {
+    F = &gpuDA->getFunction();
+  }
+  if (!F)
+    return;
+
+  // Dumps all divergent values in F, arguments first and then instructions.
+  for (auto &Arg : F->args()) {
+    OS << (isDivergent(&Arg) ? "DIVERGENT: " : "           ");
+    OS << Arg << "\n";
+  }
+  // Iterate basic blocks and their instructions in function order, skipping
+  // debug intrinsics, so the output is deterministic.
+  for (auto BI = F->begin(), BE = F->end(); BI != BE; ++BI) {
+    auto &BB = *BI;
+    OS << "\n           " << BB.getName() << ":\n";
+    for (auto &I : BB.instructionsWithoutDebug()) {
+      OS << (isDivergent(&I) ? "DIVERGENT: " : "           ");
+      OS << I << "\n";
+    }
+  }
+  OS << "\n";
+}
diff --git a/llvm/lib/Analysis/Lint.cpp b/llvm/lib/Analysis/Lint.cpp
new file mode 100644
index 000000000000..db18716c64cf
--- /dev/null
+++ b/llvm/lib/Analysis/Lint.cpp
@@ -0,0 +1,756 @@
+//===-- Lint.cpp - Check for common errors in LLVM IR --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass statically checks for common and easily-identified constructs
+// which produce undefined or likely unintended behavior in LLVM IR.
+//
+// It is not a guarantee of correctness, in two ways. First, it isn't
+// comprehensive. There are checks which could be done statically which are
+// not yet implemented. Some of these are indicated by TODO comments, but
+// those aren't comprehensive either. Second, many conditions cannot be
+// checked statically. This pass does no dynamic instrumentation, so it
+// can't check for all possible problems.
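+//
+// For example, given the IR below (valid, but almost certainly a bug), this
+// pass reports "Undefined behavior: Null pointer dereference":
+//
+//   store i32 0, i32* null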
+// +// Another limitation is that it assumes all code will be executed. A store +// through a null pointer in a basic block which is never reached is harmless, +// but this pass will warn about it anyway. This is the main reason why most +// of these checks live here instead of in the Verifier pass. +// +// Optimization passes may make conditions that this pass checks for more or +// less obvious. If an optimization pass appears to be introducing a warning, +// it may be that the optimization pass is merely exposing an existing +// condition in the code. +// +// This code may be run before instcombine. In many cases, instcombine checks +// for the same kinds of things and turns instructions with undefined behavior +// into unreachable (or equivalent). Because of this, this pass makes some +// effort to look through bitcasts and so on. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Lint.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <cstdint> +#include <iterator> +#include <string> + +using namespace llvm; + +namespace { + namespace MemRef { + static const unsigned Read = 1; + static const unsigned Write = 2; + static const unsigned Callee = 4; + static const unsigned Branchee = 8; + } // end namespace MemRef + + class Lint : public FunctionPass, public InstVisitor<Lint> { + friend class InstVisitor<Lint>; + + void visitFunction(Function &F); + + void visitCallSite(CallSite CS); + void visitMemoryReference(Instruction &I, Value *Ptr, + uint64_t Size, unsigned Align, + Type *Ty, unsigned Flags); + void visitEHBeginCatch(IntrinsicInst *II); + void visitEHEndCatch(IntrinsicInst *II); + + void visitCallInst(CallInst &I); + void visitInvokeInst(InvokeInst &I); + void visitReturnInst(ReturnInst &I); + void visitLoadInst(LoadInst &I); + void visitStoreInst(StoreInst &I); + void visitXor(BinaryOperator &I); + void visitSub(BinaryOperator &I); + void visitLShr(BinaryOperator &I); + void visitAShr(BinaryOperator &I); + void visitShl(BinaryOperator &I); + void visitSDiv(BinaryOperator &I); + void visitUDiv(BinaryOperator &I); + void visitSRem(BinaryOperator &I); + void visitURem(BinaryOperator &I); + void visitAllocaInst(AllocaInst &I); + void 
visitVAArgInst(VAArgInst &I);
+    void visitIndirectBrInst(IndirectBrInst &I);
+    void visitExtractElementInst(ExtractElementInst &I);
+    void visitInsertElementInst(InsertElementInst &I);
+    void visitUnreachableInst(UnreachableInst &I);
+
+    Value *findValue(Value *V, bool OffsetOk) const;
+    Value *findValueImpl(Value *V, bool OffsetOk,
+                         SmallPtrSetImpl<Value *> &Visited) const;
+
+  public:
+    Module *Mod;
+    const DataLayout *DL;
+    AliasAnalysis *AA;
+    AssumptionCache *AC;
+    DominatorTree *DT;
+    TargetLibraryInfo *TLI;
+
+    std::string Messages;
+    raw_string_ostream MessagesStr;
+
+    static char ID; // Pass identification, replacement for typeid
+    Lint() : FunctionPass(ID), MessagesStr(Messages) {
+      initializeLintPass(*PassRegistry::getPassRegistry());
+    }
+
+    bool runOnFunction(Function &F) override;
+
+    void getAnalysisUsage(AnalysisUsage &AU) const override {
+      AU.setPreservesAll();
+      AU.addRequired<AAResultsWrapperPass>();
+      AU.addRequired<AssumptionCacheTracker>();
+      AU.addRequired<TargetLibraryInfoWrapperPass>();
+      AU.addRequired<DominatorTreeWrapperPass>();
+    }
+    void print(raw_ostream &O, const Module *M) const override {}
+
+    void WriteValues(ArrayRef<const Value *> Vs) {
+      for (const Value *V : Vs) {
+        if (!V)
+          continue;
+        if (isa<Instruction>(V)) {
+          MessagesStr << *V << '\n';
+        } else {
+          V->printAsOperand(MessagesStr, true, Mod);
+          MessagesStr << '\n';
+        }
+      }
+    }
+
+    /// A check failed, so print out the condition and the message.
+    ///
+    /// This provides a nice place to put a breakpoint if you want to see why
+    /// something is not correct.
+    void CheckFailed(const Twine &Message) { MessagesStr << Message << '\n'; }
+
+    /// A check failed (with values to print).
+    ///
+    /// This calls the Message-only version so that the above is easier to set
+    /// a breakpoint on.
+    template <typename T1, typename... Ts>
+    void CheckFailed(const Twine &Message, const T1 &V1, const Ts &...Vs) {
+      CheckFailed(Message);
+      WriteValues({V1, Vs...});
+    }
+  };
+} // end anonymous namespace
+
+char Lint::ID = 0;
+INITIALIZE_PASS_BEGIN(Lint, "lint", "Statically lint-checks LLVM IR",
+                      false, true)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_END(Lint, "lint", "Statically lint-checks LLVM IR",
+                    false, true)
+
+// Assert - We know that cond should be true; if not, print an error message.
+#define Assert(C, ...) \
+  do { if (!(C)) { CheckFailed(__VA_ARGS__); return; } } while (false)
+
+// Lint::run - This is the main Analysis entry point for a
+// function.
+//
+bool Lint::runOnFunction(Function &F) {
+  Mod = F.getParent();
+  DL = &F.getParent()->getDataLayout();
+  AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
+  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+  visit(F);
+  dbgs() << MessagesStr.str();
+  Messages.clear();
+  return false;
+}
+
+void Lint::visitFunction(Function &F) {
+  // This isn't undefined behavior, it's just a little unusual, and it's a
+  // fairly common mistake to neglect to name a function.
+  Assert(F.hasName() || F.hasLocalLinkage(),
+         "Unusual: Unnamed function with non-local linkage", &F);
+
+  // TODO: Check for irreducible control flow.
+} + +void Lint::visitCallSite(CallSite CS) { + Instruction &I = *CS.getInstruction(); + Value *Callee = CS.getCalledValue(); + + visitMemoryReference(I, Callee, MemoryLocation::UnknownSize, 0, nullptr, + MemRef::Callee); + + if (Function *F = dyn_cast<Function>(findValue(Callee, + /*OffsetOk=*/false))) { + Assert(CS.getCallingConv() == F->getCallingConv(), + "Undefined behavior: Caller and callee calling convention differ", + &I); + + FunctionType *FT = F->getFunctionType(); + unsigned NumActualArgs = CS.arg_size(); + + Assert(FT->isVarArg() ? FT->getNumParams() <= NumActualArgs + : FT->getNumParams() == NumActualArgs, + "Undefined behavior: Call argument count mismatches callee " + "argument count", + &I); + + Assert(FT->getReturnType() == I.getType(), + "Undefined behavior: Call return type mismatches " + "callee return type", + &I); + + // Check argument types (in case the callee was casted) and attributes. + // TODO: Verify that caller and callee attributes are compatible. + Function::arg_iterator PI = F->arg_begin(), PE = F->arg_end(); + CallSite::arg_iterator AI = CS.arg_begin(), AE = CS.arg_end(); + for (; AI != AE; ++AI) { + Value *Actual = *AI; + if (PI != PE) { + Argument *Formal = &*PI++; + Assert(Formal->getType() == Actual->getType(), + "Undefined behavior: Call argument type mismatches " + "callee parameter type", + &I); + + // Check that noalias arguments don't alias other arguments. This is + // not fully precise because we don't know the sizes of the dereferenced + // memory regions. + if (Formal->hasNoAliasAttr() && Actual->getType()->isPointerTy()) { + AttributeList PAL = CS.getAttributes(); + unsigned ArgNo = 0; + for (CallSite::arg_iterator BI = CS.arg_begin(); BI != AE; + ++BI, ++ArgNo) { + // Skip ByVal arguments since they will be memcpy'd to the callee's + // stack so we're not really passing the pointer anyway. + if (PAL.hasParamAttribute(ArgNo, Attribute::ByVal)) + continue; + // If both arguments are readonly, they have no dependence. + if (Formal->onlyReadsMemory() && CS.onlyReadsMemory(ArgNo)) + continue; + if (AI != BI && (*BI)->getType()->isPointerTy()) { + AliasResult Result = AA->alias(*AI, *BI); + Assert(Result != MustAlias && Result != PartialAlias, + "Unusual: noalias argument aliases another argument", &I); + } + } + } + + // Check that an sret argument points to valid memory. + if (Formal->hasStructRetAttr() && Actual->getType()->isPointerTy()) { + Type *Ty = + cast<PointerType>(Formal->getType())->getElementType(); + visitMemoryReference(I, Actual, DL->getTypeStoreSize(Ty), + DL->getABITypeAlignment(Ty), Ty, + MemRef::Read | MemRef::Write); + } + } + } + } + + if (CS.isCall()) { + const CallInst *CI = cast<CallInst>(CS.getInstruction()); + if (CI->isTailCall()) { + const AttributeList &PAL = CI->getAttributes(); + unsigned ArgNo = 0; + for (Value *Arg : CS.args()) { + // Skip ByVal arguments since they will be memcpy'd to the callee's + // stack anyway. + if (PAL.hasParamAttribute(ArgNo++, Attribute::ByVal)) + continue; + Value *Obj = findValue(Arg, /*OffsetOk=*/true); + Assert(!isa<AllocaInst>(Obj), + "Undefined behavior: Call with \"tail\" keyword references " + "alloca", + &I); + } + } + } + + + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I)) + switch (II->getIntrinsicID()) { + default: break; + + // TODO: Check more intrinsics + + case Intrinsic::memcpy: { + MemCpyInst *MCI = cast<MemCpyInst>(&I); + // TODO: If the size is known, use it. 
+ visitMemoryReference(I, MCI->getDest(), MemoryLocation::UnknownSize, + MCI->getDestAlignment(), nullptr, MemRef::Write); + visitMemoryReference(I, MCI->getSource(), MemoryLocation::UnknownSize, + MCI->getSourceAlignment(), nullptr, MemRef::Read); + + // Check that the memcpy arguments don't overlap. The AliasAnalysis API + // isn't expressive enough for what we really want to do. Known partial + // overlap is not distinguished from the case where nothing is known. + auto Size = LocationSize::unknown(); + if (const ConstantInt *Len = + dyn_cast<ConstantInt>(findValue(MCI->getLength(), + /*OffsetOk=*/false))) + if (Len->getValue().isIntN(32)) + Size = LocationSize::precise(Len->getValue().getZExtValue()); + Assert(AA->alias(MCI->getSource(), Size, MCI->getDest(), Size) != + MustAlias, + "Undefined behavior: memcpy source and destination overlap", &I); + break; + } + case Intrinsic::memmove: { + MemMoveInst *MMI = cast<MemMoveInst>(&I); + // TODO: If the size is known, use it. + visitMemoryReference(I, MMI->getDest(), MemoryLocation::UnknownSize, + MMI->getDestAlignment(), nullptr, MemRef::Write); + visitMemoryReference(I, MMI->getSource(), MemoryLocation::UnknownSize, + MMI->getSourceAlignment(), nullptr, MemRef::Read); + break; + } + case Intrinsic::memset: { + MemSetInst *MSI = cast<MemSetInst>(&I); + // TODO: If the size is known, use it. + visitMemoryReference(I, MSI->getDest(), MemoryLocation::UnknownSize, + MSI->getDestAlignment(), nullptr, MemRef::Write); + break; + } + + case Intrinsic::vastart: + Assert(I.getParent()->getParent()->isVarArg(), + "Undefined behavior: va_start called in a non-varargs function", + &I); + + visitMemoryReference(I, CS.getArgument(0), MemoryLocation::UnknownSize, 0, + nullptr, MemRef::Read | MemRef::Write); + break; + case Intrinsic::vacopy: + visitMemoryReference(I, CS.getArgument(0), MemoryLocation::UnknownSize, 0, + nullptr, MemRef::Write); + visitMemoryReference(I, CS.getArgument(1), MemoryLocation::UnknownSize, 0, + nullptr, MemRef::Read); + break; + case Intrinsic::vaend: + visitMemoryReference(I, CS.getArgument(0), MemoryLocation::UnknownSize, 0, + nullptr, MemRef::Read | MemRef::Write); + break; + + case Intrinsic::stackrestore: + // Stackrestore doesn't read or write memory, but it sets the + // stack pointer, which the compiler may read from or write to + // at any time, so check it for both readability and writeability. + visitMemoryReference(I, CS.getArgument(0), MemoryLocation::UnknownSize, 0, + nullptr, MemRef::Read | MemRef::Write); + break; + } +} + +void Lint::visitCallInst(CallInst &I) { + return visitCallSite(&I); +} + +void Lint::visitInvokeInst(InvokeInst &I) { + return visitCallSite(&I); +} + +void Lint::visitReturnInst(ReturnInst &I) { + Function *F = I.getParent()->getParent(); + Assert(!F->doesNotReturn(), + "Unusual: Return statement in function with noreturn attribute", &I); + + if (Value *V = I.getReturnValue()) { + Value *Obj = findValue(V, /*OffsetOk=*/true); + Assert(!isa<AllocaInst>(Obj), "Unusual: Returning alloca value", &I); + } +} + +// TODO: Check that the reference is in bounds. +// TODO: Check readnone/readonly function attributes. +void Lint::visitMemoryReference(Instruction &I, + Value *Ptr, uint64_t Size, unsigned Align, + Type *Ty, unsigned Flags) { + // If no memory is being referenced, it doesn't matter if the pointer + // is valid. 
+ if (Size == 0) + return; + + Value *UnderlyingObject = findValue(Ptr, /*OffsetOk=*/true); + Assert(!isa<ConstantPointerNull>(UnderlyingObject), + "Undefined behavior: Null pointer dereference", &I); + Assert(!isa<UndefValue>(UnderlyingObject), + "Undefined behavior: Undef pointer dereference", &I); + Assert(!isa<ConstantInt>(UnderlyingObject) || + !cast<ConstantInt>(UnderlyingObject)->isMinusOne(), + "Unusual: All-ones pointer dereference", &I); + Assert(!isa<ConstantInt>(UnderlyingObject) || + !cast<ConstantInt>(UnderlyingObject)->isOne(), + "Unusual: Address one pointer dereference", &I); + + if (Flags & MemRef::Write) { + if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(UnderlyingObject)) + Assert(!GV->isConstant(), "Undefined behavior: Write to read-only memory", + &I); + Assert(!isa<Function>(UnderlyingObject) && + !isa<BlockAddress>(UnderlyingObject), + "Undefined behavior: Write to text section", &I); + } + if (Flags & MemRef::Read) { + Assert(!isa<Function>(UnderlyingObject), "Unusual: Load from function body", + &I); + Assert(!isa<BlockAddress>(UnderlyingObject), + "Undefined behavior: Load from block address", &I); + } + if (Flags & MemRef::Callee) { + Assert(!isa<BlockAddress>(UnderlyingObject), + "Undefined behavior: Call to block address", &I); + } + if (Flags & MemRef::Branchee) { + Assert(!isa<Constant>(UnderlyingObject) || + isa<BlockAddress>(UnderlyingObject), + "Undefined behavior: Branch to non-blockaddress", &I); + } + + // Check for buffer overflows and misalignment. + // Only handles memory references that read/write something simple like an + // alloca instruction or a global variable. + int64_t Offset = 0; + if (Value *Base = GetPointerBaseWithConstantOffset(Ptr, Offset, *DL)) { + // OK, so the access is to a constant offset from Ptr. Check that Ptr is + // something we can handle and if so extract the size of this base object + // along with its alignment. + uint64_t BaseSize = MemoryLocation::UnknownSize; + unsigned BaseAlign = 0; + + if (AllocaInst *AI = dyn_cast<AllocaInst>(Base)) { + Type *ATy = AI->getAllocatedType(); + if (!AI->isArrayAllocation() && ATy->isSized()) + BaseSize = DL->getTypeAllocSize(ATy); + BaseAlign = AI->getAlignment(); + if (BaseAlign == 0 && ATy->isSized()) + BaseAlign = DL->getABITypeAlignment(ATy); + } else if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Base)) { + // If the global may be defined differently in another compilation unit + // then don't warn about funky memory accesses. + if (GV->hasDefinitiveInitializer()) { + Type *GTy = GV->getValueType(); + if (GTy->isSized()) + BaseSize = DL->getTypeAllocSize(GTy); + BaseAlign = GV->getAlignment(); + if (BaseAlign == 0 && GTy->isSized()) + BaseAlign = DL->getABITypeAlignment(GTy); + } + } + + // Accesses from before the start or after the end of the object are not + // defined. + Assert(Size == MemoryLocation::UnknownSize || + BaseSize == MemoryLocation::UnknownSize || + (Offset >= 0 && Offset + Size <= BaseSize), + "Undefined behavior: Buffer overflow", &I); + + // Accesses that say that the memory is more aligned than it is are not + // defined. 
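+    // For example, if the base object is 8-byte aligned but the access is at
+    // offset 4, MinAlign(8, 4) == 4, so claiming 8-byte alignment for the
+    // access below would be flagged.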
+ if (Align == 0 && Ty && Ty->isSized()) + Align = DL->getABITypeAlignment(Ty); + Assert(!BaseAlign || Align <= MinAlign(BaseAlign, Offset), + "Undefined behavior: Memory reference address is misaligned", &I); + } +} + +void Lint::visitLoadInst(LoadInst &I) { + visitMemoryReference(I, I.getPointerOperand(), + DL->getTypeStoreSize(I.getType()), I.getAlignment(), + I.getType(), MemRef::Read); +} + +void Lint::visitStoreInst(StoreInst &I) { + visitMemoryReference(I, I.getPointerOperand(), + DL->getTypeStoreSize(I.getOperand(0)->getType()), + I.getAlignment(), + I.getOperand(0)->getType(), MemRef::Write); +} + +void Lint::visitXor(BinaryOperator &I) { + Assert(!isa<UndefValue>(I.getOperand(0)) || !isa<UndefValue>(I.getOperand(1)), + "Undefined result: xor(undef, undef)", &I); +} + +void Lint::visitSub(BinaryOperator &I) { + Assert(!isa<UndefValue>(I.getOperand(0)) || !isa<UndefValue>(I.getOperand(1)), + "Undefined result: sub(undef, undef)", &I); +} + +void Lint::visitLShr(BinaryOperator &I) { + if (ConstantInt *CI = dyn_cast<ConstantInt>(findValue(I.getOperand(1), + /*OffsetOk=*/false))) + Assert(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()), + "Undefined result: Shift count out of range", &I); +} + +void Lint::visitAShr(BinaryOperator &I) { + if (ConstantInt *CI = + dyn_cast<ConstantInt>(findValue(I.getOperand(1), /*OffsetOk=*/false))) + Assert(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()), + "Undefined result: Shift count out of range", &I); +} + +void Lint::visitShl(BinaryOperator &I) { + if (ConstantInt *CI = + dyn_cast<ConstantInt>(findValue(I.getOperand(1), /*OffsetOk=*/false))) + Assert(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()), + "Undefined result: Shift count out of range", &I); +} + +static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, + AssumptionCache *AC) { + // Assume undef could be zero. + if (isa<UndefValue>(V)) + return true; + + VectorType *VecTy = dyn_cast<VectorType>(V->getType()); + if (!VecTy) { + KnownBits Known = computeKnownBits(V, DL, 0, AC, dyn_cast<Instruction>(V), DT); + return Known.isZero(); + } + + // Per-component check doesn't work with zeroinitializer + Constant *C = dyn_cast<Constant>(V); + if (!C) + return false; + + if (C->isZeroValue()) + return true; + + // For a vector, KnownZero will only be true if all values are zero, so check + // this per component + for (unsigned I = 0, N = VecTy->getNumElements(); I != N; ++I) { + Constant *Elem = C->getAggregateElement(I); + if (isa<UndefValue>(Elem)) + return true; + + KnownBits Known = computeKnownBits(Elem, DL); + if (Known.isZero()) + return true; + } + + return false; +} + +void Lint::visitSDiv(BinaryOperator &I) { + Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), + "Undefined behavior: Division by zero", &I); +} + +void Lint::visitUDiv(BinaryOperator &I) { + Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), + "Undefined behavior: Division by zero", &I); +} + +void Lint::visitSRem(BinaryOperator &I) { + Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), + "Undefined behavior: Division by zero", &I); +} + +void Lint::visitURem(BinaryOperator &I) { + Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), + "Undefined behavior: Division by zero", &I); +} + +void Lint::visitAllocaInst(AllocaInst &I) { + if (isa<ConstantInt>(I.getArraySize())) + // This isn't undefined behavior, it's just an obvious pessimization. 
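+    // (A constant-sized alloca in the entry block becomes part of the fixed
+    // stack frame; anywhere else it must be materialized every time the
+    // enclosing block executes.)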
+ Assert(&I.getParent()->getParent()->getEntryBlock() == I.getParent(), + "Pessimization: Static alloca outside of entry block", &I); + + // TODO: Check for an unusual size (MSB set?) +} + +void Lint::visitVAArgInst(VAArgInst &I) { + visitMemoryReference(I, I.getOperand(0), MemoryLocation::UnknownSize, 0, + nullptr, MemRef::Read | MemRef::Write); +} + +void Lint::visitIndirectBrInst(IndirectBrInst &I) { + visitMemoryReference(I, I.getAddress(), MemoryLocation::UnknownSize, 0, + nullptr, MemRef::Branchee); + + Assert(I.getNumDestinations() != 0, + "Undefined behavior: indirectbr with no destinations", &I); +} + +void Lint::visitExtractElementInst(ExtractElementInst &I) { + if (ConstantInt *CI = dyn_cast<ConstantInt>(findValue(I.getIndexOperand(), + /*OffsetOk=*/false))) + Assert(CI->getValue().ult(I.getVectorOperandType()->getNumElements()), + "Undefined result: extractelement index out of range", &I); +} + +void Lint::visitInsertElementInst(InsertElementInst &I) { + if (ConstantInt *CI = dyn_cast<ConstantInt>(findValue(I.getOperand(2), + /*OffsetOk=*/false))) + Assert(CI->getValue().ult(I.getType()->getNumElements()), + "Undefined result: insertelement index out of range", &I); +} + +void Lint::visitUnreachableInst(UnreachableInst &I) { + // This isn't undefined behavior, it's merely suspicious. + Assert(&I == &I.getParent()->front() || + std::prev(I.getIterator())->mayHaveSideEffects(), + "Unusual: unreachable immediately preceded by instruction without " + "side effects", + &I); +} + +/// findValue - Look through bitcasts and simple memory reference patterns +/// to identify an equivalent, but more informative, value. If OffsetOk +/// is true, look through getelementptrs with non-zero offsets too. +/// +/// Most analysis passes don't require this logic, because instcombine +/// will simplify most of these kinds of things away. But it's a goal of +/// this Lint pass to be useful even on non-optimized IR. +Value *Lint::findValue(Value *V, bool OffsetOk) const { + SmallPtrSet<Value *, 4> Visited; + return findValueImpl(V, OffsetOk, Visited); +} + +/// findValueImpl - Implementation helper for findValue. +Value *Lint::findValueImpl(Value *V, bool OffsetOk, + SmallPtrSetImpl<Value *> &Visited) const { + // Detect self-referential values. + if (!Visited.insert(V).second) + return UndefValue::get(V->getType()); + + // TODO: Look through sext or zext cast, when the result is known to + // be interpreted as signed or unsigned, respectively. + // TODO: Look through eliminable cast pairs. + // TODO: Look through calls with unique return values. + // TODO: Look through vector insert/extract/shuffle. + V = OffsetOk ? 
GetUnderlyingObject(V, *DL) : V->stripPointerCasts(); + if (LoadInst *L = dyn_cast<LoadInst>(V)) { + BasicBlock::iterator BBI = L->getIterator(); + BasicBlock *BB = L->getParent(); + SmallPtrSet<BasicBlock *, 4> VisitedBlocks; + for (;;) { + if (!VisitedBlocks.insert(BB).second) + break; + if (Value *U = + FindAvailableLoadedValue(L, BB, BBI, DefMaxInstsToScan, AA)) + return findValueImpl(U, OffsetOk, Visited); + if (BBI != BB->begin()) break; + BB = BB->getUniquePredecessor(); + if (!BB) break; + BBI = BB->end(); + } + } else if (PHINode *PN = dyn_cast<PHINode>(V)) { + if (Value *W = PN->hasConstantValue()) + if (W != V) + return findValueImpl(W, OffsetOk, Visited); + } else if (CastInst *CI = dyn_cast<CastInst>(V)) { + if (CI->isNoopCast(*DL)) + return findValueImpl(CI->getOperand(0), OffsetOk, Visited); + } else if (ExtractValueInst *Ex = dyn_cast<ExtractValueInst>(V)) { + if (Value *W = FindInsertedValue(Ex->getAggregateOperand(), + Ex->getIndices())) + if (W != V) + return findValueImpl(W, OffsetOk, Visited); + } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { + // Same as above, but for ConstantExpr instead of Instruction. + if (Instruction::isCast(CE->getOpcode())) { + if (CastInst::isNoopCast(Instruction::CastOps(CE->getOpcode()), + CE->getOperand(0)->getType(), CE->getType(), + *DL)) + return findValueImpl(CE->getOperand(0), OffsetOk, Visited); + } else if (CE->getOpcode() == Instruction::ExtractValue) { + ArrayRef<unsigned> Indices = CE->getIndices(); + if (Value *W = FindInsertedValue(CE->getOperand(0), Indices)) + if (W != V) + return findValueImpl(W, OffsetOk, Visited); + } + } + + // As a last resort, try SimplifyInstruction or constant folding. + if (Instruction *Inst = dyn_cast<Instruction>(V)) { + if (Value *W = SimplifyInstruction(Inst, {*DL, TLI, DT, AC})) + return findValueImpl(W, OffsetOk, Visited); + } else if (auto *C = dyn_cast<Constant>(V)) { + if (Value *W = ConstantFoldConstant(C, *DL, TLI)) + if (W && W != V) + return findValueImpl(W, OffsetOk, Visited); + } + + return V; +} + +//===----------------------------------------------------------------------===// +// Implement the public interfaces to this file... +//===----------------------------------------------------------------------===// + +FunctionPass *llvm::createLintPass() { + return new Lint(); +} + +/// lintFunction - Check a function for errors, printing messages on stderr. +/// +void llvm::lintFunction(const Function &f) { + Function &F = const_cast<Function&>(f); + assert(!F.isDeclaration() && "Cannot lint external functions"); + + legacy::FunctionPassManager FPM(F.getParent()); + Lint *V = new Lint(); + FPM.add(V); + FPM.run(F); +} + +/// lintModule - Check a module for errors, printing messages on stderr. +/// +void llvm::lintModule(const Module &M) { + legacy::PassManager PM; + Lint *V = new Lint(); + PM.add(V); + PM.run(const_cast<Module&>(M)); +} diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp new file mode 100644 index 000000000000..641e92eac781 --- /dev/null +++ b/llvm/lib/Analysis/Loads.cpp @@ -0,0 +1,481 @@ +//===- Loads.cpp - Local load analysis ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines simple local analyses for load instructions. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Statepoint.h" + +using namespace llvm; + +static MaybeAlign getBaseAlign(const Value *Base, const DataLayout &DL) { + if (const MaybeAlign PA = Base->getPointerAlignment(DL)) + return *PA; + Type *const Ty = Base->getType()->getPointerElementType(); + if (!Ty->isSized()) + return None; + return Align(DL.getABITypeAlignment(Ty)); +} + +static bool isAligned(const Value *Base, const APInt &Offset, Align Alignment, + const DataLayout &DL) { + if (MaybeAlign BA = getBaseAlign(Base, DL)) { + const APInt APBaseAlign(Offset.getBitWidth(), BA->value()); + const APInt APAlign(Offset.getBitWidth(), Alignment.value()); + assert(APAlign.isPowerOf2() && "must be a power of 2!"); + return APBaseAlign.uge(APAlign) && !(Offset & (APAlign - 1)); + } + return false; +} + +/// Test if V is always a pointer to allocated and suitably aligned memory for +/// a simple load or store. +static bool isDereferenceableAndAlignedPointer( + const Value *V, Align Alignment, const APInt &Size, const DataLayout &DL, + const Instruction *CtxI, const DominatorTree *DT, + SmallPtrSetImpl<const Value *> &Visited) { + // Already visited? Bail out, we've likely hit unreachable code. + if (!Visited.insert(V).second) + return false; + + // Note that it is not safe to speculate into a malloc'd region because + // malloc may return null. + + // bitcast instructions are no-ops as far as dereferenceability is concerned. + if (const BitCastOperator *BC = dyn_cast<BitCastOperator>(V)) + return isDereferenceableAndAlignedPointer(BC->getOperand(0), Alignment, + Size, DL, CtxI, DT, Visited); + + bool CheckForNonNull = false; + APInt KnownDerefBytes(Size.getBitWidth(), + V->getPointerDereferenceableBytes(DL, CheckForNonNull)); + if (KnownDerefBytes.getBoolValue() && KnownDerefBytes.uge(Size)) + if (!CheckForNonNull || isKnownNonZero(V, DL, 0, nullptr, CtxI, DT)) { + // As we recursed through GEPs to get here, we've incrementally checked + // that each step advanced by a multiple of the alignment. If our base is + // properly aligned, then the original offset accessed must also be. + Type *Ty = V->getType(); + assert(Ty->isSized() && "must be sized"); + APInt Offset(DL.getTypeStoreSizeInBits(Ty), 0); + return isAligned(V, Offset, Alignment, DL); + } + + // For GEPs, determine if the indexing lands within the allocated object. + if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { + const Value *Base = GEP->getPointerOperand(); + + APInt Offset(DL.getIndexTypeSizeInBits(GEP->getType()), 0); + if (!GEP->accumulateConstantOffset(DL, Offset) || Offset.isNegative() || + !Offset.urem(APInt(Offset.getBitWidth(), Alignment.value())) + .isMinValue()) + return false; + + // If the base pointer is dereferenceable for Offset+Size bytes, then the + // GEP (== Base + Offset) is dereferenceable for Size bytes. 
If the base + // pointer is aligned to Align bytes, and the Offset is divisible by Align + // then the GEP (== Base + Offset == k_0 * Align + k_1 * Align) is also + // aligned to Align bytes. + + // Offset and Size may have different bit widths if we have visited an + // addrspacecast, so we can't do arithmetic directly on the APInt values. + return isDereferenceableAndAlignedPointer( + Base, Alignment, Offset + Size.sextOrTrunc(Offset.getBitWidth()), DL, + CtxI, DT, Visited); + } + + // For gc.relocate, look through relocations + if (const GCRelocateInst *RelocateInst = dyn_cast<GCRelocateInst>(V)) + return isDereferenceableAndAlignedPointer( + RelocateInst->getDerivedPtr(), Alignment, Size, DL, CtxI, DT, Visited); + + if (const AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(V)) + return isDereferenceableAndAlignedPointer(ASC->getOperand(0), Alignment, + Size, DL, CtxI, DT, Visited); + + if (const auto *Call = dyn_cast<CallBase>(V)) + if (auto *RP = getArgumentAliasingToReturnedPointer(Call, true)) + return isDereferenceableAndAlignedPointer(RP, Alignment, Size, DL, CtxI, + DT, Visited); + + // If we don't know, assume the worst. + return false; +} + +bool llvm::isDereferenceableAndAlignedPointer(const Value *V, Align Alignment, + const APInt &Size, + const DataLayout &DL, + const Instruction *CtxI, + const DominatorTree *DT) { + // Note: At the moment, Size can be zero. This ends up being interpreted as + // a query of whether [Base, V] is dereferenceable and V is aligned (since + // that's what the implementation happened to do). It's unclear if this is + // the desired semantic, but at least SelectionDAG does exercise this case. + + SmallPtrSet<const Value *, 32> Visited; + return ::isDereferenceableAndAlignedPointer(V, Alignment, Size, DL, CtxI, DT, + Visited); +} + +bool llvm::isDereferenceableAndAlignedPointer(const Value *V, Type *Ty, + MaybeAlign MA, + const DataLayout &DL, + const Instruction *CtxI, + const DominatorTree *DT) { + if (!Ty->isSized()) + return false; + + // When dereferenceability information is provided by a dereferenceable + // attribute, we know exactly how many bytes are dereferenceable. If we can + // determine the exact offset to the attributed variable, we can use that + // information here. + + // Require ABI alignment for loads without alignment specification + const Align Alignment = DL.getValueOrABITypeAlignment(MA, Ty); + APInt AccessSize(DL.getIndexTypeSizeInBits(V->getType()), + DL.getTypeStoreSize(Ty)); + return isDereferenceableAndAlignedPointer(V, Alignment, AccessSize, DL, CtxI, + DT); +} + +bool llvm::isDereferenceablePointer(const Value *V, Type *Ty, + const DataLayout &DL, + const Instruction *CtxI, + const DominatorTree *DT) { + return isDereferenceableAndAlignedPointer(V, Ty, Align::None(), DL, CtxI, DT); +} + +/// Test if A and B will obviously have the same value. +/// +/// This includes recognizing that %t0 and %t1 will have the same +/// value in code like this: +/// \code +/// %t0 = getelementptr \@a, 0, 3 +/// store i32 0, i32* %t0 +/// %t1 = getelementptr \@a, 0, 3 +/// %t2 = load i32* %t1 +/// \endcode +/// +static bool AreEquivalentAddressValues(const Value *A, const Value *B) { + // Test if the values are trivially equivalent. + if (A == B) + return true; + + // Test if the values come from identical arithmetic instructions. 
+ // Use isIdenticalToWhenDefined instead of isIdenticalTo because + // this function is only used when one address use dominates the + // other, which means that they'll always either have the same + // value or one of them will have an undefined value. + if (isa<BinaryOperator>(A) || isa<CastInst>(A) || isa<PHINode>(A) || + isa<GetElementPtrInst>(A)) + if (const Instruction *BI = dyn_cast<Instruction>(B)) + if (cast<Instruction>(A)->isIdenticalToWhenDefined(BI)) + return true; + + // Otherwise they may not be equivalent. + return false; +} + +bool llvm::isDereferenceableAndAlignedInLoop(LoadInst *LI, Loop *L, + ScalarEvolution &SE, + DominatorTree &DT) { + auto &DL = LI->getModule()->getDataLayout(); + Value *Ptr = LI->getPointerOperand(); + + APInt EltSize(DL.getIndexTypeSizeInBits(Ptr->getType()), + DL.getTypeStoreSize(LI->getType())); + const Align Alignment = DL.getValueOrABITypeAlignment( + MaybeAlign(LI->getAlignment()), LI->getType()); + + Instruction *HeaderFirstNonPHI = L->getHeader()->getFirstNonPHI(); + + // If given a uniform (i.e. non-varying) address, see if we can prove the + // access is safe within the loop w/o needing predication. + if (L->isLoopInvariant(Ptr)) + return isDereferenceableAndAlignedPointer(Ptr, Alignment, EltSize, DL, + HeaderFirstNonPHI, &DT); + + // Otherwise, check to see if we have a repeating access pattern where we can + // prove that all accesses are well aligned and dereferenceable. + auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Ptr)); + if (!AddRec || AddRec->getLoop() != L || !AddRec->isAffine()) + return false; + auto* Step = dyn_cast<SCEVConstant>(AddRec->getStepRecurrence(SE)); + if (!Step) + return false; + // TODO: generalize to access patterns which have gaps + if (Step->getAPInt() != EltSize) + return false; + + // TODO: If the symbolic trip count has a small bound (max count), we might + // be able to prove safety. + auto TC = SE.getSmallConstantTripCount(L); + if (!TC) + return false; + + const APInt AccessSize = TC * EltSize; + + auto *StartS = dyn_cast<SCEVUnknown>(AddRec->getStart()); + if (!StartS) + return false; + assert(SE.isLoopInvariant(StartS, L) && "implied by addrec definition"); + Value *Base = StartS->getValue(); + + // For the moment, restrict ourselves to the case where the access size is a + // multiple of the requested alignment and the base is aligned. + // TODO: generalize if a case found which warrants + if (EltSize.urem(Alignment.value()) != 0) + return false; + return isDereferenceableAndAlignedPointer(Base, Alignment, AccessSize, DL, + HeaderFirstNonPHI, &DT); +} + +/// Check if executing a load of this pointer value cannot trap. +/// +/// If DT and ScanFrom are specified this method performs context-sensitive +/// analysis and returns true if it is safe to load immediately before ScanFrom. +/// +/// If it is not obviously safe to load from the specified pointer, we do +/// a quick local scan of the basic block containing \c ScanFrom, to determine +/// if the address is already accessed. +/// +/// This uses the pointee type to determine how many bytes need to be safe to +/// load from the pointer. 
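+///
+/// A minimal usage sketch (Ptr, Ty, and InsertPt are hypothetical values from
+/// the caller's context):
+/// \code
+///   const DataLayout &DL = InsertPt->getModule()->getDataLayout();
+///   APInt Size(DL.getIndexTypeSizeInBits(Ptr->getType()),
+///              DL.getTypeStoreSize(Ty));
+///   if (isSafeToLoadUnconditionally(Ptr, MaybeAlign(), Size, DL,
+///                                   /*ScanFrom=*/InsertPt, /*DT=*/nullptr))
+///     ; // the load may be speculated before InsertPt
+/// \endcode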
+bool llvm::isSafeToLoadUnconditionally(Value *V, MaybeAlign MA, APInt &Size, + const DataLayout &DL, + Instruction *ScanFrom, + const DominatorTree *DT) { + // Zero alignment means that the load has the ABI alignment for the target + const Align Alignment = + DL.getValueOrABITypeAlignment(MA, V->getType()->getPointerElementType()); + + // If DT is not specified we can't make context-sensitive query + const Instruction* CtxI = DT ? ScanFrom : nullptr; + if (isDereferenceableAndAlignedPointer(V, Alignment, Size, DL, CtxI, DT)) + return true; + + if (!ScanFrom) + return false; + + if (Size.getBitWidth() > 64) + return false; + const uint64_t LoadSize = Size.getZExtValue(); + + // Otherwise, be a little bit aggressive by scanning the local block where we + // want to check to see if the pointer is already being loaded or stored + // from/to. If so, the previous load or store would have already trapped, + // so there is no harm doing an extra load (also, CSE will later eliminate + // the load entirely). + BasicBlock::iterator BBI = ScanFrom->getIterator(), + E = ScanFrom->getParent()->begin(); + + // We can at least always strip pointer casts even though we can't use the + // base here. + V = V->stripPointerCasts(); + + while (BBI != E) { + --BBI; + + // If we see a free or a call which may write to memory (i.e. which might do + // a free) the pointer could be marked invalid. + if (isa<CallInst>(BBI) && BBI->mayWriteToMemory() && + !isa<DbgInfoIntrinsic>(BBI)) + return false; + + Value *AccessedPtr; + MaybeAlign MaybeAccessedAlign; + if (LoadInst *LI = dyn_cast<LoadInst>(BBI)) { + // Ignore volatile loads. The execution of a volatile load cannot + // be used to prove an address is backed by regular memory; it can, + // for example, point to an MMIO register. + if (LI->isVolatile()) + continue; + AccessedPtr = LI->getPointerOperand(); + MaybeAccessedAlign = MaybeAlign(LI->getAlignment()); + } else if (StoreInst *SI = dyn_cast<StoreInst>(BBI)) { + // Ignore volatile stores (see comment for loads). + if (SI->isVolatile()) + continue; + AccessedPtr = SI->getPointerOperand(); + MaybeAccessedAlign = MaybeAlign(SI->getAlignment()); + } else + continue; + + Type *AccessedTy = AccessedPtr->getType()->getPointerElementType(); + + const Align AccessedAlign = + DL.getValueOrABITypeAlignment(MaybeAccessedAlign, AccessedTy); + if (AccessedAlign < Alignment) + continue; + + // Handle trivial cases. + if (AccessedPtr == V && + LoadSize <= DL.getTypeStoreSize(AccessedTy)) + return true; + + if (AreEquivalentAddressValues(AccessedPtr->stripPointerCasts(), V) && + LoadSize <= DL.getTypeStoreSize(AccessedTy)) + return true; + } + return false; +} + +bool llvm::isSafeToLoadUnconditionally(Value *V, Type *Ty, MaybeAlign Alignment, + const DataLayout &DL, + Instruction *ScanFrom, + const DominatorTree *DT) { + APInt Size(DL.getIndexTypeSizeInBits(V->getType()), DL.getTypeStoreSize(Ty)); + return isSafeToLoadUnconditionally(V, Alignment, Size, DL, ScanFrom, DT); +} + + /// DefMaxInstsToScan - the default number of maximum instructions +/// to scan in the block, used by FindAvailableLoadedValue(). +/// FindAvailableLoadedValue() was introduced in r60148, to improve jump +/// threading in part by eliminating partially redundant loads. +/// At that point, the value of MaxInstsToScan was already set to '6' +/// without documented explanation. 
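+/// The value is adjustable through the option defined below, e.g. by passing
+/// -available-load-scan-limit=<n> to opt (an illustrative invocation, not
+/// prescribed by this file).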
+cl::opt<unsigned> +llvm::DefMaxInstsToScan("available-load-scan-limit", cl::init(6), cl::Hidden, + cl::desc("Use this to specify the default maximum number of instructions " + "to scan backward from a given instruction, when searching for " + "available loaded value")); + +Value *llvm::FindAvailableLoadedValue(LoadInst *Load, + BasicBlock *ScanBB, + BasicBlock::iterator &ScanFrom, + unsigned MaxInstsToScan, + AliasAnalysis *AA, bool *IsLoad, + unsigned *NumScanedInst) { + // Don't CSE load that is volatile or anything stronger than unordered. + if (!Load->isUnordered()) + return nullptr; + + return FindAvailablePtrLoadStore( + Load->getPointerOperand(), Load->getType(), Load->isAtomic(), ScanBB, + ScanFrom, MaxInstsToScan, AA, IsLoad, NumScanedInst); +} + +Value *llvm::FindAvailablePtrLoadStore(Value *Ptr, Type *AccessTy, + bool AtLeastAtomic, BasicBlock *ScanBB, + BasicBlock::iterator &ScanFrom, + unsigned MaxInstsToScan, + AliasAnalysis *AA, bool *IsLoadCSE, + unsigned *NumScanedInst) { + if (MaxInstsToScan == 0) + MaxInstsToScan = ~0U; + + const DataLayout &DL = ScanBB->getModule()->getDataLayout(); + + // Try to get the store size for the type. + auto AccessSize = LocationSize::precise(DL.getTypeStoreSize(AccessTy)); + + Value *StrippedPtr = Ptr->stripPointerCasts(); + + while (ScanFrom != ScanBB->begin()) { + // We must ignore debug info directives when counting (otherwise they + // would affect codegen). + Instruction *Inst = &*--ScanFrom; + if (isa<DbgInfoIntrinsic>(Inst)) + continue; + + // Restore ScanFrom to expected value in case next test succeeds + ScanFrom++; + + if (NumScanedInst) + ++(*NumScanedInst); + + // Don't scan huge blocks. + if (MaxInstsToScan-- == 0) + return nullptr; + + --ScanFrom; + // If this is a load of Ptr, the loaded value is available. + // (This is true even if the load is volatile or atomic, although + // those cases are unlikely.) + if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) + if (AreEquivalentAddressValues( + LI->getPointerOperand()->stripPointerCasts(), StrippedPtr) && + CastInst::isBitOrNoopPointerCastable(LI->getType(), AccessTy, DL)) { + + // We can value forward from an atomic to a non-atomic, but not the + // other way around. + if (LI->isAtomic() < AtLeastAtomic) + return nullptr; + + if (IsLoadCSE) + *IsLoadCSE = true; + return LI; + } + + if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + Value *StorePtr = SI->getPointerOperand()->stripPointerCasts(); + // If this is a store through Ptr, the value is available! + // (This is true even if the store is volatile or atomic, although + // those cases are unlikely.) + if (AreEquivalentAddressValues(StorePtr, StrippedPtr) && + CastInst::isBitOrNoopPointerCastable(SI->getValueOperand()->getType(), + AccessTy, DL)) { + + // We can value forward from an atomic to a non-atomic, but not the + // other way around. + if (SI->isAtomic() < AtLeastAtomic) + return nullptr; + + if (IsLoadCSE) + *IsLoadCSE = false; + return SI->getOperand(0); + } + + // If both StrippedPtr and StorePtr reach all the way to an alloca or + // global and they are different, ignore the store. This is a trivial form + // of alias analysis that is important for reg2mem'd code. + if ((isa<AllocaInst>(StrippedPtr) || isa<GlobalVariable>(StrippedPtr)) && + (isa<AllocaInst>(StorePtr) || isa<GlobalVariable>(StorePtr)) && + StrippedPtr != StorePtr) + continue; + + // If we have alias analysis and it says the store won't modify the loaded + // value, ignore the store. 
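+      // (getModRefInfo conservatively answers ModRef when nothing better is
+      // known; only a result with the Mod bit clear lets us keep scanning
+      // past the store.)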
+      if (AA && !isModSet(AA->getModRefInfo(SI, StrippedPtr, AccessSize)))
+        continue;
+
+      // Otherwise the store may or may not alias the pointer; bail out.
+      ++ScanFrom;
+      return nullptr;
+    }
+
+    // If this is some other instruction that may clobber Ptr, bail out.
+    if (Inst->mayWriteToMemory()) {
+      // If alias analysis claims that it really won't modify the load,
+      // ignore it.
+      if (AA && !isModSet(AA->getModRefInfo(Inst, StrippedPtr, AccessSize)))
+        continue;
+
+      // May modify the pointer, bail out.
+      ++ScanFrom;
+      return nullptr;
+    }
+  }
+
+  // Got to the start of the block; we didn't find the value, but we are done
+  // for this block.
+  return nullptr;
+}
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
new file mode 100644
index 000000000000..3d8f77675f3a
--- /dev/null
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -0,0 +1,2464 @@
+//===- LoopAccessAnalysis.cpp - Loop Access Analysis Implementation ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the loop memory dependence analysis that was
+// originally developed for the loop vectorizer.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LoopAccessAnalysis.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AliasSetTracker.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DebugLoc.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <cstdlib>
+#include <iterator>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-accesses"
+
+static cl::opt<unsigned, true>
+VectorizationFactor("force-vector-width", cl::Hidden,
+                    cl::desc("Sets the SIMD width. Zero is autoselect."),
+                    cl::location(VectorizerParams::VectorizationFactor));
+unsigned VectorizerParams::VectorizationFactor;
+
+static cl::opt<unsigned, true>
+VectorizationInterleave("force-vector-interleave", cl::Hidden,
+                        cl::desc("Sets the vectorization interleave count. "
+                                 "Zero is autoselect."),
+                        cl::location(
+                            VectorizerParams::VectorizationInterleave));
+unsigned VectorizerParams::VectorizationInterleave;
+
+static cl::opt<unsigned, true> RuntimeMemoryCheckThreshold(
+    "runtime-memory-check-threshold", cl::Hidden,
+    cl::desc("When performing memory disambiguation checks at runtime do not "
+             "generate more than this number of comparisons (default = 8)."),
+    cl::location(VectorizerParams::RuntimeMemoryCheckThreshold), cl::init(8));
+unsigned VectorizerParams::RuntimeMemoryCheckThreshold;
+
+/// The maximum iterations used to merge memory checks
+static cl::opt<unsigned> MemoryCheckMergeThreshold(
+    "memory-check-merge-threshold", cl::Hidden,
+    cl::desc("Maximum number of comparisons done when trying to merge "
+             "runtime memory checks. (default = 100)"),
+    cl::init(100));
+
+/// Maximum SIMD width.
+const unsigned VectorizerParams::MaxVectorWidth = 64;
+
+/// We collect dependences up to this threshold.
+static cl::opt<unsigned>
+    MaxDependences("max-dependences", cl::Hidden,
+                   cl::desc("Maximum number of dependences collected by "
+                            "loop-access analysis (default = 100)"),
+                   cl::init(100));
+
+/// This enables versioning on the strides of symbolically striding memory
+/// accesses in code like the following.
+///   for (i = 0; i < N; ++i)
+///     A[i * Stride1] += B[i * Stride2] ...
+///
+/// Will be roughly translated to
+///   if (Stride1 == 1 && Stride2 == 1) {
+///     for (i = 0; i < N; i+=4)
+///       A[i:i+3] += ...
+///   } else
+///     ...
+static cl::opt<bool> EnableMemAccessVersioning(
+    "enable-mem-access-versioning", cl::init(true), cl::Hidden,
+    cl::desc("Enable symbolic stride memory access versioning"));
+
+/// Enable store-to-load forwarding conflict detection. This option can
+/// be disabled for correctness testing.
+static cl::opt<bool> EnableForwardingConflictDetection(
+    "store-to-load-forwarding-conflict-detection", cl::Hidden,
+    cl::desc("Enable conflict detection in loop-access analysis"),
+    cl::init(true));
+
+bool VectorizerParams::isInterleaveForced() {
+  return ::VectorizationInterleave.getNumOccurrences() > 0;
+}
+
+Value *llvm::stripIntegerCast(Value *V) {
+  if (auto *CI = dyn_cast<CastInst>(V))
+    if (CI->getOperand(0)->getType()->isIntegerTy())
+      return CI->getOperand(0);
+  return V;
+}
+
+const SCEV *llvm::replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
+                                            const ValueToValueMap &PtrToStride,
+                                            Value *Ptr, Value *OrigPtr) {
+  const SCEV *OrigSCEV = PSE.getSCEV(Ptr);
+
+  // If there is an entry in the map return the SCEV of the pointer with the
+  // symbolic stride replaced by one.
+  ValueToValueMap::const_iterator SI =
+      PtrToStride.find(OrigPtr ? OrigPtr : Ptr);
+  if (SI != PtrToStride.end()) {
+    Value *StrideVal = SI->second;
+
+    // Strip casts.
+    StrideVal = stripIntegerCast(StrideVal);
+
+    ScalarEvolution *SE = PSE.getSE();
+    const auto *U = cast<SCEVUnknown>(SE->getSCEV(StrideVal));
+    const auto *CT =
+        static_cast<const SCEVConstant *>(SE->getOne(StrideVal->getType()));
+
+    PSE.addPredicate(*SE->getEqualPredicate(U, CT));
+    auto *Expr = PSE.getSCEV(Ptr);
+
+    LLVM_DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV
+                      << " by: " << *Expr << "\n");
+    return Expr;
+  }
+
+  // Otherwise, just return the SCEV of the original pointer.
+ return OrigSCEV; +} + +/// Calculate Start and End points of memory access. +/// Let's assume A is the first access and B is a memory access on N-th loop +/// iteration. Then B is calculated as: +/// B = A + Step*N . +/// Step value may be positive or negative. +/// N is a calculated back-edge taken count: +/// N = (TripCount > 0) ? RoundDown(TripCount -1 , VF) : 0 +/// Start and End points are calculated in the following way: +/// Start = UMIN(A, B) ; End = UMAX(A, B) + SizeOfElt, +/// where SizeOfElt is the size of single memory access in bytes. +/// +/// There is no conflict when the intervals are disjoint: +/// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End) +void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr, + unsigned DepSetId, unsigned ASId, + const ValueToValueMap &Strides, + PredicatedScalarEvolution &PSE) { + // Get the stride replaced scev. + const SCEV *Sc = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); + ScalarEvolution *SE = PSE.getSE(); + + const SCEV *ScStart; + const SCEV *ScEnd; + + if (SE->isLoopInvariant(Sc, Lp)) + ScStart = ScEnd = Sc; + else { + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc); + assert(AR && "Invalid addrec expression"); + const SCEV *Ex = PSE.getBackedgeTakenCount(); + + ScStart = AR->getStart(); + ScEnd = AR->evaluateAtIteration(Ex, *SE); + const SCEV *Step = AR->getStepRecurrence(*SE); + + // For expressions with negative step, the upper bound is ScStart and the + // lower bound is ScEnd. + if (const auto *CStep = dyn_cast<SCEVConstant>(Step)) { + if (CStep->getValue()->isNegative()) + std::swap(ScStart, ScEnd); + } else { + // Fallback case: the step is not constant, but we can still + // get the upper and lower bounds of the interval by using min/max + // expressions. + ScStart = SE->getUMinExpr(ScStart, ScEnd); + ScEnd = SE->getUMaxExpr(AR->getStart(), ScEnd); + } + // Add the size of the pointed element to ScEnd. + unsigned EltSize = + Ptr->getType()->getPointerElementType()->getScalarSizeInBits() / 8; + const SCEV *EltSizeSCEV = SE->getConstant(ScEnd->getType(), EltSize); + ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV); + } + + Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc); +} + +SmallVector<RuntimePointerChecking::PointerCheck, 4> +RuntimePointerChecking::generateChecks() const { + SmallVector<PointerCheck, 4> Checks; + + for (unsigned I = 0; I < CheckingGroups.size(); ++I) { + for (unsigned J = I + 1; J < CheckingGroups.size(); ++J) { + const RuntimePointerChecking::CheckingPtrGroup &CGI = CheckingGroups[I]; + const RuntimePointerChecking::CheckingPtrGroup &CGJ = CheckingGroups[J]; + + if (needsChecking(CGI, CGJ)) + Checks.push_back(std::make_pair(&CGI, &CGJ)); + } + } + return Checks; +} + +void RuntimePointerChecking::generateChecks( + MemoryDepChecker::DepCandidates &DepCands, bool UseDependencies) { + assert(Checks.empty() && "Checks is not empty"); + groupChecks(DepCands, UseDependencies); + Checks = generateChecks(); +} + +bool RuntimePointerChecking::needsChecking(const CheckingPtrGroup &M, + const CheckingPtrGroup &N) const { + for (unsigned I = 0, EI = M.Members.size(); EI != I; ++I) + for (unsigned J = 0, EJ = N.Members.size(); EJ != J; ++J) + if (needsChecking(M.Members[I], N.Members[J])) + return true; + return false; +} + +/// Compare \p I and \p J and return the minimum. +/// Return nullptr in case we couldn't find an answer. 
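+/// For illustration, with I = (%a + 4) and J = (%a + 8) the difference
+/// J - I folds to the non-negative constant 4, so I is the minimum; for
+/// unrelated expressions such as %a and %b the difference is not a
+/// SCEVConstant and nullptr is returned. (Example values are assumed.)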
+static const SCEV *getMinFromExprs(const SCEV *I, const SCEV *J, + ScalarEvolution *SE) { + const SCEV *Diff = SE->getMinusSCEV(J, I); + const SCEVConstant *C = dyn_cast<const SCEVConstant>(Diff); + + if (!C) + return nullptr; + if (C->getValue()->isNegative()) + return J; + return I; +} + +bool RuntimePointerChecking::CheckingPtrGroup::addPointer(unsigned Index) { + const SCEV *Start = RtCheck.Pointers[Index].Start; + const SCEV *End = RtCheck.Pointers[Index].End; + + // Compare the starts and ends with the known minimum and maximum + // of this set. We need to know how we compare against the min/max + // of the set in order to be able to emit memchecks. + const SCEV *Min0 = getMinFromExprs(Start, Low, RtCheck.SE); + if (!Min0) + return false; + + const SCEV *Min1 = getMinFromExprs(End, High, RtCheck.SE); + if (!Min1) + return false; + + // Update the low bound expression if we've found a new min value. + if (Min0 == Start) + Low = Start; + + // Update the high bound expression if we've found a new max value. + if (Min1 != End) + High = End; + + Members.push_back(Index); + return true; +} + +void RuntimePointerChecking::groupChecks( + MemoryDepChecker::DepCandidates &DepCands, bool UseDependencies) { + // We build the groups from dependency candidates equivalence classes + // because: + // - We know that pointers in the same equivalence class share + // the same underlying object and therefore there is a chance + // that we can compare pointers + // - We wouldn't be able to merge two pointers for which we need + // to emit a memcheck. The classes in DepCands are already + // conveniently built such that no two pointers in the same + // class need checking against each other. + + // We use the following (greedy) algorithm to construct the groups + // For every pointer in the equivalence class: + // For each existing group: + // - if the difference between this pointer and the min/max bounds + // of the group is a constant, then make the pointer part of the + // group and update the min/max bounds of that group as required. + + CheckingGroups.clear(); + + // If we need to check two pointers to the same underlying object + // with a non-constant difference, we shouldn't perform any pointer + // grouping with those pointers. This is because we can easily get + // into cases where the resulting check would return false, even when + // the accesses are safe. + // + // The following example shows this: + // for (i = 0; i < 1000; ++i) + // a[5000 + i * m] = a[i] + a[i + 9000] + // + // Here grouping gives a check of (5000, 5000 + 1000 * m) against + // (0, 10000) which is always false. However, if m is 1, there is no + // dependence. Not grouping the checks for a[i] and a[i + 9000] allows + // us to perform an accurate check in this case. + // + // The above case requires that we have an UnknownDependence between + // accesses to the same underlying object. This cannot happen unless + // FoundNonConstantDistanceDependence is set, and therefore UseDependencies + // is also false. In this case we will use the fallback path and create + // separate checking groups for all pointers. + + // If we don't have the dependency partitions, construct a new + // checking pointer group for each pointer. This is also required + // for correctness, because in this case we can have checking between + // pointers to the same underlying object. 
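+  // For illustration, if the loop accesses A, A + 4 and an unrelated B and
+  // no dependency information is available, we build three single-member
+  // groups here, giving up to three pairwise checks instead of one merged
+  // {A, A + 4} group checked against B. (Example pointers are assumed.)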
+ if (!UseDependencies) { + for (unsigned I = 0; I < Pointers.size(); ++I) + CheckingGroups.push_back(CheckingPtrGroup(I, *this)); + return; + } + + unsigned TotalComparisons = 0; + + DenseMap<Value *, unsigned> PositionMap; + for (unsigned Index = 0; Index < Pointers.size(); ++Index) + PositionMap[Pointers[Index].PointerValue] = Index; + + // We need to keep track of what pointers we've already seen so we + // don't process them twice. + SmallSet<unsigned, 2> Seen; + + // Go through all equivalence classes, get the "pointer check groups" + // and add them to the overall solution. We use the order in which accesses + // appear in 'Pointers' to enforce determinism. + for (unsigned I = 0; I < Pointers.size(); ++I) { + // We've seen this pointer before, and therefore already processed + // its equivalence class. + if (Seen.count(I)) + continue; + + MemoryDepChecker::MemAccessInfo Access(Pointers[I].PointerValue, + Pointers[I].IsWritePtr); + + SmallVector<CheckingPtrGroup, 2> Groups; + auto LeaderI = DepCands.findValue(DepCands.getLeaderValue(Access)); + + // Because DepCands is constructed by visiting accesses in the order in + // which they appear in alias sets (which is deterministic) and the + // iteration order within an equivalence class member is only dependent on + // the order in which unions and insertions are performed on the + // equivalence class, the iteration order is deterministic. + for (auto MI = DepCands.member_begin(LeaderI), ME = DepCands.member_end(); + MI != ME; ++MI) { + unsigned Pointer = PositionMap[MI->getPointer()]; + bool Merged = false; + // Mark this pointer as seen. + Seen.insert(Pointer); + + // Go through all the existing sets and see if we can find one + // which can include this pointer. + for (CheckingPtrGroup &Group : Groups) { + // Don't perform more than a certain amount of comparisons. + // This should limit the cost of grouping the pointers to something + // reasonable. If we do end up hitting this threshold, the algorithm + // will create separate groups for all remaining pointers. + if (TotalComparisons > MemoryCheckMergeThreshold) + break; + + TotalComparisons++; + + if (Group.addPointer(Pointer)) { + Merged = true; + break; + } + } + + if (!Merged) + // We couldn't add this pointer to any existing set or the threshold + // for the number of comparisons has been reached. Create a new group + // to hold the current pointer. + Groups.push_back(CheckingPtrGroup(Pointer, *this)); + } + + // We've computed the grouped checks for this partition. + // Save the results and continue with the next one. + llvm::copy(Groups, std::back_inserter(CheckingGroups)); + } +} + +bool RuntimePointerChecking::arePointersInSamePartition( + const SmallVectorImpl<int> &PtrToPartition, unsigned PtrIdx1, + unsigned PtrIdx2) { + return (PtrToPartition[PtrIdx1] != -1 && + PtrToPartition[PtrIdx1] == PtrToPartition[PtrIdx2]); +} + +bool RuntimePointerChecking::needsChecking(unsigned I, unsigned J) const { + const PointerInfo &PointerI = Pointers[I]; + const PointerInfo &PointerJ = Pointers[J]; + + // No need to check if two readonly pointers intersect. + if (!PointerI.IsWritePtr && !PointerJ.IsWritePtr) + return false; + + // Only need to check pointers between two different dependency sets. + if (PointerI.DependencySetId == PointerJ.DependencySetId) + return false; + + // Only need to check pointers in the same alias set. 
+  if (PointerI.AliasSetId != PointerJ.AliasSetId)
+    return false;
+
+  return true;
+}
+
+void RuntimePointerChecking::printChecks(
+    raw_ostream &OS, const SmallVectorImpl<PointerCheck> &Checks,
+    unsigned Depth) const {
+  unsigned N = 0;
+  for (const auto &Check : Checks) {
+    const auto &First = Check.first->Members, &Second = Check.second->Members;
+
+    OS.indent(Depth) << "Check " << N++ << ":\n";
+
+    OS.indent(Depth + 2) << "Comparing group (" << Check.first << "):\n";
+    for (unsigned K = 0; K < First.size(); ++K)
+      OS.indent(Depth + 2) << *Pointers[First[K]].PointerValue << "\n";
+
+    OS.indent(Depth + 2) << "Against group (" << Check.second << "):\n";
+    for (unsigned K = 0; K < Second.size(); ++K)
+      OS.indent(Depth + 2) << *Pointers[Second[K]].PointerValue << "\n";
+  }
+}
+
+void RuntimePointerChecking::print(raw_ostream &OS, unsigned Depth) const {
+  OS.indent(Depth) << "Run-time memory checks:\n";
+  printChecks(OS, Checks, Depth);
+
+  OS.indent(Depth) << "Grouped accesses:\n";
+  for (unsigned I = 0; I < CheckingGroups.size(); ++I) {
+    const auto &CG = CheckingGroups[I];
+
+    OS.indent(Depth + 2) << "Group " << &CG << ":\n";
+    OS.indent(Depth + 4) << "(Low: " << *CG.Low << " High: " << *CG.High
+                         << ")\n";
+    for (unsigned J = 0; J < CG.Members.size(); ++J) {
+      OS.indent(Depth + 6) << "Member: " << *Pointers[CG.Members[J]].Expr
+                           << "\n";
+    }
+  }
+}
+
+namespace {
+
+/// Analyses memory accesses in a loop.
+///
+/// Checks whether run-time pointer checks are needed and builds sets for data
+/// dependence checking.
+class AccessAnalysis {
+public:
+  /// Read or write access location.
+  typedef PointerIntPair<Value *, 1, bool> MemAccessInfo;
+  typedef SmallVector<MemAccessInfo, 8> MemAccessInfoList;
+
+  AccessAnalysis(const DataLayout &Dl, Loop *TheLoop, AliasAnalysis *AA,
+                 LoopInfo *LI, MemoryDepChecker::DepCandidates &DA,
+                 PredicatedScalarEvolution &PSE)
+      : DL(Dl), TheLoop(TheLoop), AST(*AA), LI(LI), DepCands(DA),
+        IsRTCheckAnalysisNeeded(false), PSE(PSE) {}
+
+  /// Register a load and whether it is only read from.
+  void addLoad(MemoryLocation &Loc, bool IsReadOnly) {
+    Value *Ptr = const_cast<Value*>(Loc.Ptr);
+    AST.add(Ptr, LocationSize::unknown(), Loc.AATags);
+    Accesses.insert(MemAccessInfo(Ptr, false));
+    if (IsReadOnly)
+      ReadOnlyPtr.insert(Ptr);
+  }
+
+  /// Register a store.
+  void addStore(MemoryLocation &Loc) {
+    Value *Ptr = const_cast<Value*>(Loc.Ptr);
+    AST.add(Ptr, LocationSize::unknown(), Loc.AATags);
+    Accesses.insert(MemAccessInfo(Ptr, true));
+  }
+
+  /// Check if we can emit a run-time no-alias check for \p Access.
+  ///
+  /// Returns true if we can emit a run-time no-alias check for \p Access.
+  /// If we can check this access, this also adds it to a dependence set and
+  /// adds a run-time check for it to \p RtCheck. If \p Assume is true,
+  /// we will attempt to use additional run-time checks in order to get
+  /// the bounds of the pointer.
+  bool createCheckForAccess(RuntimePointerChecking &RtCheck,
+                            MemAccessInfo Access,
+                            const ValueToValueMap &Strides,
+                            DenseMap<Value *, unsigned> &DepSetId,
+                            Loop *TheLoop, unsigned &RunningDepId,
+                            unsigned ASId, bool ShouldCheckStride,
+                            bool Assume);
+
+  /// Check whether we can check the pointers at runtime for
+  /// non-intersection.
+  ///
+  /// Returns true if we need no check or if we do and we can generate them
+  /// (i.e. the pointers have computable bounds).
+  bool canCheckPtrAtRT(RuntimePointerChecking &RtCheck, ScalarEvolution *SE,
+                       Loop *TheLoop, const ValueToValueMap &Strides,
+                       bool ShouldCheckWrap = false);
+
+  /// Goes over all memory accesses, checks whether an RT check is needed
+  /// and builds sets of dependent accesses.
+  void buildDependenceSets() {
+    processMemAccesses();
+  }
+
+  /// Initial processing of memory accesses determined that we need to
+  /// perform dependency checking.
+  ///
+  /// Note that this can later be cleared if we retry memcheck analysis without
+  /// dependency checking (i.e. FoundNonConstantDistanceDependence).
+  bool isDependencyCheckNeeded() { return !CheckDeps.empty(); }
+
+  /// We decided that no dependence analysis would be used. Reset the state.
+  void resetDepChecks(MemoryDepChecker &DepChecker) {
+    CheckDeps.clear();
+    DepChecker.clearDependences();
+  }
+
+  MemAccessInfoList &getDependenciesToCheck() { return CheckDeps; }
+
+private:
+  typedef SetVector<MemAccessInfo> PtrAccessSet;
+
+  /// Go over all memory accesses and check whether runtime pointer checks
+  /// are needed and build sets of dependency check candidates.
+  void processMemAccesses();
+
+  /// Set of all accesses.
+  PtrAccessSet Accesses;
+
+  const DataLayout &DL;
+
+  /// The loop being checked.
+  const Loop *TheLoop;
+
+  /// List of accesses that need a further dependence check.
+  MemAccessInfoList CheckDeps;
+
+  /// Set of pointers that are read only.
+  SmallPtrSet<Value*, 16> ReadOnlyPtr;
+
+  /// An alias set tracker to partition the access set by underlying object and
+  /// intrinsic property (such as TBAA metadata).
+  AliasSetTracker AST;
+
+  LoopInfo *LI;
+
+  /// Sets of potentially dependent accesses - members of one set share an
+  /// underlying pointer. The set "CheckDeps" identifies which sets really
+  /// need a dependence check.
+  MemoryDepChecker::DepCandidates &DepCands;
+
+  /// Initial processing of memory accesses determined that we may need
+  /// to add memchecks. Perform the analysis to determine the necessary checks.
+  ///
+  /// Note that this is different from isDependencyCheckNeeded. When we retry
+  /// memcheck analysis without dependency checking
+  /// (i.e. FoundNonConstantDistanceDependence), isDependencyCheckNeeded is
+  /// cleared while this remains set if we have potentially dependent accesses.
+  bool IsRTCheckAnalysisNeeded;
+
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  PredicatedScalarEvolution &PSE;
+};
+
+} // end anonymous namespace
+
+/// Check whether a pointer can participate in a runtime bounds check.
+/// If \p Assume, try harder to prove that we can compute the bounds of \p Ptr
+/// by adding run-time checks (overflow checks) if necessary.
+static bool hasComputableBounds(PredicatedScalarEvolution &PSE,
+                                const ValueToValueMap &Strides, Value *Ptr,
+                                Loop *L, bool Assume) {
+  const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
+
+  // The bounds for a loop-invariant pointer are trivial.
+  if (PSE.getSE()->isLoopInvariant(PtrScev, L))
+    return true;
+
+  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
+
+  if (!AR && Assume)
+    AR = PSE.getAsAddRec(Ptr);
+
+  if (!AR)
+    return false;
+
+  return AR->isAffine();
+}
+
+/// Check whether a pointer address cannot wrap.
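+/// A loop-invariant pointer trivially cannot wrap; a unit-stride access such
+/// as A[i] is also accepted, as is any pointer for which an IncrementNUSW
+/// no-overflow predicate has already been added to \p PSE.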
+static bool isNoWrap(PredicatedScalarEvolution &PSE,
+                     const ValueToValueMap &Strides, Value *Ptr, Loop *L) {
+  const SCEV *PtrScev = PSE.getSCEV(Ptr);
+  if (PSE.getSE()->isLoopInvariant(PtrScev, L))
+    return true;
+
+  int64_t Stride = getPtrStride(PSE, Ptr, L, Strides);
+  if (Stride == 1 || PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW))
+    return true;
+
+  return false;
+}
+
+bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck,
+                                          MemAccessInfo Access,
+                                          const ValueToValueMap &StridesMap,
+                                          DenseMap<Value *, unsigned> &DepSetId,
+                                          Loop *TheLoop, unsigned &RunningDepId,
+                                          unsigned ASId, bool ShouldCheckWrap,
+                                          bool Assume) {
+  Value *Ptr = Access.getPointer();
+
+  if (!hasComputableBounds(PSE, StridesMap, Ptr, TheLoop, Assume))
+    return false;
+
+  // When we run after a failing dependency check we have to make sure
+  // we don't have wrapping pointers.
+  if (ShouldCheckWrap && !isNoWrap(PSE, StridesMap, Ptr, TheLoop)) {
+    auto *Expr = PSE.getSCEV(Ptr);
+    if (!Assume || !isa<SCEVAddRecExpr>(Expr))
+      return false;
+    PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
+  }
+
+  // The id of the dependence set.
+  unsigned DepId;
+
+  if (isDependencyCheckNeeded()) {
+    Value *Leader = DepCands.getLeaderValue(Access).getPointer();
+    unsigned &LeaderId = DepSetId[Leader];
+    if (!LeaderId)
+      LeaderId = RunningDepId++;
+    DepId = LeaderId;
+  } else
+    // Each access has its own dependence set.
+    DepId = RunningDepId++;
+
+  bool IsWrite = Access.getInt();
+  RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, PSE);
+  LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
+
+  return true;
+}
+
+bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
+                                     ScalarEvolution *SE, Loop *TheLoop,
+                                     const ValueToValueMap &StridesMap,
+                                     bool ShouldCheckWrap) {
+  // Find pointers with computable bounds. We are going to use this information
+  // to place a runtime bound check.
+  bool CanDoRT = true;
+
+  bool NeedRTCheck = false;
+  if (!IsRTCheckAnalysisNeeded) return true;
+
+  bool IsDepCheckNeeded = isDependencyCheckNeeded();
+
+  // We assign a consecutive id to accesses from different alias sets.
+  // Accesses between different groups don't need to be checked.
+  unsigned ASId = 1;
+  for (auto &AS : AST) {
+    int NumReadPtrChecks = 0;
+    int NumWritePtrChecks = 0;
+    bool CanDoAliasSetRT = true;
+
+    // We assign consecutive ids to accesses from different dependence sets.
+    // Accesses within the same set don't need a runtime check.
+    unsigned RunningDepId = 1;
+    DenseMap<Value *, unsigned> DepSetId;
+
+    SmallVector<MemAccessInfo, 4> Retries;
+
+    for (auto A : AS) {
+      Value *Ptr = A.getValue();
+      bool IsWrite = Accesses.count(MemAccessInfo(Ptr, true));
+      MemAccessInfo Access(Ptr, IsWrite);
+
+      if (IsWrite)
+        ++NumWritePtrChecks;
+      else
+        ++NumReadPtrChecks;
+
+      if (!createCheckForAccess(RtCheck, Access, StridesMap, DepSetId, TheLoop,
+                                RunningDepId, ASId, ShouldCheckWrap, false)) {
+        LLVM_DEBUG(dbgs() << "LAA: Can't find bounds for ptr:" << *Ptr << '\n');
+        Retries.push_back(Access);
+        CanDoAliasSetRT = false;
+      }
+    }
+
+    // If we have at least two writes or one write and a read then we need to
+    // check them. But there is no need for checks if there is only one
+    // dependence set for this alias set.
+    //
+    // Note that this function computes CanDoRT and NeedRTCheck independently.
+ // For example CanDoRT=false, NeedRTCheck=false means that we have a pointer + // for which we couldn't find the bounds but we don't actually need to emit + // any checks so it does not matter. + bool NeedsAliasSetRTCheck = false; + if (!(IsDepCheckNeeded && CanDoAliasSetRT && RunningDepId == 2)) + NeedsAliasSetRTCheck = (NumWritePtrChecks >= 2 || + (NumReadPtrChecks >= 1 && NumWritePtrChecks >= 1)); + + // We need to perform run-time alias checks, but some pointers had bounds + // that couldn't be checked. + if (NeedsAliasSetRTCheck && !CanDoAliasSetRT) { + // Reset the CanDoSetRt flag and retry all accesses that have failed. + // We know that we need these checks, so we can now be more aggressive + // and add further checks if required (overflow checks). + CanDoAliasSetRT = true; + for (auto Access : Retries) + if (!createCheckForAccess(RtCheck, Access, StridesMap, DepSetId, + TheLoop, RunningDepId, ASId, + ShouldCheckWrap, /*Assume=*/true)) { + CanDoAliasSetRT = false; + break; + } + } + + CanDoRT &= CanDoAliasSetRT; + NeedRTCheck |= NeedsAliasSetRTCheck; + ++ASId; + } + + // If the pointers that we would use for the bounds comparison have different + // address spaces, assume the values aren't directly comparable, so we can't + // use them for the runtime check. We also have to assume they could + // overlap. In the future there should be metadata for whether address spaces + // are disjoint. + unsigned NumPointers = RtCheck.Pointers.size(); + for (unsigned i = 0; i < NumPointers; ++i) { + for (unsigned j = i + 1; j < NumPointers; ++j) { + // Only need to check pointers between two different dependency sets. + if (RtCheck.Pointers[i].DependencySetId == + RtCheck.Pointers[j].DependencySetId) + continue; + // Only need to check pointers in the same alias set. + if (RtCheck.Pointers[i].AliasSetId != RtCheck.Pointers[j].AliasSetId) + continue; + + Value *PtrI = RtCheck.Pointers[i].PointerValue; + Value *PtrJ = RtCheck.Pointers[j].PointerValue; + + unsigned ASi = PtrI->getType()->getPointerAddressSpace(); + unsigned ASj = PtrJ->getType()->getPointerAddressSpace(); + if (ASi != ASj) { + LLVM_DEBUG( + dbgs() << "LAA: Runtime check would require comparison between" + " different address spaces\n"); + return false; + } + } + } + + if (NeedRTCheck && CanDoRT) + RtCheck.generateChecks(DepCands, IsDepCheckNeeded); + + LLVM_DEBUG(dbgs() << "LAA: We need to do " << RtCheck.getNumberOfChecks() + << " pointer comparisons.\n"); + + RtCheck.Need = NeedRTCheck; + + bool CanDoRTIfNeeded = !NeedRTCheck || CanDoRT; + if (!CanDoRTIfNeeded) + RtCheck.reset(); + return CanDoRTIfNeeded; +} + +void AccessAnalysis::processMemAccesses() { + // We process the set twice: first we process read-write pointers, last we + // process read-only pointers. This allows us to skip dependence tests for + // read-only pointers. + + LLVM_DEBUG(dbgs() << "LAA: Processing memory accesses...\n"); + LLVM_DEBUG(dbgs() << " AST: "; AST.dump()); + LLVM_DEBUG(dbgs() << "LAA: Accesses(" << Accesses.size() << "):\n"); + LLVM_DEBUG({ + for (auto A : Accesses) + dbgs() << "\t" << *A.getPointer() << " (" << + (A.getInt() ? "write" : (ReadOnlyPtr.count(A.getPointer()) ? + "read-only" : "read")) << ")\n"; + }); + + // The AliasSetTracker has nicely partitioned our pointers by metadata + // compatibility and potential for underlying-object overlap. As a result, we + // only need to check for potential pointer dependencies within each alias + // set. 
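+  // For illustration, an int array A and an unrelated, TBAA-distinct float
+  // array F will typically land in different alias sets, so no A-vs-F
+  // dependence is ever queried below. (Example types are assumed.)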
+  for (auto &AS : AST) {
+    // Note that both the alias-set tracker and the alias sets themselves use
+    // linked lists internally and so the iteration order here is deterministic
+    // (matching the original instruction order within each set).
+
+    bool SetHasWrite = false;
+
+    // Map of pointers to last access encountered.
+    typedef DenseMap<const Value*, MemAccessInfo> UnderlyingObjToAccessMap;
+    UnderlyingObjToAccessMap ObjToLastAccess;
+
+    // Set of accesses to check after all writes have been processed.
+    PtrAccessSet DeferredAccesses;
+
+    // Iterate over each alias set twice, once to process read/write pointers,
+    // and then to process read-only pointers.
+    for (int SetIteration = 0; SetIteration < 2; ++SetIteration) {
+      bool UseDeferred = SetIteration > 0;
+      PtrAccessSet &S = UseDeferred ? DeferredAccesses : Accesses;
+
+      for (auto AV : AS) {
+        Value *Ptr = AV.getValue();
+
+        // For a single memory access in AliasSetTracker, Accesses may contain
+        // both read and write, and they both need to be handled for CheckDeps.
+        for (auto AC : S) {
+          if (AC.getPointer() != Ptr)
+            continue;
+
+          bool IsWrite = AC.getInt();
+
+          // If we're using the deferred access set, then it contains only
+          // reads.
+          bool IsReadOnlyPtr = ReadOnlyPtr.count(Ptr) && !IsWrite;
+          if (UseDeferred && !IsReadOnlyPtr)
+            continue;
+          // Otherwise, the pointer must be in the PtrAccessSet, either as a
+          // read or a write.
+          assert(((IsReadOnlyPtr && UseDeferred) || IsWrite ||
+                  S.count(MemAccessInfo(Ptr, false))) &&
+                 "Alias-set pointer not in the access set?");
+
+          MemAccessInfo Access(Ptr, IsWrite);
+          DepCands.insert(Access);
+
+          // Memorize read-only pointers for later processing and skip them in
+          // the first round (they need to be checked after we have seen all
+          // write pointers). Note: we also mark pointers that are not
+          // consecutive as "read-only" pointers (so that we check
+          // "a[b[i]] +="). Hence, we need the second check for "!IsWrite".
+          if (!UseDeferred && IsReadOnlyPtr) {
+            DeferredAccesses.insert(Access);
+            continue;
+          }
+
+          // If this is a write, check other reads and writes for conflicts.
+          // If this is a read, only check other writes for conflicts (but only
+          // if there is no other write to the ptr - this is an optimization to
+          // catch "a[i] = a[i] + " without having to do a dependence check).
+          if ((IsWrite || IsReadOnlyPtr) && SetHasWrite) {
+            CheckDeps.push_back(Access);
+            IsRTCheckAnalysisNeeded = true;
+          }
+
+          if (IsWrite)
+            SetHasWrite = true;
+
+          // Create sets of pointers connected by a shared alias set and
+          // underlying object.
+          typedef SmallVector<const Value *, 16> ValueVector;
+          ValueVector TempObjects;
+
+          GetUnderlyingObjects(Ptr, TempObjects, DL, LI);
+          LLVM_DEBUG(dbgs()
+                     << "Underlying objects for pointer " << *Ptr << "\n");
+          for (const Value *UnderlyingObj : TempObjects) {
+            // nullptr never aliases; don't join sets for pointers that have
+            // "null" in their UnderlyingObjects list.
+ if (isa<ConstantPointerNull>(UnderlyingObj) && + !NullPointerIsDefined( + TheLoop->getHeader()->getParent(), + UnderlyingObj->getType()->getPointerAddressSpace())) + continue; + + UnderlyingObjToAccessMap::iterator Prev = + ObjToLastAccess.find(UnderlyingObj); + if (Prev != ObjToLastAccess.end()) + DepCands.unionSets(Access, Prev->second); + + ObjToLastAccess[UnderlyingObj] = Access; + LLVM_DEBUG(dbgs() << " " << *UnderlyingObj << "\n"); + } + } + } + } + } +} + +static bool isInBoundsGep(Value *Ptr) { + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr)) + return GEP->isInBounds(); + return false; +} + +/// Return true if an AddRec pointer \p Ptr is unsigned non-wrapping, +/// i.e. monotonically increasing/decreasing. +static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, + PredicatedScalarEvolution &PSE, const Loop *L) { + // FIXME: This should probably only return true for NUW. + if (AR->getNoWrapFlags(SCEV::NoWrapMask)) + return true; + + // Scalar evolution does not propagate the non-wrapping flags to values that + // are derived from a non-wrapping induction variable because non-wrapping + // could be flow-sensitive. + // + // Look through the potentially overflowing instruction to try to prove + // non-wrapping for the *specific* value of Ptr. + + // The arithmetic implied by an inbounds GEP can't overflow. + auto *GEP = dyn_cast<GetElementPtrInst>(Ptr); + if (!GEP || !GEP->isInBounds()) + return false; + + // Make sure there is only one non-const index and analyze that. + Value *NonConstIndex = nullptr; + for (Value *Index : make_range(GEP->idx_begin(), GEP->idx_end())) + if (!isa<ConstantInt>(Index)) { + if (NonConstIndex) + return false; + NonConstIndex = Index; + } + if (!NonConstIndex) + // The recurrence is on the pointer, ignore for now. + return false; + + // The index in GEP is signed. It is non-wrapping if it's derived from a NSW + // AddRec using a NSW operation. + if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(NonConstIndex)) + if (OBO->hasNoSignedWrap() && + // Assume constant for other the operand so that the AddRec can be + // easily found. + isa<ConstantInt>(OBO->getOperand(1))) { + auto *OpScev = PSE.getSCEV(OBO->getOperand(0)); + + if (auto *OpAR = dyn_cast<SCEVAddRecExpr>(OpScev)) + return OpAR->getLoop() == L && OpAR->getNoWrapFlags(SCEV::FlagNSW); + } + + return false; +} + +/// Check whether the access through \p Ptr has a constant stride. +int64_t llvm::getPtrStride(PredicatedScalarEvolution &PSE, Value *Ptr, + const Loop *Lp, const ValueToValueMap &StridesMap, + bool Assume, bool ShouldCheckWrap) { + Type *Ty = Ptr->getType(); + assert(Ty->isPointerTy() && "Unexpected non-ptr"); + + // Make sure that the pointer does not point to aggregate types. + auto *PtrTy = cast<PointerType>(Ty); + if (PtrTy->getElementType()->isAggregateType()) { + LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not a pointer to a scalar type" + << *Ptr << "\n"); + return 0; + } + + const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr); + + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev); + if (Assume && !AR) + AR = PSE.getAsAddRec(Ptr); + + if (!AR) { + LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer " << *Ptr + << " SCEV: " << *PtrScev << "\n"); + return 0; + } + + // The access function must stride over the innermost loop. 
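+  // For illustration, for A[j] inside "for (j) { for (i) ... }" the addrec
+  // runs over the outer j-loop; when queried for the inner loop we give up
+  // here rather than report a spurious stride. (Example nest is assumed.)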
+ if (Lp != AR->getLoop()) { + LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not striding over innermost loop " + << *Ptr << " SCEV: " << *AR << "\n"); + return 0; + } + + // The address calculation must not wrap. Otherwise, a dependence could be + // inverted. + // An inbounds getelementptr that is a AddRec with a unit stride + // cannot wrap per definition. The unit stride requirement is checked later. + // An getelementptr without an inbounds attribute and unit stride would have + // to access the pointer value "0" which is undefined behavior in address + // space 0, therefore we can also vectorize this case. + bool IsInBoundsGEP = isInBoundsGep(Ptr); + bool IsNoWrapAddRec = !ShouldCheckWrap || + PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW) || + isNoWrapAddRec(Ptr, AR, PSE, Lp); + if (!IsNoWrapAddRec && !IsInBoundsGEP && + NullPointerIsDefined(Lp->getHeader()->getParent(), + PtrTy->getAddressSpace())) { + if (Assume) { + PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW); + IsNoWrapAddRec = true; + LLVM_DEBUG(dbgs() << "LAA: Pointer may wrap in the address space:\n" + << "LAA: Pointer: " << *Ptr << "\n" + << "LAA: SCEV: " << *AR << "\n" + << "LAA: Added an overflow assumption\n"); + } else { + LLVM_DEBUG( + dbgs() << "LAA: Bad stride - Pointer may wrap in the address space " + << *Ptr << " SCEV: " << *AR << "\n"); + return 0; + } + } + + // Check the step is constant. + const SCEV *Step = AR->getStepRecurrence(*PSE.getSE()); + + // Calculate the pointer stride and check if it is constant. + const SCEVConstant *C = dyn_cast<SCEVConstant>(Step); + if (!C) { + LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not a constant strided " << *Ptr + << " SCEV: " << *AR << "\n"); + return 0; + } + + auto &DL = Lp->getHeader()->getModule()->getDataLayout(); + int64_t Size = DL.getTypeAllocSize(PtrTy->getElementType()); + const APInt &APStepVal = C->getAPInt(); + + // Huge step value - give up. + if (APStepVal.getBitWidth() > 64) + return 0; + + int64_t StepVal = APStepVal.getSExtValue(); + + // Strided access. + int64_t Stride = StepVal / Size; + int64_t Rem = StepVal % Size; + if (Rem) + return 0; + + // If the SCEV could wrap but we have an inbounds gep with a unit stride we + // know we can't "wrap around the address space". In case of address space + // zero we know that this won't happen without triggering undefined behavior. + if (!IsNoWrapAddRec && Stride != 1 && Stride != -1 && + (IsInBoundsGEP || !NullPointerIsDefined(Lp->getHeader()->getParent(), + PtrTy->getAddressSpace()))) { + if (Assume) { + // We can avoid this case by adding a run-time check. + LLVM_DEBUG(dbgs() << "LAA: Non unit strided pointer which is not either " + << "inbounds or in address space 0 may wrap:\n" + << "LAA: Pointer: " << *Ptr << "\n" + << "LAA: SCEV: " << *AR << "\n" + << "LAA: Added an overflow assumption\n"); + PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW); + } else + return 0; + } + + return Stride; +} + +bool llvm::sortPtrAccesses(ArrayRef<Value *> VL, const DataLayout &DL, + ScalarEvolution &SE, + SmallVectorImpl<unsigned> &SortedIndices) { + assert(llvm::all_of( + VL, [](const Value *V) { return V->getType()->isPointerTy(); }) && + "Expected list of pointer operands."); + SmallVector<std::pair<int64_t, Value *>, 4> OffValPairs; + OffValPairs.reserve(VL.size()); + + // Walk over the pointers, and map each of them to an offset relative to + // first pointer in the array. 
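+  // For illustration, for VL = {A+8, A, A+4} (constant byte offsets 8, 0 and
+  // 4 from A) the loop below collects {(8,A+8), (0,A), (4,A+4)} and the sort
+  // further down yields SortedIndices = {1, 2, 0}. (Example values assumed.)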
+ Value *Ptr0 = VL[0]; + const SCEV *Scev0 = SE.getSCEV(Ptr0); + Value *Obj0 = GetUnderlyingObject(Ptr0, DL); + + llvm::SmallSet<int64_t, 4> Offsets; + for (auto *Ptr : VL) { + // TODO: Outline this code as a special, more time consuming, version of + // computeConstantDifference() function. + if (Ptr->getType()->getPointerAddressSpace() != + Ptr0->getType()->getPointerAddressSpace()) + return false; + // If a pointer refers to a different underlying object, bail - the + // pointers are by definition incomparable. + Value *CurrObj = GetUnderlyingObject(Ptr, DL); + if (CurrObj != Obj0) + return false; + + const SCEV *Scev = SE.getSCEV(Ptr); + const auto *Diff = dyn_cast<SCEVConstant>(SE.getMinusSCEV(Scev, Scev0)); + // The pointers may not have a constant offset from each other, or SCEV + // may just not be smart enough to figure out they do. Regardless, + // there's nothing we can do. + if (!Diff) + return false; + + // Check if the pointer with the same offset is found. + int64_t Offset = Diff->getAPInt().getSExtValue(); + if (!Offsets.insert(Offset).second) + return false; + OffValPairs.emplace_back(Offset, Ptr); + } + SortedIndices.clear(); + SortedIndices.resize(VL.size()); + std::iota(SortedIndices.begin(), SortedIndices.end(), 0); + + // Sort the memory accesses and keep the order of their uses in UseOrder. + llvm::stable_sort(SortedIndices, [&](unsigned Left, unsigned Right) { + return OffValPairs[Left].first < OffValPairs[Right].first; + }); + + // Check if the order is consecutive already. + if (llvm::all_of(SortedIndices, [&SortedIndices](const unsigned I) { + return I == SortedIndices[I]; + })) + SortedIndices.clear(); + + return true; +} + +/// Take the address space operand from the Load/Store instruction. +/// Returns -1 if this is not a valid Load/Store instruction. +static unsigned getAddressSpaceOperand(Value *I) { + if (LoadInst *L = dyn_cast<LoadInst>(I)) + return L->getPointerAddressSpace(); + if (StoreInst *S = dyn_cast<StoreInst>(I)) + return S->getPointerAddressSpace(); + return -1; +} + +/// Returns true if the memory operations \p A and \p B are consecutive. +bool llvm::isConsecutiveAccess(Value *A, Value *B, const DataLayout &DL, + ScalarEvolution &SE, bool CheckType) { + Value *PtrA = getLoadStorePointerOperand(A); + Value *PtrB = getLoadStorePointerOperand(B); + unsigned ASA = getAddressSpaceOperand(A); + unsigned ASB = getAddressSpaceOperand(B); + + // Check that the address spaces match and that the pointers are valid. + if (!PtrA || !PtrB || (ASA != ASB)) + return false; + + // Make sure that A and B are different pointers. + if (PtrA == PtrB) + return false; + + // Make sure that A and B have the same type if required. + if (CheckType && PtrA->getType() != PtrB->getType()) + return false; + + unsigned IdxWidth = DL.getIndexSizeInBits(ASA); + Type *Ty = cast<PointerType>(PtrA->getType())->getElementType(); + + APInt OffsetA(IdxWidth, 0), OffsetB(IdxWidth, 0); + PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA); + PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB); + + // Retrieve the address space again as pointer stripping now tracks through + // `addrspacecast`. + ASA = cast<PointerType>(PtrA->getType())->getAddressSpace(); + ASB = cast<PointerType>(PtrB->getType())->getAddressSpace(); + // Check that the address spaces match and that the pointers are valid. 
+ if (ASA != ASB) + return false; + + IdxWidth = DL.getIndexSizeInBits(ASA); + OffsetA = OffsetA.sextOrTrunc(IdxWidth); + OffsetB = OffsetB.sextOrTrunc(IdxWidth); + + APInt Size(IdxWidth, DL.getTypeStoreSize(Ty)); + + // OffsetDelta = OffsetB - OffsetA; + const SCEV *OffsetSCEVA = SE.getConstant(OffsetA); + const SCEV *OffsetSCEVB = SE.getConstant(OffsetB); + const SCEV *OffsetDeltaSCEV = SE.getMinusSCEV(OffsetSCEVB, OffsetSCEVA); + const APInt &OffsetDelta = cast<SCEVConstant>(OffsetDeltaSCEV)->getAPInt(); + + // Check if they are based on the same pointer. That makes the offsets + // sufficient. + if (PtrA == PtrB) + return OffsetDelta == Size; + + // Compute the necessary base pointer delta to have the necessary final delta + // equal to the size. + // BaseDelta = Size - OffsetDelta; + const SCEV *SizeSCEV = SE.getConstant(Size); + const SCEV *BaseDelta = SE.getMinusSCEV(SizeSCEV, OffsetDeltaSCEV); + + // Otherwise compute the distance with SCEV between the base pointers. + const SCEV *PtrSCEVA = SE.getSCEV(PtrA); + const SCEV *PtrSCEVB = SE.getSCEV(PtrB); + const SCEV *X = SE.getAddExpr(PtrSCEVA, BaseDelta); + return X == PtrSCEVB; +} + +MemoryDepChecker::VectorizationSafetyStatus +MemoryDepChecker::Dependence::isSafeForVectorization(DepType Type) { + switch (Type) { + case NoDep: + case Forward: + case BackwardVectorizable: + return VectorizationSafetyStatus::Safe; + + case Unknown: + return VectorizationSafetyStatus::PossiblySafeWithRtChecks; + case ForwardButPreventsForwarding: + case Backward: + case BackwardVectorizableButPreventsForwarding: + return VectorizationSafetyStatus::Unsafe; + } + llvm_unreachable("unexpected DepType!"); +} + +bool MemoryDepChecker::Dependence::isBackward() const { + switch (Type) { + case NoDep: + case Forward: + case ForwardButPreventsForwarding: + case Unknown: + return false; + + case BackwardVectorizable: + case Backward: + case BackwardVectorizableButPreventsForwarding: + return true; + } + llvm_unreachable("unexpected DepType!"); +} + +bool MemoryDepChecker::Dependence::isPossiblyBackward() const { + return isBackward() || Type == Unknown; +} + +bool MemoryDepChecker::Dependence::isForward() const { + switch (Type) { + case Forward: + case ForwardButPreventsForwarding: + return true; + + case NoDep: + case Unknown: + case BackwardVectorizable: + case Backward: + case BackwardVectorizableButPreventsForwarding: + return false; + } + llvm_unreachable("unexpected DepType!"); +} + +bool MemoryDepChecker::couldPreventStoreLoadForward(uint64_t Distance, + uint64_t TypeByteSize) { + // If loads occur at a distance that is not a multiple of a feasible vector + // factor store-load forwarding does not take place. + // Positive dependences might cause troubles because vectorizing them might + // prevent store-load forwarding making vectorized code run a lot slower. + // a[i] = a[i-3] ^ a[i-8]; + // The stores to a[i:i+1] don't align with the stores to a[i-3:i-2] and + // hence on your typical architecture store-load forwarding does not take + // place. Vectorizing in such cases does not make sense. + // Store-load forwarding distance. + + // After this many iterations store-to-load forwarding conflicts should not + // cause any slowdowns. + const uint64_t NumItersForStoreLoadThroughMemory = 8 * TypeByteSize; + // Maximum vector factor. + uint64_t MaxVFWithoutSLForwardIssues = std::min( + VectorizerParams::MaxVectorWidth * TypeByteSize, MaxSafeDepDistBytes); + + // Compute the smallest VF at which the store and load would be misaligned. 
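+  // For illustration, with TypeByteSize = 4 and Distance = 12 (as in
+  // "a[i] = a[i-3] ^ ..."), the first candidate VF = 8 gives
+  // Distance % VF == 4 with Distance / VF == 1 below the threshold of 32, so
+  // MaxVFWithoutSLForwardIssues is halved to 4 and a conflict is reported
+  // below. (Example numbers are assumed.)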
+ for (uint64_t VF = 2 * TypeByteSize; VF <= MaxVFWithoutSLForwardIssues; + VF *= 2) { + // If the number of vector iteration between the store and the load are + // small we could incur conflicts. + if (Distance % VF && Distance / VF < NumItersForStoreLoadThroughMemory) { + MaxVFWithoutSLForwardIssues = (VF >>= 1); + break; + } + } + + if (MaxVFWithoutSLForwardIssues < 2 * TypeByteSize) { + LLVM_DEBUG( + dbgs() << "LAA: Distance " << Distance + << " that could cause a store-load forwarding conflict\n"); + return true; + } + + if (MaxVFWithoutSLForwardIssues < MaxSafeDepDistBytes && + MaxVFWithoutSLForwardIssues != + VectorizerParams::MaxVectorWidth * TypeByteSize) + MaxSafeDepDistBytes = MaxVFWithoutSLForwardIssues; + return false; +} + +void MemoryDepChecker::mergeInStatus(VectorizationSafetyStatus S) { + if (Status < S) + Status = S; +} + +/// Given a non-constant (unknown) dependence-distance \p Dist between two +/// memory accesses, that have the same stride whose absolute value is given +/// in \p Stride, and that have the same type size \p TypeByteSize, +/// in a loop whose takenCount is \p BackedgeTakenCount, check if it is +/// possible to prove statically that the dependence distance is larger +/// than the range that the accesses will travel through the execution of +/// the loop. If so, return true; false otherwise. This is useful for +/// example in loops such as the following (PR31098): +/// for (i = 0; i < D; ++i) { +/// = out[i]; +/// out[i+D] = +/// } +static bool isSafeDependenceDistance(const DataLayout &DL, ScalarEvolution &SE, + const SCEV &BackedgeTakenCount, + const SCEV &Dist, uint64_t Stride, + uint64_t TypeByteSize) { + + // If we can prove that + // (**) |Dist| > BackedgeTakenCount * Step + // where Step is the absolute stride of the memory accesses in bytes, + // then there is no dependence. + // + // Rationale: + // We basically want to check if the absolute distance (|Dist/Step|) + // is >= the loop iteration count (or > BackedgeTakenCount). + // This is equivalent to the Strong SIV Test (Practical Dependence Testing, + // Section 4.2.1); Note, that for vectorization it is sufficient to prove + // that the dependence distance is >= VF; This is checked elsewhere. + // But in some cases we can prune unknown dependence distances early, and + // even before selecting the VF, and without a runtime test, by comparing + // the distance against the loop iteration count. Since the vectorized code + // will be executed only if LoopCount >= VF, proving distance >= LoopCount + // also guarantees that distance >= VF. + // + const uint64_t ByteStride = Stride * TypeByteSize; + const SCEV *Step = SE.getConstant(BackedgeTakenCount.getType(), ByteStride); + const SCEV *Product = SE.getMulExpr(&BackedgeTakenCount, Step); + + const SCEV *CastedDist = &Dist; + const SCEV *CastedProduct = Product; + uint64_t DistTypeSize = DL.getTypeAllocSize(Dist.getType()); + uint64_t ProductTypeSize = DL.getTypeAllocSize(Product->getType()); + + // The dependence distance can be positive/negative, so we sign extend Dist; + // The multiplication of the absolute stride in bytes and the + // backedgeTakenCount is non-negative, so we zero extend Product. + if (DistTypeSize > ProductTypeSize) + CastedProduct = SE.getZeroExtendExpr(Product, Dist.getType()); + else + CastedDist = SE.getNoopOrSignExtend(&Dist, Product->getType()); + + // Is Dist - (BackedgeTakenCount * Step) > 0 ? 
+ // (If so, then we have proven (**) because |Dist| >= Dist) + const SCEV *Minus = SE.getMinusSCEV(CastedDist, CastedProduct); + if (SE.isKnownPositive(Minus)) + return true; + + // Second try: Is -Dist - (BackedgeTakenCount * Step) > 0 ? + // (If so, then we have proven (**) because |Dist| >= -1*Dist) + const SCEV *NegDist = SE.getNegativeSCEV(CastedDist); + Minus = SE.getMinusSCEV(NegDist, CastedProduct); + if (SE.isKnownPositive(Minus)) + return true; + + return false; +} + +/// Check the dependence for two accesses with the same stride \p Stride. +/// \p Distance is the positive distance and \p TypeByteSize is type size in +/// bytes. +/// +/// \returns true if they are independent. +static bool areStridedAccessesIndependent(uint64_t Distance, uint64_t Stride, + uint64_t TypeByteSize) { + assert(Stride > 1 && "The stride must be greater than 1"); + assert(TypeByteSize > 0 && "The type size in byte must be non-zero"); + assert(Distance > 0 && "The distance must be non-zero"); + + // Skip if the distance is not multiple of type byte size. + if (Distance % TypeByteSize) + return false; + + uint64_t ScaledDist = Distance / TypeByteSize; + + // No dependence if the scaled distance is not multiple of the stride. + // E.g. + // for (i = 0; i < 1024 ; i += 4) + // A[i+2] = A[i] + 1; + // + // Two accesses in memory (scaled distance is 2, stride is 4): + // | A[0] | | | | A[4] | | | | + // | | | A[2] | | | | A[6] | | + // + // E.g. + // for (i = 0; i < 1024 ; i += 3) + // A[i+4] = A[i] + 1; + // + // Two accesses in memory (scaled distance is 4, stride is 3): + // | A[0] | | | A[3] | | | A[6] | | | + // | | | | | A[4] | | | A[7] | | + return ScaledDist % Stride; +} + +MemoryDepChecker::Dependence::DepType +MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, + const MemAccessInfo &B, unsigned BIdx, + const ValueToValueMap &Strides) { + assert (AIdx < BIdx && "Must pass arguments in program order"); + + Value *APtr = A.getPointer(); + Value *BPtr = B.getPointer(); + bool AIsWrite = A.getInt(); + bool BIsWrite = B.getInt(); + + // Two reads are independent. + if (!AIsWrite && !BIsWrite) + return Dependence::NoDep; + + // We cannot check pointers in different address spaces. + if (APtr->getType()->getPointerAddressSpace() != + BPtr->getType()->getPointerAddressSpace()) + return Dependence::Unknown; + + int64_t StrideAPtr = getPtrStride(PSE, APtr, InnermostLoop, Strides, true); + int64_t StrideBPtr = getPtrStride(PSE, BPtr, InnermostLoop, Strides, true); + + const SCEV *Src = PSE.getSCEV(APtr); + const SCEV *Sink = PSE.getSCEV(BPtr); + + // If the induction step is negative we have to invert source and sink of the + // dependence. + if (StrideAPtr < 0) { + std::swap(APtr, BPtr); + std::swap(Src, Sink); + std::swap(AIsWrite, BIsWrite); + std::swap(AIdx, BIdx); + std::swap(StrideAPtr, StrideBPtr); + } + + const SCEV *Dist = PSE.getSE()->getMinusSCEV(Sink, Src); + + LLVM_DEBUG(dbgs() << "LAA: Src Scev: " << *Src << "Sink Scev: " << *Sink + << "(Induction step: " << StrideAPtr << ")\n"); + LLVM_DEBUG(dbgs() << "LAA: Distance for " << *InstMap[AIdx] << " to " + << *InstMap[BIdx] << ": " << *Dist << "\n"); + + // Need accesses with constant stride. We don't want to vectorize + // "A[B[i]] += ..." and similar code or pointer arithmetic that could wrap in + // the address space. 
+  if (!StrideAPtr || !StrideBPtr || StrideAPtr != StrideBPtr) {
+    LLVM_DEBUG(dbgs() << "Pointer access with non-constant stride\n");
+    return Dependence::Unknown;
+  }
+
+  Type *ATy = APtr->getType()->getPointerElementType();
+  Type *BTy = BPtr->getType()->getPointerElementType();
+  auto &DL = InnermostLoop->getHeader()->getModule()->getDataLayout();
+  uint64_t TypeByteSize = DL.getTypeAllocSize(ATy);
+  uint64_t Stride = std::abs(StrideAPtr);
+  const SCEVConstant *C = dyn_cast<SCEVConstant>(Dist);
+  if (!C) {
+    if (TypeByteSize == DL.getTypeAllocSize(BTy) &&
+        isSafeDependenceDistance(DL, *(PSE.getSE()),
+                                 *(PSE.getBackedgeTakenCount()), *Dist, Stride,
+                                 TypeByteSize))
+      return Dependence::NoDep;
+
+    LLVM_DEBUG(dbgs() << "LAA: Dependence because of non-constant distance\n");
+    FoundNonConstantDistanceDependence = true;
+    return Dependence::Unknown;
+  }
+
+  const APInt &Val = C->getAPInt();
+  int64_t Distance = Val.getSExtValue();
+
+  // Attempt to prove strided accesses independent.
+  if (std::abs(Distance) > 0 && Stride > 1 && ATy == BTy &&
+      areStridedAccessesIndependent(std::abs(Distance), Stride, TypeByteSize)) {
+    LLVM_DEBUG(dbgs() << "LAA: Strided accesses are independent\n");
+    return Dependence::NoDep;
+  }
+
+  // Negative distances are not plausible dependencies.
+  if (Val.isNegative()) {
+    bool IsTrueDataDependence = (AIsWrite && !BIsWrite);
+    if (IsTrueDataDependence && EnableForwardingConflictDetection &&
+        (couldPreventStoreLoadForward(Val.abs().getZExtValue(), TypeByteSize) ||
+         ATy != BTy)) {
+      LLVM_DEBUG(dbgs() << "LAA: Forward but may prevent st->ld forwarding\n");
+      return Dependence::ForwardButPreventsForwarding;
+    }
+
+    LLVM_DEBUG(dbgs() << "LAA: Dependence is negative\n");
+    return Dependence::Forward;
+  }
+
+  // Write to the same location with the same size.
+  // Could be improved to assert type sizes are the same (i32 == float, etc).
+  if (Val == 0) {
+    if (ATy == BTy)
+      return Dependence::Forward;
+    LLVM_DEBUG(
+        dbgs() << "LAA: Zero dependence difference but different types\n");
+    return Dependence::Unknown;
+  }
+
+  assert(Val.isStrictlyPositive() && "Expect a positive value");
+
+  if (ATy != BTy) {
+    LLVM_DEBUG(
+        dbgs()
+        << "LAA: ReadWrite-Write positive dependency with different types\n");
+    return Dependence::Unknown;
+  }
+
+  // Bail out early if passed-in parameters make vectorization not feasible.
+  unsigned ForcedFactor = (VectorizerParams::VectorizationFactor ?
+                           VectorizerParams::VectorizationFactor : 1);
+  unsigned ForcedUnroll = (VectorizerParams::VectorizationInterleave ?
+                           VectorizerParams::VectorizationInterleave : 1);
+  // The minimum number of iterations for a vectorized/unrolled version.
+  unsigned MinNumIter = std::max(ForcedFactor * ForcedUnroll, 2U);
+
+  // It's not vectorizable if the distance is smaller than the minimum distance
+  // needed for a vectorized/unrolled version. Vectorizing one iteration in
+  // front needs TypeByteSize * Stride. Vectorizing the last iteration needs
+  // TypeByteSize (no need to add the last gap distance).
+  //
+  // E.g. Assume one char is 1 byte in memory and one int is 4 bytes.
+  // foo(int *A) {
+  //   int *B = (int *)((char *)A + 14);
+  //   for (i = 0 ; i < 1024 ; i += 2)
+  //     B[i] = A[i] + 1;
+  // }
+  //
+  // Two accesses in memory (stride is 2):
+  //     | A[0] |      | A[2] |      | A[4] |      | A[6] |      |
+  //                              | B[0] |      | B[2] |      | B[4] |
+  //
+  // Distance needed for vectorizing iterations except the last iteration:
+  // 4 * 2 * (MinNumIter - 1). Distance needed for the last iteration: 4.
+  // So the minimum distance needed is: 4 * 2 * (MinNumIter - 1) + 4.
+  //
+  // If MinNumIter is 2, it is vectorizable as the minimum distance needed is
+  // 12, which is less than distance.
+  //
+  // If MinNumIter is 4 (say a user forces the vectorization factor to be 4),
+  // the minimum distance needed is 28, which is greater than distance. It is
+  // not safe to do vectorization.
+  uint64_t MinDistanceNeeded =
+      TypeByteSize * Stride * (MinNumIter - 1) + TypeByteSize;
+  if (MinDistanceNeeded > static_cast<uint64_t>(Distance)) {
+    LLVM_DEBUG(dbgs() << "LAA: Failure because of positive distance "
+                      << Distance << '\n');
+    return Dependence::Backward;
+  }
+
+  // Unsafe if the minimum distance needed is greater than max safe distance.
+  if (MinDistanceNeeded > MaxSafeDepDistBytes) {
+    LLVM_DEBUG(dbgs() << "LAA: Failure because it needs at least "
+                      << MinDistanceNeeded << " size in bytes");
+    return Dependence::Backward;
+  }
+
+  // Positive distance bigger than max vectorization factor.
+  // FIXME: Should use max factor instead of max distance in bytes, which could
+  // not handle different types.
+  // E.g. Assume one char is 1 byte in memory and one int is 4 bytes.
+  //      void foo (int *A, char *B) {
+  //        for (unsigned i = 0; i < 1024; i++) {
+  //          A[i+2] = A[i] + 1;
+  //          B[i+2] = B[i] + 1;
+  //        }
+  //      }
+  //
+  // This case is currently unsafe according to the max safe distance. If we
+  // analyze the two accesses on array B, the max safe dependence distance
+  // is 2. Then we analyze the accesses on array A, for which the minimum
+  // distance needed is 8, which is greater than 2, so vectorization is
+  // forbidden. But actually both A and B could be vectorized by 2 iterations.
+  MaxSafeDepDistBytes =
+      std::min(static_cast<uint64_t>(Distance), MaxSafeDepDistBytes);
+
+  bool IsTrueDataDependence = (!AIsWrite && BIsWrite);
+  if (IsTrueDataDependence && EnableForwardingConflictDetection &&
+      couldPreventStoreLoadForward(Distance, TypeByteSize))
+    return Dependence::BackwardVectorizableButPreventsForwarding;
+
+  uint64_t MaxVF = MaxSafeDepDistBytes / (TypeByteSize * Stride);
+  LLVM_DEBUG(dbgs() << "LAA: Positive distance " << Val.getSExtValue()
+                    << " with max VF = " << MaxVF << '\n');
+  uint64_t MaxVFInBits = MaxVF * TypeByteSize * 8;
+  MaxSafeRegisterWidth = std::min(MaxSafeRegisterWidth, MaxVFInBits);
+  return Dependence::BackwardVectorizable;
+}
+
+bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets,
+                                   MemAccessInfoList &CheckDeps,
+                                   const ValueToValueMap &Strides) {
+
+  MaxSafeDepDistBytes = -1;
+  SmallPtrSet<MemAccessInfo, 8> Visited;
+  for (MemAccessInfo CurAccess : CheckDeps) {
+    if (Visited.count(CurAccess))
+      continue;
+
+    // Get the relevant memory access set.
+    EquivalenceClasses<MemAccessInfo>::iterator I =
+        AccessSets.findValue(AccessSets.getLeaderValue(CurAccess));
+
+    // Check accesses within this set.
+    EquivalenceClasses<MemAccessInfo>::member_iterator AI =
+        AccessSets.member_begin(I);
+    EquivalenceClasses<MemAccessInfo>::member_iterator AE =
+        AccessSets.member_end();
+
+    // Check every access pair.
+    while (AI != AE) {
+      Visited.insert(*AI);
+      bool AIIsWrite = AI->getInt();
+      // Check loads only against the next equivalence class, but stores also
+      // against other stores in the same equivalence class - to the same
+      // address.
+      EquivalenceClasses<MemAccessInfo>::member_iterator OI =
+          (AIIsWrite ? AI : std::next(AI));
+      while (OI != AE) {
+        // Check every accessing instruction pair in program order.
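+        // For illustration, a store recorded at index 3 and a load recorded
+        // at index 7 are always handed to isDependent as (3, 7); the swap
+        // below restores program order if they are visited as (7, 3).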
+ for (std::vector<unsigned>::iterator I1 = Accesses[*AI].begin(), + I1E = Accesses[*AI].end(); I1 != I1E; ++I1) + // Scan all accesses of another equivalence class, but only the next + // accesses of the same equivalent class. + for (std::vector<unsigned>::iterator + I2 = (OI == AI ? std::next(I1) : Accesses[*OI].begin()), + I2E = (OI == AI ? I1E : Accesses[*OI].end()); + I2 != I2E; ++I2) { + auto A = std::make_pair(&*AI, *I1); + auto B = std::make_pair(&*OI, *I2); + + assert(*I1 != *I2); + if (*I1 > *I2) + std::swap(A, B); + + Dependence::DepType Type = + isDependent(*A.first, A.second, *B.first, B.second, Strides); + mergeInStatus(Dependence::isSafeForVectorization(Type)); + + // Gather dependences unless we accumulated MaxDependences + // dependences. In that case return as soon as we find the first + // unsafe dependence. This puts a limit on this quadratic + // algorithm. + if (RecordDependences) { + if (Type != Dependence::NoDep) + Dependences.push_back(Dependence(A.second, B.second, Type)); + + if (Dependences.size() >= MaxDependences) { + RecordDependences = false; + Dependences.clear(); + LLVM_DEBUG(dbgs() + << "Too many dependences, stopped recording\n"); + } + } + if (!RecordDependences && !isSafeForVectorization()) + return false; + } + ++OI; + } + AI++; + } + } + + LLVM_DEBUG(dbgs() << "Total Dependences: " << Dependences.size() << "\n"); + return isSafeForVectorization(); +} + +SmallVector<Instruction *, 4> +MemoryDepChecker::getInstructionsForAccess(Value *Ptr, bool isWrite) const { + MemAccessInfo Access(Ptr, isWrite); + auto &IndexVector = Accesses.find(Access)->second; + + SmallVector<Instruction *, 4> Insts; + transform(IndexVector, + std::back_inserter(Insts), + [&](unsigned Idx) { return this->InstMap[Idx]; }); + return Insts; +} + +const char *MemoryDepChecker::Dependence::DepName[] = { + "NoDep", "Unknown", "Forward", "ForwardButPreventsForwarding", "Backward", + "BackwardVectorizable", "BackwardVectorizableButPreventsForwarding"}; + +void MemoryDepChecker::Dependence::print( + raw_ostream &OS, unsigned Depth, + const SmallVectorImpl<Instruction *> &Instrs) const { + OS.indent(Depth) << DepName[Type] << ":\n"; + OS.indent(Depth + 2) << *Instrs[Source] << " -> \n"; + OS.indent(Depth + 2) << *Instrs[Destination] << "\n"; +} + +bool LoopAccessInfo::canAnalyzeLoop() { + // We need to have a loop header. + LLVM_DEBUG(dbgs() << "LAA: Found a loop in " + << TheLoop->getHeader()->getParent()->getName() << ": " + << TheLoop->getHeader()->getName() << '\n'); + + // We can only analyze innermost loops. + if (!TheLoop->empty()) { + LLVM_DEBUG(dbgs() << "LAA: loop is not the innermost loop\n"); + recordAnalysis("NotInnerMostLoop") << "loop is not the innermost loop"; + return false; + } + + // We must have a single backedge. + if (TheLoop->getNumBackEdges() != 1) { + LLVM_DEBUG( + dbgs() << "LAA: loop control flow is not understood by analyzer\n"); + recordAnalysis("CFGNotUnderstood") + << "loop control flow is not understood by analyzer"; + return false; + } + + // We must have a single exiting block. + if (!TheLoop->getExitingBlock()) { + LLVM_DEBUG( + dbgs() << "LAA: loop control flow is not understood by analyzer\n"); + recordAnalysis("CFGNotUnderstood") + << "loop control flow is not understood by analyzer"; + return false; + } + + // We only handle bottom-tested loops, i.e. loop in which the condition is + // checked at the end of each iteration. With that we can assume that all + // instructions in the loop are executed the same number of times. 
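+  // For illustration, a rotated do-while style loop whose latch ends in
+  // "br i1 %cond, label %header, label %exit" is accepted, while a loop that
+  // exits from a non-latch block is rejected below. (Example IR is assumed.)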
+  if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) {
+    LLVM_DEBUG(
+        dbgs() << "LAA: loop control flow is not understood by analyzer\n");
+    recordAnalysis("CFGNotUnderstood")
+        << "loop control flow is not understood by analyzer";
+    return false;
+  }
+
+  // ScalarEvolution needs to be able to find the exit count.
+  const SCEV *ExitCount = PSE->getBackedgeTakenCount();
+  if (ExitCount == PSE->getSE()->getCouldNotCompute()) {
+    recordAnalysis("CantComputeNumberOfIterations")
+        << "could not determine number of loop iterations";
+    LLVM_DEBUG(dbgs() << "LAA: SCEV could not compute the loop exit count.\n");
+    return false;
+  }
+
+  return true;
+}
+
+void LoopAccessInfo::analyzeLoop(AliasAnalysis *AA, LoopInfo *LI,
+                                 const TargetLibraryInfo *TLI,
+                                 DominatorTree *DT) {
+  typedef SmallPtrSet<Value*, 16> ValueSet;
+
+  // Holds the Load and Store instructions.
+  SmallVector<LoadInst *, 16> Loads;
+  SmallVector<StoreInst *, 16> Stores;
+
+  // Holds all the different accesses in the loop.
+  unsigned NumReads = 0;
+  unsigned NumReadWrites = 0;
+
+  bool HasComplexMemInst = false;
+
+  // A runtime check is only legal to insert if there are no convergent calls.
+  HasConvergentOp = false;
+
+  PtrRtChecking->Pointers.clear();
+  PtrRtChecking->Need = false;
+
+  const bool IsAnnotatedParallel = TheLoop->isAnnotatedParallel();
+
+  // For each block.
+  for (BasicBlock *BB : TheLoop->blocks()) {
+    // Scan the BB and collect legal loads and stores. Also detect any
+    // convergent instructions.
+    for (Instruction &I : *BB) {
+      if (auto *Call = dyn_cast<CallBase>(&I)) {
+        if (Call->isConvergent())
+          HasConvergentOp = true;
+      }
+
+      // If we have found both a non-vectorizable memory instruction and a
+      // convergent operation in this loop, there is no reason to continue
+      // the search.
+      if (HasComplexMemInst && HasConvergentOp) {
+        CanVecMem = false;
+        return;
+      }
+
+      // Avoid hitting recordAnalysis multiple times.
+      if (HasComplexMemInst)
+        continue;
+
+      // If this is a load, save it. If this instruction can read from memory
+      // but is not a load, then we quit. Notice that we don't handle function
+      // calls that read or write.
+      if (I.mayReadFromMemory()) {
+        // Many math library functions read the rounding mode. We will only
+        // vectorize a loop if it contains known function calls that don't set
+        // the flag. Therefore, it is safe to ignore this read from memory.
+        auto *Call = dyn_cast<CallInst>(&I);
+        if (Call && getVectorIntrinsicIDForCall(Call, TLI))
+          continue;
+
+        // If the function has an explicit vectorized counterpart, we can
+        // safely assume that it can be vectorized.
+        if (Call && !Call->isNoBuiltin() && Call->getCalledFunction() &&
+            TLI->isFunctionVectorizable(Call->getCalledFunction()->getName()))
+          continue;
+
+        auto *Ld = dyn_cast<LoadInst>(&I);
+        if (!Ld) {
+          recordAnalysis("CantVectorizeInstruction", Ld)
+              << "instruction cannot be vectorized";
+          HasComplexMemInst = true;
+          continue;
+        }
+        if (!Ld->isSimple() && !IsAnnotatedParallel) {
+          recordAnalysis("NonSimpleLoad", Ld)
+              << "read with atomic ordering or volatile read";
+          LLVM_DEBUG(dbgs() << "LAA: Found a non-simple load.\n");
+          HasComplexMemInst = true;
+          continue;
+        }
+        NumLoads++;
+        Loads.push_back(Ld);
+        DepChecker->addAccess(Ld);
+        if (EnableMemAccessVersioning)
+          collectStridedAccess(Ld);
+        continue;
+      }
+
+      // Save 'store' instructions. Abort if other instructions write to memory. 
+ if (I.mayWriteToMemory()) { + auto *St = dyn_cast<StoreInst>(&I); + if (!St) { + recordAnalysis("CantVectorizeInstruction", St) + << "instruction cannot be vectorized"; + HasComplexMemInst = true; + continue; + } + if (!St->isSimple() && !IsAnnotatedParallel) { + recordAnalysis("NonSimpleStore", St) + << "write with atomic ordering or volatile write"; + LLVM_DEBUG(dbgs() << "LAA: Found a non-simple store.\n"); + HasComplexMemInst = true; + continue; + } + NumStores++; + Stores.push_back(St); + DepChecker->addAccess(St); + if (EnableMemAccessVersioning) + collectStridedAccess(St); + } + } // Next instr. + } // Next block. + + if (HasComplexMemInst) { + CanVecMem = false; + return; + } + + // Now we have two lists that hold the loads and the stores. + // Next, we find the pointers that they use. + + // Check if we see any stores. If there are no stores, then we don't + // care if the pointers are *restrict*. + if (!Stores.size()) { + LLVM_DEBUG(dbgs() << "LAA: Found a read-only loop!\n"); + CanVecMem = true; + return; + } + + MemoryDepChecker::DepCandidates DependentAccesses; + AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(), + TheLoop, AA, LI, DependentAccesses, *PSE); + + // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects + // multiple times on the same object. If the ptr is accessed twice, once + // for read and once for write, it will only appear once (on the write + // list). This is okay, since we are going to check for conflicts between + // writes and between reads and writes, but not between reads and reads. + ValueSet Seen; + + // Record uniform store addresses to identify if we have multiple stores + // to the same address. + ValueSet UniformStores; + + for (StoreInst *ST : Stores) { + Value *Ptr = ST->getPointerOperand(); + + if (isUniform(Ptr)) + HasDependenceInvolvingLoopInvariantAddress |= + !UniformStores.insert(Ptr).second; + + // If we did *not* see this pointer before, insert it to the read-write + // list. At this phase it is only a 'write' list. + if (Seen.insert(Ptr).second) { + ++NumReadWrites; + + MemoryLocation Loc = MemoryLocation::get(ST); + // The TBAA metadata could have a control dependency on the predication + // condition, so we cannot rely on it when determining whether or not we + // need runtime pointer checks. + if (blockNeedsPredication(ST->getParent(), TheLoop, DT)) + Loc.AATags.TBAA = nullptr; + + Accesses.addStore(Loc); + } + } + + if (IsAnnotatedParallel) { + LLVM_DEBUG( + dbgs() << "LAA: A loop annotated parallel, ignore memory dependency " + << "checks.\n"); + CanVecMem = true; + return; + } + + for (LoadInst *LD : Loads) { + Value *Ptr = LD->getPointerOperand(); + // If we did *not* see this pointer before, insert it to the + // read list. If we *did* see it before, then it is already in + // the read-write list. This allows us to vectorize expressions + // such as A[i] += x; Because the address of A[i] is a read-write + // pointer. This only works if the index of A[i] is consecutive. + // If the address of i is unknown (for example A[B[i]]) then we may + // read a few words, modify, and write a few words, and some of the + // words may be written to the same address. + bool IsReadOnlyPtr = false; + if (Seen.insert(Ptr).second || + !getPtrStride(*PSE, Ptr, TheLoop, SymbolicStrides)) { + ++NumReads; + IsReadOnlyPtr = true; + } + + // See if there is an unsafe dependency between a load to a uniform address and + // store to the same uniform address. 
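+    // E.g. a sketch of the pattern being detected here (illustrative only):
+    //
+    //   for (unsigned i = 0; i < N; i++)
+    //     *Sum = *Sum + A[i];
+    //
+    // Both the load and the store go through the loop-invariant pointer Sum,
+    // so every iteration depends on the store of the previous one.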
+    if (UniformStores.count(Ptr)) {
+      LLVM_DEBUG(dbgs() << "LAA: Found an unsafe dependency between a uniform "
+                           "load and uniform store to the same address!\n");
+      HasDependenceInvolvingLoopInvariantAddress = true;
+    }
+
+    MemoryLocation Loc = MemoryLocation::get(LD);
+    // The TBAA metadata could have a control dependency on the predication
+    // condition, so we cannot rely on it when determining whether or not we
+    // need runtime pointer checks.
+    if (blockNeedsPredication(LD->getParent(), TheLoop, DT))
+      Loc.AATags.TBAA = nullptr;
+
+    Accesses.addLoad(Loc, IsReadOnlyPtr);
+  }
+
+  // If we write (or read-write) to a single destination and there are no
+  // other reads in this loop then it is safe to vectorize.
+  if (NumReadWrites == 1 && NumReads == 0) {
+    LLVM_DEBUG(dbgs() << "LAA: Found a write-only loop!\n");
+    CanVecMem = true;
+    return;
+  }
+
+  // Build dependence sets and check whether we need a runtime pointer bounds
+  // check.
+  Accesses.buildDependenceSets();
+
+  // Find pointers with computable bounds. We are going to use this information
+  // to place a runtime bound check.
+  bool CanDoRTIfNeeded = Accesses.canCheckPtrAtRT(*PtrRtChecking, PSE->getSE(),
+                                                  TheLoop, SymbolicStrides);
+  if (!CanDoRTIfNeeded) {
+    recordAnalysis("CantIdentifyArrayBounds") << "cannot identify array bounds";
+    LLVM_DEBUG(dbgs() << "LAA: We can't vectorize because we can't find "
+                      << "the array bounds.\n");
+    CanVecMem = false;
+    return;
+  }
+
+  LLVM_DEBUG(
+      dbgs() << "LAA: May be able to perform a memory runtime check if needed.\n");
+
+  CanVecMem = true;
+  if (Accesses.isDependencyCheckNeeded()) {
+    LLVM_DEBUG(dbgs() << "LAA: Checking memory dependencies\n");
+    CanVecMem = DepChecker->areDepsSafe(
+        DependentAccesses, Accesses.getDependenciesToCheck(), SymbolicStrides);
+    MaxSafeDepDistBytes = DepChecker->getMaxSafeDepDistBytes();
+
+    if (!CanVecMem && DepChecker->shouldRetryWithRuntimeCheck()) {
+      LLVM_DEBUG(dbgs() << "LAA: Retrying with memory checks\n");
+
+      // Clear the dependency checks. We assume they are not needed.
+      Accesses.resetDepChecks(*DepChecker);
+
+      PtrRtChecking->reset();
+      PtrRtChecking->Need = true;
+
+      auto *SE = PSE->getSE();
+      CanDoRTIfNeeded = Accesses.canCheckPtrAtRT(*PtrRtChecking, SE, TheLoop,
+                                                 SymbolicStrides, true);
+
+      // Check that we found the bounds for the pointer.
+      if (!CanDoRTIfNeeded) {
+        recordAnalysis("CantCheckMemDepsAtRunTime")
+            << "cannot check memory dependencies at runtime";
+        LLVM_DEBUG(dbgs() << "LAA: Can't vectorize with memory checks\n");
+        CanVecMem = false;
+        return;
+      }
+
+      CanVecMem = true;
+    }
+  }
+
+  if (HasConvergentOp) {
+    recordAnalysis("CantInsertRuntimeCheckWithConvergent")
+        << "cannot add control dependency to convergent operation";
+    LLVM_DEBUG(dbgs() << "LAA: We can't vectorize because a runtime check "
+                         "would be needed with a convergent operation\n");
+    CanVecMem = false;
+    return;
+  }
+
+  if (CanVecMem)
+    LLVM_DEBUG(
+        dbgs() << "LAA: No unsafe dependent memory operations in loop. We"
+               << (PtrRtChecking->Need ? "" : " don't")
+               << " need runtime memory checks.\n");
+  else {
+    recordAnalysis("UnsafeMemDep")
+        << "unsafe dependent memory operations in loop. 
Use " + "#pragma loop distribute(enable) to allow loop distribution " + "to attempt to isolate the offending operations into a separate " + "loop"; + LLVM_DEBUG(dbgs() << "LAA: unsafe dependent memory operations in loop\n"); + } +} + +bool LoopAccessInfo::blockNeedsPredication(BasicBlock *BB, Loop *TheLoop, + DominatorTree *DT) { + assert(TheLoop->contains(BB) && "Unknown block used"); + + // Blocks that do not dominate the latch need predication. + BasicBlock* Latch = TheLoop->getLoopLatch(); + return !DT->dominates(BB, Latch); +} + +OptimizationRemarkAnalysis &LoopAccessInfo::recordAnalysis(StringRef RemarkName, + Instruction *I) { + assert(!Report && "Multiple reports generated"); + + Value *CodeRegion = TheLoop->getHeader(); + DebugLoc DL = TheLoop->getStartLoc(); + + if (I) { + CodeRegion = I->getParent(); + // If there is no debug location attached to the instruction, revert back to + // using the loop's. + if (I->getDebugLoc()) + DL = I->getDebugLoc(); + } + + Report = std::make_unique<OptimizationRemarkAnalysis>(DEBUG_TYPE, RemarkName, DL, + CodeRegion); + return *Report; +} + +bool LoopAccessInfo::isUniform(Value *V) const { + auto *SE = PSE->getSE(); + // Since we rely on SCEV for uniformity, if the type is not SCEVable, it is + // never considered uniform. + // TODO: Is this really what we want? Even without FP SCEV, we may want some + // trivially loop-invariant FP values to be considered uniform. + if (!SE->isSCEVable(V->getType())) + return false; + return (SE->isLoopInvariant(SE->getSCEV(V), TheLoop)); +} + +// FIXME: this function is currently a duplicate of the one in +// LoopVectorize.cpp. +static Instruction *getFirstInst(Instruction *FirstInst, Value *V, + Instruction *Loc) { + if (FirstInst) + return FirstInst; + if (Instruction *I = dyn_cast<Instruction>(V)) + return I->getParent() == Loc->getParent() ? I : nullptr; + return nullptr; +} + +namespace { + +/// IR Values for the lower and upper bounds of a pointer evolution. We +/// need to use value-handles because SCEV expansion can invalidate previously +/// expanded values. Thus expansion of a pointer can invalidate the bounds for +/// a previous one. +struct PointerBounds { + TrackingVH<Value> Start; + TrackingVH<Value> End; +}; + +} // end anonymous namespace + +/// Expand code for the lower and upper bound of the pointer group \p CG +/// in \p TheLoop. \return the values for the bounds. +static PointerBounds +expandBounds(const RuntimePointerChecking::CheckingPtrGroup *CG, Loop *TheLoop, + Instruction *Loc, SCEVExpander &Exp, ScalarEvolution *SE, + const RuntimePointerChecking &PtrRtChecking) { + Value *Ptr = PtrRtChecking.Pointers[CG->Members[0]].PointerValue; + const SCEV *Sc = SE->getSCEV(Ptr); + + unsigned AS = Ptr->getType()->getPointerAddressSpace(); + LLVMContext &Ctx = Loc->getContext(); + + // Use this type for pointer arithmetic. + Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS); + + if (SE->isLoopInvariant(Sc, TheLoop)) { + LLVM_DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" + << *Ptr << "\n"); + // Ptr could be in the loop body. If so, expand a new one at the correct + // location. + Instruction *Inst = dyn_cast<Instruction>(Ptr); + Value *NewPtr = (Inst && TheLoop->contains(Inst)) + ? Exp.expandCodeFor(Sc, PtrArithTy, Loc) + : Ptr; + // We must return a half-open range, which means incrementing Sc. 
+ const SCEV *ScPlusOne = SE->getAddExpr(Sc, SE->getOne(PtrArithTy)); + Value *NewPtrPlusOne = Exp.expandCodeFor(ScPlusOne, PtrArithTy, Loc); + return {NewPtr, NewPtrPlusOne}; + } else { + Value *Start = nullptr, *End = nullptr; + LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n"); + Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc); + End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc); + LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High + << "\n"); + return {Start, End}; + } +} + +/// Turns a collection of checks into a collection of expanded upper and +/// lower bounds for both pointers in the check. +static SmallVector<std::pair<PointerBounds, PointerBounds>, 4> expandBounds( + const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks, + Loop *L, Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp, + const RuntimePointerChecking &PtrRtChecking) { + SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds; + + // Here we're relying on the SCEV Expander's cache to only emit code for the + // same bounds once. + transform( + PointerChecks, std::back_inserter(ChecksWithBounds), + [&](const RuntimePointerChecking::PointerCheck &Check) { + PointerBounds + First = expandBounds(Check.first, L, Loc, Exp, SE, PtrRtChecking), + Second = expandBounds(Check.second, L, Loc, Exp, SE, PtrRtChecking); + return std::make_pair(First, Second); + }); + + return ChecksWithBounds; +} + +std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeChecks( + Instruction *Loc, + const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks) + const { + const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout(); + auto *SE = PSE->getSE(); + SCEVExpander Exp(*SE, DL, "induction"); + auto ExpandedChecks = + expandBounds(PointerChecks, TheLoop, Loc, SE, Exp, *PtrRtChecking); + + LLVMContext &Ctx = Loc->getContext(); + Instruction *FirstInst = nullptr; + IRBuilder<> ChkBuilder(Loc); + // Our instructions might fold to a constant. + Value *MemoryRuntimeCheck = nullptr; + + for (const auto &Check : ExpandedChecks) { + const PointerBounds &A = Check.first, &B = Check.second; + // Check if two pointers (A and B) conflict where conflict is computed as: + // start(A) <= end(B) && start(B) <= end(A) + unsigned AS0 = A.Start->getType()->getPointerAddressSpace(); + unsigned AS1 = B.Start->getType()->getPointerAddressSpace(); + + assert((AS0 == B.End->getType()->getPointerAddressSpace()) && + (AS1 == A.End->getType()->getPointerAddressSpace()) && + "Trying to bounds check pointers with different address spaces"); + + Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0); + Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1); + + Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc"); + Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc"); + Value *End0 = ChkBuilder.CreateBitCast(A.End, PtrArithTy1, "bc"); + Value *End1 = ChkBuilder.CreateBitCast(B.End, PtrArithTy0, "bc"); + + // [A|B].Start points to the first accessed byte under base [A|B]. + // [A|B].End points to the last accessed byte, plus one. 
+    // There is no conflict when the intervals are disjoint:
+    // NoConflict = (B.Start >= A.End) || (A.Start >= B.End)
+    //
+    // bound0 = (B.Start < A.End)
+    // bound1 = (A.Start < B.End)
+    // IsConflict = bound0 & bound1
+    Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0");
+    FirstInst = getFirstInst(FirstInst, Cmp0, Loc);
+    Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1");
+    FirstInst = getFirstInst(FirstInst, Cmp1, Loc);
+    Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+    FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
+    if (MemoryRuntimeCheck) {
+      IsConflict =
+          ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
+      FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
+    }
+    MemoryRuntimeCheck = IsConflict;
+  }
+
+  if (!MemoryRuntimeCheck)
+    return std::make_pair(nullptr, nullptr);
+
+  // We have to do this trickery because the IRBuilder might fold the check to
+  // a constant expression in which case there is no Instruction anchored in
+  // the block.
+  Instruction *Check = BinaryOperator::CreateAnd(MemoryRuntimeCheck,
+                                                 ConstantInt::getTrue(Ctx));
+  ChkBuilder.Insert(Check, "memcheck.conflict");
+  FirstInst = getFirstInst(FirstInst, Check, Loc);
+  return std::make_pair(FirstInst, Check);
+}
+
+std::pair<Instruction *, Instruction *>
+LoopAccessInfo::addRuntimeChecks(Instruction *Loc) const {
+  if (!PtrRtChecking->Need)
+    return std::make_pair(nullptr, nullptr);
+
+  return addRuntimeChecks(Loc, PtrRtChecking->getChecks());
+}
+
+void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
+  Value *Ptr = nullptr;
+  if (LoadInst *LI = dyn_cast<LoadInst>(MemAccess))
+    Ptr = LI->getPointerOperand();
+  else if (StoreInst *SI = dyn_cast<StoreInst>(MemAccess))
+    Ptr = SI->getPointerOperand();
+  else
+    return;
+
+  Value *Stride = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
+  if (!Stride)
+    return;
+
+  LLVM_DEBUG(dbgs() << "LAA: Found a strided access that is a candidate for "
+                       "versioning:");
+  LLVM_DEBUG(dbgs() << "  Ptr: " << *Ptr << " Stride: " << *Stride << "\n");
+
+  // Avoid adding the "Stride == 1" predicate when we know that
+  // Stride >= Trip-Count. Such a predicate will effectively optimize a single
+  // or zero iteration loop, as Trip-Count <= Stride == 1.
+  //
+  // TODO: We are currently not making a very informed decision on when it is
+  // beneficial to apply stride versioning. It might make more sense that the
+  // users of this analysis (such as the vectorizer) will trigger it, based on
+  // their specific cost considerations; For example, in cases where stride
+  // versioning does not help resolving memory accesses/dependences, the
+  // vectorizer should evaluate the cost of the runtime test, and the benefit
+  // of various possible stride specializations, considering the alternatives
+  // of using gather/scatters (if available).
+
+  const SCEV *StrideExpr = PSE->getSCEV(Stride);
+  const SCEV *BETakenCount = PSE->getBackedgeTakenCount();
+
+  // Match the types so we can compare the stride and the BETakenCount.
+  // The Stride can be positive/negative, so we sign extend Stride;
+  // The backedgeTakenCount is non-negative, so we zero extend BETakenCount. 
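+  // E.g. (illustrative) a stride of -1 in an i32 expression sign-extends to
+  // -1 in i64, whereas a zero extension would wrongly yield 2^32 - 1; the
+  // backedge-taken count, being non-negative, can always be zero-extended
+  // safely.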
+  const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout();
+  uint64_t StrideTypeSize = DL.getTypeAllocSize(StrideExpr->getType());
+  uint64_t BETypeSize = DL.getTypeAllocSize(BETakenCount->getType());
+  const SCEV *CastedStride = StrideExpr;
+  const SCEV *CastedBECount = BETakenCount;
+  ScalarEvolution *SE = PSE->getSE();
+  if (BETypeSize >= StrideTypeSize)
+    CastedStride = SE->getNoopOrSignExtend(StrideExpr, BETakenCount->getType());
+  else
+    CastedBECount = SE->getZeroExtendExpr(BETakenCount, StrideExpr->getType());
+  const SCEV *StrideMinusBETaken = SE->getMinusSCEV(CastedStride, CastedBECount);
+  // Since TripCount == BackEdgeTakenCount + 1, checking:
+  // "Stride >= TripCount" is equivalent to checking:
+  // Stride - BETakenCount > 0
+  if (SE->isKnownPositive(StrideMinusBETaken)) {
+    LLVM_DEBUG(
+        dbgs() << "LAA: Stride>=TripCount; No point in versioning as the "
+                  "Stride==1 predicate will imply that the loop executes "
+                  "at most once.\n");
+    return;
+  }
+  LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.");
+
+  SymbolicStrides[Ptr] = Stride;
+  StrideSet.insert(Stride);
+}
+
+LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
+                               const TargetLibraryInfo *TLI, AliasAnalysis *AA,
+                               DominatorTree *DT, LoopInfo *LI)
+    : PSE(std::make_unique<PredicatedScalarEvolution>(*SE, *L)),
+      PtrRtChecking(std::make_unique<RuntimePointerChecking>(SE)),
+      DepChecker(std::make_unique<MemoryDepChecker>(*PSE, L)), TheLoop(L),
+      NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1), CanVecMem(false),
+      HasConvergentOp(false),
+      HasDependenceInvolvingLoopInvariantAddress(false) {
+  if (canAnalyzeLoop())
+    analyzeLoop(AA, LI, TLI, DT);
+}
+
+void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const {
+  if (CanVecMem) {
+    OS.indent(Depth) << "Memory dependences are safe";
+    if (MaxSafeDepDistBytes != -1ULL)
+      OS << " with a maximum dependence distance of " << MaxSafeDepDistBytes
+         << " bytes";
+    if (PtrRtChecking->Need)
+      OS << " with run-time checks";
+    OS << "\n";
+  }
+
+  if (HasConvergentOp)
+    OS.indent(Depth) << "Has convergent operation in loop\n";
+
+  if (Report)
+    OS.indent(Depth) << "Report: " << Report->getMsg() << "\n";
+
+  if (auto *Dependences = DepChecker->getDependences()) {
+    OS.indent(Depth) << "Dependences:\n";
+    for (auto &Dep : *Dependences) {
+      Dep.print(OS, Depth + 2, DepChecker->getMemoryInstructions());
+      OS << "\n";
+    }
+  } else
+    OS.indent(Depth) << "Too many dependences, not recorded\n";
+
+  // List the pairs of accesses that need run-time checks to prove
+  // independence.
+  PtrRtChecking->print(OS, Depth);
+  OS << "\n";
+
+  OS.indent(Depth) << "Non vectorizable stores to invariant address were "
+                   << (HasDependenceInvolvingLoopInvariantAddress ? 
"" : "not ") + << "found in loop.\n"; + + OS.indent(Depth) << "SCEV assumptions:\n"; + PSE->getUnionPredicate().print(OS, Depth); + + OS << "\n"; + + OS.indent(Depth) << "Expressions re-written:\n"; + PSE->print(OS, Depth); +} + +const LoopAccessInfo &LoopAccessLegacyAnalysis::getInfo(Loop *L) { + auto &LAI = LoopAccessInfoMap[L]; + + if (!LAI) + LAI = std::make_unique<LoopAccessInfo>(L, SE, TLI, AA, DT, LI); + + return *LAI.get(); +} + +void LoopAccessLegacyAnalysis::print(raw_ostream &OS, const Module *M) const { + LoopAccessLegacyAnalysis &LAA = *const_cast<LoopAccessLegacyAnalysis *>(this); + + for (Loop *TopLevelLoop : *LI) + for (Loop *L : depth_first(TopLevelLoop)) { + OS.indent(2) << L->getHeader()->getName() << ":\n"; + auto &LAI = LAA.getInfo(L); + LAI.print(OS, 4); + } +} + +bool LoopAccessLegacyAnalysis::runOnFunction(Function &F) { + SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); + TLI = TLIP ? &TLIP->getTLI(F) : nullptr; + AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + + return false; +} + +void LoopAccessLegacyAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + + AU.setPreservesAll(); +} + +char LoopAccessLegacyAnalysis::ID = 0; +static const char laa_name[] = "Loop Access Analysis"; +#define LAA_NAME "loop-accesses" + +INITIALIZE_PASS_BEGIN(LoopAccessLegacyAnalysis, LAA_NAME, laa_name, false, true) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(LoopAccessLegacyAnalysis, LAA_NAME, laa_name, false, true) + +AnalysisKey LoopAccessAnalysis::Key; + +LoopAccessInfo LoopAccessAnalysis::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR) { + return LoopAccessInfo(&L, &AR.SE, &AR.TLI, &AR.AA, &AR.DT, &AR.LI); +} + +namespace llvm { + + Pass *createLAAPass() { + return new LoopAccessLegacyAnalysis(); + } + +} // end namespace llvm diff --git a/llvm/lib/Analysis/LoopAnalysisManager.cpp b/llvm/lib/Analysis/LoopAnalysisManager.cpp new file mode 100644 index 000000000000..02d40fb8d72a --- /dev/null +++ b/llvm/lib/Analysis/LoopAnalysisManager.cpp @@ -0,0 +1,151 @@ +//===- LoopAnalysisManager.cpp - Loop analysis management -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/IR/Dominators.h" + +using namespace llvm; + +namespace llvm { +// Explicit template instantiations and specialization definitions for core +// template typedefs. 
+template class AllAnalysesOn<Loop>; +template class AnalysisManager<Loop, LoopStandardAnalysisResults &>; +template class InnerAnalysisManagerProxy<LoopAnalysisManager, Function>; +template class OuterAnalysisManagerProxy<FunctionAnalysisManager, Loop, + LoopStandardAnalysisResults &>; + +bool LoopAnalysisManagerFunctionProxy::Result::invalidate( + Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &Inv) { + // First compute the sequence of IR units covered by this proxy. We will want + // to visit this in postorder, but because this is a tree structure we can do + // this by building a preorder sequence and walking it backwards. We also + // want siblings in forward program order to match the LoopPassManager so we + // get the preorder with siblings reversed. + SmallVector<Loop *, 4> PreOrderLoops = LI->getLoopsInReverseSiblingPreorder(); + + // If this proxy or the loop info is going to be invalidated, we also need + // to clear all the keys coming from that analysis. We also completely blow + // away the loop analyses if any of the standard analyses provided by the + // loop pass manager go away so that loop analyses can freely use these + // without worrying about declaring dependencies on them etc. + // FIXME: It isn't clear if this is the right tradeoff. We could instead make + // loop analyses declare any dependencies on these and use the more general + // invalidation logic below to act on that. + auto PAC = PA.getChecker<LoopAnalysisManagerFunctionProxy>(); + bool invalidateMemorySSAAnalysis = false; + if (MSSAUsed) + invalidateMemorySSAAnalysis = Inv.invalidate<MemorySSAAnalysis>(F, PA); + if (!(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) || + Inv.invalidate<AAManager>(F, PA) || + Inv.invalidate<AssumptionAnalysis>(F, PA) || + Inv.invalidate<DominatorTreeAnalysis>(F, PA) || + Inv.invalidate<LoopAnalysis>(F, PA) || + Inv.invalidate<ScalarEvolutionAnalysis>(F, PA) || + invalidateMemorySSAAnalysis) { + // Note that the LoopInfo may be stale at this point, however the loop + // objects themselves remain the only viable keys that could be in the + // analysis manager's cache. So we just walk the keys and forcibly clear + // those results. Note that the order doesn't matter here as this will just + // directly destroy the results without calling methods on them. + for (Loop *L : PreOrderLoops) { + // NB! `L` may not be in a good enough state to run Loop::getName. + InnerAM->clear(*L, "<possibly invalidated loop>"); + } + + // We also need to null out the inner AM so that when the object gets + // destroyed as invalid we don't try to clear the inner AM again. At that + // point we won't be able to reliably walk the loops for this function and + // only clear results associated with those loops the way we do here. + // FIXME: Making InnerAM null at this point isn't very nice. Most analyses + // try to remain valid during invalidation. Maybe we should add an + // `IsClean` flag? + InnerAM = nullptr; + + // Now return true to indicate this *is* invalid and a fresh proxy result + // needs to be built. This is especially important given the null InnerAM. + return true; + } + + // Directly check if the relevant set is preserved so we can short circuit + // invalidating loops. 
+ bool AreLoopAnalysesPreserved = + PA.allAnalysesInSetPreserved<AllAnalysesOn<Loop>>(); + + // Since we have a valid LoopInfo we can actually leave the cached results in + // the analysis manager associated with the Loop keys, but we need to + // propagate any necessary invalidation logic into them. We'd like to + // invalidate things in roughly the same order as they were put into the + // cache and so we walk the preorder list in reverse to form a valid + // postorder. + for (Loop *L : reverse(PreOrderLoops)) { + Optional<PreservedAnalyses> InnerPA; + + // Check to see whether the preserved set needs to be adjusted based on + // function-level analysis invalidation triggering deferred invalidation + // for this loop. + if (auto *OuterProxy = + InnerAM->getCachedResult<FunctionAnalysisManagerLoopProxy>(*L)) + for (const auto &OuterInvalidationPair : + OuterProxy->getOuterInvalidations()) { + AnalysisKey *OuterAnalysisID = OuterInvalidationPair.first; + const auto &InnerAnalysisIDs = OuterInvalidationPair.second; + if (Inv.invalidate(OuterAnalysisID, F, PA)) { + if (!InnerPA) + InnerPA = PA; + for (AnalysisKey *InnerAnalysisID : InnerAnalysisIDs) + InnerPA->abandon(InnerAnalysisID); + } + } + + // Check if we needed a custom PA set. If so we'll need to run the inner + // invalidation. + if (InnerPA) { + InnerAM->invalidate(*L, *InnerPA); + continue; + } + + // Otherwise we only need to do invalidation if the original PA set didn't + // preserve all Loop analyses. + if (!AreLoopAnalysesPreserved) + InnerAM->invalidate(*L, PA); + } + + // Return false to indicate that this result is still a valid proxy. + return false; +} + +template <> +LoopAnalysisManagerFunctionProxy::Result +LoopAnalysisManagerFunctionProxy::run(Function &F, + FunctionAnalysisManager &AM) { + return Result(*InnerAM, AM.getResult<LoopAnalysis>(F)); +} +} + +PreservedAnalyses llvm::getLoopPassPreservedAnalyses() { + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<LoopAnalysis>(); + PA.preserve<LoopAnalysisManagerFunctionProxy>(); + PA.preserve<ScalarEvolutionAnalysis>(); + // FIXME: What we really want to do here is preserve an AA category, but that + // concept doesn't exist yet. + PA.preserve<AAManager>(); + PA.preserve<BasicAA>(); + PA.preserve<GlobalsAA>(); + PA.preserve<SCEVAA>(); + return PA; +} diff --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp new file mode 100644 index 000000000000..10d2fe07884a --- /dev/null +++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp @@ -0,0 +1,625 @@ +//===- LoopCacheAnalysis.cpp - Loop Cache Analysis -------------------------==// +// +// The LLVM Compiler Infrastructure +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines the implementation for the loop cache analysis. +/// The implementation is largely based on the following paper: +/// +/// Compiler Optimizations for Improving Data Locality +/// By: Steve Carr, Katherine S. McKinley, Chau-Wen Tseng +/// http://www.cs.utexas.edu/users/mckinley/papers/asplos-1994.pdf +/// +/// The general approach taken to estimate the number of cache lines used by the +/// memory references in an inner loop is: +/// 1. 
Partition memory references that exhibit temporal or spatial reuse
+///     into reference groups.
+///  2. For each loop L in the loop nest LN:
+///     a. Compute the cost of the reference group
+///     b. Compute the loop cost by summing up the reference groups' costs
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LoopCacheAnalysis.h"
+#include "llvm/ADT/BreadthFirstIterator.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-cache-cost"
+
+static cl::opt<unsigned> DefaultTripCount(
+    "default-trip-count", cl::init(100), cl::Hidden,
+    cl::desc("Use this to specify the default trip count of a loop"));
+
+// In this analysis two array references are considered to exhibit temporal
+// reuse if they access either the same memory location, or a memory location
+// with distance smaller than a configurable threshold.
+static cl::opt<unsigned> TemporalReuseThreshold(
+    "temporal-reuse-threshold", cl::init(2), cl::Hidden,
+    cl::desc("Use this to specify the max. distance between array elements "
+             "accessed in a loop so that the elements are classified to have "
+             "temporal reuse"));
+
+/// Retrieve the innermost loop in the given loop nest \p Loops. It returns
+/// nullptr if any loop in the given loop vector has more than one sibling.
+/// The loop vector is expected to contain loops collected in breadth-first
+/// order.
+static Loop *getInnerMostLoop(const LoopVectorTy &Loops) {
+  assert(!Loops.empty() && "Expecting a non-empty loop vector");
+
+  Loop *LastLoop = Loops.back();
+  Loop *ParentLoop = LastLoop->getParentLoop();
+
+  if (ParentLoop == nullptr) {
+    assert(Loops.size() == 1 && "Expecting a single loop");
+    return LastLoop;
+  }
+
+  return (std::is_sorted(Loops.begin(), Loops.end(),
+                         [](const Loop *L1, const Loop *L2) {
+                           return L1->getLoopDepth() < L2->getLoopDepth();
+                         }))
+             ? LastLoop
+             : nullptr;
+}
+
+static bool isOneDimensionalArray(const SCEV &AccessFn, const SCEV &ElemSize,
+                                  const Loop &L, ScalarEvolution &SE) {
+  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(&AccessFn);
+  if (!AR || !AR->isAffine())
+    return false;
+
+  assert(AR->getLoop() && "AR should have a loop");
+
+  // Check that start and increment are not add recurrences.
+  const SCEV *Start = AR->getStart();
+  const SCEV *Step = AR->getStepRecurrence(SE);
+  if (isa<SCEVAddRecExpr>(Start) || isa<SCEVAddRecExpr>(Step))
+    return false;
+
+  // Check that start and increment are both invariant in the loop.
+  if (!SE.isLoopInvariant(Start, &L) || !SE.isLoopInvariant(Step, &L))
+    return false;
+
+  return AR->getStepRecurrence(SE) == &ElemSize;
+}
+
+/// Compute the trip count for the given loop \p L. Return the SCEV expression
+/// for the trip count or nullptr if it cannot be computed. 
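+/// For example (illustrative), a loop iterating i = 0, 1, ..., 99 has a
+/// constant backedge-taken count of 99, so the returned trip count is the
+/// SCEV constant 100 (BackedgeTakenCount + 1).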
+static const SCEV *computeTripCount(const Loop &L, ScalarEvolution &SE) {
+  const SCEV *BackedgeTakenCount = SE.getBackedgeTakenCount(&L);
+  if (isa<SCEVCouldNotCompute>(BackedgeTakenCount) ||
+      !isa<SCEVConstant>(BackedgeTakenCount))
+    return nullptr;
+
+  return SE.getAddExpr(BackedgeTakenCount,
+                       SE.getOne(BackedgeTakenCount->getType()));
+}
+
+//===----------------------------------------------------------------------===//
+// IndexedReference implementation
+//
+raw_ostream &llvm::operator<<(raw_ostream &OS, const IndexedReference &R) {
+  if (!R.IsValid) {
+    OS << R.StoreOrLoadInst;
+    OS << ", IsValid=false.";
+    return OS;
+  }
+
+  OS << *R.BasePointer;
+  for (const SCEV *Subscript : R.Subscripts)
+    OS << "[" << *Subscript << "]";
+
+  OS << ", Sizes: ";
+  for (const SCEV *Size : R.Sizes)
+    OS << "[" << *Size << "]";
+
+  return OS;
+}
+
+IndexedReference::IndexedReference(Instruction &StoreOrLoadInst,
+                                   const LoopInfo &LI, ScalarEvolution &SE)
+    : StoreOrLoadInst(StoreOrLoadInst), SE(SE) {
+  assert((isa<StoreInst>(StoreOrLoadInst) || isa<LoadInst>(StoreOrLoadInst)) &&
+         "Expecting a load or store instruction");
+
+  IsValid = delinearize(LI);
+  if (IsValid)
+    LLVM_DEBUG(dbgs().indent(2) << "Successfully delinearized: " << *this
+                                << "\n");
+}
+
+Optional<bool> IndexedReference::hasSpacialReuse(const IndexedReference &Other,
+                                                 unsigned CLS,
+                                                 AliasAnalysis &AA) const {
+  assert(IsValid && "Expecting a valid reference");
+
+  if (BasePointer != Other.getBasePointer() && !isAliased(Other, AA)) {
+    LLVM_DEBUG(dbgs().indent(2)
+               << "No spacial reuse: different base pointers\n");
+    return false;
+  }
+
+  unsigned NumSubscripts = getNumSubscripts();
+  if (NumSubscripts != Other.getNumSubscripts()) {
+    LLVM_DEBUG(dbgs().indent(2)
+               << "No spacial reuse: different number of subscripts\n");
+    return false;
+  }
+
+  // All subscripts must be equal, except the last one.
+  for (auto SubNum : seq<unsigned>(0, NumSubscripts - 1)) {
+    if (getSubscript(SubNum) != Other.getSubscript(SubNum)) {
+      LLVM_DEBUG(dbgs().indent(2) << "No spacial reuse, different subscripts: "
+                                  << "\n\t" << *getSubscript(SubNum) << "\n\t"
+                                  << *Other.getSubscript(SubNum) << "\n");
+      return false;
+    }
+  }
+
+  // The difference between the last subscripts must be less than the cache
+  // line size. 
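+  // E.g. (illustrative) A[i][j] and A[i][j + 1] differ only in the last
+  // subscript, by one element; with 4-byte elements and a 64-byte cache line
+  // the two references usually land on the same line, giving spatial reuse.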
+  const SCEV *LastSubscript = getLastSubscript();
+  const SCEV *OtherLastSubscript = Other.getLastSubscript();
+  const SCEVConstant *Diff = dyn_cast<SCEVConstant>(
+      SE.getMinusSCEV(LastSubscript, OtherLastSubscript));
+
+  if (Diff == nullptr) {
+    LLVM_DEBUG(dbgs().indent(2)
+               << "No spacial reuse, difference between subscript:\n\t"
+               << *LastSubscript << "\n\t" << *OtherLastSubscript
+               << "\nis not constant.\n");
+    return None;
+  }
+
+  bool InSameCacheLine = (Diff->getValue()->getSExtValue() < CLS);
+
+  LLVM_DEBUG({
+    if (InSameCacheLine)
+      dbgs().indent(2) << "Found spacial reuse.\n";
+    else
+      dbgs().indent(2) << "No spacial reuse.\n";
+  });
+
+  return InSameCacheLine;
+}
+
+Optional<bool> IndexedReference::hasTemporalReuse(const IndexedReference &Other,
+                                                  unsigned MaxDistance,
+                                                  const Loop &L,
+                                                  DependenceInfo &DI,
+                                                  AliasAnalysis &AA) const {
+  assert(IsValid && "Expecting a valid reference");
+
+  if (BasePointer != Other.getBasePointer() && !isAliased(Other, AA)) {
+    LLVM_DEBUG(dbgs().indent(2)
+               << "No temporal reuse: different base pointer\n");
+    return false;
+  }
+
+  std::unique_ptr<Dependence> D =
+      DI.depends(&StoreOrLoadInst, &Other.StoreOrLoadInst, true);
+
+  if (D == nullptr) {
+    LLVM_DEBUG(dbgs().indent(2) << "No temporal reuse: no dependence\n");
+    return false;
+  }
+
+  if (D->isLoopIndependent()) {
+    LLVM_DEBUG(dbgs().indent(2) << "Found temporal reuse\n");
+    return true;
+  }
+
+  // Check the dependence distance at every loop level. There is temporal reuse
+  // if the distance at the given loop's depth is small (|d| <= MaxDistance) and
+  // it is zero at every other loop level.
+  int LoopDepth = L.getLoopDepth();
+  int Levels = D->getLevels();
+  for (int Level = 1; Level <= Levels; ++Level) {
+    const SCEV *Distance = D->getDistance(Level);
+    const SCEVConstant *SCEVConst = dyn_cast_or_null<SCEVConstant>(Distance);
+
+    if (SCEVConst == nullptr) {
+      LLVM_DEBUG(dbgs().indent(2) << "No temporal reuse: distance unknown\n");
+      return None;
+    }
+
+    const ConstantInt &CI = *SCEVConst->getValue();
+    if (Level != LoopDepth && !CI.isZero()) {
+      LLVM_DEBUG(dbgs().indent(2)
+                 << "No temporal reuse: distance is not zero at depth=" << Level
+                 << "\n");
+      return false;
+    } else if (Level == LoopDepth && CI.getSExtValue() > MaxDistance) {
+      LLVM_DEBUG(
+          dbgs().indent(2)
+          << "No temporal reuse: distance is greater than MaxDistance at depth="
+          << Level << "\n");
+      return false;
+    }
+  }
+
+  LLVM_DEBUG(dbgs().indent(2) << "Found temporal reuse\n");
+  return true;
+}
+
+CacheCostTy IndexedReference::computeRefCost(const Loop &L,
+                                             unsigned CLS) const {
+  assert(IsValid && "Expecting a valid reference");
+  LLVM_DEBUG({
+    dbgs().indent(2) << "Computing cache cost for:\n";
+    dbgs().indent(4) << *this << "\n";
+  });
+
+  // If the indexed reference is loop invariant the cost is one.
+  if (isLoopInvariant(L)) {
+    LLVM_DEBUG(dbgs().indent(4) << "Reference is loop invariant: RefCost=1\n");
+    return 1;
+  }
+
+  const SCEV *TripCount = computeTripCount(L, SE);
+  if (!TripCount) {
+    LLVM_DEBUG(dbgs() << "Trip count of loop " << L.getName()
+                      << " could not be computed, using DefaultTripCount\n");
+    const SCEV *ElemSize = Sizes.back();
+    TripCount = SE.getConstant(ElemSize->getType(), DefaultTripCount);
+  }
+  LLVM_DEBUG(dbgs() << "TripCount=" << *TripCount << "\n");
+
+  // If the indexed reference is 'consecutive' the cost is
+  // (TripCount*Stride)/CLS, otherwise the cost is TripCount. 
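+  // E.g. (illustrative) for a unit-stride reference with 4-byte elements, a
+  // 64-byte cache line, and TripCount = 1024, the cost is
+  // (1024 * 4) / 64 = 64 cache lines, while a non-consecutive reference would
+  // touch a new line on every iteration, for a cost of 1024.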
+  const SCEV *RefCost = TripCount;
+
+  if (isConsecutive(L, CLS)) {
+    const SCEV *Coeff = getLastCoefficient();
+    const SCEV *ElemSize = Sizes.back();
+    const SCEV *Stride = SE.getMulExpr(Coeff, ElemSize);
+    const SCEV *CacheLineSize = SE.getConstant(Stride->getType(), CLS);
+    const SCEV *Numerator = SE.getMulExpr(Stride, TripCount);
+    RefCost = SE.getUDivExpr(Numerator, CacheLineSize);
+    LLVM_DEBUG(dbgs().indent(4)
+               << "Access is consecutive: RefCost=(TripCount*Stride)/CLS="
+               << *RefCost << "\n");
+  } else
+    LLVM_DEBUG(dbgs().indent(4)
+               << "Access is not consecutive: RefCost=TripCount=" << *RefCost
+               << "\n");
+
+  // Attempt to fold RefCost into a constant.
+  if (auto ConstantCost = dyn_cast<SCEVConstant>(RefCost))
+    return ConstantCost->getValue()->getSExtValue();
+
+  LLVM_DEBUG(dbgs().indent(4)
+             << "RefCost is not a constant! Setting to RefCost=InvalidCost "
+                "(invalid value).\n");
+
+  return CacheCost::InvalidCost;
+}
+
+bool IndexedReference::delinearize(const LoopInfo &LI) {
+  assert(Subscripts.empty() && "Subscripts should be empty");
+  assert(Sizes.empty() && "Sizes should be empty");
+  assert(!IsValid && "Should be called once from the constructor");
+  LLVM_DEBUG(dbgs() << "Delinearizing: " << StoreOrLoadInst << "\n");
+
+  const SCEV *ElemSize = SE.getElementSize(&StoreOrLoadInst);
+  const BasicBlock *BB = StoreOrLoadInst.getParent();
+
+  for (Loop *L = LI.getLoopFor(BB); L != nullptr; L = L->getParentLoop()) {
+    const SCEV *AccessFn =
+        SE.getSCEVAtScope(getPointerOperand(&StoreOrLoadInst), L);
+
+    BasePointer = dyn_cast<SCEVUnknown>(SE.getPointerBase(AccessFn));
+    if (BasePointer == nullptr) {
+      LLVM_DEBUG(
+          dbgs().indent(2)
+          << "ERROR: failed to delinearize, can't identify base pointer\n");
+      return false;
+    }
+
+    AccessFn = SE.getMinusSCEV(AccessFn, BasePointer);
+
+    LLVM_DEBUG(dbgs().indent(2) << "In Loop '" << L->getName()
+                                << "', AccessFn: " << *AccessFn << "\n");
+
+    SE.delinearize(AccessFn, Subscripts, Sizes,
+                   SE.getElementSize(&StoreOrLoadInst));
+
+    if (Subscripts.empty() || Sizes.empty() ||
+        Subscripts.size() != Sizes.size()) {
+      // Attempt to determine whether we have a single dimensional array access
+      // before giving up.
+      if (!isOneDimensionalArray(*AccessFn, *ElemSize, *L, SE)) {
+        LLVM_DEBUG(dbgs().indent(2)
+                   << "ERROR: failed to delinearize reference\n");
+        Subscripts.clear();
+        Sizes.clear();
+        break;
+      }
+
+      const SCEV *Div = SE.getUDivExactExpr(AccessFn, ElemSize);
+      Subscripts.push_back(Div);
+      Sizes.push_back(ElemSize);
+    }
+
+    return all_of(Subscripts, [&](const SCEV *Subscript) {
+      return isSimpleAddRecurrence(*Subscript, *L);
+    });
+  }
+
+  return false;
+}
+
+bool IndexedReference::isLoopInvariant(const Loop &L) const {
+  Value *Addr = getPointerOperand(&StoreOrLoadInst);
+  assert(Addr != nullptr && "Expecting either a load or a store instruction");
+  assert(SE.isSCEVable(Addr->getType()) && "Addr should be SCEVable");
+
+  if (SE.isLoopInvariant(SE.getSCEV(Addr), &L))
+    return true;
+
+  // The indexed reference is loop invariant if none of the coefficients use
+  // the loop induction variable.
+  bool allCoeffForLoopAreZero = all_of(Subscripts, [&](const SCEV *Subscript) {
+    return isCoeffForLoopZeroOrInvariant(*Subscript, L);
+  });
+
+  return allCoeffForLoopAreZero;
+}
+
+bool IndexedReference::isConsecutive(const Loop &L, unsigned CLS) const {
+  // The indexed reference is 'consecutive' if the only coefficient that uses
+  // the loop induction variable is the last one... 
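+  // E.g. (illustrative) in a loop over j, A[i][j] is consecutive because only
+  // the last subscript varies with j, while A[j][i] is not, since its first
+  // subscript advances by a whole row on every iteration.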
+ const SCEV *LastSubscript = Subscripts.back(); + for (const SCEV *Subscript : Subscripts) { + if (Subscript == LastSubscript) + continue; + if (!isCoeffForLoopZeroOrInvariant(*Subscript, L)) + return false; + } + + // ...and the access stride is less than the cache line size. + const SCEV *Coeff = getLastCoefficient(); + const SCEV *ElemSize = Sizes.back(); + const SCEV *Stride = SE.getMulExpr(Coeff, ElemSize); + const SCEV *CacheLineSize = SE.getConstant(Stride->getType(), CLS); + + return SE.isKnownPredicate(ICmpInst::ICMP_ULT, Stride, CacheLineSize); +} + +const SCEV *IndexedReference::getLastCoefficient() const { + const SCEV *LastSubscript = getLastSubscript(); + assert(isa<SCEVAddRecExpr>(LastSubscript) && + "Expecting a SCEV add recurrence expression"); + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LastSubscript); + return AR->getStepRecurrence(SE); +} + +bool IndexedReference::isCoeffForLoopZeroOrInvariant(const SCEV &Subscript, + const Loop &L) const { + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(&Subscript); + return (AR != nullptr) ? AR->getLoop() != &L + : SE.isLoopInvariant(&Subscript, &L); +} + +bool IndexedReference::isSimpleAddRecurrence(const SCEV &Subscript, + const Loop &L) const { + if (!isa<SCEVAddRecExpr>(Subscript)) + return false; + + const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(&Subscript); + assert(AR->getLoop() && "AR should have a loop"); + + if (!AR->isAffine()) + return false; + + const SCEV *Start = AR->getStart(); + const SCEV *Step = AR->getStepRecurrence(SE); + + if (!SE.isLoopInvariant(Start, &L) || !SE.isLoopInvariant(Step, &L)) + return false; + + return true; +} + +bool IndexedReference::isAliased(const IndexedReference &Other, + AliasAnalysis &AA) const { + const auto &Loc1 = MemoryLocation::get(&StoreOrLoadInst); + const auto &Loc2 = MemoryLocation::get(&Other.StoreOrLoadInst); + return AA.isMustAlias(Loc1, Loc2); +} + +//===----------------------------------------------------------------------===// +// CacheCost implementation +// +raw_ostream &llvm::operator<<(raw_ostream &OS, const CacheCost &CC) { + for (const auto &LC : CC.LoopCosts) { + const Loop *L = LC.first; + OS << "Loop '" << L->getName() << "' has cost = " << LC.second << "\n"; + } + return OS; +} + +CacheCost::CacheCost(const LoopVectorTy &Loops, const LoopInfo &LI, + ScalarEvolution &SE, TargetTransformInfo &TTI, + AliasAnalysis &AA, DependenceInfo &DI, + Optional<unsigned> TRT) + : Loops(Loops), TripCounts(), LoopCosts(), + TRT(TRT == None ? Optional<unsigned>(TemporalReuseThreshold) : TRT), + LI(LI), SE(SE), TTI(TTI), AA(AA), DI(DI) { + assert(!Loops.empty() && "Expecting a non-empty loop vector."); + + for (const Loop *L : Loops) { + unsigned TripCount = SE.getSmallConstantTripCount(L); + TripCount = (TripCount == 0) ? 
DefaultTripCount : TripCount; + TripCounts.push_back({L, TripCount}); + } + + calculateCacheFootprint(); +} + +std::unique_ptr<CacheCost> +CacheCost::getCacheCost(Loop &Root, LoopStandardAnalysisResults &AR, + DependenceInfo &DI, Optional<unsigned> TRT) { + if (Root.getParentLoop()) { + LLVM_DEBUG(dbgs() << "Expecting the outermost loop in a loop nest\n"); + return nullptr; + } + + LoopVectorTy Loops; + for (Loop *L : breadth_first(&Root)) + Loops.push_back(L); + + if (!getInnerMostLoop(Loops)) { + LLVM_DEBUG(dbgs() << "Cannot compute cache cost of loop nest with more " + "than one innermost loop\n"); + return nullptr; + } + + return std::make_unique<CacheCost>(Loops, AR.LI, AR.SE, AR.TTI, AR.AA, DI, TRT); +} + +void CacheCost::calculateCacheFootprint() { + LLVM_DEBUG(dbgs() << "POPULATING REFERENCE GROUPS\n"); + ReferenceGroupsTy RefGroups; + if (!populateReferenceGroups(RefGroups)) + return; + + LLVM_DEBUG(dbgs() << "COMPUTING LOOP CACHE COSTS\n"); + for (const Loop *L : Loops) { + assert((std::find_if(LoopCosts.begin(), LoopCosts.end(), + [L](const LoopCacheCostTy &LCC) { + return LCC.first == L; + }) == LoopCosts.end()) && + "Should not add duplicate element"); + CacheCostTy LoopCost = computeLoopCacheCost(*L, RefGroups); + LoopCosts.push_back(std::make_pair(L, LoopCost)); + } + + sortLoopCosts(); + RefGroups.clear(); +} + +bool CacheCost::populateReferenceGroups(ReferenceGroupsTy &RefGroups) const { + assert(RefGroups.empty() && "Reference groups should be empty"); + + unsigned CLS = TTI.getCacheLineSize(); + Loop *InnerMostLoop = getInnerMostLoop(Loops); + assert(InnerMostLoop != nullptr && "Expecting a valid innermost loop"); + + for (BasicBlock *BB : InnerMostLoop->getBlocks()) { + for (Instruction &I : *BB) { + if (!isa<StoreInst>(I) && !isa<LoadInst>(I)) + continue; + + std::unique_ptr<IndexedReference> R(new IndexedReference(I, LI, SE)); + if (!R->isValid()) + continue; + + bool Added = false; + for (ReferenceGroupTy &RefGroup : RefGroups) { + const IndexedReference &Representative = *RefGroup.front().get(); + LLVM_DEBUG({ + dbgs() << "References:\n"; + dbgs().indent(2) << *R << "\n"; + dbgs().indent(2) << Representative << "\n"; + }); + + Optional<bool> HasTemporalReuse = + R->hasTemporalReuse(Representative, *TRT, *InnerMostLoop, DI, AA); + Optional<bool> HasSpacialReuse = + R->hasSpacialReuse(Representative, CLS, AA); + + if ((HasTemporalReuse.hasValue() && *HasTemporalReuse) || + (HasSpacialReuse.hasValue() && *HasSpacialReuse)) { + RefGroup.push_back(std::move(R)); + Added = true; + break; + } + } + + if (!Added) { + ReferenceGroupTy RG; + RG.push_back(std::move(R)); + RefGroups.push_back(std::move(RG)); + } + } + } + + if (RefGroups.empty()) + return false; + + LLVM_DEBUG({ + dbgs() << "\nIDENTIFIED REFERENCE GROUPS:\n"; + int n = 1; + for (const ReferenceGroupTy &RG : RefGroups) { + dbgs().indent(2) << "RefGroup " << n << ":\n"; + for (const auto &IR : RG) + dbgs().indent(4) << *IR << "\n"; + n++; + } + dbgs() << "\n"; + }); + + return true; +} + +CacheCostTy +CacheCost::computeLoopCacheCost(const Loop &L, + const ReferenceGroupsTy &RefGroups) const { + if (!L.isLoopSimplifyForm()) + return InvalidCost; + + LLVM_DEBUG(dbgs() << "Considering loop '" << L.getName() + << "' as innermost loop.\n"); + + // Compute the product of the trip counts of each other loop in the nest. 
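+  // E.g. (illustrative) in the nest
+  //
+  //   for (i = 0; i < 100; i++)
+  //     for (j = 0; j < 50; j++)
+  //       A[i][j] = 0;
+  //
+  // when the j-loop is considered as innermost the product is 100, i.e. the
+  // reference-group cost computed for one full execution of the j-loop is
+  // scaled by the trip count of the enclosing i-loop.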
+ CacheCostTy TripCountsProduct = 1; + for (const auto &TC : TripCounts) { + if (TC.first == &L) + continue; + TripCountsProduct *= TC.second; + } + + CacheCostTy LoopCost = 0; + for (const ReferenceGroupTy &RG : RefGroups) { + CacheCostTy RefGroupCost = computeRefGroupCacheCost(RG, L); + LoopCost += RefGroupCost * TripCountsProduct; + } + + LLVM_DEBUG(dbgs().indent(2) << "Loop '" << L.getName() + << "' has cost=" << LoopCost << "\n"); + + return LoopCost; +} + +CacheCostTy CacheCost::computeRefGroupCacheCost(const ReferenceGroupTy &RG, + const Loop &L) const { + assert(!RG.empty() && "Reference group should have at least one member."); + + const IndexedReference *Representative = RG.front().get(); + return Representative->computeRefCost(L, TTI.getCacheLineSize()); +} + +//===----------------------------------------------------------------------===// +// LoopCachePrinterPass implementation +// +PreservedAnalyses LoopCachePrinterPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &U) { + Function *F = L.getHeader()->getParent(); + DependenceInfo DI(F, &AR.AA, &AR.SE, &AR.LI); + + if (auto CC = CacheCost::getCacheCost(L, AR, DI)) + OS << *CC; + + return PreservedAnalyses::all(); +} diff --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp new file mode 100644 index 000000000000..dbab5db7dbc2 --- /dev/null +++ b/llvm/lib/Analysis/LoopInfo.cpp @@ -0,0 +1,1109 @@ +//===- LoopInfo.cpp - Natural Loop Calculator -----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the LoopInfo class that is used to identify natural loops +// and determine the loop depth of various nodes of the CFG. Note that the +// loops identified may actually be several natural loops that share the same +// header node... not just a single natural loop. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/IVDescriptors.h" +#include "llvm/Analysis/LoopInfoImpl.h" +#include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRPrintingPasses.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +using namespace llvm; + +// Explicitly instantiate methods in LoopInfoImpl.h for IR-level Loops. +template class llvm::LoopBase<BasicBlock, Loop>; +template class llvm::LoopInfoBase<BasicBlock, Loop>; + +// Always verify loopinfo if expensive checking is enabled. 
+#ifdef EXPENSIVE_CHECKS +bool llvm::VerifyLoopInfo = true; +#else +bool llvm::VerifyLoopInfo = false; +#endif +static cl::opt<bool, true> + VerifyLoopInfoX("verify-loop-info", cl::location(VerifyLoopInfo), + cl::Hidden, cl::desc("Verify loop info (time consuming)")); + +//===----------------------------------------------------------------------===// +// Loop implementation +// + +bool Loop::isLoopInvariant(const Value *V) const { + if (const Instruction *I = dyn_cast<Instruction>(V)) + return !contains(I); + return true; // All non-instructions are loop invariant +} + +bool Loop::hasLoopInvariantOperands(const Instruction *I) const { + return all_of(I->operands(), [this](Value *V) { return isLoopInvariant(V); }); +} + +bool Loop::makeLoopInvariant(Value *V, bool &Changed, Instruction *InsertPt, + MemorySSAUpdater *MSSAU) const { + if (Instruction *I = dyn_cast<Instruction>(V)) + return makeLoopInvariant(I, Changed, InsertPt, MSSAU); + return true; // All non-instructions are loop-invariant. +} + +bool Loop::makeLoopInvariant(Instruction *I, bool &Changed, + Instruction *InsertPt, + MemorySSAUpdater *MSSAU) const { + // Test if the value is already loop-invariant. + if (isLoopInvariant(I)) + return true; + if (!isSafeToSpeculativelyExecute(I)) + return false; + if (I->mayReadFromMemory()) + return false; + // EH block instructions are immobile. + if (I->isEHPad()) + return false; + // Determine the insertion point, unless one was given. + if (!InsertPt) { + BasicBlock *Preheader = getLoopPreheader(); + // Without a preheader, hoisting is not feasible. + if (!Preheader) + return false; + InsertPt = Preheader->getTerminator(); + } + // Don't hoist instructions with loop-variant operands. + for (Value *Operand : I->operands()) + if (!makeLoopInvariant(Operand, Changed, InsertPt, MSSAU)) + return false; + + // Hoist. + I->moveBefore(InsertPt); + if (MSSAU) + if (auto *MUD = MSSAU->getMemorySSA()->getMemoryAccess(I)) + MSSAU->moveToPlace(MUD, InsertPt->getParent(), MemorySSA::End); + + // There is possibility of hoisting this instruction above some arbitrary + // condition. Any metadata defined on it can be control dependent on this + // condition. Conservatively strip it here so that we don't give any wrong + // information to the optimizer. + I->dropUnknownNonDebugMetadata(); + + Changed = true; + return true; +} + +bool Loop::getIncomingAndBackEdge(BasicBlock *&Incoming, + BasicBlock *&Backedge) const { + BasicBlock *H = getHeader(); + + Incoming = nullptr; + Backedge = nullptr; + pred_iterator PI = pred_begin(H); + assert(PI != pred_end(H) && "Loop must have at least one backedge!"); + Backedge = *PI++; + if (PI == pred_end(H)) + return false; // dead loop + Incoming = *PI++; + if (PI != pred_end(H)) + return false; // multiple backedges? + + if (contains(Incoming)) { + if (contains(Backedge)) + return false; + std::swap(Incoming, Backedge); + } else if (!contains(Backedge)) + return false; + + assert(Incoming && Backedge && "expected non-null incoming and backedges"); + return true; +} + +PHINode *Loop::getCanonicalInductionVariable() const { + BasicBlock *H = getHeader(); + + BasicBlock *Incoming = nullptr, *Backedge = nullptr; + if (!getIncomingAndBackEdge(Incoming, Backedge)) + return nullptr; + + // Loop over all of the PHI nodes, looking for a canonical indvar. 
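+  // A canonical induction variable looks like this (sketch, in IR form):
+  //
+  //   %iv = phi i32 [ 0, %preheader ], [ %iv.next, %backedge ]
+  //   ...
+  //   %iv.next = add i32 %iv, 1
+  //
+  // i.e. a PHI that starts at zero on the incoming edge and is incremented
+  // by exactly one along the backedge.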
+  for (BasicBlock::iterator I = H->begin(); isa<PHINode>(I); ++I) {
+    PHINode *PN = cast<PHINode>(I);
+    if (ConstantInt *CI =
+            dyn_cast<ConstantInt>(PN->getIncomingValueForBlock(Incoming)))
+      if (CI->isZero())
+        if (Instruction *Inc =
+                dyn_cast<Instruction>(PN->getIncomingValueForBlock(Backedge)))
+          if (Inc->getOpcode() == Instruction::Add && Inc->getOperand(0) == PN)
+            if (ConstantInt *CI = dyn_cast<ConstantInt>(Inc->getOperand(1)))
+              if (CI->isOne())
+                return PN;
+  }
+  return nullptr;
+}
+
+/// Get the latch condition instruction.
+static ICmpInst *getLatchCmpInst(const Loop &L) {
+  if (BasicBlock *Latch = L.getLoopLatch())
+    if (BranchInst *BI = dyn_cast_or_null<BranchInst>(Latch->getTerminator()))
+      if (BI->isConditional())
+        return dyn_cast<ICmpInst>(BI->getCondition());
+
+  return nullptr;
+}
+
+/// Return the final value of the loop induction variable if found.
+static Value *findFinalIVValue(const Loop &L, const PHINode &IndVar,
+                               const Instruction &StepInst) {
+  ICmpInst *LatchCmpInst = getLatchCmpInst(L);
+  if (!LatchCmpInst)
+    return nullptr;
+
+  Value *Op0 = LatchCmpInst->getOperand(0);
+  Value *Op1 = LatchCmpInst->getOperand(1);
+  if (Op0 == &IndVar || Op0 == &StepInst)
+    return Op1;
+
+  if (Op1 == &IndVar || Op1 == &StepInst)
+    return Op0;
+
+  return nullptr;
+}
+
+Optional<Loop::LoopBounds> Loop::LoopBounds::getBounds(const Loop &L,
+                                                       PHINode &IndVar,
+                                                       ScalarEvolution &SE) {
+  InductionDescriptor IndDesc;
+  if (!InductionDescriptor::isInductionPHI(&IndVar, &L, &SE, IndDesc))
+    return None;
+
+  Value *InitialIVValue = IndDesc.getStartValue();
+  Instruction *StepInst = IndDesc.getInductionBinOp();
+  if (!InitialIVValue || !StepInst)
+    return None;
+
+  const SCEV *Step = IndDesc.getStep();
+  Value *StepInstOp1 = StepInst->getOperand(1);
+  Value *StepInstOp0 = StepInst->getOperand(0);
+  Value *StepValue = nullptr;
+  if (SE.getSCEV(StepInstOp1) == Step)
+    StepValue = StepInstOp1;
+  else if (SE.getSCEV(StepInstOp0) == Step)
+    StepValue = StepInstOp0;
+
+  Value *FinalIVValue = findFinalIVValue(L, IndVar, *StepInst);
+  if (!FinalIVValue)
+    return None;
+
+  return LoopBounds(L, *InitialIVValue, *StepInst, StepValue, *FinalIVValue,
+                    SE);
+}
+
+using Direction = Loop::LoopBounds::Direction;
+
+ICmpInst::Predicate Loop::LoopBounds::getCanonicalPredicate() const {
+  BasicBlock *Latch = L.getLoopLatch();
+  assert(Latch && "Expecting valid latch");
+
+  BranchInst *BI = dyn_cast_or_null<BranchInst>(Latch->getTerminator());
+  assert(BI && BI->isConditional() && "Expecting conditional latch branch");
+
+  ICmpInst *LatchCmpInst = dyn_cast<ICmpInst>(BI->getCondition());
+  assert(LatchCmpInst &&
+         "Expecting the latch compare instruction to be a CmpInst");
+
+  // Need to invert the predicate when the first successor is not the loop
+  // header
+  ICmpInst::Predicate Pred = (BI->getSuccessor(0) == L.getHeader())
+                                 ? LatchCmpInst->getPredicate()
+                                 : LatchCmpInst->getInversePredicate();
+
+  if (LatchCmpInst->getOperand(0) == &getFinalIVValue())
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+
+  // Need to flip the strictness of the predicate when the latch compare
+  // instruction is not using StepInst
+  if (LatchCmpInst->getOperand(0) == &getStepInst() ||
+      LatchCmpInst->getOperand(1) == &getStepInst())
+    return Pred;
+
+  // Cannot flip strictness of NE and EQ
+  if (Pred != ICmpInst::ICMP_NE && Pred != ICmpInst::ICMP_EQ)
+    return ICmpInst::getFlippedStrictnessPredicate(Pred);
+
+  Direction D = getDirection();
+  if (D == Direction::Increasing)
+    return ICmpInst::ICMP_SLT;
+
+  if (D == Direction::Decreasing)
+    return ICmpInst::ICMP_SGT;
+
+  // If the direction cannot be determined, then we are unable to find the
+  // canonical predicate
+  return ICmpInst::BAD_ICMP_PREDICATE;
+}
+
+Direction Loop::LoopBounds::getDirection() const {
+  if (const SCEVAddRecExpr *StepAddRecExpr =
+          dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&getStepInst())))
+    if (const SCEV *StepRecur = StepAddRecExpr->getStepRecurrence(SE)) {
+      if (SE.isKnownPositive(StepRecur))
+        return Direction::Increasing;
+      if (SE.isKnownNegative(StepRecur))
+        return Direction::Decreasing;
+    }
+
+  return Direction::Unknown;
+}
+
+Optional<Loop::LoopBounds> Loop::getBounds(ScalarEvolution &SE) const {
+  if (PHINode *IndVar = getInductionVariable(SE))
+    return LoopBounds::getBounds(*this, *IndVar, SE);
+
+  return None;
+}
+
+PHINode *Loop::getInductionVariable(ScalarEvolution &SE) const {
+  if (!isLoopSimplifyForm())
+    return nullptr;
+
+  BasicBlock *Header = getHeader();
+  assert(Header && "Expected a valid loop header");
+  ICmpInst *CmpInst = getLatchCmpInst(*this);
+  if (!CmpInst)
+    return nullptr;
+
+  Instruction *LatchCmpOp0 = dyn_cast<Instruction>(CmpInst->getOperand(0));
+  Instruction *LatchCmpOp1 = dyn_cast<Instruction>(CmpInst->getOperand(1));
+
+  for (PHINode &IndVar : Header->phis()) {
+    InductionDescriptor IndDesc;
+    if (!InductionDescriptor::isInductionPHI(&IndVar, this, &SE, IndDesc))
+      continue;
+
+    Instruction *StepInst = IndDesc.getInductionBinOp();
+
+    // case 1:
+    // IndVar = phi[{InitialValue, preheader}, {StepInst, latch}]
+    // StepInst = IndVar + step
+    // cmp = StepInst < FinalValue
+    if (StepInst == LatchCmpOp0 || StepInst == LatchCmpOp1)
+      return &IndVar;
+
+    // case 2:
+    // IndVar = phi[{InitialValue, preheader}, {StepInst, latch}]
+    // StepInst = IndVar + step
+    // cmp = IndVar < FinalValue
+    if (&IndVar == LatchCmpOp0 || &IndVar == LatchCmpOp1)
+      return &IndVar;
+  }
+
+  return nullptr;
+}
+
+bool Loop::getInductionDescriptor(ScalarEvolution &SE,
+                                  InductionDescriptor &IndDesc) const {
+  if (PHINode *IndVar = getInductionVariable(SE))
+    return InductionDescriptor::isInductionPHI(IndVar, this, &SE, IndDesc);
+
+  return false;
+}
+
+bool Loop::isAuxiliaryInductionVariable(PHINode &AuxIndVar,
+                                        ScalarEvolution &SE) const {
+  // Located in the loop header
+  BasicBlock *Header = getHeader();
+  if (AuxIndVar.getParent() != Header)
+    return false;
+
+  // No uses outside of the loop
+  for (User *U : AuxIndVar.users())
+    if (const Instruction *I = dyn_cast<Instruction>(U))
+      if (!contains(I))
+        return false;
+
+  InductionDescriptor IndDesc;
+  if (!InductionDescriptor::isInductionPHI(&AuxIndVar, this, &SE, IndDesc))
+    return false;
+
+  // The step instruction opcode should be add or sub.
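+  // E.g. (illustrative IR, names hypothetical) a countdown counter maintained
+  // alongside the primary induction variable:
+  //   %aux      = phi i32 [ %n, %preheader ], [ %aux.next, %latch ]
+  //   %aux.next = sub i32 %aux, 1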
+  if (IndDesc.getInductionOpcode() != Instruction::Add &&
+      IndDesc.getInductionOpcode() != Instruction::Sub)
+    return false;
+
+  // Incremented by a loop invariant step for each loop iteration
+  return SE.isLoopInvariant(IndDesc.getStep(), this);
+}
+
+BranchInst *Loop::getLoopGuardBranch() const {
+  if (!isLoopSimplifyForm())
+    return nullptr;
+
+  BasicBlock *Preheader = getLoopPreheader();
+  BasicBlock *Latch = getLoopLatch();
+  assert(Preheader && Latch &&
+         "Expecting a loop with valid preheader and latch");
+
+  // Loop should be in rotate form.
+  if (!isLoopExiting(Latch))
+    return nullptr;
+
+  // Disallow loops with more than one unique exit block, as we do not verify
+  // that GuardOtherSucc post dominates all exit blocks.
+  BasicBlock *ExitFromLatch = getUniqueExitBlock();
+  if (!ExitFromLatch)
+    return nullptr;
+
+  BasicBlock *ExitFromLatchSucc = ExitFromLatch->getUniqueSuccessor();
+  if (!ExitFromLatchSucc)
+    return nullptr;
+
+  BasicBlock *GuardBB = Preheader->getUniquePredecessor();
+  if (!GuardBB)
+    return nullptr;
+
+  assert(GuardBB->getTerminator() && "Expecting valid guard terminator");
+
+  BranchInst *GuardBI = dyn_cast<BranchInst>(GuardBB->getTerminator());
+  if (!GuardBI || GuardBI->isUnconditional())
+    return nullptr;
+
+  BasicBlock *GuardOtherSucc = (GuardBI->getSuccessor(0) == Preheader)
+                                   ? GuardBI->getSuccessor(1)
+                                   : GuardBI->getSuccessor(0);
+  return (GuardOtherSucc == ExitFromLatchSucc) ? GuardBI : nullptr;
+}
+
+bool Loop::isCanonical(ScalarEvolution &SE) const {
+  InductionDescriptor IndDesc;
+  if (!getInductionDescriptor(SE, IndDesc))
+    return false;
+
+  ConstantInt *Init = dyn_cast_or_null<ConstantInt>(IndDesc.getStartValue());
+  if (!Init || !Init->isZero())
+    return false;
+
+  if (IndDesc.getInductionOpcode() != Instruction::Add)
+    return false;
+
+  ConstantInt *Step = IndDesc.getConstIntStepValue();
+  if (!Step || !Step->isOne())
+    return false;
+
+  return true;
+}
+
+// Check that 'BB' doesn't have any uses outside of the loop 'L'
+static bool isBlockInLCSSAForm(const Loop &L, const BasicBlock &BB,
+                               DominatorTree &DT) {
+  for (const Instruction &I : BB) {
+    // Tokens can't be used in PHI nodes and live-out tokens prevent loop
+    // optimizations, so for the purposes of LCSSA form, we can ignore them.
+    if (I.getType()->isTokenTy())
+      continue;
+
+    for (const Use &U : I.uses()) {
+      const Instruction *UI = cast<Instruction>(U.getUser());
+      const BasicBlock *UserBB = UI->getParent();
+      if (const PHINode *P = dyn_cast<PHINode>(UI))
+        UserBB = P->getIncomingBlock(U);
+
+      // Check the current block, as a fast-path, before checking whether
+      // the use is anywhere in the loop. Most values are used in the same
+      // block they are defined in. Also, blocks not reachable from the
+      // entry are special; uses in them don't need to go through PHIs.
+      if (UserBB != &BB && !L.contains(UserBB) &&
+          DT.isReachableFromEntry(UserBB))
+        return false;
+    }
+  }
+  return true;
+}
+
+bool Loop::isLCSSAForm(DominatorTree &DT) const {
+  // For each block we check that it doesn't have any uses outside of this loop.
+  return all_of(this->blocks(), [&](const BasicBlock *BB) {
+    return isBlockInLCSSAForm(*this, *BB, DT);
+  });
+}
+
+bool Loop::isRecursivelyLCSSAForm(DominatorTree &DT, const LoopInfo &LI) const {
+  // For each block we check that it doesn't have any uses outside of its
+  // innermost loop. This process will transitively guarantee that the current
+  // loop and all of the nested loops are in LCSSA form.
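+  // In LCSSA form, a value defined inside a loop is used outside of it only
+  // through PHI nodes in the loop's exit blocks, e.g. (illustrative IR):
+  //   exit:
+  //     %v.lcssa = phi i32 [ %v, %latch ] ; %v is defined inside the loop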
+ return all_of(this->blocks(), [&](const BasicBlock *BB) { + return isBlockInLCSSAForm(*LI.getLoopFor(BB), *BB, DT); + }); +} + +bool Loop::isLoopSimplifyForm() const { + // Normal-form loops have a preheader, a single backedge, and all of their + // exits have all their predecessors inside the loop. + return getLoopPreheader() && getLoopLatch() && hasDedicatedExits(); +} + +// Routines that reform the loop CFG and split edges often fail on indirectbr. +bool Loop::isSafeToClone() const { + // Return false if any loop blocks contain indirectbrs, or there are any calls + // to noduplicate functions. + // FIXME: it should be ok to clone CallBrInst's if we correctly update the + // operand list to reflect the newly cloned labels. + for (BasicBlock *BB : this->blocks()) { + if (isa<IndirectBrInst>(BB->getTerminator()) || + isa<CallBrInst>(BB->getTerminator())) + return false; + + for (Instruction &I : *BB) + if (auto CS = CallSite(&I)) + if (CS.cannotDuplicate()) + return false; + } + return true; +} + +MDNode *Loop::getLoopID() const { + MDNode *LoopID = nullptr; + + // Go through the latch blocks and check the terminator for the metadata. + SmallVector<BasicBlock *, 4> LatchesBlocks; + getLoopLatches(LatchesBlocks); + for (BasicBlock *BB : LatchesBlocks) { + Instruction *TI = BB->getTerminator(); + MDNode *MD = TI->getMetadata(LLVMContext::MD_loop); + + if (!MD) + return nullptr; + + if (!LoopID) + LoopID = MD; + else if (MD != LoopID) + return nullptr; + } + if (!LoopID || LoopID->getNumOperands() == 0 || + LoopID->getOperand(0) != LoopID) + return nullptr; + return LoopID; +} + +void Loop::setLoopID(MDNode *LoopID) const { + assert((!LoopID || LoopID->getNumOperands() > 0) && + "Loop ID needs at least one operand"); + assert((!LoopID || LoopID->getOperand(0) == LoopID) && + "Loop ID should refer to itself"); + + SmallVector<BasicBlock *, 4> LoopLatches; + getLoopLatches(LoopLatches); + for (BasicBlock *BB : LoopLatches) + BB->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID); +} + +void Loop::setLoopAlreadyUnrolled() { + LLVMContext &Context = getHeader()->getContext(); + + MDNode *DisableUnrollMD = + MDNode::get(Context, MDString::get(Context, "llvm.loop.unroll.disable")); + MDNode *LoopID = getLoopID(); + MDNode *NewLoopID = makePostTransformationMetadata( + Context, LoopID, {"llvm.loop.unroll."}, {DisableUnrollMD}); + setLoopID(NewLoopID); +} + +bool Loop::isAnnotatedParallel() const { + MDNode *DesiredLoopIdMetadata = getLoopID(); + + if (!DesiredLoopIdMetadata) + return false; + + MDNode *ParallelAccesses = + findOptionMDForLoop(this, "llvm.loop.parallel_accesses"); + SmallPtrSet<MDNode *, 4> + ParallelAccessGroups; // For scalable 'contains' check. + if (ParallelAccesses) { + for (const MDOperand &MD : drop_begin(ParallelAccesses->operands(), 1)) { + MDNode *AccGroup = cast<MDNode>(MD.get()); + assert(isValidAsAccessGroup(AccGroup) && + "List item must be an access group"); + ParallelAccessGroups.insert(AccGroup); + } + } + + // The loop branch contains the parallel loop metadata. In order to ensure + // that any parallel-loop-unaware optimization pass hasn't added loop-carried + // dependencies (thus converted the loop back to a sequential loop), check + // that all the memory instructions in the loop belong to an access group that + // is parallel to this loop. 
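+  // E.g. (illustrative IR and metadata, names hypothetical):
+  //   %v = load i32, i32* %p, !llvm.access.group !2
+  //   ...
+  //   br i1 %cond, label %header, label %exit, !llvm.loop !0
+  //   !0 = distinct !{!0, !1}
+  //   !1 = !{!"llvm.loop.parallel_accesses", !2}
+  //   !2 = distinct !{}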
+  for (BasicBlock *BB : this->blocks()) {
+    for (Instruction &I : *BB) {
+      if (!I.mayReadOrWriteMemory())
+        continue;
+
+      if (MDNode *AccessGroup = I.getMetadata(LLVMContext::MD_access_group)) {
+        auto ContainsAccessGroup = [&ParallelAccessGroups](MDNode *AG) -> bool {
+          if (AG->getNumOperands() == 0) {
+            assert(isValidAsAccessGroup(AG) && "Item must be an access group");
+            return ParallelAccessGroups.count(AG);
+          }
+
+          for (const MDOperand &AccessListItem : AG->operands()) {
+            MDNode *AccGroup = cast<MDNode>(AccessListItem.get());
+            assert(isValidAsAccessGroup(AccGroup) &&
+                   "List item must be an access group");
+            if (ParallelAccessGroups.count(AccGroup))
+              return true;
+          }
+          return false;
+        };
+
+        if (ContainsAccessGroup(AccessGroup))
+          continue;
+      }
+
+      // The memory instruction can refer to the loop identifier metadata
+      // directly or indirectly through another list metadata (in case of
+      // nested parallel loops). The loop identifier metadata refers to
+      // itself so we can check both cases with the same routine.
+      MDNode *LoopIdMD =
+          I.getMetadata(LLVMContext::MD_mem_parallel_loop_access);
+
+      if (!LoopIdMD)
+        return false;
+
+      bool LoopIdMDFound = false;
+      for (const MDOperand &MDOp : LoopIdMD->operands()) {
+        if (MDOp == DesiredLoopIdMetadata) {
+          LoopIdMDFound = true;
+          break;
+        }
+      }
+
+      if (!LoopIdMDFound)
+        return false;
+    }
+  }
+  return true;
+}
+
+DebugLoc Loop::getStartLoc() const { return getLocRange().getStart(); }
+
+Loop::LocRange Loop::getLocRange() const {
+  // If we have a debug location in the loop ID, then use it.
+  if (MDNode *LoopID = getLoopID()) {
+    DebugLoc Start;
+    // We use the first DebugLoc in the header as the start location of the
+    // loop and, if there is a second DebugLoc in the header, we use it as the
+    // end location of the loop.
+    for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) {
+      if (DILocation *L = dyn_cast<DILocation>(LoopID->getOperand(i))) {
+        if (!Start)
+          Start = DebugLoc(L);
+        else
+          return LocRange(Start, DebugLoc(L));
+      }
+    }
+
+    if (Start)
+      return LocRange(Start);
+  }
+
+  // Try the pre-header first.
+  if (BasicBlock *PHeadBB = getLoopPreheader())
+    if (DebugLoc DL = PHeadBB->getTerminator()->getDebugLoc())
+      return LocRange(DL);
+
+  // If we have no pre-header or there are no instructions with debug
+  // info in it, try the header.
+  if (BasicBlock *HeadBB = getHeader())
+    return LocRange(HeadBB->getTerminator()->getDebugLoc());
+
+  return LocRange();
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void Loop::dump() const { print(dbgs()); }
+
+LLVM_DUMP_METHOD void Loop::dumpVerbose() const {
+  print(dbgs(), /*Depth=*/0, /*Verbose=*/true);
+}
+#endif
+
+//===----------------------------------------------------------------------===//
+// UnloopUpdater implementation
+//
+
+namespace {
+/// Find the new parent loop for all blocks within the "unloop" whose last
+/// backedge has just been removed.
+class UnloopUpdater {
+  Loop &Unloop;
+  LoopInfo *LI;
+
+  LoopBlocksDFS DFS;
+
+  // Map unloop's immediate subloops to their nearest reachable parents. Nested
+  // loops within these subloops will not change parents. However, an immediate
+  // subloop's new parent will be the nearest loop reachable from either its own
+  // exits *or* any of its nested loops' exits.
+  DenseMap<Loop *, Loop *> SubloopParents;
+
+  // Flag the presence of an irreducible backedge whose destination is a block
+  // directly contained by the original unloop.
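+  // (Irreducible control flow contains cycles that can be entered at more
+  // than one block and therefore do not form natural loops; such a backedge
+  // can reach a block before the traversal has assigned it a parent.)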
+ bool FoundIB; + +public: + UnloopUpdater(Loop *UL, LoopInfo *LInfo) + : Unloop(*UL), LI(LInfo), DFS(UL), FoundIB(false) {} + + void updateBlockParents(); + + void removeBlocksFromAncestors(); + + void updateSubloopParents(); + +protected: + Loop *getNearestLoop(BasicBlock *BB, Loop *BBLoop); +}; +} // end anonymous namespace + +/// Update the parent loop for all blocks that are directly contained within the +/// original "unloop". +void UnloopUpdater::updateBlockParents() { + if (Unloop.getNumBlocks()) { + // Perform a post order CFG traversal of all blocks within this loop, + // propagating the nearest loop from successors to predecessors. + LoopBlocksTraversal Traversal(DFS, LI); + for (BasicBlock *POI : Traversal) { + + Loop *L = LI->getLoopFor(POI); + Loop *NL = getNearestLoop(POI, L); + + if (NL != L) { + // For reducible loops, NL is now an ancestor of Unloop. + assert((NL != &Unloop && (!NL || NL->contains(&Unloop))) && + "uninitialized successor"); + LI->changeLoopFor(POI, NL); + } else { + // Or the current block is part of a subloop, in which case its parent + // is unchanged. + assert((FoundIB || Unloop.contains(L)) && "uninitialized successor"); + } + } + } + // Each irreducible loop within the unloop induces a round of iteration using + // the DFS result cached by Traversal. + bool Changed = FoundIB; + for (unsigned NIters = 0; Changed; ++NIters) { + assert(NIters < Unloop.getNumBlocks() && "runaway iterative algorithm"); + + // Iterate over the postorder list of blocks, propagating the nearest loop + // from successors to predecessors as before. + Changed = false; + for (LoopBlocksDFS::POIterator POI = DFS.beginPostorder(), + POE = DFS.endPostorder(); + POI != POE; ++POI) { + + Loop *L = LI->getLoopFor(*POI); + Loop *NL = getNearestLoop(*POI, L); + if (NL != L) { + assert(NL != &Unloop && (!NL || NL->contains(&Unloop)) && + "uninitialized successor"); + LI->changeLoopFor(*POI, NL); + Changed = true; + } + } + } +} + +/// Remove unloop's blocks from all ancestors below their new parents. +void UnloopUpdater::removeBlocksFromAncestors() { + // Remove all unloop's blocks (including those in nested subloops) from + // ancestors below the new parent loop. + for (Loop::block_iterator BI = Unloop.block_begin(), BE = Unloop.block_end(); + BI != BE; ++BI) { + Loop *OuterParent = LI->getLoopFor(*BI); + if (Unloop.contains(OuterParent)) { + while (OuterParent->getParentLoop() != &Unloop) + OuterParent = OuterParent->getParentLoop(); + OuterParent = SubloopParents[OuterParent]; + } + // Remove blocks from former Ancestors except Unloop itself which will be + // deleted. + for (Loop *OldParent = Unloop.getParentLoop(); OldParent != OuterParent; + OldParent = OldParent->getParentLoop()) { + assert(OldParent && "new loop is not an ancestor of the original"); + OldParent->removeBlockFromLoop(*BI); + } + } +} + +/// Update the parent loop for all subloops directly nested within unloop. +void UnloopUpdater::updateSubloopParents() { + while (!Unloop.empty()) { + Loop *Subloop = *std::prev(Unloop.end()); + Unloop.removeChildLoop(std::prev(Unloop.end())); + + assert(SubloopParents.count(Subloop) && "DFS failed to visit subloop"); + if (Loop *Parent = SubloopParents[Subloop]) + Parent->addChildLoop(Subloop); + else + LI->addTopLevelLoop(Subloop); + } +} + +/// Return the nearest parent loop among this block's successors. If a successor +/// is a subloop header, consider its parent to be the nearest parent of the +/// subloop's exits. 
+/// +/// For subloop blocks, simply update SubloopParents and return NULL. +Loop *UnloopUpdater::getNearestLoop(BasicBlock *BB, Loop *BBLoop) { + + // Initially for blocks directly contained by Unloop, NearLoop == Unloop and + // is considered uninitialized. + Loop *NearLoop = BBLoop; + + Loop *Subloop = nullptr; + if (NearLoop != &Unloop && Unloop.contains(NearLoop)) { + Subloop = NearLoop; + // Find the subloop ancestor that is directly contained within Unloop. + while (Subloop->getParentLoop() != &Unloop) { + Subloop = Subloop->getParentLoop(); + assert(Subloop && "subloop is not an ancestor of the original loop"); + } + // Get the current nearest parent of the Subloop exits, initially Unloop. + NearLoop = SubloopParents.insert({Subloop, &Unloop}).first->second; + } + + succ_iterator I = succ_begin(BB), E = succ_end(BB); + if (I == E) { + assert(!Subloop && "subloop blocks must have a successor"); + NearLoop = nullptr; // unloop blocks may now exit the function. + } + for (; I != E; ++I) { + if (*I == BB) + continue; // self loops are uninteresting + + Loop *L = LI->getLoopFor(*I); + if (L == &Unloop) { + // This successor has not been processed. This path must lead to an + // irreducible backedge. + assert((FoundIB || !DFS.hasPostorder(*I)) && "should have seen IB"); + FoundIB = true; + } + if (L != &Unloop && Unloop.contains(L)) { + // Successor is in a subloop. + if (Subloop) + continue; // Branching within subloops. Ignore it. + + // BB branches from the original into a subloop header. + assert(L->getParentLoop() == &Unloop && "cannot skip into nested loops"); + + // Get the current nearest parent of the Subloop's exits. + L = SubloopParents[L]; + // L could be Unloop if the only exit was an irreducible backedge. + } + if (L == &Unloop) { + continue; + } + // Handle critical edges from Unloop into a sibling loop. + if (L && !L->contains(&Unloop)) { + L = L->getParentLoop(); + } + // Remember the nearest parent loop among successors or subloop exits. + if (NearLoop == &Unloop || !NearLoop || NearLoop->contains(L)) + NearLoop = L; + } + if (Subloop) { + SubloopParents[Subloop] = NearLoop; + return BBLoop; + } + return NearLoop; +} + +LoopInfo::LoopInfo(const DomTreeBase<BasicBlock> &DomTree) { analyze(DomTree); } + +bool LoopInfo::invalidate(Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &) { + // Check whether the analysis, all analyses on functions, or the function's + // CFG have been preserved. + auto PAC = PA.getChecker<LoopAnalysis>(); + return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || + PAC.preservedSet<CFGAnalyses>()); +} + +void LoopInfo::erase(Loop *Unloop) { + assert(!Unloop->isInvalid() && "Loop has already been erased!"); + + auto InvalidateOnExit = make_scope_exit([&]() { destroy(Unloop); }); + + // First handle the special case of no parent loop to simplify the algorithm. + if (!Unloop->getParentLoop()) { + // Since BBLoop had no parent, Unloop blocks are no longer in a loop. + for (Loop::block_iterator I = Unloop->block_begin(), + E = Unloop->block_end(); + I != E; ++I) { + + // Don't reparent blocks in subloops. + if (getLoopFor(*I) != Unloop) + continue; + + // Blocks no longer have a parent but are still referenced by Unloop until + // the Unloop object is deleted. + changeLoopFor(*I, nullptr); + } + + // Remove the loop from the top-level LoopInfo object. 
+    for (iterator I = begin();; ++I) {
+      assert(I != end() && "Couldn't find loop");
+      if (*I == Unloop) {
+        removeLoop(I);
+        break;
+      }
+    }
+
+    // Move all of the subloops to the top-level.
+    while (!Unloop->empty())
+      addTopLevelLoop(Unloop->removeChildLoop(std::prev(Unloop->end())));
+
+    return;
+  }
+
+  // Update the parent loop for all blocks within the loop. Blocks within
+  // subloops will not change parents.
+  UnloopUpdater Updater(Unloop, this);
+  Updater.updateBlockParents();
+
+  // Remove blocks from former ancestor loops.
+  Updater.removeBlocksFromAncestors();
+
+  // Add direct subloops as children in their new parent loop.
+  Updater.updateSubloopParents();
+
+  // Remove unloop from its parent loop.
+  Loop *ParentLoop = Unloop->getParentLoop();
+  for (Loop::iterator I = ParentLoop->begin();; ++I) {
+    assert(I != ParentLoop->end() && "Couldn't find loop");
+    if (*I == Unloop) {
+      ParentLoop->removeChildLoop(I);
+      break;
+    }
+  }
+}
+
+AnalysisKey LoopAnalysis::Key;
+
+LoopInfo LoopAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
+  // FIXME: Currently we create a LoopInfo from scratch for every function.
+  // This may prove to be too wasteful due to deallocating and re-allocating
+  // memory each time for the underlying map and vector data structures. At
+  // some point it may prove worthwhile to use a freelist and recycle LoopInfo
+  // objects. I don't want to add that kind of complexity until the scope of
+  // the problem is better understood.
+  LoopInfo LI;
+  LI.analyze(AM.getResult<DominatorTreeAnalysis>(F));
+  return LI;
+}
+
+PreservedAnalyses LoopPrinterPass::run(Function &F,
+                                       FunctionAnalysisManager &AM) {
+  AM.getResult<LoopAnalysis>(F).print(OS);
+  return PreservedAnalyses::all();
+}
+
+void llvm::printLoop(Loop &L, raw_ostream &OS, const std::string &Banner) {
+
+  if (forcePrintModuleIR()) {
+    // handling -print-module-scope
+    OS << Banner << " (loop: ";
+    L.getHeader()->printAsOperand(OS, false);
+    OS << ")\n";
+
+    // printing whole module
+    OS << *L.getHeader()->getModule();
+    return;
+  }
+
+  OS << Banner;
+
+  auto *PreHeader = L.getLoopPreheader();
+  if (PreHeader) {
+    OS << "\n; Preheader:";
+    PreHeader->print(OS);
+    OS << "\n; Loop:";
+  }
+
+  for (auto *Block : L.blocks())
+    if (Block)
+      Block->print(OS);
+    else
+      OS << "Printing <null> block";
+
+  SmallVector<BasicBlock *, 8> ExitBlocks;
+  L.getExitBlocks(ExitBlocks);
+  if (!ExitBlocks.empty()) {
+    OS << "\n; Exit blocks";
+    for (auto *Block : ExitBlocks)
+      if (Block)
+        Block->print(OS);
+      else
+        OS << "Printing <null> block";
+  }
+}
+
+MDNode *llvm::findOptionMDForLoopID(MDNode *LoopID, StringRef Name) {
+  // No loop metadata node, no loop properties.
+  if (!LoopID)
+    return nullptr;
+
+  // First operand should refer to the metadata node itself, for legacy reasons.
+  assert(LoopID->getNumOperands() > 0 && "requires at least one operand");
+  assert(LoopID->getOperand(0) == LoopID && "invalid loop id");
+
+  // Iterate over the metadata node operands and look for MDString metadata.
+  for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) {
+    MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i));
+    if (!MD || MD->getNumOperands() < 1)
+      continue;
+    MDString *S = dyn_cast<MDString>(MD->getOperand(0));
+    if (!S)
+      continue;
+    // Return the operand node if MDString holds expected metadata.
+    if (Name.equals(S->getString()))
+      return MD;
+  }
+
+  // Loop property not found.
+ return nullptr; +} + +MDNode *llvm::findOptionMDForLoop(const Loop *TheLoop, StringRef Name) { + return findOptionMDForLoopID(TheLoop->getLoopID(), Name); +} + +bool llvm::isValidAsAccessGroup(MDNode *Node) { + return Node->getNumOperands() == 0 && Node->isDistinct(); +} + +MDNode *llvm::makePostTransformationMetadata(LLVMContext &Context, + MDNode *OrigLoopID, + ArrayRef<StringRef> RemovePrefixes, + ArrayRef<MDNode *> AddAttrs) { + // First remove any existing loop metadata related to this transformation. + SmallVector<Metadata *, 4> MDs; + + // Reserve first location for self reference to the LoopID metadata node. + TempMDTuple TempNode = MDNode::getTemporary(Context, None); + MDs.push_back(TempNode.get()); + + // Remove metadata for the transformation that has been applied or that became + // outdated. + if (OrigLoopID) { + for (unsigned i = 1, ie = OrigLoopID->getNumOperands(); i < ie; ++i) { + bool IsVectorMetadata = false; + Metadata *Op = OrigLoopID->getOperand(i); + if (MDNode *MD = dyn_cast<MDNode>(Op)) { + const MDString *S = dyn_cast<MDString>(MD->getOperand(0)); + if (S) + IsVectorMetadata = + llvm::any_of(RemovePrefixes, [S](StringRef Prefix) -> bool { + return S->getString().startswith(Prefix); + }); + } + if (!IsVectorMetadata) + MDs.push_back(Op); + } + } + + // Add metadata to avoid reapplying a transformation, such as + // llvm.loop.unroll.disable and llvm.loop.isvectorized. + MDs.append(AddAttrs.begin(), AddAttrs.end()); + + MDNode *NewLoopID = MDNode::getDistinct(Context, MDs); + // Replace the temporary node with a self-reference. + NewLoopID->replaceOperandWith(0, NewLoopID); + return NewLoopID; +} + +//===----------------------------------------------------------------------===// +// LoopInfo implementation +// + +char LoopInfoWrapperPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopInfoWrapperPass, "loops", "Natural Loop Information", + true, true) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(LoopInfoWrapperPass, "loops", "Natural Loop Information", + true, true) + +bool LoopInfoWrapperPass::runOnFunction(Function &) { + releaseMemory(); + LI.analyze(getAnalysis<DominatorTreeWrapperPass>().getDomTree()); + return false; +} + +void LoopInfoWrapperPass::verifyAnalysis() const { + // LoopInfoWrapperPass is a FunctionPass, but verifying every loop in the + // function each time verifyAnalysis is called is very expensive. The + // -verify-loop-info option can enable this. In order to perform some + // checking by default, LoopPass has been taught to call verifyLoop manually + // during loop pass sequences. + if (VerifyLoopInfo) { + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + LI.verify(DT); + } +} + +void LoopInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequiredTransitive<DominatorTreeWrapperPass>(); +} + +void LoopInfoWrapperPass::print(raw_ostream &OS, const Module *) const { + LI.print(OS); +} + +PreservedAnalyses LoopVerifierPass::run(Function &F, + FunctionAnalysisManager &AM) { + LoopInfo &LI = AM.getResult<LoopAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + LI.verify(DT); + return PreservedAnalyses::all(); +} + +//===----------------------------------------------------------------------===// +// LoopBlocksDFS implementation +// + +/// Traverse the loop blocks and store the DFS result. +/// Useful for clients that just want the final DFS result and don't need to +/// visit blocks during the initial traversal. 
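+/// (Simply walking a LoopBlocksTraversal populates this object's postorder
+/// cache as a side effect, which is all the empty loop body below relies on.)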
+void LoopBlocksDFS::perform(LoopInfo *LI) { + LoopBlocksTraversal Traversal(*this, LI); + for (LoopBlocksTraversal::POTIterator POI = Traversal.begin(), + POE = Traversal.end(); + POI != POE; ++POI) + ; +} diff --git a/llvm/lib/Analysis/LoopPass.cpp b/llvm/lib/Analysis/LoopPass.cpp new file mode 100644 index 000000000000..4ab3798039d8 --- /dev/null +++ b/llvm/lib/Analysis/LoopPass.cpp @@ -0,0 +1,414 @@ +//===- LoopPass.cpp - Loop Pass and Loop Pass Manager ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements LoopPass and LPPassManager. All loop optimization +// and transformation passes are derived from LoopPass. LPPassManager is +// responsible for managing LoopPasses. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRPrintingPasses.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/OptBisect.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/PassTimingInfo.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Timer.h" +#include "llvm/Support/TimeProfiler.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +#define DEBUG_TYPE "loop-pass-manager" + +namespace { + +/// PrintLoopPass - Print a Function corresponding to a Loop. +/// +class PrintLoopPassWrapper : public LoopPass { + raw_ostream &OS; + std::string Banner; + +public: + static char ID; + PrintLoopPassWrapper() : LoopPass(ID), OS(dbgs()) {} + PrintLoopPassWrapper(raw_ostream &OS, const std::string &Banner) + : LoopPass(ID), OS(OS), Banner(Banner) {} + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + } + + bool runOnLoop(Loop *L, LPPassManager &) override { + auto BBI = llvm::find_if(L->blocks(), [](BasicBlock *BB) { return BB; }); + if (BBI != L->blocks().end() && + isFunctionInPrintList((*BBI)->getParent()->getName())) { + printLoop(*L, OS, Banner); + } + return false; + } + + StringRef getPassName() const override { return "Print Loop IR"; } +}; + +char PrintLoopPassWrapper::ID = 0; +} + +//===----------------------------------------------------------------------===// +// LPPassManager +// + +char LPPassManager::ID = 0; + +LPPassManager::LPPassManager() + : FunctionPass(ID), PMDataManager() { + LI = nullptr; + CurrentLoop = nullptr; +} + +// Insert loop into loop nest (LoopInfo) and loop queue (LQ). +void LPPassManager::addLoop(Loop &L) { + if (!L.getParentLoop()) { + // This is the top level loop. + LQ.push_front(&L); + return; + } + + // Insert L into the loop queue after the parent loop. + for (auto I = LQ.begin(), E = LQ.end(); I != E; ++I) { + if (*I == L.getParentLoop()) { + // deque does not support insert after. + ++I; + LQ.insert(I, 1, &L); + return; + } + } +} + +/// cloneBasicBlockSimpleAnalysis - Invoke cloneBasicBlockAnalysis hook for +/// all loop passes. 
+void LPPassManager::cloneBasicBlockSimpleAnalysis(BasicBlock *From,
+                                                  BasicBlock *To, Loop *L) {
+  for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+    LoopPass *LP = getContainedPass(Index);
+    LP->cloneBasicBlockAnalysis(From, To, L);
+  }
+}
+
+/// deleteSimpleAnalysisValue - Invoke deleteAnalysisValue hook for all passes.
+void LPPassManager::deleteSimpleAnalysisValue(Value *V, Loop *L) {
+  if (BasicBlock *BB = dyn_cast<BasicBlock>(V)) {
+    for (Instruction &I : *BB) {
+      deleteSimpleAnalysisValue(&I, L);
+    }
+  }
+  for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+    LoopPass *LP = getContainedPass(Index);
+    LP->deleteAnalysisValue(V, L);
+  }
+}
+
+/// Invoke deleteAnalysisLoop hook for all passes.
+void LPPassManager::deleteSimpleAnalysisLoop(Loop *L) {
+  for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+    LoopPass *LP = getContainedPass(Index);
+    LP->deleteAnalysisLoop(L);
+  }
+}
+
+
+// Recursively add the loop and all of its subloops into LQ.
+static void addLoopIntoQueue(Loop *L, std::deque<Loop *> &LQ) {
+  LQ.push_back(L);
+  for (Loop *I : reverse(*L))
+    addLoopIntoQueue(I, LQ);
+}
+
+/// Pass Manager itself does not invalidate any analysis info.
+void LPPassManager::getAnalysisUsage(AnalysisUsage &Info) const {
+  // LPPassManager needs LoopInfo. In the long term LoopInfo class will
+  // become part of LPPassManager.
+  Info.addRequired<LoopInfoWrapperPass>();
+  Info.addRequired<DominatorTreeWrapperPass>();
+  Info.setPreservesAll();
+}
+
+void LPPassManager::markLoopAsDeleted(Loop &L) {
+  assert((&L == CurrentLoop || CurrentLoop->contains(&L)) &&
+         "Must not delete loop outside the current loop tree!");
+  // If this loop appears elsewhere within the queue, we also need to remove it
+  // there. However, we have to be careful to not remove the back of the queue
+  // as that is assumed to match the current loop.
+  assert(LQ.back() == CurrentLoop && "Loop queue back isn't the current loop!");
+  LQ.erase(std::remove(LQ.begin(), LQ.end(), &L), LQ.end());
+
+  if (&L == CurrentLoop) {
+    CurrentLoopDeleted = true;
+    // Add this loop back onto the back of the queue to preserve our invariants.
+    LQ.push_back(&L);
+  }
+}
+
+/// run - Execute all of the passes scheduled for execution. Keep track of
+/// whether any of the passes modifies the function, and if so, return true.
+bool LPPassManager::runOnFunction(Function &F) {
+  auto &LIWP = getAnalysis<LoopInfoWrapperPass>();
+  LI = &LIWP.getLoopInfo();
+  Module &M = *F.getParent();
+#if 0
+  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+#endif
+  bool Changed = false;
+
+  // Collect inherited analysis from Module level pass manager.
+  populateInheritedAnalysis(TPM->activeStack);
+
+  // Populate the loop queue in reverse program order. There is no clear need to
+  // process sibling loops in either forward or reverse order. There may be some
+  // advantage in deleting uses in a later loop before optimizing the
+  // definitions in an earlier loop. If we find a clear reason to process in
+  // forward order, then a forward variant of LoopPassManager should be created.
+  //
+  // Note that LoopInfo::iterator visits loops in reverse program
+  // order. Here, reverse_iterator gives us a forward order, and the LoopQueue
+  // reverses the order a third time by popping from the back.
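+  // E.g. (illustrative) with top-level loops L1 and L2, where L2 contains
+  // L2.1, the queue becomes [L1, L2, L2.1]; popping from the back processes
+  // L2.1 before its parent L2, and L2 before L1.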
+  for (Loop *L : reverse(*LI))
+    addLoopIntoQueue(L, LQ);
+
+  if (LQ.empty()) // No loops, skip calling finalizers
+    return false;
+
+  // Initialization
+  for (Loop *L : LQ) {
+    for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+      LoopPass *P = getContainedPass(Index);
+      Changed |= P->doInitialization(L, *this);
+    }
+  }
+
+  // Walk Loops
+  unsigned InstrCount, FunctionSize = 0;
+  StringMap<std::pair<unsigned, unsigned>> FunctionToInstrCount;
+  bool EmitICRemark = M.shouldEmitInstrCountChangedRemark();
+  // Collect the initial size of the module and the function we're looking at.
+  if (EmitICRemark) {
+    InstrCount = initSizeRemarkInfo(M, FunctionToInstrCount);
+    FunctionSize = F.getInstructionCount();
+  }
+  while (!LQ.empty()) {
+    CurrentLoopDeleted = false;
+    CurrentLoop = LQ.back();
+
+    // Run all passes on the current Loop.
+    for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+      LoopPass *P = getContainedPass(Index);
+
+      llvm::TimeTraceScope LoopPassScope("RunLoopPass", P->getPassName());
+
+      dumpPassInfo(P, EXECUTION_MSG, ON_LOOP_MSG,
+                   CurrentLoop->getHeader()->getName());
+      dumpRequiredSet(P);
+
+      initializeAnalysisImpl(P);
+
+      bool LocalChanged = false;
+      {
+        PassManagerPrettyStackEntry X(P, *CurrentLoop->getHeader());
+        TimeRegion PassTimer(getPassTimer(P));
+        LocalChanged = P->runOnLoop(CurrentLoop, *this);
+        Changed |= LocalChanged;
+        if (EmitICRemark) {
+          unsigned NewSize = F.getInstructionCount();
+          // Update the size of the function, emit a remark, and update the
+          // size of the module.
+          if (NewSize != FunctionSize) {
+            int64_t Delta = static_cast<int64_t>(NewSize) -
+                            static_cast<int64_t>(FunctionSize);
+            emitInstrCountChangedRemark(P, M, Delta, InstrCount,
+                                        FunctionToInstrCount, &F);
+            InstrCount = static_cast<int64_t>(InstrCount) + Delta;
+            FunctionSize = NewSize;
+          }
+        }
+      }
+
+      if (LocalChanged)
+        dumpPassInfo(P, MODIFICATION_MSG, ON_LOOP_MSG,
+                     CurrentLoopDeleted ? "<deleted loop>"
+                                        : CurrentLoop->getName());
+      dumpPreservedSet(P);
+
+      if (CurrentLoopDeleted) {
+        // Notify passes that the loop is being deleted.
+        deleteSimpleAnalysisLoop(CurrentLoop);
+      } else {
+        // Manually check that this loop is still healthy. This is done
+        // instead of relying on LoopInfo::verifyLoop since LoopInfo
+        // is a function pass and it's really expensive to verify every
+        // loop in the function every time. That level of checking can be
+        // enabled with the -verify-loop-info option.
+        {
+          TimeRegion PassTimer(getPassTimer(&LIWP));
+          CurrentLoop->verifyLoop();
+        }
+        // Here we apply the same reasoning as in the above case. The only
+        // difference is that LPPassManager might run passes which do not
+        // require LCSSA form (LoopPassPrinter for example). We should skip
+        // verification for such passes.
+        // FIXME: Loop-sink currently breaks LCSSA. Fix it and reenable the
+        // verification!
+#if 0
+        if (mustPreserveAnalysisID(LCSSAVerificationPass::ID))
+          assert(CurrentLoop->isRecursivelyLCSSAForm(*DT, *LI));
+#endif
+
+        // Then call the regular verifyAnalysis functions.
+        verifyPreservedAnalysis(P);
+
+        F.getContext().yield();
+      }
+
+      removeNotPreservedAnalysis(P);
+      recordAvailableAnalysis(P);
+      removeDeadPasses(P,
+                       CurrentLoopDeleted ? "<deleted>"
+                                          : CurrentLoop->getHeader()->getName(),
+                       ON_LOOP_MSG);
+
+      if (CurrentLoopDeleted)
+        // Do not run other passes on this loop.
+        break;
+    }
+
+    // If the loop was deleted, release all the loop passes. This frees up
+    // some memory, and avoids trouble with the pass manager trying to call
+    // verifyAnalysis on them.
+    if (CurrentLoopDeleted) {
+      for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+        Pass *P = getContainedPass(Index);
+        freePass(P, "<deleted>", ON_LOOP_MSG);
+      }
+    }
+
+    // Pop the loop from queue after running all passes.
+    LQ.pop_back();
+  }
+
+  // Finalization
+  for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+    LoopPass *P = getContainedPass(Index);
+    Changed |= P->doFinalization();
+  }
+
+  return Changed;
+}
+
+/// Print passes managed by this manager.
+void LPPassManager::dumpPassStructure(unsigned Offset) {
+  errs().indent(Offset*2) << "Loop Pass Manager\n";
+  for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+    Pass *P = getContainedPass(Index);
+    P->dumpPassStructure(Offset + 1);
+    dumpLastUses(P, Offset+1);
+  }
+}
+
+
+//===----------------------------------------------------------------------===//
+// LoopPass
+
+Pass *LoopPass::createPrinterPass(raw_ostream &O,
+                                  const std::string &Banner) const {
+  return new PrintLoopPassWrapper(O, Banner);
+}
+
+// Check if this pass is suitable for the current LPPassManager, if
+// available. This pass P is not suitable for a LPPassManager if P
+// is not preserving higher level analysis info used by other
+// LPPassManager passes. In such a case, pop LPPassManager from the
+// stack. This will force assignPassManager() to create a new
+// LPPassManager as expected.
+void LoopPass::preparePassManager(PMStack &PMS) {
+
+  // Find LPPassManager
+  while (!PMS.empty() &&
+         PMS.top()->getPassManagerType() > PMT_LoopPassManager)
+    PMS.pop();
+
+  // If this pass is destroying high level information that is used
+  // by other passes that are managed by LPM then do not insert
+  // this pass in current LPM. Use new LPPassManager.
+  if (PMS.top()->getPassManagerType() == PMT_LoopPassManager &&
+      !PMS.top()->preserveHigherLevelAnalysis(this))
+    PMS.pop();
+}
+
+/// Assign pass manager to manage this pass.
+void LoopPass::assignPassManager(PMStack &PMS,
+                                 PassManagerType PreferredType) {
+  // Find LPPassManager
+  while (!PMS.empty() &&
+         PMS.top()->getPassManagerType() > PMT_LoopPassManager)
+    PMS.pop();
+
+  LPPassManager *LPPM;
+  if (PMS.top()->getPassManagerType() == PMT_LoopPassManager)
+    LPPM = (LPPassManager*)PMS.top();
+  else {
+    // Create new Loop Pass Manager if it does not exist.
+    assert(!PMS.empty() && "Unable to create Loop Pass Manager");
+    PMDataManager *PMD = PMS.top();
+
+    // [1] Create new Loop Pass Manager
+    LPPM = new LPPassManager();
+    LPPM->populateInheritedAnalysis(PMS);
+
+    // [2] Set up new manager's top level manager
+    PMTopLevelManager *TPM = PMD->getTopLevelManager();
+    TPM->addIndirectPassManager(LPPM);
+
+    // [3] Assign manager to manage this new manager. This may create
+    // and push new managers into PMS
+    Pass *P = LPPM->getAsPass();
+    TPM->schedulePass(P);
+
+    // [4] Push new manager into PMS
+    PMS.push(LPPM);
+  }
+
+  LPPM->add(this);
+}
+
+static std::string getDescription(const Loop &L) {
+  return "loop";
+}
+
+bool LoopPass::skipLoop(const Loop *L) const {
+  const Function *F = L->getHeader()->getParent();
+  if (!F)
+    return false;
+  // Check the opt bisect limit.
+  OptPassGate &Gate = F->getContext().getOptPassGate();
+  if (Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(*L)))
+    return true;
+  // Check for the OptimizeNone attribute.
+  if (F->hasOptNone()) {
+    // FIXME: Report this to dbgs() only once per function.
+    LLVM_DEBUG(dbgs() << "Skipping pass '" << getPassName() << "' in function "
+                      << F->getName() << "\n");
+    // FIXME: Delete loop from pass manager's queue?
+    return true;
+  }
+  return false;
+}
+
+char LCSSAVerificationPass::ID = 0;
+INITIALIZE_PASS(LCSSAVerificationPass, "lcssa-verification", "LCSSA Verifier",
+                false, false)
diff --git a/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp b/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp
new file mode 100644
index 000000000000..762623de41e9
--- /dev/null
+++ b/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp
@@ -0,0 +1,214 @@
+//===- LoopUnrollAnalyzer.cpp - Unrolling Effect Estimation -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the UnrolledInstAnalyzer class. It's used for predicting
+// potential effects that loop unrolling might have, such as enabling constant
+// propagation and other optimizations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LoopUnrollAnalyzer.h"
+
+using namespace llvm;
+
+/// Try to simplify instruction \param I using its SCEV expression.
+///
+/// The idea is that some AddRec expressions become constants, which then
+/// could trigger folding of other instructions. However, that only happens
+/// for expressions whose start value is also constant, which isn't always the
+/// case. In another common and important case the start value is just some
+/// address (i.e. SCEVUnknown) - in this case we compute the offset and save
+/// it along with the base address instead.
+bool UnrolledInstAnalyzer::simplifyInstWithSCEV(Instruction *I) {
+  if (!SE.isSCEVable(I->getType()))
+    return false;
+
+  const SCEV *S = SE.getSCEV(I);
+  if (auto *SC = dyn_cast<SCEVConstant>(S)) {
+    SimplifiedValues[I] = SC->getValue();
+    return true;
+  }
+
+  auto *AR = dyn_cast<SCEVAddRecExpr>(S);
+  if (!AR || AR->getLoop() != L)
+    return false;
+
+  const SCEV *ValueAtIteration = AR->evaluateAtIteration(IterationNumber, SE);
+  // Check if the AddRec expression becomes a constant.
+  if (auto *SC = dyn_cast<SCEVConstant>(ValueAtIteration)) {
+    SimplifiedValues[I] = SC->getValue();
+    return true;
+  }
+
+  // Check if the offset from the base address becomes a constant.
+  auto *Base = dyn_cast<SCEVUnknown>(SE.getPointerBase(S));
+  if (!Base)
+    return false;
+  auto *Offset =
+      dyn_cast<SCEVConstant>(SE.getMinusSCEV(ValueAtIteration, Base));
+  if (!Offset)
+    return false;
+  SimplifiedAddress Address;
+  Address.Base = Base->getValue();
+  Address.Offset = Offset->getValue();
+  SimplifiedAddresses[I] = Address;
+  return false;
+}
+
+/// Try to simplify binary operator I.
+///
+/// TODO: It's probably worth hoisting the code for estimating the
+/// simplification effects into a separate class, since we already have very
+/// similar code in InlineCost.
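+///
+/// For example (illustrative), if %x is known to simplify to the constant 4
+/// on the iteration being analyzed, then "%y = add i32 %x, 1" folds to 5 and
+/// is recorded in SimplifiedValues for use by later visits.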
+bool UnrolledInstAnalyzer::visitBinaryOperator(BinaryOperator &I) { + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + if (!isa<Constant>(LHS)) + if (Constant *SimpleLHS = SimplifiedValues.lookup(LHS)) + LHS = SimpleLHS; + if (!isa<Constant>(RHS)) + if (Constant *SimpleRHS = SimplifiedValues.lookup(RHS)) + RHS = SimpleRHS; + + Value *SimpleV = nullptr; + const DataLayout &DL = I.getModule()->getDataLayout(); + if (auto FI = dyn_cast<FPMathOperator>(&I)) + SimpleV = + SimplifyBinOp(I.getOpcode(), LHS, RHS, FI->getFastMathFlags(), DL); + else + SimpleV = SimplifyBinOp(I.getOpcode(), LHS, RHS, DL); + + if (Constant *C = dyn_cast_or_null<Constant>(SimpleV)) + SimplifiedValues[&I] = C; + + if (SimpleV) + return true; + return Base::visitBinaryOperator(I); +} + +/// Try to fold load I. +bool UnrolledInstAnalyzer::visitLoad(LoadInst &I) { + Value *AddrOp = I.getPointerOperand(); + + auto AddressIt = SimplifiedAddresses.find(AddrOp); + if (AddressIt == SimplifiedAddresses.end()) + return false; + ConstantInt *SimplifiedAddrOp = AddressIt->second.Offset; + + auto *GV = dyn_cast<GlobalVariable>(AddressIt->second.Base); + // We're only interested in loads that can be completely folded to a + // constant. + if (!GV || !GV->hasDefinitiveInitializer() || !GV->isConstant()) + return false; + + ConstantDataSequential *CDS = + dyn_cast<ConstantDataSequential>(GV->getInitializer()); + if (!CDS) + return false; + + // We might have a vector load from an array. FIXME: for now we just bail + // out in this case, but we should be able to resolve and simplify such + // loads. + if (CDS->getElementType() != I.getType()) + return false; + + unsigned ElemSize = CDS->getElementType()->getPrimitiveSizeInBits() / 8U; + if (SimplifiedAddrOp->getValue().getActiveBits() > 64) + return false; + int64_t SimplifiedAddrOpV = SimplifiedAddrOp->getSExtValue(); + if (SimplifiedAddrOpV < 0) { + // FIXME: For now we conservatively ignore out of bound accesses, but + // we're allowed to perform the optimization in this case. + return false; + } + uint64_t Index = static_cast<uint64_t>(SimplifiedAddrOpV) / ElemSize; + if (Index >= CDS->getNumElements()) { + // FIXME: For now we conservatively ignore out of bound accesses, but + // we're allowed to perform the optimization in this case. + return false; + } + + Constant *CV = CDS->getElementAsConstant(Index); + assert(CV && "Constant expected."); + SimplifiedValues[&I] = CV; + + return true; +} + +/// Try to simplify cast instruction. +bool UnrolledInstAnalyzer::visitCastInst(CastInst &I) { + // Propagate constants through casts. + Constant *COp = dyn_cast<Constant>(I.getOperand(0)); + if (!COp) + COp = SimplifiedValues.lookup(I.getOperand(0)); + + // If we know a simplified value for this operand and cast is valid, save the + // result to SimplifiedValues. + // The cast can be invalid, because SimplifiedValues contains results of SCEV + // analysis, which operates on integers (and, e.g., might convert i8* null to + // i32 0). + if (COp && CastInst::castIsValid(I.getOpcode(), COp, I.getType())) { + if (Constant *C = + ConstantExpr::getCast(I.getOpcode(), COp, I.getType())) { + SimplifiedValues[&I] = C; + return true; + } + } + + return Base::visitCastInst(I); +} + +/// Try to simplify cmp instruction. +bool UnrolledInstAnalyzer::visitCmpInst(CmpInst &I) { + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + + // First try to handle simplified comparisons. 
+  if (!isa<Constant>(LHS))
+    if (Constant *SimpleLHS = SimplifiedValues.lookup(LHS))
+      LHS = SimpleLHS;
+  if (!isa<Constant>(RHS))
+    if (Constant *SimpleRHS = SimplifiedValues.lookup(RHS))
+      RHS = SimpleRHS;
+
+  if (!isa<Constant>(LHS) && !isa<Constant>(RHS)) {
+    auto SimplifiedLHS = SimplifiedAddresses.find(LHS);
+    if (SimplifiedLHS != SimplifiedAddresses.end()) {
+      auto SimplifiedRHS = SimplifiedAddresses.find(RHS);
+      if (SimplifiedRHS != SimplifiedAddresses.end()) {
+        SimplifiedAddress &LHSAddr = SimplifiedLHS->second;
+        SimplifiedAddress &RHSAddr = SimplifiedRHS->second;
+        if (LHSAddr.Base == RHSAddr.Base) {
+          LHS = LHSAddr.Offset;
+          RHS = RHSAddr.Offset;
+        }
+      }
+    }
+  }
+
+  if (Constant *CLHS = dyn_cast<Constant>(LHS)) {
+    if (Constant *CRHS = dyn_cast<Constant>(RHS)) {
+      if (CLHS->getType() == CRHS->getType()) {
+        if (Constant *C = ConstantExpr::getCompare(I.getPredicate(), CLHS, CRHS)) {
+          SimplifiedValues[&I] = C;
+          return true;
+        }
+      }
+    }
+  }
+
+  return Base::visitCmpInst(I);
+}
+
+bool UnrolledInstAnalyzer::visitPHINode(PHINode &PN) {
+  // Run the base visitor first. This way we can gather some information that
+  // is useful for later analysis.
+  if (Base::visitPHINode(PN))
+    return true;
+
+  // The loop induction PHI nodes are definitionally free.
+  return PN.getParent() == L->getHeader();
+}
diff --git a/llvm/lib/Analysis/MemDepPrinter.cpp b/llvm/lib/Analysis/MemDepPrinter.cpp
new file mode 100644
index 000000000000..6e1bb50e8893
--- /dev/null
+++ b/llvm/lib/Analysis/MemDepPrinter.cpp
@@ -0,0 +1,164 @@
+//===- MemDepPrinter.cpp - Printer for MemoryDependenceAnalysis -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Analysis/MemoryDependenceAnalysis.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace llvm;
+
+namespace {
+  struct MemDepPrinter : public FunctionPass {
+    const Function *F;
+
+    enum DepType {
+      Clobber = 0,
+      Def,
+      NonFuncLocal,
+      Unknown
+    };
+
+    static const char *const DepTypeStr[];
+
+    typedef PointerIntPair<const Instruction *, 2, DepType> InstTypePair;
+    typedef std::pair<InstTypePair, const BasicBlock *> Dep;
+    typedef SmallSetVector<Dep, 4> DepSet;
+    typedef DenseMap<const Instruction *, DepSet> DepSetMap;
+    DepSetMap Deps;
+
+    static char ID; // Pass identification, replacement for typeid
+    MemDepPrinter() : FunctionPass(ID) {
+      initializeMemDepPrinterPass(*PassRegistry::getPassRegistry());
+    }
+
+    bool runOnFunction(Function &F) override;
+
+    void print(raw_ostream &OS, const Module * = nullptr) const override;
+
+    void getAnalysisUsage(AnalysisUsage &AU) const override {
+      AU.addRequiredTransitive<AAResultsWrapperPass>();
+      AU.addRequiredTransitive<MemoryDependenceWrapperPass>();
+      AU.setPreservesAll();
+    }
+
+    void releaseMemory() override {
+      Deps.clear();
+      F = nullptr;
+    }
+
+  private:
+    static InstTypePair getInstTypePair(MemDepResult dep) {
+      if (dep.isClobber())
+        return InstTypePair(dep.getInst(), Clobber);
+      if (dep.isDef())
+        return InstTypePair(dep.getInst(), Def);
+      if (dep.isNonFuncLocal())
+        return
InstTypePair(dep.getInst(), NonFuncLocal); + assert(dep.isUnknown() && "unexpected dependence type"); + return InstTypePair(dep.getInst(), Unknown); + } + static InstTypePair getInstTypePair(const Instruction* inst, DepType type) { + return InstTypePair(inst, type); + } + }; +} + +char MemDepPrinter::ID = 0; +INITIALIZE_PASS_BEGIN(MemDepPrinter, "print-memdeps", + "Print MemDeps of function", false, true) +INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) +INITIALIZE_PASS_END(MemDepPrinter, "print-memdeps", + "Print MemDeps of function", false, true) + +FunctionPass *llvm::createMemDepPrinter() { + return new MemDepPrinter(); +} + +const char *const MemDepPrinter::DepTypeStr[] + = {"Clobber", "Def", "NonFuncLocal", "Unknown"}; + +bool MemDepPrinter::runOnFunction(Function &F) { + this->F = &F; + MemoryDependenceResults &MDA = getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); + + // All this code uses non-const interfaces because MemDep is not + // const-friendly, though nothing is actually modified. + for (auto &I : instructions(F)) { + Instruction *Inst = &I; + + if (!Inst->mayReadFromMemory() && !Inst->mayWriteToMemory()) + continue; + + MemDepResult Res = MDA.getDependency(Inst); + if (!Res.isNonLocal()) { + Deps[Inst].insert(std::make_pair(getInstTypePair(Res), + static_cast<BasicBlock *>(nullptr))); + } else if (auto *Call = dyn_cast<CallBase>(Inst)) { + const MemoryDependenceResults::NonLocalDepInfo &NLDI = + MDA.getNonLocalCallDependency(Call); + + DepSet &InstDeps = Deps[Inst]; + for (const NonLocalDepEntry &I : NLDI) { + const MemDepResult &Res = I.getResult(); + InstDeps.insert(std::make_pair(getInstTypePair(Res), I.getBB())); + } + } else { + SmallVector<NonLocalDepResult, 4> NLDI; + assert( (isa<LoadInst>(Inst) || isa<StoreInst>(Inst) || + isa<VAArgInst>(Inst)) && "Unknown memory instruction!"); + MDA.getNonLocalPointerDependency(Inst, NLDI); + + DepSet &InstDeps = Deps[Inst]; + for (const NonLocalDepResult &I : NLDI) { + const MemDepResult &Res = I.getResult(); + InstDeps.insert(std::make_pair(getInstTypePair(Res), I.getBB())); + } + } + } + + return false; +} + +void MemDepPrinter::print(raw_ostream &OS, const Module *M) const { + for (const auto &I : instructions(*F)) { + const Instruction *Inst = &I; + + DepSetMap::const_iterator DI = Deps.find(Inst); + if (DI == Deps.end()) + continue; + + const DepSet &InstDeps = DI->second; + + for (const auto &I : InstDeps) { + const Instruction *DepInst = I.first.getPointer(); + DepType type = I.first.getInt(); + const BasicBlock *DepBB = I.second; + + OS << " "; + OS << DepTypeStr[type]; + if (DepBB) { + OS << " in block "; + DepBB->printAsOperand(OS, /*PrintType=*/false, M); + } + if (DepInst) { + OS << " from: "; + DepInst->print(OS); + } + OS << "\n"; + } + + Inst->print(OS); + OS << "\n\n"; + } +} diff --git a/llvm/lib/Analysis/MemDerefPrinter.cpp b/llvm/lib/Analysis/MemDerefPrinter.cpp new file mode 100644 index 000000000000..5cf516a538b5 --- /dev/null +++ b/llvm/lib/Analysis/MemDerefPrinter.cpp @@ -0,0 +1,76 @@ +//===- MemDerefPrinter.cpp - Printer for isDereferenceablePointer ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace llvm;
+
+namespace {
+  struct MemDerefPrinter : public FunctionPass {
+    SmallVector<Value *, 4> Deref;
+    SmallPtrSet<Value *, 4> DerefAndAligned;
+
+    static char ID; // Pass identification, replacement for typeid
+    MemDerefPrinter() : FunctionPass(ID) {
+      initializeMemDerefPrinterPass(*PassRegistry::getPassRegistry());
+    }
+    void getAnalysisUsage(AnalysisUsage &AU) const override {
+      AU.setPreservesAll();
+    }
+    bool runOnFunction(Function &F) override;
+    void print(raw_ostream &OS, const Module * = nullptr) const override;
+    void releaseMemory() override {
+      Deref.clear();
+      DerefAndAligned.clear();
+    }
+  };
+}
+
+char MemDerefPrinter::ID = 0;
+INITIALIZE_PASS_BEGIN(MemDerefPrinter, "print-memderefs",
+                      "Memory Dereferenceability of pointers in function", false, true)
+INITIALIZE_PASS_END(MemDerefPrinter, "print-memderefs",
+                    "Memory Dereferenceability of pointers in function", false, true)
+
+FunctionPass *llvm::createMemDerefPrinter() {
+  return new MemDerefPrinter();
+}
+
+bool MemDerefPrinter::runOnFunction(Function &F) {
+  const DataLayout &DL = F.getParent()->getDataLayout();
+  for (auto &I : instructions(F)) {
+    if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
+      Value *PO = LI->getPointerOperand();
+      if (isDereferenceablePointer(PO, LI->getType(), DL))
+        Deref.push_back(PO);
+      if (isDereferenceableAndAlignedPointer(
+              PO, LI->getType(), MaybeAlign(LI->getAlignment()), DL))
+        DerefAndAligned.insert(PO);
+    }
+  }
+  return false;
+}
+
+void MemDerefPrinter::print(raw_ostream &OS, const Module *M) const {
+  OS << "The following are dereferenceable:\n";
+  for (Value *V : Deref) {
+    V->print(OS);
+    if (DerefAndAligned.count(V))
+      OS << "\t(aligned)";
+    else
+      OS << "\t(unaligned)";
+    OS << "\n\n";
+  }
+}
diff --git a/llvm/lib/Analysis/MemoryBuiltins.cpp b/llvm/lib/Analysis/MemoryBuiltins.cpp
new file mode 100644
index 000000000000..172c86eb4646
--- /dev/null
+++ b/llvm/lib/Analysis/MemoryBuiltins.cpp
@@ -0,0 +1,1051 @@
+//===- MemoryBuiltins.cpp - Identify calls to memory builtins -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This family of functions identifies calls to builtin functions that allocate
+// or free memory.
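+// Recognized families include malloc/calloc/realloc/strdup, the C++ operator
+// new variants, and the matching free/operator delete forms; see the
+// AllocationFnData table and isLibFreeFunction below.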
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/TargetFolder.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/Utils/Local.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <cstdint> +#include <iterator> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "memory-builtins" + +enum AllocType : uint8_t { + OpNewLike = 1<<0, // allocates; never returns null + MallocLike = 1<<1 | OpNewLike, // allocates; may return null + CallocLike = 1<<2, // allocates + bzero + ReallocLike = 1<<3, // reallocates + StrDupLike = 1<<4, + MallocOrCallocLike = MallocLike | CallocLike, + AllocLike = MallocLike | CallocLike | StrDupLike, + AnyAlloc = AllocLike | ReallocLike +}; + +struct AllocFnsTy { + AllocType AllocTy; + unsigned NumParams; + // First and Second size parameters (or -1 if unused) + int FstParam, SndParam; +}; + +// FIXME: certain users need more information. E.g., SimplifyLibCalls needs to +// know which functions are nounwind, noalias, nocapture parameters, etc. 
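+// Each entry below pairs a LibFunc with {AllocTy, NumParams, FstParam,
+// SndParam}. For example, {LibFunc_calloc, {CallocLike, 2, 0, 1}} records that
+// calloc takes two parameters and that its allocation size is the product of
+// parameters 0 and 1.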
+static const std::pair<LibFunc, AllocFnsTy> AllocationFnData[] = { + {LibFunc_malloc, {MallocLike, 1, 0, -1}}, + {LibFunc_valloc, {MallocLike, 1, 0, -1}}, + {LibFunc_Znwj, {OpNewLike, 1, 0, -1}}, // new(unsigned int) + {LibFunc_ZnwjRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new(unsigned int, nothrow) + {LibFunc_ZnwjSt11align_val_t, {OpNewLike, 2, 0, -1}}, // new(unsigned int, align_val_t) + {LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t, // new(unsigned int, align_val_t, nothrow) + {MallocLike, 3, 0, -1}}, + {LibFunc_Znwm, {OpNewLike, 1, 0, -1}}, // new(unsigned long) + {LibFunc_ZnwmRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new(unsigned long, nothrow) + {LibFunc_ZnwmSt11align_val_t, {OpNewLike, 2, 0, -1}}, // new(unsigned long, align_val_t) + {LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t, // new(unsigned long, align_val_t, nothrow) + {MallocLike, 3, 0, -1}}, + {LibFunc_Znaj, {OpNewLike, 1, 0, -1}}, // new[](unsigned int) + {LibFunc_ZnajRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new[](unsigned int, nothrow) + {LibFunc_ZnajSt11align_val_t, {OpNewLike, 2, 0, -1}}, // new[](unsigned int, align_val_t) + {LibFunc_ZnajSt11align_val_tRKSt9nothrow_t, // new[](unsigned int, align_val_t, nothrow) + {MallocLike, 3, 0, -1}}, + {LibFunc_Znam, {OpNewLike, 1, 0, -1}}, // new[](unsigned long) + {LibFunc_ZnamRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new[](unsigned long, nothrow) + {LibFunc_ZnamSt11align_val_t, {OpNewLike, 2, 0, -1}}, // new[](unsigned long, align_val_t) + {LibFunc_ZnamSt11align_val_tRKSt9nothrow_t, // new[](unsigned long, align_val_t, nothrow) + {MallocLike, 3, 0, -1}}, + {LibFunc_msvc_new_int, {OpNewLike, 1, 0, -1}}, // new(unsigned int) + {LibFunc_msvc_new_int_nothrow, {MallocLike, 2, 0, -1}}, // new(unsigned int, nothrow) + {LibFunc_msvc_new_longlong, {OpNewLike, 1, 0, -1}}, // new(unsigned long long) + {LibFunc_msvc_new_longlong_nothrow, {MallocLike, 2, 0, -1}}, // new(unsigned long long, nothrow) + {LibFunc_msvc_new_array_int, {OpNewLike, 1, 0, -1}}, // new[](unsigned int) + {LibFunc_msvc_new_array_int_nothrow, {MallocLike, 2, 0, -1}}, // new[](unsigned int, nothrow) + {LibFunc_msvc_new_array_longlong, {OpNewLike, 1, 0, -1}}, // new[](unsigned long long) + {LibFunc_msvc_new_array_longlong_nothrow, {MallocLike, 2, 0, -1}}, // new[](unsigned long long, nothrow) + {LibFunc_calloc, {CallocLike, 2, 0, 1}}, + {LibFunc_realloc, {ReallocLike, 2, 1, -1}}, + {LibFunc_reallocf, {ReallocLike, 2, 1, -1}}, + {LibFunc_strdup, {StrDupLike, 1, -1, -1}}, + {LibFunc_strndup, {StrDupLike, 2, 1, -1}} + // TODO: Handle "int posix_memalign(void **, size_t, size_t)" +}; + +static const Function *getCalledFunction(const Value *V, bool LookThroughBitCast, + bool &IsNoBuiltin) { + // Don't care about intrinsics in this case. + if (isa<IntrinsicInst>(V)) + return nullptr; + + if (LookThroughBitCast) + V = V->stripPointerCasts(); + + ImmutableCallSite CS(V); + if (!CS.getInstruction()) + return nullptr; + + IsNoBuiltin = CS.isNoBuiltin(); + + if (const Function *Callee = CS.getCalledFunction()) + return Callee; + return nullptr; +} + +/// Returns the allocation data for the given value if it's either a call to a +/// known allocation function, or a call to a function with the allocsize +/// attribute. +static Optional<AllocFnsTy> +getAllocationDataForFunction(const Function *Callee, AllocType AllocTy, + const TargetLibraryInfo *TLI) { + // Make sure that the function is available. 
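+  // (i.e. it is recognized by TargetLibraryInfo and usable on this target).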
+ StringRef FnName = Callee->getName(); + LibFunc TLIFn; + if (!TLI || !TLI->getLibFunc(FnName, TLIFn) || !TLI->has(TLIFn)) + return None; + + const auto *Iter = find_if( + AllocationFnData, [TLIFn](const std::pair<LibFunc, AllocFnsTy> &P) { + return P.first == TLIFn; + }); + + if (Iter == std::end(AllocationFnData)) + return None; + + const AllocFnsTy *FnData = &Iter->second; + if ((FnData->AllocTy & AllocTy) != FnData->AllocTy) + return None; + + // Check function prototype. + int FstParam = FnData->FstParam; + int SndParam = FnData->SndParam; + FunctionType *FTy = Callee->getFunctionType(); + + if (FTy->getReturnType() == Type::getInt8PtrTy(FTy->getContext()) && + FTy->getNumParams() == FnData->NumParams && + (FstParam < 0 || + (FTy->getParamType(FstParam)->isIntegerTy(32) || + FTy->getParamType(FstParam)->isIntegerTy(64))) && + (SndParam < 0 || + FTy->getParamType(SndParam)->isIntegerTy(32) || + FTy->getParamType(SndParam)->isIntegerTy(64))) + return *FnData; + return None; +} + +static Optional<AllocFnsTy> getAllocationData(const Value *V, AllocType AllocTy, + const TargetLibraryInfo *TLI, + bool LookThroughBitCast = false) { + bool IsNoBuiltinCall; + if (const Function *Callee = + getCalledFunction(V, LookThroughBitCast, IsNoBuiltinCall)) + if (!IsNoBuiltinCall) + return getAllocationDataForFunction(Callee, AllocTy, TLI); + return None; +} + +static Optional<AllocFnsTy> +getAllocationData(const Value *V, AllocType AllocTy, + function_ref<const TargetLibraryInfo &(Function &)> GetTLI, + bool LookThroughBitCast = false) { + bool IsNoBuiltinCall; + if (const Function *Callee = + getCalledFunction(V, LookThroughBitCast, IsNoBuiltinCall)) + if (!IsNoBuiltinCall) + return getAllocationDataForFunction( + Callee, AllocTy, &GetTLI(const_cast<Function &>(*Callee))); + return None; +} + +static Optional<AllocFnsTy> getAllocationSize(const Value *V, + const TargetLibraryInfo *TLI) { + bool IsNoBuiltinCall; + const Function *Callee = + getCalledFunction(V, /*LookThroughBitCast=*/false, IsNoBuiltinCall); + if (!Callee) + return None; + + // Prefer to use existing information over allocsize. This will give us an + // accurate AllocTy. + if (!IsNoBuiltinCall) + if (Optional<AllocFnsTy> Data = + getAllocationDataForFunction(Callee, AnyAlloc, TLI)) + return Data; + + Attribute Attr = Callee->getFnAttribute(Attribute::AllocSize); + if (Attr == Attribute()) + return None; + + std::pair<unsigned, Optional<unsigned>> Args = Attr.getAllocSizeArgs(); + + AllocFnsTy Result; + // Because allocsize only tells us how many bytes are allocated, we're not + // really allowed to assume anything, so we use MallocLike. + Result.AllocTy = MallocLike; + Result.NumParams = Callee->getNumOperands(); + Result.FstParam = Args.first; + Result.SndParam = Args.second.getValueOr(-1); + return Result; +} + +static bool hasNoAliasAttr(const Value *V, bool LookThroughBitCast) { + ImmutableCallSite CS(LookThroughBitCast ? V->stripPointerCasts() : V); + return CS && CS.hasRetAttr(Attribute::NoAlias); +} + +/// Tests if a value is a call or invoke to a library function that +/// allocates or reallocates memory (either malloc, calloc, realloc, or strdup +/// like). 
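+/// Matching is driven by the AllocationFnData table above; calls marked
+/// nobuiltin are never treated as allocation functions.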
+bool llvm::isAllocationFn(const Value *V, const TargetLibraryInfo *TLI,
+                          bool LookThroughBitCast) {
+  return getAllocationData(V, AnyAlloc, TLI, LookThroughBitCast).hasValue();
+}
+bool llvm::isAllocationFn(
+    const Value *V, function_ref<const TargetLibraryInfo &(Function &)> GetTLI,
+    bool LookThroughBitCast) {
+  return getAllocationData(V, AnyAlloc, GetTLI, LookThroughBitCast).hasValue();
+}
+
+/// Tests if a value is a call or invoke to a function that returns a
+/// NoAlias pointer (including malloc/calloc/realloc/strdup-like functions).
+bool llvm::isNoAliasFn(const Value *V, const TargetLibraryInfo *TLI,
+                       bool LookThroughBitCast) {
+  // It's safe to consider realloc as noalias since accessing the original
+  // pointer is undefined behavior.
+  return isAllocationFn(V, TLI, LookThroughBitCast) ||
+         hasNoAliasAttr(V, LookThroughBitCast);
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// allocates uninitialized memory (such as malloc).
+bool llvm::isMallocLikeFn(const Value *V, const TargetLibraryInfo *TLI,
+                          bool LookThroughBitCast) {
+  return getAllocationData(V, MallocLike, TLI, LookThroughBitCast).hasValue();
+}
+bool llvm::isMallocLikeFn(
+    const Value *V, function_ref<const TargetLibraryInfo &(Function &)> GetTLI,
+    bool LookThroughBitCast) {
+  return getAllocationData(V, MallocLike, GetTLI, LookThroughBitCast)
+      .hasValue();
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// allocates zero-filled memory (such as calloc).
+bool llvm::isCallocLikeFn(const Value *V, const TargetLibraryInfo *TLI,
+                          bool LookThroughBitCast) {
+  return getAllocationData(V, CallocLike, TLI, LookThroughBitCast).hasValue();
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// allocates memory similar to malloc or calloc.
+bool llvm::isMallocOrCallocLikeFn(const Value *V, const TargetLibraryInfo *TLI,
+                                  bool LookThroughBitCast) {
+  return getAllocationData(V, MallocOrCallocLike, TLI,
+                           LookThroughBitCast).hasValue();
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// allocates memory (either malloc, calloc, or strdup like).
+bool llvm::isAllocLikeFn(const Value *V, const TargetLibraryInfo *TLI,
+                         bool LookThroughBitCast) {
+  return getAllocationData(V, AllocLike, TLI, LookThroughBitCast).hasValue();
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// reallocates memory (e.g., realloc).
+bool llvm::isReallocLikeFn(const Value *V, const TargetLibraryInfo *TLI,
+                           bool LookThroughBitCast) {
+  return getAllocationData(V, ReallocLike, TLI, LookThroughBitCast).hasValue();
+}
+
+/// Tests if a function is a library function that reallocates memory (e.g.,
+/// realloc).
+bool llvm::isReallocLikeFn(const Function *F, const TargetLibraryInfo *TLI) {
+  return getAllocationDataForFunction(F, ReallocLike, TLI).hasValue();
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// allocates memory and throws if an allocation failed (e.g., new).
+bool llvm::isOpNewLikeFn(const Value *V, const TargetLibraryInfo *TLI,
+                         bool LookThroughBitCast) {
+  return getAllocationData(V, OpNewLike, TLI, LookThroughBitCast).hasValue();
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// allocates memory (strdup, strndup).
+bool llvm::isStrdupLikeFn(const Value *V, const TargetLibraryInfo *TLI,
+                          bool LookThroughBitCast) {
+  return getAllocationData(V, StrDupLike, TLI, LookThroughBitCast).hasValue();
+}
+
+/// extractMallocCall - Returns the corresponding CallInst if the instruction
+/// is a malloc call. Since CallInst::CreateMalloc() only creates calls, we
+/// ignore InvokeInst here.
+const CallInst *llvm::extractMallocCall(
+    const Value *I,
+    function_ref<const TargetLibraryInfo &(Function &)> GetTLI) {
+  return isMallocLikeFn(I, GetTLI) ? dyn_cast<CallInst>(I) : nullptr;
+}
+
+static Value *computeArraySize(const CallInst *CI, const DataLayout &DL,
+                               const TargetLibraryInfo *TLI,
+                               bool LookThroughSExt = false) {
+  if (!CI)
+    return nullptr;
+
+  // The size of the malloc's result type must be known to determine array size.
+  Type *T = getMallocAllocatedType(CI, TLI);
+  if (!T || !T->isSized())
+    return nullptr;
+
+  unsigned ElementSize = DL.getTypeAllocSize(T);
+  if (StructType *ST = dyn_cast<StructType>(T))
+    ElementSize = DL.getStructLayout(ST)->getSizeInBytes();
+
+  // If the malloc call's arg can be determined to be a multiple of ElementSize,
+  // return the multiple. Otherwise, return NULL.
+  Value *MallocArg = CI->getArgOperand(0);
+  Value *Multiple = nullptr;
+  if (ComputeMultiple(MallocArg, ElementSize, Multiple, LookThroughSExt))
+    return Multiple;
+
+  return nullptr;
+}
+
+/// getMallocType - Returns the PointerType resulting from the malloc call.
+/// The PointerType depends on the number of bitcast uses of the malloc call:
+///   0: PointerType is the call's return type.
+///   1: PointerType is the bitcast's result type.
+///  >1: Unique PointerType cannot be determined, return NULL.
+PointerType *llvm::getMallocType(const CallInst *CI,
+                                 const TargetLibraryInfo *TLI) {
+  assert(isMallocLikeFn(CI, TLI) && "getMallocType and not malloc call");
+
+  PointerType *MallocType = nullptr;
+  unsigned NumOfBitCastUses = 0;
+
+  // Determine if CallInst has a bitcast use.
+  for (Value::const_user_iterator UI = CI->user_begin(), E = CI->user_end();
+       UI != E;)
+    if (const BitCastInst *BCI = dyn_cast<BitCastInst>(*UI++)) {
+      MallocType = cast<PointerType>(BCI->getDestTy());
+      NumOfBitCastUses++;
+    }
+
+  // Malloc call has 1 bitcast use, so type is the bitcast's destination type.
+  if (NumOfBitCastUses == 1)
+    return MallocType;
+
+  // Malloc call was not bitcast, so type is the malloc function's return type.
+  if (NumOfBitCastUses == 0)
+    return cast<PointerType>(CI->getType());
+
+  // Type could not be determined.
+  return nullptr;
+}
+
+/// getMallocAllocatedType - Returns the Type allocated by malloc call.
+/// The Type depends on the number of bitcast uses of the malloc call:
+///   0: PointerType is the malloc call's return type.
+///   1: PointerType is the bitcast's result type.
+///  >1: Unique PointerType cannot be determined, return NULL.
+Type *llvm::getMallocAllocatedType(const CallInst *CI,
+                                   const TargetLibraryInfo *TLI) {
+  PointerType *PT = getMallocType(CI, TLI);
+  return PT ? PT->getElementType() : nullptr;
+}
+
+/// getMallocArraySize - Returns the array size of a malloc call. If the
+/// argument passed to malloc is a multiple of the size of the malloced type,
+/// then return that multiple. For non-array mallocs, the multiple is
+/// constant 1. Otherwise, return NULL for mallocs whose array size cannot be
+/// determined.
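+/// For example (illustrative): for 'malloc(4 * n)' whose result is bitcast to
+/// i32*, the array size is 'n' on targets where i32 has an alloc size of 4.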
+Value *llvm::getMallocArraySize(CallInst *CI, const DataLayout &DL,
+                                const TargetLibraryInfo *TLI,
+                                bool LookThroughSExt) {
+  assert(isMallocLikeFn(CI, TLI) && "getMallocArraySize and not malloc call");
+  return computeArraySize(CI, DL, TLI, LookThroughSExt);
+}
+
+/// extractCallocCall - Returns the corresponding CallInst if the instruction
+/// is a calloc call.
+const CallInst *llvm::extractCallocCall(const Value *I,
+                                        const TargetLibraryInfo *TLI) {
+  return isCallocLikeFn(I, TLI) ? cast<CallInst>(I) : nullptr;
+}
+
+/// isLibFreeFunction - Returns true if the function is a builtin free()
+bool llvm::isLibFreeFunction(const Function *F, const LibFunc TLIFn) {
+  unsigned ExpectedNumParams;
+  if (TLIFn == LibFunc_free ||
+      TLIFn == LibFunc_ZdlPv || // operator delete(void*)
+      TLIFn == LibFunc_ZdaPv || // operator delete[](void*)
+      TLIFn == LibFunc_msvc_delete_ptr32 || // operator delete(void*)
+      TLIFn == LibFunc_msvc_delete_ptr64 || // operator delete(void*)
+      TLIFn == LibFunc_msvc_delete_array_ptr32 || // operator delete[](void*)
+      TLIFn == LibFunc_msvc_delete_array_ptr64) // operator delete[](void*)
+    ExpectedNumParams = 1;
+  else if (TLIFn == LibFunc_ZdlPvj || // delete(void*, uint)
+           TLIFn == LibFunc_ZdlPvm || // delete(void*, ulong)
+           TLIFn == LibFunc_ZdlPvRKSt9nothrow_t || // delete(void*, nothrow)
+           TLIFn == LibFunc_ZdlPvSt11align_val_t || // delete(void*, align_val_t)
+           TLIFn == LibFunc_ZdaPvj || // delete[](void*, uint)
+           TLIFn == LibFunc_ZdaPvm || // delete[](void*, ulong)
+           TLIFn == LibFunc_ZdaPvRKSt9nothrow_t || // delete[](void*, nothrow)
+           TLIFn == LibFunc_ZdaPvSt11align_val_t || // delete[](void*, align_val_t)
+           TLIFn == LibFunc_msvc_delete_ptr32_int || // delete(void*, uint)
+           TLIFn == LibFunc_msvc_delete_ptr64_longlong || // delete(void*, ulonglong)
+           TLIFn == LibFunc_msvc_delete_ptr32_nothrow || // delete(void*, nothrow)
+           TLIFn == LibFunc_msvc_delete_ptr64_nothrow || // delete(void*, nothrow)
+           TLIFn == LibFunc_msvc_delete_array_ptr32_int || // delete[](void*, uint)
+           TLIFn == LibFunc_msvc_delete_array_ptr64_longlong || // delete[](void*, ulonglong)
+           TLIFn == LibFunc_msvc_delete_array_ptr32_nothrow || // delete[](void*, nothrow)
+           TLIFn == LibFunc_msvc_delete_array_ptr64_nothrow) // delete[](void*, nothrow)
+    ExpectedNumParams = 2;
+  else if (TLIFn == LibFunc_ZdaPvSt11align_val_tRKSt9nothrow_t || // delete[](void*, align_val_t, nothrow)
+           TLIFn == LibFunc_ZdlPvSt11align_val_tRKSt9nothrow_t) // delete(void*, align_val_t, nothrow)
+    ExpectedNumParams = 3;
+  else
+    return false;
+
+  // Check free prototype.
+  // FIXME: workaround for PR5130; this will be obsolete when a nobuiltin
+  // attribute exists.
+  FunctionType *FTy = F->getFunctionType();
+  if (!FTy->getReturnType()->isVoidTy())
+    return false;
+  if (FTy->getNumParams() != ExpectedNumParams)
+    return false;
+  if (FTy->getParamType(0) != Type::getInt8PtrTy(F->getContext()))
+    return false;
+
+  return true;
+}
+
+/// isFreeCall - Returns non-null if the value is a call to the builtin free()
+const CallInst *llvm::isFreeCall(const Value *I, const TargetLibraryInfo *TLI) {
+  bool IsNoBuiltinCall;
+  const Function *Callee =
+      getCalledFunction(I, /*LookThroughBitCast=*/false, IsNoBuiltinCall);
+  if (Callee == nullptr || IsNoBuiltinCall)
+    return nullptr;
+
+  StringRef FnName = Callee->getName();
+  LibFunc TLIFn;
+  if (!TLI || !TLI->getLibFunc(FnName, TLIFn) || !TLI->has(TLIFn))
+    return nullptr;
+
+  return isLibFreeFunction(Callee, TLIFn) ? dyn_cast<CallInst>(I) : nullptr;
+}
+
+
+//===----------------------------------------------------------------------===//
+//  Utility functions to compute size of objects.
+//
+static APInt getSizeWithOverflow(const SizeOffsetType &Data) {
+  if (Data.second.isNegative() || Data.first.ult(Data.second))
+    return APInt(Data.first.getBitWidth(), 0);
+  return Data.first - Data.second;
+}
+
+/// Compute the size of the object pointed by Ptr. Returns true and the
+/// object size in Size if successful, and false otherwise.
+/// If RoundToAlign is true, then Size is rounded up to the alignment of
+/// allocas, byval arguments, and global variables.
+bool llvm::getObjectSize(const Value *Ptr, uint64_t &Size, const DataLayout &DL,
+                         const TargetLibraryInfo *TLI, ObjectSizeOpts Opts) {
+  ObjectSizeOffsetVisitor Visitor(DL, TLI, Ptr->getContext(), Opts);
+  SizeOffsetType Data = Visitor.compute(const_cast<Value*>(Ptr));
+  if (!Visitor.bothKnown(Data))
+    return false;
+
+  Size = getSizeWithOverflow(Data).getZExtValue();
+  return true;
+}
+
+Value *llvm::lowerObjectSizeCall(IntrinsicInst *ObjectSize,
+                                 const DataLayout &DL,
+                                 const TargetLibraryInfo *TLI,
+                                 bool MustSucceed) {
+  assert(ObjectSize->getIntrinsicID() == Intrinsic::objectsize &&
+         "ObjectSize must be a call to llvm.objectsize!");
+
+  bool MaxVal = cast<ConstantInt>(ObjectSize->getArgOperand(1))->isZero();
+  ObjectSizeOpts EvalOptions;
+  // Unless we have to fold this to something, try to be as accurate as
+  // possible.
+  if (MustSucceed)
+    EvalOptions.EvalMode =
+        MaxVal ? ObjectSizeOpts::Mode::Max : ObjectSizeOpts::Mode::Min;
+  else
+    EvalOptions.EvalMode = ObjectSizeOpts::Mode::Exact;
+
+  EvalOptions.NullIsUnknownSize =
+      cast<ConstantInt>(ObjectSize->getArgOperand(2))->isOne();
+
+  auto *ResultType = cast<IntegerType>(ObjectSize->getType());
+  bool StaticOnly = cast<ConstantInt>(ObjectSize->getArgOperand(3))->isZero();
+  if (StaticOnly) {
+    // FIXME: Does it make sense to just return a failure value if the size won't
+    // fit in the output and `!MustSucceed`?
+    uint64_t Size;
+    if (getObjectSize(ObjectSize->getArgOperand(0), Size, DL, TLI, EvalOptions) &&
+        isUIntN(ResultType->getBitWidth(), Size))
+      return ConstantInt::get(ResultType, Size);
+  } else {
+    LLVMContext &Ctx = ObjectSize->getFunction()->getContext();
+    ObjectSizeOffsetEvaluator Eval(DL, TLI, Ctx, EvalOptions);
+    SizeOffsetEvalType SizeOffsetPair =
+        Eval.compute(ObjectSize->getArgOperand(0));
+
+    if (SizeOffsetPair != ObjectSizeOffsetEvaluator::unknown()) {
+      IRBuilder<TargetFolder> Builder(Ctx, TargetFolder(DL));
+      Builder.SetInsertPoint(ObjectSize);
+
+      // If we're outside the end of the object, then we can always access
+      // exactly 0 bytes.
+      Value *ResultSize =
+          Builder.CreateSub(SizeOffsetPair.first, SizeOffsetPair.second);
+      Value *UseZero =
+          Builder.CreateICmpULT(SizeOffsetPair.first, SizeOffsetPair.second);
+      return Builder.CreateSelect(UseZero, ConstantInt::get(ResultType, 0),
+                                  ResultSize);
+    }
+  }
+
+  if (!MustSucceed)
+    return nullptr;
+
+  return ConstantInt::get(ResultType, MaxVal ?
-1ULL : 0); +} + +STATISTIC(ObjectVisitorArgument, + "Number of arguments with unsolved size and offset"); +STATISTIC(ObjectVisitorLoad, + "Number of load instructions with unsolved size and offset"); + +APInt ObjectSizeOffsetVisitor::align(APInt Size, uint64_t Alignment) { + if (Options.RoundToAlign && Alignment) + return APInt(IntTyBits, alignTo(Size.getZExtValue(), Align(Alignment))); + return Size; +} + +ObjectSizeOffsetVisitor::ObjectSizeOffsetVisitor(const DataLayout &DL, + const TargetLibraryInfo *TLI, + LLVMContext &Context, + ObjectSizeOpts Options) + : DL(DL), TLI(TLI), Options(Options) { + // Pointer size must be rechecked for each object visited since it could have + // a different address space. +} + +SizeOffsetType ObjectSizeOffsetVisitor::compute(Value *V) { + IntTyBits = DL.getPointerTypeSizeInBits(V->getType()); + Zero = APInt::getNullValue(IntTyBits); + + V = V->stripPointerCasts(); + if (Instruction *I = dyn_cast<Instruction>(V)) { + // If we have already seen this instruction, bail out. Cycles can happen in + // unreachable code after constant propagation. + if (!SeenInsts.insert(I).second) + return unknown(); + + if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) + return visitGEPOperator(*GEP); + return visit(*I); + } + if (Argument *A = dyn_cast<Argument>(V)) + return visitArgument(*A); + if (ConstantPointerNull *P = dyn_cast<ConstantPointerNull>(V)) + return visitConstantPointerNull(*P); + if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) + return visitGlobalAlias(*GA); + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) + return visitGlobalVariable(*GV); + if (UndefValue *UV = dyn_cast<UndefValue>(V)) + return visitUndefValue(*UV); + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { + if (CE->getOpcode() == Instruction::IntToPtr) + return unknown(); // clueless + if (CE->getOpcode() == Instruction::GetElementPtr) + return visitGEPOperator(cast<GEPOperator>(*CE)); + } + + LLVM_DEBUG(dbgs() << "ObjectSizeOffsetVisitor::compute() unhandled value: " + << *V << '\n'); + return unknown(); +} + +/// When we're compiling N-bit code, and the user uses parameters that are +/// greater than N bits (e.g. uint64_t on a 32-bit build), we can run into +/// trouble with APInt size issues. This function handles resizing + overflow +/// checks for us. Check and zext or trunc \p I depending on IntTyBits and +/// I's value. +bool ObjectSizeOffsetVisitor::CheckedZextOrTrunc(APInt &I) { + // More bits than we can handle. Checking the bit width isn't necessary, but + // it's faster than checking active bits, and should give `false` in the + // vast majority of cases. + if (I.getBitWidth() > IntTyBits && I.getActiveBits() > IntTyBits) + return false; + if (I.getBitWidth() != IntTyBits) + I = I.zextOrTrunc(IntTyBits); + return true; +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitAllocaInst(AllocaInst &I) { + if (!I.getAllocatedType()->isSized()) + return unknown(); + + APInt Size(IntTyBits, DL.getTypeAllocSize(I.getAllocatedType())); + if (!I.isArrayAllocation()) + return std::make_pair(align(Size, I.getAlignment()), Zero); + + Value *ArraySize = I.getArraySize(); + if (const ConstantInt *C = dyn_cast<ConstantInt>(ArraySize)) { + APInt NumElems = C->getValue(); + if (!CheckedZextOrTrunc(NumElems)) + return unknown(); + + bool Overflow; + Size = Size.umul_ov(NumElems, Overflow); + return Overflow ? 
unknown() : std::make_pair(align(Size, I.getAlignment()),
+                                                 Zero);
+  }
+  return unknown();
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitArgument(Argument &A) {
+  // No interprocedural analysis is done at the moment.
+  if (!A.hasByValOrInAllocaAttr()) {
+    ++ObjectVisitorArgument;
+    return unknown();
+  }
+  PointerType *PT = cast<PointerType>(A.getType());
+  APInt Size(IntTyBits, DL.getTypeAllocSize(PT->getElementType()));
+  return std::make_pair(align(Size, A.getParamAlignment()), Zero);
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitCallSite(CallSite CS) {
+  Optional<AllocFnsTy> FnData = getAllocationSize(CS.getInstruction(), TLI);
+  if (!FnData)
+    return unknown();
+
+  // Handle strdup-like functions separately.
+  if (FnData->AllocTy == StrDupLike) {
+    APInt Size(IntTyBits, GetStringLength(CS.getArgument(0)));
+    if (!Size)
+      return unknown();
+
+    // Strndup limits strlen.
+    if (FnData->FstParam > 0) {
+      ConstantInt *Arg =
+          dyn_cast<ConstantInt>(CS.getArgument(FnData->FstParam));
+      if (!Arg)
+        return unknown();
+
+      APInt MaxSize = Arg->getValue().zextOrSelf(IntTyBits);
+      if (Size.ugt(MaxSize))
+        Size = MaxSize + 1;
+    }
+    return std::make_pair(Size, Zero);
+  }
+
+  ConstantInt *Arg = dyn_cast<ConstantInt>(CS.getArgument(FnData->FstParam));
+  if (!Arg)
+    return unknown();
+
+  APInt Size = Arg->getValue();
+  if (!CheckedZextOrTrunc(Size))
+    return unknown();
+
+  // Size is determined by just 1 parameter.
+  if (FnData->SndParam < 0)
+    return std::make_pair(Size, Zero);
+
+  Arg = dyn_cast<ConstantInt>(CS.getArgument(FnData->SndParam));
+  if (!Arg)
+    return unknown();
+
+  APInt NumElems = Arg->getValue();
+  if (!CheckedZextOrTrunc(NumElems))
+    return unknown();
+
+  bool Overflow;
+  Size = Size.umul_ov(NumElems, Overflow);
+  return Overflow ? unknown() : std::make_pair(Size, Zero);
+
+  // TODO: handle more standard functions (+ wchar cousins):
+  // - strdup / strndup
+  // - strcpy / strncpy
+  // - strcat / strncat
+  // - memcpy / memmove
+  // - memset
+}
+
+SizeOffsetType
+ObjectSizeOffsetVisitor::visitConstantPointerNull(ConstantPointerNull& CPN) {
+  // If null is unknown, there's nothing we can do. Additionally, non-zero
+  // address spaces can make use of null, so we don't presume to know anything
+  // about that.
+  //
+  // TODO: How should this work with address space casts? We currently just drop
+  // them on the floor, but it's unclear what we should do when a NULL from
+  // addrspace(1) gets cast to addrspace(0) (or vice-versa).
+  if (Options.NullIsUnknownSize || CPN.getType()->getAddressSpace())
+    return unknown();
+  return std::make_pair(Zero, Zero);
+}
+
+SizeOffsetType
+ObjectSizeOffsetVisitor::visitExtractElementInst(ExtractElementInst&) {
+  return unknown();
+}
+
+SizeOffsetType
+ObjectSizeOffsetVisitor::visitExtractValueInst(ExtractValueInst&) {
+  // Easy cases were already folded by previous passes.
+ return unknown(); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitGEPOperator(GEPOperator &GEP) { + SizeOffsetType PtrData = compute(GEP.getPointerOperand()); + APInt Offset(IntTyBits, 0); + if (!bothKnown(PtrData) || !GEP.accumulateConstantOffset(DL, Offset)) + return unknown(); + + return std::make_pair(PtrData.first, PtrData.second + Offset); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitGlobalAlias(GlobalAlias &GA) { + if (GA.isInterposable()) + return unknown(); + return compute(GA.getAliasee()); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitGlobalVariable(GlobalVariable &GV){ + if (!GV.hasDefinitiveInitializer()) + return unknown(); + + APInt Size(IntTyBits, DL.getTypeAllocSize(GV.getValueType())); + return std::make_pair(align(Size, GV.getAlignment()), Zero); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitIntToPtrInst(IntToPtrInst&) { + // clueless + return unknown(); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitLoadInst(LoadInst&) { + ++ObjectVisitorLoad; + return unknown(); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitPHINode(PHINode&) { + // too complex to analyze statically. + return unknown(); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitSelectInst(SelectInst &I) { + SizeOffsetType TrueSide = compute(I.getTrueValue()); + SizeOffsetType FalseSide = compute(I.getFalseValue()); + if (bothKnown(TrueSide) && bothKnown(FalseSide)) { + if (TrueSide == FalseSide) { + return TrueSide; + } + + APInt TrueResult = getSizeWithOverflow(TrueSide); + APInt FalseResult = getSizeWithOverflow(FalseSide); + + if (TrueResult == FalseResult) { + return TrueSide; + } + if (Options.EvalMode == ObjectSizeOpts::Mode::Min) { + if (TrueResult.slt(FalseResult)) + return TrueSide; + return FalseSide; + } + if (Options.EvalMode == ObjectSizeOpts::Mode::Max) { + if (TrueResult.sgt(FalseResult)) + return TrueSide; + return FalseSide; + } + } + return unknown(); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitUndefValue(UndefValue&) { + return std::make_pair(Zero, Zero); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitInstruction(Instruction &I) { + LLVM_DEBUG(dbgs() << "ObjectSizeOffsetVisitor unknown instruction:" << I + << '\n'); + return unknown(); +} + +ObjectSizeOffsetEvaluator::ObjectSizeOffsetEvaluator( + const DataLayout &DL, const TargetLibraryInfo *TLI, LLVMContext &Context, + ObjectSizeOpts EvalOpts) + : DL(DL), TLI(TLI), Context(Context), + Builder(Context, TargetFolder(DL), + IRBuilderCallbackInserter( + [&](Instruction *I) { InsertedInstructions.insert(I); })), + EvalOpts(EvalOpts) { + // IntTy and Zero must be set for each compute() since the address space may + // be different for later objects. +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::compute(Value *V) { + // XXX - Are vectors of pointers possible here? + IntTy = cast<IntegerType>(DL.getIntPtrType(V->getType())); + Zero = ConstantInt::get(IntTy, 0); + + SizeOffsetEvalType Result = compute_(V); + + if (!bothKnown(Result)) { + // Erase everything that was computed in this iteration from the cache, so + // that no dangling references are left behind. We could be a bit smarter if + // we kept a dependency graph. It's probably not worth the complexity. + for (const Value *SeenVal : SeenVals) { + CacheMapTy::iterator CacheIt = CacheMap.find(SeenVal); + // non-computable results can be safely cached + if (CacheIt != CacheMap.end() && anyKnown(CacheIt->second)) + CacheMap.erase(CacheIt); + } + + // Erase any instructions we inserted as part of the traversal. 
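+    // Uses are redirected to undef first so that erasing one inserted
+    // instruction cannot leave dangling operands in another.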
+ for (Instruction *I : InsertedInstructions) { + I->replaceAllUsesWith(UndefValue::get(I->getType())); + I->eraseFromParent(); + } + } + + SeenVals.clear(); + InsertedInstructions.clear(); + return Result; +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::compute_(Value *V) { + ObjectSizeOffsetVisitor Visitor(DL, TLI, Context, EvalOpts); + SizeOffsetType Const = Visitor.compute(V); + if (Visitor.bothKnown(Const)) + return std::make_pair(ConstantInt::get(Context, Const.first), + ConstantInt::get(Context, Const.second)); + + V = V->stripPointerCasts(); + + // Check cache. + CacheMapTy::iterator CacheIt = CacheMap.find(V); + if (CacheIt != CacheMap.end()) + return CacheIt->second; + + // Always generate code immediately before the instruction being + // processed, so that the generated code dominates the same BBs. + BuilderTy::InsertPointGuard Guard(Builder); + if (Instruction *I = dyn_cast<Instruction>(V)) + Builder.SetInsertPoint(I); + + // Now compute the size and offset. + SizeOffsetEvalType Result; + + // Record the pointers that were handled in this run, so that they can be + // cleaned later if something fails. We also use this set to break cycles that + // can occur in dead code. + if (!SeenVals.insert(V).second) { + Result = unknown(); + } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { + Result = visitGEPOperator(*GEP); + } else if (Instruction *I = dyn_cast<Instruction>(V)) { + Result = visit(*I); + } else if (isa<Argument>(V) || + (isa<ConstantExpr>(V) && + cast<ConstantExpr>(V)->getOpcode() == Instruction::IntToPtr) || + isa<GlobalAlias>(V) || + isa<GlobalVariable>(V)) { + // Ignore values where we cannot do more than ObjectSizeVisitor. + Result = unknown(); + } else { + LLVM_DEBUG( + dbgs() << "ObjectSizeOffsetEvaluator::compute() unhandled value: " << *V + << '\n'); + Result = unknown(); + } + + // Don't reuse CacheIt since it may be invalid at this point. + CacheMap[V] = Result; + return Result; +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitAllocaInst(AllocaInst &I) { + if (!I.getAllocatedType()->isSized()) + return unknown(); + + // must be a VLA + assert(I.isArrayAllocation()); + Value *ArraySize = I.getArraySize(); + Value *Size = ConstantInt::get(ArraySize->getType(), + DL.getTypeAllocSize(I.getAllocatedType())); + Size = Builder.CreateMul(Size, ArraySize); + return std::make_pair(Size, Zero); +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitCallSite(CallSite CS) { + Optional<AllocFnsTy> FnData = getAllocationSize(CS.getInstruction(), TLI); + if (!FnData) + return unknown(); + + // Handle strdup-like functions separately. 
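+  // (No IR is emitted for the strdup family yet; see the TODO just below.)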
+  if (FnData->AllocTy == StrDupLike) {
+    // TODO
+    return unknown();
+  }
+
+  Value *FirstArg = CS.getArgument(FnData->FstParam);
+  FirstArg = Builder.CreateZExt(FirstArg, IntTy);
+  if (FnData->SndParam < 0)
+    return std::make_pair(FirstArg, Zero);
+
+  Value *SecondArg = CS.getArgument(FnData->SndParam);
+  SecondArg = Builder.CreateZExt(SecondArg, IntTy);
+  Value *Size = Builder.CreateMul(FirstArg, SecondArg);
+  return std::make_pair(Size, Zero);
+
+  // TODO: handle more standard functions (+ wchar cousins):
+  // - strdup / strndup
+  // - strcpy / strncpy
+  // - strcat / strncat
+  // - memcpy / memmove
+  // - memset
+}
+
+SizeOffsetEvalType
+ObjectSizeOffsetEvaluator::visitExtractElementInst(ExtractElementInst&) {
+  return unknown();
+}
+
+SizeOffsetEvalType
+ObjectSizeOffsetEvaluator::visitExtractValueInst(ExtractValueInst&) {
+  return unknown();
+}
+
+SizeOffsetEvalType
+ObjectSizeOffsetEvaluator::visitGEPOperator(GEPOperator &GEP) {
+  SizeOffsetEvalType PtrData = compute_(GEP.getPointerOperand());
+  if (!bothKnown(PtrData))
+    return unknown();
+
+  Value *Offset = EmitGEPOffset(&Builder, DL, &GEP, /*NoAssumptions=*/true);
+  Offset = Builder.CreateAdd(PtrData.second, Offset);
+  return std::make_pair(PtrData.first, Offset);
+}
+
+SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitIntToPtrInst(IntToPtrInst&) {
+  // clueless
+  return unknown();
+}
+
+SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitLoadInst(LoadInst&) {
+  return unknown();
+}
+
+SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitPHINode(PHINode &PHI) {
+  // Create 2 PHIs: one for size and another for offset.
+  PHINode *SizePHI = Builder.CreatePHI(IntTy, PHI.getNumIncomingValues());
+  PHINode *OffsetPHI = Builder.CreatePHI(IntTy, PHI.getNumIncomingValues());
+
+  // Insert right away in the cache to handle recursive PHIs.
+  CacheMap[&PHI] = std::make_pair(SizePHI, OffsetPHI);
+
+  // Compute offset/size for each PHI incoming pointer.
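+  // Code for each incoming value is emitted at the first insertion point of
+  // its incoming block, so the computed size/offset dominate the new PHIs.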
+ for (unsigned i = 0, e = PHI.getNumIncomingValues(); i != e; ++i) { + Builder.SetInsertPoint(&*PHI.getIncomingBlock(i)->getFirstInsertionPt()); + SizeOffsetEvalType EdgeData = compute_(PHI.getIncomingValue(i)); + + if (!bothKnown(EdgeData)) { + OffsetPHI->replaceAllUsesWith(UndefValue::get(IntTy)); + OffsetPHI->eraseFromParent(); + InsertedInstructions.erase(OffsetPHI); + SizePHI->replaceAllUsesWith(UndefValue::get(IntTy)); + SizePHI->eraseFromParent(); + InsertedInstructions.erase(SizePHI); + return unknown(); + } + SizePHI->addIncoming(EdgeData.first, PHI.getIncomingBlock(i)); + OffsetPHI->addIncoming(EdgeData.second, PHI.getIncomingBlock(i)); + } + + Value *Size = SizePHI, *Offset = OffsetPHI; + if (Value *Tmp = SizePHI->hasConstantValue()) { + Size = Tmp; + SizePHI->replaceAllUsesWith(Size); + SizePHI->eraseFromParent(); + InsertedInstructions.erase(SizePHI); + } + if (Value *Tmp = OffsetPHI->hasConstantValue()) { + Offset = Tmp; + OffsetPHI->replaceAllUsesWith(Offset); + OffsetPHI->eraseFromParent(); + InsertedInstructions.erase(OffsetPHI); + } + return std::make_pair(Size, Offset); +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitSelectInst(SelectInst &I) { + SizeOffsetEvalType TrueSide = compute_(I.getTrueValue()); + SizeOffsetEvalType FalseSide = compute_(I.getFalseValue()); + + if (!bothKnown(TrueSide) || !bothKnown(FalseSide)) + return unknown(); + if (TrueSide == FalseSide) + return TrueSide; + + Value *Size = Builder.CreateSelect(I.getCondition(), TrueSide.first, + FalseSide.first); + Value *Offset = Builder.CreateSelect(I.getCondition(), TrueSide.second, + FalseSide.second); + return std::make_pair(Size, Offset); +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitInstruction(Instruction &I) { + LLVM_DEBUG(dbgs() << "ObjectSizeOffsetEvaluator unknown instruction:" << I + << '\n'); + return unknown(); +} diff --git a/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp b/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp new file mode 100644 index 000000000000..884587e020bb --- /dev/null +++ b/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp @@ -0,0 +1,1824 @@ +//===- MemoryDependenceAnalysis.cpp - Mem Deps Implementation -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements an analysis that determines, for a given memory +// operation, what preceding memory operations it depends on. It builds on +// alias analysis information, and tries to provide a lazy, caching interface to +// a common kind of alias information query. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/OrderedBasicBlock.h" +#include "llvm/Analysis/PHITransAddr.h" +#include "llvm/Analysis/PhiValues.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PredIteratorCache.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <iterator> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "memdep" + +STATISTIC(NumCacheNonLocal, "Number of fully cached non-local responses"); +STATISTIC(NumCacheDirtyNonLocal, "Number of dirty cached non-local responses"); +STATISTIC(NumUncacheNonLocal, "Number of uncached non-local responses"); + +STATISTIC(NumCacheNonLocalPtr, + "Number of fully cached non-local ptr responses"); +STATISTIC(NumCacheDirtyNonLocalPtr, + "Number of cached, but dirty, non-local ptr responses"); +STATISTIC(NumUncacheNonLocalPtr, "Number of uncached non-local ptr responses"); +STATISTIC(NumCacheCompleteNonLocalPtr, + "Number of block queries that were completely cached"); + +// Limit for the number of instructions to scan in a block. + +static cl::opt<unsigned> BlockScanLimit( + "memdep-block-scan-limit", cl::Hidden, cl::init(100), + cl::desc("The number of instructions to scan in a block in memory " + "dependency analysis (default = 100)")); + +static cl::opt<unsigned> + BlockNumberLimit("memdep-block-number-limit", cl::Hidden, cl::init(1000), + cl::desc("The number of blocks to scan during memory " + "dependency analysis (default = 1000)")); + +// Limit on the number of memdep results to process. +static const unsigned int NumResultsLimit = 100; + +/// This is a helper function that removes Val from 'Inst's set in ReverseMap. +/// +/// If the set becomes empty, remove Inst's entry. 
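+/// Both Inst and Val are expected to be present; the helper asserts otherwise.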
+template <typename KeyTy> +static void +RemoveFromReverseMap(DenseMap<Instruction *, SmallPtrSet<KeyTy, 4>> &ReverseMap, + Instruction *Inst, KeyTy Val) { + typename DenseMap<Instruction *, SmallPtrSet<KeyTy, 4>>::iterator InstIt = + ReverseMap.find(Inst); + assert(InstIt != ReverseMap.end() && "Reverse map out of sync?"); + bool Found = InstIt->second.erase(Val); + assert(Found && "Invalid reverse map!"); + (void)Found; + if (InstIt->second.empty()) + ReverseMap.erase(InstIt); +} + +/// If the given instruction references a specific memory location, fill in Loc +/// with the details, otherwise set Loc.Ptr to null. +/// +/// Returns a ModRefInfo value describing the general behavior of the +/// instruction. +static ModRefInfo GetLocation(const Instruction *Inst, MemoryLocation &Loc, + const TargetLibraryInfo &TLI) { + if (const LoadInst *LI = dyn_cast<LoadInst>(Inst)) { + if (LI->isUnordered()) { + Loc = MemoryLocation::get(LI); + return ModRefInfo::Ref; + } + if (LI->getOrdering() == AtomicOrdering::Monotonic) { + Loc = MemoryLocation::get(LI); + return ModRefInfo::ModRef; + } + Loc = MemoryLocation(); + return ModRefInfo::ModRef; + } + + if (const StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + if (SI->isUnordered()) { + Loc = MemoryLocation::get(SI); + return ModRefInfo::Mod; + } + if (SI->getOrdering() == AtomicOrdering::Monotonic) { + Loc = MemoryLocation::get(SI); + return ModRefInfo::ModRef; + } + Loc = MemoryLocation(); + return ModRefInfo::ModRef; + } + + if (const VAArgInst *V = dyn_cast<VAArgInst>(Inst)) { + Loc = MemoryLocation::get(V); + return ModRefInfo::ModRef; + } + + if (const CallInst *CI = isFreeCall(Inst, &TLI)) { + // calls to free() deallocate the entire structure + Loc = MemoryLocation(CI->getArgOperand(0)); + return ModRefInfo::Mod; + } + + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + switch (II->getIntrinsicID()) { + case Intrinsic::lifetime_start: + case Intrinsic::lifetime_end: + case Intrinsic::invariant_start: + Loc = MemoryLocation::getForArgument(II, 1, TLI); + // These intrinsics don't really modify the memory, but returning Mod + // will allow them to be handled conservatively. + return ModRefInfo::Mod; + case Intrinsic::invariant_end: + Loc = MemoryLocation::getForArgument(II, 2, TLI); + // These intrinsics don't really modify the memory, but returning Mod + // will allow them to be handled conservatively. + return ModRefInfo::Mod; + default: + break; + } + } + + // Otherwise, just do the coarse-grained thing that always works. + if (Inst->mayWriteToMemory()) + return ModRefInfo::ModRef; + if (Inst->mayReadFromMemory()) + return ModRefInfo::Ref; + return ModRefInfo::NoModRef; +} + +/// Private helper for finding the local dependencies of a call site. +MemDepResult MemoryDependenceResults::getCallDependencyFrom( + CallBase *Call, bool isReadOnlyCall, BasicBlock::iterator ScanIt, + BasicBlock *BB) { + unsigned Limit = getDefaultBlockScanLimit(); + + // Walk backwards through the block, looking for dependencies. + while (ScanIt != BB->begin()) { + Instruction *Inst = &*--ScanIt; + // Debug intrinsics don't cause dependences and should not affect Limit + if (isa<DbgInfoIntrinsic>(Inst)) + continue; + + // Limit the amount of scanning we do so we don't end up with quadratic + // running time on extreme testcases. 
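+    // The budget is the -memdep-block-scan-limit value (100 by default).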
+ --Limit; + if (!Limit) + return MemDepResult::getUnknown(); + + // If this inst is a memory op, get the pointer it accessed + MemoryLocation Loc; + ModRefInfo MR = GetLocation(Inst, Loc, TLI); + if (Loc.Ptr) { + // A simple instruction. + if (isModOrRefSet(AA.getModRefInfo(Call, Loc))) + return MemDepResult::getClobber(Inst); + continue; + } + + if (auto *CallB = dyn_cast<CallBase>(Inst)) { + // If these two calls do not interfere, look past it. + if (isNoModRef(AA.getModRefInfo(Call, CallB))) { + // If the two calls are the same, return Inst as a Def, so that + // Call can be found redundant and eliminated. + if (isReadOnlyCall && !isModSet(MR) && + Call->isIdenticalToWhenDefined(CallB)) + return MemDepResult::getDef(Inst); + + // Otherwise if the two calls don't interact (e.g. CallB is readnone) + // keep scanning. + continue; + } else + return MemDepResult::getClobber(Inst); + } + + // If we could not obtain a pointer for the instruction and the instruction + // touches memory then assume that this is a dependency. + if (isModOrRefSet(MR)) + return MemDepResult::getClobber(Inst); + } + + // No dependence found. If this is the entry block of the function, it is + // unknown, otherwise it is non-local. + if (BB != &BB->getParent()->getEntryBlock()) + return MemDepResult::getNonLocal(); + return MemDepResult::getNonFuncLocal(); +} + +unsigned MemoryDependenceResults::getLoadLoadClobberFullWidthSize( + const Value *MemLocBase, int64_t MemLocOffs, unsigned MemLocSize, + const LoadInst *LI) { + // We can only extend simple integer loads. + if (!isa<IntegerType>(LI->getType()) || !LI->isSimple()) + return 0; + + // Load widening is hostile to ThreadSanitizer: it may cause false positives + // or make the reports more cryptic (access sizes are wrong). + if (LI->getParent()->getParent()->hasFnAttribute(Attribute::SanitizeThread)) + return 0; + + const DataLayout &DL = LI->getModule()->getDataLayout(); + + // Get the base of this load. + int64_t LIOffs = 0; + const Value *LIBase = + GetPointerBaseWithConstantOffset(LI->getPointerOperand(), LIOffs, DL); + + // If the two pointers are not based on the same pointer, we can't tell that + // they are related. + if (LIBase != MemLocBase) + return 0; + + // Okay, the two values are based on the same pointer, but returned as + // no-alias. This happens when we have things like two byte loads at "P+1" + // and "P+3". Check to see if increasing the size of the "LI" load up to its + // alignment (or the largest native integer type) will allow us to load all + // the bits required by MemLoc. + + // If MemLoc is before LI, then no widening of LI will help us out. + if (MemLocOffs < LIOffs) + return 0; + + // Get the alignment of the load in bytes. We assume that it is safe to load + // any legal integer up to this size without a problem. For example, if we're + // looking at an i8 load on x86-32 that is known 1024 byte aligned, we can + // widen it up to an i32 load. If it is known 2-byte aligned, we can widen it + // to i16. + unsigned LoadAlign = LI->getAlignment(); + + int64_t MemLocEnd = MemLocOffs + MemLocSize; + + // If no amount of rounding up will let MemLoc fit into LI, then bail out. + if (LIOffs + LoadAlign < MemLocEnd) + return 0; + + // This is the size of the load to try. Start with the next larger power of + // two. 
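+  // For example, an i16 load starts the search at 4 bytes (the next power of
+  // two) and doubles from there until it covers MemLoc or exceeds LoadAlign.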
+  unsigned NewLoadByteSize = LI->getType()->getPrimitiveSizeInBits() / 8U;
+  NewLoadByteSize = NextPowerOf2(NewLoadByteSize);
+
+  while (true) {
+    // If this load size is bigger than our known alignment or would not fit
+    // into a native integer register, then we fail.
+    if (NewLoadByteSize > LoadAlign ||
+        !DL.fitsInLegalInteger(NewLoadByteSize * 8))
+      return 0;
+
+    if (LIOffs + NewLoadByteSize > MemLocEnd &&
+        (LI->getParent()->getParent()->hasFnAttribute(
+             Attribute::SanitizeAddress) ||
+         LI->getParent()->getParent()->hasFnAttribute(
+             Attribute::SanitizeHWAddress)))
+      // We will be reading past the location accessed by the original program.
+      // While this is safe in a regular build, Address Safety analysis tools
+      // may start reporting false warnings. So, don't do widening.
+      return 0;
+
+    // If a load of this width would include all of MemLoc, then we succeed.
+    if (LIOffs + NewLoadByteSize >= MemLocEnd)
+      return NewLoadByteSize;
+
+    NewLoadByteSize <<= 1;
+  }
+}
+
+static bool isVolatile(Instruction *Inst) {
+  if (auto *LI = dyn_cast<LoadInst>(Inst))
+    return LI->isVolatile();
+  if (auto *SI = dyn_cast<StoreInst>(Inst))
+    return SI->isVolatile();
+  if (auto *AI = dyn_cast<AtomicCmpXchgInst>(Inst))
+    return AI->isVolatile();
+  return false;
+}
+
+MemDepResult MemoryDependenceResults::getPointerDependencyFrom(
+    const MemoryLocation &MemLoc, bool isLoad, BasicBlock::iterator ScanIt,
+    BasicBlock *BB, Instruction *QueryInst, unsigned *Limit,
+    OrderedBasicBlock *OBB) {
+  MemDepResult InvariantGroupDependency = MemDepResult::getUnknown();
+  if (QueryInst != nullptr) {
+    if (auto *LI = dyn_cast<LoadInst>(QueryInst)) {
+      InvariantGroupDependency = getInvariantGroupPointerDependency(LI, BB);
+
+      if (InvariantGroupDependency.isDef())
+        return InvariantGroupDependency;
+    }
+  }
+  MemDepResult SimpleDep = getSimplePointerDependencyFrom(
+      MemLoc, isLoad, ScanIt, BB, QueryInst, Limit, OBB);
+  if (SimpleDep.isDef())
+    return SimpleDep;
+  // A non-local invariant group dependency indicates there is a non-local Def
+  // (it only returns nonLocal if it finds a non-local def), which is better
+  // than a local clobber and everything else.
+  if (InvariantGroupDependency.isNonLocal())
+    return InvariantGroupDependency;
+
+  assert(InvariantGroupDependency.isUnknown() &&
+         "InvariantGroupDependency should be only unknown at this point");
+  return SimpleDep;
+}
+
+MemDepResult
+MemoryDependenceResults::getInvariantGroupPointerDependency(LoadInst *LI,
+                                                            BasicBlock *BB) {
+
+  if (!LI->hasMetadata(LLVMContext::MD_invariant_group))
+    return MemDepResult::getUnknown();
+
+  // Take the ptr operand after all casts and geps 0. This way we can search
+  // only down the cast graph.
+  Value *LoadOperand = LI->getPointerOperand()->stripPointerCasts();
+
+  // It is not safe to walk the use list of a global value, because function
+  // passes aren't allowed to look outside their functions.
+  // FIXME: this could be fixed by filtering instructions from outside
+  // of the current function.
+  if (isa<GlobalValue>(LoadOperand))
+    return MemDepResult::getUnknown();
+
+  // Queue to process all pointers that are equivalent to the load operand.
+  SmallVector<const Value *, 8> LoadOperandsQueue;
+  LoadOperandsQueue.push_back(LoadOperand);
+
+  Instruction *ClosestDependency = nullptr;
+  // The order of instructions in the uses list is unpredictable. In order to
+  // always get the same result, we will look for the closest dominance.
+  auto GetClosestDependency = [this](Instruction *Best, Instruction *Other) {
+    assert(Other && "Must be called with a non-null instruction");
+    if (Best == nullptr || DT.dominates(Best, Other))
+      return Other;
+    return Best;
+  };
+
+  // FIXME: This loop is O(N^2) because dominates can be O(n) and in the worst
+  // case we will see all the instructions. This should be fixed in MSSA.
+  while (!LoadOperandsQueue.empty()) {
+    const Value *Ptr = LoadOperandsQueue.pop_back_val();
+    assert(Ptr && !isa<GlobalValue>(Ptr) &&
+           "Null or GlobalValue should not be inserted");
+
+    for (const Use &Us : Ptr->uses()) {
+      auto *U = dyn_cast<Instruction>(Us.getUser());
+      if (!U || U == LI || !DT.dominates(U, LI))
+        continue;
+
+      // A bitcast or a gep with all-zero indices is using Ptr. Add it to the
+      // queue to check its users. U = bitcast Ptr
+      if (isa<BitCastInst>(U)) {
+        LoadOperandsQueue.push_back(U);
+        continue;
+      }
+      // A gep with all-zero indices is equivalent to a bitcast.
+      // FIXME: we are not sure if some bitcast should be canonicalized to gep 0
+      // or gep 0 to bitcast because of SROA, so there are 2 forms. When
+      // typeless pointers are ready, both cases will be gone (and this BFS
+      // also won't be needed).
+      if (auto *GEP = dyn_cast<GetElementPtrInst>(U))
+        if (GEP->hasAllZeroIndices()) {
+          LoadOperandsQueue.push_back(U);
+          continue;
+        }
+
+      // If we hit a load/store with the same invariant.group metadata (and the
+      // same pointer operand), we can assume that the value pointed to by the
+      // pointer operand didn't change.
+      if ((isa<LoadInst>(U) || isa<StoreInst>(U)) &&
+          U->hasMetadata(LLVMContext::MD_invariant_group))
+        ClosestDependency = GetClosestDependency(ClosestDependency, U);
+    }
+  }
+
+  if (!ClosestDependency)
+    return MemDepResult::getUnknown();
+  if (ClosestDependency->getParent() == BB)
+    return MemDepResult::getDef(ClosestDependency);
+  // Def(U) can't be returned here because it is non-local. If a local
+  // dependency is not found, return nonLocal, counting on the user to call
+  // getNonLocalPointerDependency, which will return the cached result.
+  NonLocalDefsCache.try_emplace(
+      LI, NonLocalDepResult(ClosestDependency->getParent(),
+                            MemDepResult::getDef(ClosestDependency), nullptr));
+  ReverseNonLocalDefsCache[ClosestDependency].insert(LI);
+  return MemDepResult::getNonLocal();
+}
+
+MemDepResult MemoryDependenceResults::getSimplePointerDependencyFrom(
+    const MemoryLocation &MemLoc, bool isLoad, BasicBlock::iterator ScanIt,
+    BasicBlock *BB, Instruction *QueryInst, unsigned *Limit,
+    OrderedBasicBlock *OBB) {
+  bool isInvariantLoad = false;
+
+  unsigned DefaultLimit = getDefaultBlockScanLimit();
+  if (!Limit)
+    Limit = &DefaultLimit;
+
+  // We must be careful with atomic accesses, as they may allow another thread
+  // to touch this location, clobbering it. We are conservative: if the
+  // QueryInst is not a simple (non-atomic) memory access, we automatically
+  // return getClobber.
+  // If it is simple, we know based on the results of
+  // "Compiler testing via a theory of sound optimisations in the C11/C++11
+  // memory model" in PLDI 2013, that a non-atomic location can only be
+  // clobbered between a pair of a release and an acquire action, with no
+  // access to the location in between.
+  // Here is an example for giving the general intuition behind this rule.
+  // In the following code:
+  //   store x 0;
+  //   release action; [1]
+  //   acquire action; [4]
+  //   %val = load x;
+  // It is unsafe to replace %val by 0 because another thread may be running:
+  //   acquire action; [2]
+  //   store x 42;
+  //   release action; [3]
+  // with synchronization from 1 to 2 and from 3 to 4, resulting in %val
+  // being 42. A key property of this program however is that if either
+  // 1 or 4 were missing, there would be a race between the store of 42
+  // and either the store of 0 or the load (making the whole program racy).
+  // The paper mentioned above shows that the same property is respected
+  // by every program that can detect any optimization of that kind: either
+  // it is racy (undefined) or there is a release followed by an acquire
+  // between the pair of accesses under consideration.
+
+  // If the load is invariant, we "know" that it doesn't alias *any* write. We
+  // do want to respect mustalias results since defs are useful for value
+  // forwarding, but any mayalias write can be assumed to be noalias.
+  // Arguably, this logic should be pushed inside AliasAnalysis itself.
+  if (isLoad && QueryInst) {
+    LoadInst *LI = dyn_cast<LoadInst>(QueryInst);
+    if (LI && LI->hasMetadata(LLVMContext::MD_invariant_load))
+      isInvariantLoad = true;
+  }
+
+  const DataLayout &DL = BB->getModule()->getDataLayout();
+
+  // If the caller did not provide an ordered basic block,
+  // create one to lazily compute and cache instruction
+  // positions inside a BB. This is used to provide fast queries for relative
+  // position between two instructions in a BB and can be used by
+  // AliasAnalysis::callCapturesBefore.
+  OrderedBasicBlock OBBTmp(BB);
+  if (!OBB)
+    OBB = &OBBTmp;
+
+  // Return "true" if and only if the instruction I is either a non-simple
+  // load or a non-simple store.
+  auto isNonSimpleLoadOrStore = [](Instruction *I) -> bool {
+    if (auto *LI = dyn_cast<LoadInst>(I))
+      return !LI->isSimple();
+    if (auto *SI = dyn_cast<StoreInst>(I))
+      return !SI->isSimple();
+    return false;
+  };
+
+  // Return "true" if I is not a load and not a store, but it does access
+  // memory.
+  auto isOtherMemAccess = [](Instruction *I) -> bool {
+    return !isa<LoadInst>(I) && !isa<StoreInst>(I) && I->mayReadOrWriteMemory();
+  };
+
+  // Walk backwards through the basic block, looking for dependencies.
+  while (ScanIt != BB->begin()) {
+    Instruction *Inst = &*--ScanIt;
+
+    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst))
+      // Debug intrinsics don't (and can't) cause dependencies.
+      if (isa<DbgInfoIntrinsic>(II))
+        continue;
+
+    // Limit the amount of scanning we do so we don't end up with quadratic
+    // running time on extreme testcases.
+    --*Limit;
+    if (!*Limit)
+      return MemDepResult::getUnknown();
+
+    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
+      // If we reach a lifetime begin or end marker, then the query ends here
+      // because the value is undefined.
+      if (II->getIntrinsicID() == Intrinsic::lifetime_start) {
+        // FIXME: This only considers queries directly on the invariant-tagged
+        // pointer, not on query pointers that are indexed off of them. It'd
+        // be nice to handle that at some point (the right approach is to use
+        // GetPointerBaseWithConstantOffset).
+        if (AA.isMustAlias(MemoryLocation(II->getArgOperand(1)), MemLoc))
+          return MemDepResult::getDef(II);
+        continue;
+      }
+    }
+
+    // Values depend on loads if the pointers are must-aliased. This means
+    // that a load depends on another must-aliased load from the same pointer.
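+    // For instance (hypothetical IR):
+    //   %x = load i32, i32* %p
+    //   ...
+    //   %y = load i32, i32* %p
+    // The second load must-aliases the first, so a query on %y reports %x as
+    // a Def, allowing clients to forward %x instead of reloading.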
+    // One exception is atomic loads: a value can depend on an atomic load
+    // that it does not alias with when this atomic load indicates that
+    // another thread may be accessing the location.
+    if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
+      // While volatile accesses cannot be eliminated, they do not have to
+      // clobber non-aliasing locations; normal accesses, for example, can be
+      // safely reordered with volatile accesses.
+      if (LI->isVolatile()) {
+        if (!QueryInst)
+          // Original QueryInst *may* be volatile
+          return MemDepResult::getClobber(LI);
+        if (isVolatile(QueryInst))
+          // Ordering required if QueryInst is itself volatile
+          return MemDepResult::getClobber(LI);
+        // Otherwise, volatile doesn't imply any special ordering
+      }
+
+      // Atomic loads have extra complications.
+      // A Monotonic (or higher) load is OK if the query inst is itself not
+      // atomic.
+      // FIXME: This is overly conservative.
+      if (LI->isAtomic() && isStrongerThanUnordered(LI->getOrdering())) {
+        if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) ||
+            isOtherMemAccess(QueryInst))
+          return MemDepResult::getClobber(LI);
+        if (LI->getOrdering() != AtomicOrdering::Monotonic)
+          return MemDepResult::getClobber(LI);
+      }
+
+      MemoryLocation LoadLoc = MemoryLocation::get(LI);
+
+      // If we found a pointer, check if it could be the same as our pointer.
+      AliasResult R = AA.alias(LoadLoc, MemLoc);
+
+      if (isLoad) {
+        if (R == NoAlias)
+          continue;
+
+        // Must-aliased loads are defs of each other.
+        if (R == MustAlias)
+          return MemDepResult::getDef(Inst);
+
+#if 0 // FIXME: Temporarily disabled. GVN is cleverly rewriting loads
+      // in terms of clobbering loads, but since it does this by looking
+      // at the clobbering load directly, it doesn't know about any
+      // phi translation that may have happened along the way.
+
+        // If we have a partial alias, then return this as a clobber for the
+        // client to handle.
+        if (R == PartialAlias)
+          return MemDepResult::getClobber(Inst);
+#endif
+
+        // Other may-alias loads don't impose a dependence: a load does not
+        // clobber another load.
+        continue;
+      }
+
+      // Stores don't depend on other non-aliasing accesses.
+      if (R == NoAlias)
+        continue;
+
+      // Stores don't alias loads from read-only memory.
+      if (AA.pointsToConstantMemory(LoadLoc))
+        continue;
+
+      // Stores depend on may/must-aliased loads.
+      return MemDepResult::getDef(Inst);
+    }
+
+    if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
+      // Atomic stores have extra complications.
+      // A Monotonic store is OK if the query inst is itself not atomic.
+      // FIXME: This is overly conservative.
+      if (!SI->isUnordered() && SI->isAtomic()) {
+        if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) ||
+            isOtherMemAccess(QueryInst))
+          return MemDepResult::getClobber(SI);
+        if (SI->getOrdering() != AtomicOrdering::Monotonic)
+          return MemDepResult::getClobber(SI);
+      }
+
+      // FIXME: this is overly conservative.
+      // While volatile accesses cannot be eliminated, they do not have to
+      // clobber non-aliasing locations; normal accesses can, for example, be
+      // reordered with volatile accesses.
+      if (SI->isVolatile())
+        if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) ||
+            isOtherMemAccess(QueryInst))
+          return MemDepResult::getClobber(SI);
+
+      // If alias analysis can tell that this store is guaranteed to not modify
+      // the query pointer, ignore it. Use getModRefInfo to handle cases where
+      // the query pointer points to constant memory etc.
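+      // E.g. (hypothetically) if MemLoc points into a constant global, no
+      // store can Mod it, so getModRefInfo returns NoModRef and the store is
+      // skipped even where a plain alias() query would say MayAlias.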
+ if (!isModOrRefSet(AA.getModRefInfo(SI, MemLoc))) + continue; + + // Ok, this store might clobber the query pointer. Check to see if it is + // a must alias: in this case, we want to return this as a def. + // FIXME: Use ModRefInfo::Must bit from getModRefInfo call above. + MemoryLocation StoreLoc = MemoryLocation::get(SI); + + // If we found a pointer, check if it could be the same as our pointer. + AliasResult R = AA.alias(StoreLoc, MemLoc); + + if (R == NoAlias) + continue; + if (R == MustAlias) + return MemDepResult::getDef(Inst); + if (isInvariantLoad) + continue; + return MemDepResult::getClobber(Inst); + } + + // If this is an allocation, and if we know that the accessed pointer is to + // the allocation, return Def. This means that there is no dependence and + // the access can be optimized based on that. For example, a load could + // turn into undef. Note that we can bypass the allocation itself when + // looking for a clobber in many cases; that's an alias property and is + // handled by BasicAA. + if (isa<AllocaInst>(Inst) || isNoAliasFn(Inst, &TLI)) { + const Value *AccessPtr = GetUnderlyingObject(MemLoc.Ptr, DL); + if (AccessPtr == Inst || AA.isMustAlias(Inst, AccessPtr)) + return MemDepResult::getDef(Inst); + } + + if (isInvariantLoad) + continue; + + // A release fence requires that all stores complete before it, but does + // not prevent the reordering of following loads or stores 'before' the + // fence. As a result, we look past it when finding a dependency for + // loads. DSE uses this to find preceding stores to delete and thus we + // can't bypass the fence if the query instruction is a store. + if (FenceInst *FI = dyn_cast<FenceInst>(Inst)) + if (isLoad && FI->getOrdering() == AtomicOrdering::Release) + continue; + + // See if this instruction (e.g. a call or vaarg) mod/ref's the pointer. + ModRefInfo MR = AA.getModRefInfo(Inst, MemLoc); + // If necessary, perform additional analysis. + if (isModAndRefSet(MR)) + MR = AA.callCapturesBefore(Inst, MemLoc, &DT, OBB); + switch (clearMust(MR)) { + case ModRefInfo::NoModRef: + // If the call has no effect on the queried pointer, just ignore it. + continue; + case ModRefInfo::Mod: + return MemDepResult::getClobber(Inst); + case ModRefInfo::Ref: + // If the call is known to never store to the pointer, and if this is a + // load query, we can safely ignore it (scan past it). + if (isLoad) + continue; + LLVM_FALLTHROUGH; + default: + // Otherwise, there is a potential dependence. Return a clobber. + return MemDepResult::getClobber(Inst); + } + } + + // No dependence found. If this is the entry block of the function, it is + // unknown, otherwise it is non-local. + if (BB != &BB->getParent()->getEntryBlock()) + return MemDepResult::getNonLocal(); + return MemDepResult::getNonFuncLocal(); +} + +MemDepResult MemoryDependenceResults::getDependency(Instruction *QueryInst, + OrderedBasicBlock *OBB) { + Instruction *ScanPos = QueryInst; + + // Check for a cached result + MemDepResult &LocalCache = LocalDeps[QueryInst]; + + // If the cached entry is non-dirty, just return it. Note that this depends + // on MemDepResult's default constructing to 'dirty'. + if (!LocalCache.isDirty()) + return LocalCache; + + // Otherwise, if we have a dirty entry, we know we can start the scan at that + // instruction, which may save us some work. + if (Instruction *Inst = LocalCache.getInst()) { + ScanPos = Inst; + + RemoveFromReverseMap(ReverseLocalDeps, Inst, QueryInst); + } + + BasicBlock *QueryParent = QueryInst->getParent(); + + // Do the scan. 
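+  // (GetLocation classifies QueryInst: a memory access yields a pointer scan
+  //  below, a call yields a call scan, and anything else is Unknown.)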
+ if (BasicBlock::iterator(QueryInst) == QueryParent->begin()) { + // No dependence found. If this is the entry block of the function, it is + // unknown, otherwise it is non-local. + if (QueryParent != &QueryParent->getParent()->getEntryBlock()) + LocalCache = MemDepResult::getNonLocal(); + else + LocalCache = MemDepResult::getNonFuncLocal(); + } else { + MemoryLocation MemLoc; + ModRefInfo MR = GetLocation(QueryInst, MemLoc, TLI); + if (MemLoc.Ptr) { + // If we can do a pointer scan, make it happen. + bool isLoad = !isModSet(MR); + if (auto *II = dyn_cast<IntrinsicInst>(QueryInst)) + isLoad |= II->getIntrinsicID() == Intrinsic::lifetime_start; + + LocalCache = + getPointerDependencyFrom(MemLoc, isLoad, ScanPos->getIterator(), + QueryParent, QueryInst, nullptr, OBB); + } else if (auto *QueryCall = dyn_cast<CallBase>(QueryInst)) { + bool isReadOnly = AA.onlyReadsMemory(QueryCall); + LocalCache = getCallDependencyFrom(QueryCall, isReadOnly, + ScanPos->getIterator(), QueryParent); + } else + // Non-memory instruction. + LocalCache = MemDepResult::getUnknown(); + } + + // Remember the result! + if (Instruction *I = LocalCache.getInst()) + ReverseLocalDeps[I].insert(QueryInst); + + return LocalCache; +} + +#ifndef NDEBUG +/// This method is used when -debug is specified to verify that cache arrays +/// are properly kept sorted. +static void AssertSorted(MemoryDependenceResults::NonLocalDepInfo &Cache, + int Count = -1) { + if (Count == -1) + Count = Cache.size(); + assert(std::is_sorted(Cache.begin(), Cache.begin() + Count) && + "Cache isn't sorted!"); +} +#endif + +const MemoryDependenceResults::NonLocalDepInfo & +MemoryDependenceResults::getNonLocalCallDependency(CallBase *QueryCall) { + assert(getDependency(QueryCall).isNonLocal() && + "getNonLocalCallDependency should only be used on calls with " + "non-local deps!"); + PerInstNLInfo &CacheP = NonLocalDeps[QueryCall]; + NonLocalDepInfo &Cache = CacheP.first; + + // This is the set of blocks that need to be recomputed. In the cached case, + // this can happen due to instructions being deleted etc. In the uncached + // case, this starts out as the set of predecessors we care about. + SmallVector<BasicBlock *, 32> DirtyBlocks; + + if (!Cache.empty()) { + // Okay, we have a cache entry. If we know it is not dirty, just return it + // with no computation. + if (!CacheP.second) { + ++NumCacheNonLocal; + return Cache; + } + + // If we already have a partially computed set of results, scan them to + // determine what is dirty, seeding our initial DirtyBlocks worklist. + for (auto &Entry : Cache) + if (Entry.getResult().isDirty()) + DirtyBlocks.push_back(Entry.getBB()); + + // Sort the cache so that we can do fast binary search lookups below. + llvm::sort(Cache); + + ++NumCacheDirtyNonLocal; + // cerr << "CACHED CASE: " << DirtyBlocks.size() << " dirty: " + // << Cache.size() << " cached: " << *QueryInst; + } else { + // Seed DirtyBlocks with each of the preds of QueryInst's block. + BasicBlock *QueryBB = QueryCall->getParent(); + for (BasicBlock *Pred : PredCache.get(QueryBB)) + DirtyBlocks.push_back(Pred); + ++NumUncacheNonLocal; + } + + // isReadonlyCall - If this is a read-only call, we can be more aggressive. + bool isReadonlyCall = AA.onlyReadsMemory(QueryCall); + + SmallPtrSet<BasicBlock *, 32> Visited; + + unsigned NumSortedEntries = Cache.size(); + LLVM_DEBUG(AssertSorted(Cache)); + + // Iterate while we still have blocks to update. 
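+  // (Each iteration pops one dirty block, recomputes QueryCall's dependency
+  //  in it, and, if the block turns out to be transparent, enqueues the
+  //  block's predecessors: a standard reverse-CFG worklist.)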
+  while (!DirtyBlocks.empty()) {
+    BasicBlock *DirtyBB = DirtyBlocks.back();
+    DirtyBlocks.pop_back();
+
+    // Already processed this block?
+    if (!Visited.insert(DirtyBB).second)
+      continue;
+
+    // Do a binary search to see if we already have an entry for this block in
+    // the cache set. If so, find it.
+    LLVM_DEBUG(AssertSorted(Cache, NumSortedEntries));
+    NonLocalDepInfo::iterator Entry =
+        std::upper_bound(Cache.begin(), Cache.begin() + NumSortedEntries,
+                         NonLocalDepEntry(DirtyBB));
+    if (Entry != Cache.begin() && std::prev(Entry)->getBB() == DirtyBB)
+      --Entry;
+
+    NonLocalDepEntry *ExistingResult = nullptr;
+    if (Entry != Cache.begin() + NumSortedEntries &&
+        Entry->getBB() == DirtyBB) {
+      // If we already have an entry, and if it isn't already dirty, the block
+      // is done.
+      if (!Entry->getResult().isDirty())
+        continue;
+
+      // Otherwise, remember this slot so we can update the value.
+      ExistingResult = &*Entry;
+    }
+
+    // If the dirty entry has a pointer, start scanning from it so we don't
+    // have to rescan the entire block.
+    BasicBlock::iterator ScanPos = DirtyBB->end();
+    if (ExistingResult) {
+      if (Instruction *Inst = ExistingResult->getResult().getInst()) {
+        ScanPos = Inst->getIterator();
+        // We're removing QueryCall's use of Inst.
+        RemoveFromReverseMap<Instruction *>(ReverseNonLocalDeps, Inst,
+                                            QueryCall);
+      }
+    }
+
+    // Find out if this block has a local dependency for QueryCall.
+    MemDepResult Dep;
+
+    if (ScanPos != DirtyBB->begin()) {
+      Dep = getCallDependencyFrom(QueryCall, isReadonlyCall, ScanPos, DirtyBB);
+    } else if (DirtyBB != &DirtyBB->getParent()->getEntryBlock()) {
+      // No dependence found. If this is not the entry block of the function,
+      // the result is non-local; in the entry block it is non-function-local.
+      Dep = MemDepResult::getNonLocal();
+    } else {
+      Dep = MemDepResult::getNonFuncLocal();
+    }
+
+    // If we had a dirty entry for the block, update it. Otherwise, just add
+    // a new entry.
+    if (ExistingResult)
+      ExistingResult->setResult(Dep);
+    else
+      Cache.push_back(NonLocalDepEntry(DirtyBB, Dep));
+
+    // If the block has a dependency (i.e. it isn't completely transparent to
+    // the value), remember the association!
+    if (!Dep.isNonLocal()) {
+      // Keep the ReverseNonLocalDeps map up to date so we can efficiently
+      // update this when we remove instructions.
+      if (Instruction *Inst = Dep.getInst())
+        ReverseNonLocalDeps[Inst].insert(QueryCall);
+    } else {
+
+      // If the block *is* completely transparent to the load, we need to check
+      // the predecessors of this block. Add them to our worklist.
+      for (BasicBlock *Pred : PredCache.get(DirtyBB))
+        DirtyBlocks.push_back(Pred);
+    }
+  }
+
+  return Cache;
+}
+
+void MemoryDependenceResults::getNonLocalPointerDependency(
+    Instruction *QueryInst, SmallVectorImpl<NonLocalDepResult> &Result) {
+  const MemoryLocation Loc = MemoryLocation::get(QueryInst);
+  bool isLoad = isa<LoadInst>(QueryInst);
+  BasicBlock *FromBB = QueryInst->getParent();
+  assert(FromBB);
+
+  assert(Loc.Ptr->getType()->isPointerTy() &&
+         "Can't get pointer deps of a non-pointer!");
+  Result.clear();
+  {
+    // Check if there is a cached Def with invariant.group.
+    auto NonLocalDefIt = NonLocalDefsCache.find(QueryInst);
+    if (NonLocalDefIt != NonLocalDefsCache.end()) {
+      Result.push_back(NonLocalDefIt->second);
+      ReverseNonLocalDefsCache[NonLocalDefIt->second.getResult().getInst()]
+          .erase(QueryInst);
+      NonLocalDefsCache.erase(NonLocalDefIt);
+      return;
+    }
+  }
+  // This routine does not expect to deal with volatile instructions.
+ // Doing so would require piping through the QueryInst all the way through. + // TODO: volatiles can't be elided, but they can be reordered with other + // non-volatile accesses. + + // We currently give up on any instruction which is ordered, but we do handle + // atomic instructions which are unordered. + // TODO: Handle ordered instructions + auto isOrdered = [](Instruction *Inst) { + if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { + return !LI->isUnordered(); + } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + return !SI->isUnordered(); + } + return false; + }; + if (isVolatile(QueryInst) || isOrdered(QueryInst)) { + Result.push_back(NonLocalDepResult(FromBB, MemDepResult::getUnknown(), + const_cast<Value *>(Loc.Ptr))); + return; + } + const DataLayout &DL = FromBB->getModule()->getDataLayout(); + PHITransAddr Address(const_cast<Value *>(Loc.Ptr), DL, &AC); + + // This is the set of blocks we've inspected, and the pointer we consider in + // each block. Because of critical edges, we currently bail out if querying + // a block with multiple different pointers. This can happen during PHI + // translation. + DenseMap<BasicBlock *, Value *> Visited; + if (getNonLocalPointerDepFromBB(QueryInst, Address, Loc, isLoad, FromBB, + Result, Visited, true)) + return; + Result.clear(); + Result.push_back(NonLocalDepResult(FromBB, MemDepResult::getUnknown(), + const_cast<Value *>(Loc.Ptr))); +} + +/// Compute the memdep value for BB with Pointer/PointeeSize using either +/// cached information in Cache or by doing a lookup (which may use dirty cache +/// info if available). +/// +/// If we do a lookup, add the result to the cache. +MemDepResult MemoryDependenceResults::GetNonLocalInfoForBlock( + Instruction *QueryInst, const MemoryLocation &Loc, bool isLoad, + BasicBlock *BB, NonLocalDepInfo *Cache, unsigned NumSortedEntries) { + + // Do a binary search to see if we already have an entry for this block in + // the cache set. If so, find it. + NonLocalDepInfo::iterator Entry = std::upper_bound( + Cache->begin(), Cache->begin() + NumSortedEntries, NonLocalDepEntry(BB)); + if (Entry != Cache->begin() && (Entry - 1)->getBB() == BB) + --Entry; + + NonLocalDepEntry *ExistingResult = nullptr; + if (Entry != Cache->begin() + NumSortedEntries && Entry->getBB() == BB) + ExistingResult = &*Entry; + + // If we have a cached entry, and it is non-dirty, use it as the value for + // this dependency. + if (ExistingResult && !ExistingResult->getResult().isDirty()) { + ++NumCacheNonLocalPtr; + return ExistingResult->getResult(); + } + + // Otherwise, we have to scan for the value. If we have a dirty cache + // entry, start scanning from its position, otherwise we scan from the end + // of the block. + BasicBlock::iterator ScanPos = BB->end(); + if (ExistingResult && ExistingResult->getResult().getInst()) { + assert(ExistingResult->getResult().getInst()->getParent() == BB && + "Instruction invalidated?"); + ++NumCacheDirtyNonLocalPtr; + ScanPos = ExistingResult->getResult().getInst()->getIterator(); + + // Eliminating the dirty entry from 'Cache', so update the reverse info. + ValueIsLoadPair CacheKey(Loc.Ptr, isLoad); + RemoveFromReverseMap(ReverseNonLocalPtrDeps, &*ScanPos, CacheKey); + } else { + ++NumUncacheNonLocalPtr; + } + + // Scan the block for the dependency. + MemDepResult Dep = + getPointerDependencyFrom(Loc, isLoad, ScanPos, BB, QueryInst); + + // If we had a dirty entry for the block, update it. Otherwise, just add + // a new entry. 
+ if (ExistingResult) + ExistingResult->setResult(Dep); + else + Cache->push_back(NonLocalDepEntry(BB, Dep)); + + // If the block has a dependency (i.e. it isn't completely transparent to + // the value), remember the reverse association because we just added it + // to Cache! + if (!Dep.isDef() && !Dep.isClobber()) + return Dep; + + // Keep the ReverseNonLocalPtrDeps map up to date so we can efficiently + // update MemDep when we remove instructions. + Instruction *Inst = Dep.getInst(); + assert(Inst && "Didn't depend on anything?"); + ValueIsLoadPair CacheKey(Loc.Ptr, isLoad); + ReverseNonLocalPtrDeps[Inst].insert(CacheKey); + return Dep; +} + +/// Sort the NonLocalDepInfo cache, given a certain number of elements in the +/// array that are already properly ordered. +/// +/// This is optimized for the case when only a few entries are added. +static void +SortNonLocalDepInfoCache(MemoryDependenceResults::NonLocalDepInfo &Cache, + unsigned NumSortedEntries) { + switch (Cache.size() - NumSortedEntries) { + case 0: + // done, no new entries. + break; + case 2: { + // Two new entries, insert the last one into place. + NonLocalDepEntry Val = Cache.back(); + Cache.pop_back(); + MemoryDependenceResults::NonLocalDepInfo::iterator Entry = + std::upper_bound(Cache.begin(), Cache.end() - 1, Val); + Cache.insert(Entry, Val); + LLVM_FALLTHROUGH; + } + case 1: + // One new entry, Just insert the new value at the appropriate position. + if (Cache.size() != 1) { + NonLocalDepEntry Val = Cache.back(); + Cache.pop_back(); + MemoryDependenceResults::NonLocalDepInfo::iterator Entry = + std::upper_bound(Cache.begin(), Cache.end(), Val); + Cache.insert(Entry, Val); + } + break; + default: + // Added many values, do a full scale sort. + llvm::sort(Cache); + break; + } +} + +/// Perform a dependency query based on pointer/pointeesize starting at the end +/// of StartBB. +/// +/// Add any clobber/def results to the results vector and keep track of which +/// blocks are visited in 'Visited'. +/// +/// This has special behavior for the first block queries (when SkipFirstBlock +/// is true). In this special case, it ignores the contents of the specified +/// block and starts returning dependence info for its predecessors. +/// +/// This function returns true on success, or false to indicate that it could +/// not compute dependence information for some reason. This should be treated +/// as a clobber dependence on the first instruction in the predecessor block. +bool MemoryDependenceResults::getNonLocalPointerDepFromBB( + Instruction *QueryInst, const PHITransAddr &Pointer, + const MemoryLocation &Loc, bool isLoad, BasicBlock *StartBB, + SmallVectorImpl<NonLocalDepResult> &Result, + DenseMap<BasicBlock *, Value *> &Visited, bool SkipFirstBlock) { + // Look up the cached info for Pointer. + ValueIsLoadPair CacheKey(Pointer.getAddr(), isLoad); + + // Set up a temporary NLPI value. If the map doesn't yet have an entry for + // CacheKey, this value will be inserted as the associated value. Otherwise, + // it'll be ignored, and we'll have to check to see if the cached size and + // aa tags are consistent with the current query. + NonLocalPointerInfo InitialNLPI; + InitialNLPI.Size = Loc.Size; + InitialNLPI.AATags = Loc.AATags; + + // Get the NLPI for CacheKey, inserting one into the map if it doesn't + // already have one. 
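+  // (The key pairs the pointer with the isLoad flag, so load and store
+  //  queries on the same pointer keep separate non-local caches.)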
+ std::pair<CachedNonLocalPointerInfo::iterator, bool> Pair = + NonLocalPointerDeps.insert(std::make_pair(CacheKey, InitialNLPI)); + NonLocalPointerInfo *CacheInfo = &Pair.first->second; + + // If we already have a cache entry for this CacheKey, we may need to do some + // work to reconcile the cache entry and the current query. + if (!Pair.second) { + if (CacheInfo->Size != Loc.Size) { + bool ThrowOutEverything; + if (CacheInfo->Size.hasValue() && Loc.Size.hasValue()) { + // FIXME: We may be able to do better in the face of results with mixed + // precision. We don't appear to get them in practice, though, so just + // be conservative. + ThrowOutEverything = + CacheInfo->Size.isPrecise() != Loc.Size.isPrecise() || + CacheInfo->Size.getValue() < Loc.Size.getValue(); + } else { + // For our purposes, unknown size > all others. + ThrowOutEverything = !Loc.Size.hasValue(); + } + + if (ThrowOutEverything) { + // The query's Size is greater than the cached one. Throw out the + // cached data and proceed with the query at the greater size. + CacheInfo->Pair = BBSkipFirstBlockPair(); + CacheInfo->Size = Loc.Size; + for (auto &Entry : CacheInfo->NonLocalDeps) + if (Instruction *Inst = Entry.getResult().getInst()) + RemoveFromReverseMap(ReverseNonLocalPtrDeps, Inst, CacheKey); + CacheInfo->NonLocalDeps.clear(); + } else { + // This query's Size is less than the cached one. Conservatively restart + // the query using the greater size. + return getNonLocalPointerDepFromBB( + QueryInst, Pointer, Loc.getWithNewSize(CacheInfo->Size), isLoad, + StartBB, Result, Visited, SkipFirstBlock); + } + } + + // If the query's AATags are inconsistent with the cached one, + // conservatively throw out the cached data and restart the query with + // no tag if needed. + if (CacheInfo->AATags != Loc.AATags) { + if (CacheInfo->AATags) { + CacheInfo->Pair = BBSkipFirstBlockPair(); + CacheInfo->AATags = AAMDNodes(); + for (auto &Entry : CacheInfo->NonLocalDeps) + if (Instruction *Inst = Entry.getResult().getInst()) + RemoveFromReverseMap(ReverseNonLocalPtrDeps, Inst, CacheKey); + CacheInfo->NonLocalDeps.clear(); + } + if (Loc.AATags) + return getNonLocalPointerDepFromBB( + QueryInst, Pointer, Loc.getWithoutAATags(), isLoad, StartBB, Result, + Visited, SkipFirstBlock); + } + } + + NonLocalDepInfo *Cache = &CacheInfo->NonLocalDeps; + + // If we have valid cached information for exactly the block we are + // investigating, just return it with no recomputation. + if (CacheInfo->Pair == BBSkipFirstBlockPair(StartBB, SkipFirstBlock)) { + // We have a fully cached result for this query then we can just return the + // cached results and populate the visited set. However, we have to verify + // that we don't already have conflicting results for these blocks. Check + // to ensure that if a block in the results set is in the visited set that + // it was for the same pointer query. + if (!Visited.empty()) { + for (auto &Entry : *Cache) { + DenseMap<BasicBlock *, Value *>::iterator VI = + Visited.find(Entry.getBB()); + if (VI == Visited.end() || VI->second == Pointer.getAddr()) + continue; + + // We have a pointer mismatch in a block. Just return false, saying + // that something was clobbered in this result. We could also do a + // non-fully cached query, but there is little point in doing this. 
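+        // (Such a mismatch can arise, for example, when this block was
+        //  reached along two paths whose phi translations of the pointer
+        //  disagree.)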
+ return false; + } + } + + Value *Addr = Pointer.getAddr(); + for (auto &Entry : *Cache) { + Visited.insert(std::make_pair(Entry.getBB(), Addr)); + if (Entry.getResult().isNonLocal()) { + continue; + } + + if (DT.isReachableFromEntry(Entry.getBB())) { + Result.push_back( + NonLocalDepResult(Entry.getBB(), Entry.getResult(), Addr)); + } + } + ++NumCacheCompleteNonLocalPtr; + return true; + } + + // Otherwise, either this is a new block, a block with an invalid cache + // pointer or one that we're about to invalidate by putting more info into it + // than its valid cache info. If empty, the result will be valid cache info, + // otherwise it isn't. + if (Cache->empty()) + CacheInfo->Pair = BBSkipFirstBlockPair(StartBB, SkipFirstBlock); + else + CacheInfo->Pair = BBSkipFirstBlockPair(); + + SmallVector<BasicBlock *, 32> Worklist; + Worklist.push_back(StartBB); + + // PredList used inside loop. + SmallVector<std::pair<BasicBlock *, PHITransAddr>, 16> PredList; + + // Keep track of the entries that we know are sorted. Previously cached + // entries will all be sorted. The entries we add we only sort on demand (we + // don't insert every element into its sorted position). We know that we + // won't get any reuse from currently inserted values, because we don't + // revisit blocks after we insert info for them. + unsigned NumSortedEntries = Cache->size(); + unsigned WorklistEntries = BlockNumberLimit; + bool GotWorklistLimit = false; + LLVM_DEBUG(AssertSorted(*Cache)); + + while (!Worklist.empty()) { + BasicBlock *BB = Worklist.pop_back_val(); + + // If we do process a large number of blocks it becomes very expensive and + // likely it isn't worth worrying about + if (Result.size() > NumResultsLimit) { + Worklist.clear(); + // Sort it now (if needed) so that recursive invocations of + // getNonLocalPointerDepFromBB and other routines that could reuse the + // cache value will only see properly sorted cache arrays. + if (Cache && NumSortedEntries != Cache->size()) { + SortNonLocalDepInfoCache(*Cache, NumSortedEntries); + } + // Since we bail out, the "Cache" set won't contain all of the + // results for the query. This is ok (we can still use it to accelerate + // specific block queries) but we can't do the fastpath "return all + // results from the set". Clear out the indicator for this. + CacheInfo->Pair = BBSkipFirstBlockPair(); + return false; + } + + // Skip the first block if we have it. + if (!SkipFirstBlock) { + // Analyze the dependency of *Pointer in FromBB. See if we already have + // been here. + assert(Visited.count(BB) && "Should check 'visited' before adding to WL"); + + // Get the dependency info for Pointer in BB. If we have cached + // information, we will use it, otherwise we compute it. + LLVM_DEBUG(AssertSorted(*Cache, NumSortedEntries)); + MemDepResult Dep = GetNonLocalInfoForBlock(QueryInst, Loc, isLoad, BB, + Cache, NumSortedEntries); + + // If we got a Def or Clobber, add this to the list of results. + if (!Dep.isNonLocal()) { + if (DT.isReachableFromEntry(BB)) { + Result.push_back(NonLocalDepResult(BB, Dep, Pointer.getAddr())); + continue; + } + } + } + + // If 'Pointer' is an instruction defined in this block, then we need to do + // phi translation to change it into a value live in the predecessor block. + // If not, we just add the predecessors to the worklist and scan them with + // the same Pointer. 
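+    // For example (hypothetical IR), given
+    //   %p = phi i8* [ %a, %pred1 ], [ %b, %pred2 ]
+    // a query on %p in this block becomes a query on %a in %pred1 and a
+    // query on %b in %pred2.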
+ if (!Pointer.NeedsPHITranslationFromBlock(BB)) { + SkipFirstBlock = false; + SmallVector<BasicBlock *, 16> NewBlocks; + for (BasicBlock *Pred : PredCache.get(BB)) { + // Verify that we haven't looked at this block yet. + std::pair<DenseMap<BasicBlock *, Value *>::iterator, bool> InsertRes = + Visited.insert(std::make_pair(Pred, Pointer.getAddr())); + if (InsertRes.second) { + // First time we've looked at *PI. + NewBlocks.push_back(Pred); + continue; + } + + // If we have seen this block before, but it was with a different + // pointer then we have a phi translation failure and we have to treat + // this as a clobber. + if (InsertRes.first->second != Pointer.getAddr()) { + // Make sure to clean up the Visited map before continuing on to + // PredTranslationFailure. + for (unsigned i = 0; i < NewBlocks.size(); i++) + Visited.erase(NewBlocks[i]); + goto PredTranslationFailure; + } + } + if (NewBlocks.size() > WorklistEntries) { + // Make sure to clean up the Visited map before continuing on to + // PredTranslationFailure. + for (unsigned i = 0; i < NewBlocks.size(); i++) + Visited.erase(NewBlocks[i]); + GotWorklistLimit = true; + goto PredTranslationFailure; + } + WorklistEntries -= NewBlocks.size(); + Worklist.append(NewBlocks.begin(), NewBlocks.end()); + continue; + } + + // We do need to do phi translation, if we know ahead of time we can't phi + // translate this value, don't even try. + if (!Pointer.IsPotentiallyPHITranslatable()) + goto PredTranslationFailure; + + // We may have added values to the cache list before this PHI translation. + // If so, we haven't done anything to ensure that the cache remains sorted. + // Sort it now (if needed) so that recursive invocations of + // getNonLocalPointerDepFromBB and other routines that could reuse the cache + // value will only see properly sorted cache arrays. + if (Cache && NumSortedEntries != Cache->size()) { + SortNonLocalDepInfoCache(*Cache, NumSortedEntries); + NumSortedEntries = Cache->size(); + } + Cache = nullptr; + + PredList.clear(); + for (BasicBlock *Pred : PredCache.get(BB)) { + PredList.push_back(std::make_pair(Pred, Pointer)); + + // Get the PHI translated pointer in this predecessor. This can fail if + // not translatable, in which case the getAddr() returns null. + PHITransAddr &PredPointer = PredList.back().second; + PredPointer.PHITranslateValue(BB, Pred, &DT, /*MustDominate=*/false); + Value *PredPtrVal = PredPointer.getAddr(); + + // Check to see if we have already visited this pred block with another + // pointer. If so, we can't do this lookup. This failure can occur + // with PHI translation when a critical edge exists and the PHI node in + // the successor translates to a pointer value different than the + // pointer the block was first analyzed with. + std::pair<DenseMap<BasicBlock *, Value *>::iterator, bool> InsertRes = + Visited.insert(std::make_pair(Pred, PredPtrVal)); + + if (!InsertRes.second) { + // We found the pred; take it off the list of preds to visit. + PredList.pop_back(); + + // If the predecessor was visited with PredPtr, then we already did + // the analysis and can ignore it. + if (InsertRes.first->second == PredPtrVal) + continue; + + // Otherwise, the block was previously analyzed with a different + // pointer. We can't represent the result of this case, so we just + // treat this as a phi translation failure. + + // Make sure to clean up the Visited map before continuing on to + // PredTranslationFailure. 
+        for (unsigned i = 0, n = PredList.size(); i < n; ++i)
+          Visited.erase(PredList[i].first);
+
+        goto PredTranslationFailure;
+      }
+    }
+
+    // Actually process results here; this needs to be a separate loop to
+    // avoid calling getNonLocalPointerDepFromBB for blocks we don't want to
+    // return any results for. (getNonLocalPointerDepFromBB will modify our
+    // datastructures in ways the code after the PredTranslationFailure label
+    // doesn't expect.)
+    for (unsigned i = 0, n = PredList.size(); i < n; ++i) {
+      BasicBlock *Pred = PredList[i].first;
+      PHITransAddr &PredPointer = PredList[i].second;
+      Value *PredPtrVal = PredPointer.getAddr();
+
+      bool CanTranslate = true;
+      // If PHI translation was unable to find an available pointer in this
+      // predecessor, then we have to assume that the pointer is clobbered in
+      // that predecessor. We can still do PRE of the load, which would insert
+      // a computation of the pointer in this predecessor.
+      if (!PredPtrVal)
+        CanTranslate = false;
+
+      // FIXME: it is entirely possible that PHI translating will end up with
+      // the same value. Consider PHI translating something like:
+      //   X = phi [x, bb1], [y, bb2].
+      // PHI translating for bb1 doesn't *need* to recurse here, pedantically
+      // speaking.
+
+      // If getNonLocalPointerDepFromBB fails here, that means the cached
+      // result conflicted with the Visited list; we have to conservatively
+      // assume it is unknown, but this also does not block PRE of the load.
+      if (!CanTranslate ||
+          !getNonLocalPointerDepFromBB(QueryInst, PredPointer,
+                                       Loc.getWithNewPtr(PredPtrVal), isLoad,
+                                       Pred, Result, Visited)) {
+        // Add the entry to the Result list.
+        NonLocalDepResult Entry(Pred, MemDepResult::getUnknown(), PredPtrVal);
+        Result.push_back(Entry);
+
+        // Since we had a phi translation failure, the cache for CacheKey won't
+        // include all of the entries that we need to immediately satisfy
+        // future queries. Mark this in NonLocalPointerDeps by setting the
+        // BBSkipFirstBlockPair pointer to null. Future reuse of the cached
+        // value will have to do more work, but it won't miss the phi
+        // translation failure.
+        NonLocalPointerInfo &NLPI = NonLocalPointerDeps[CacheKey];
+        NLPI.Pair = BBSkipFirstBlockPair();
+        continue;
+      }
+    }
+
+    // Refresh the CacheInfo/Cache pointer so that it isn't invalidated.
+    CacheInfo = &NonLocalPointerDeps[CacheKey];
+    Cache = &CacheInfo->NonLocalDeps;
+    NumSortedEntries = Cache->size();
+
+    // Since we did phi translation, the "Cache" set won't contain all of the
+    // results for the query. This is ok (we can still use it to accelerate
+    // specific block queries) but we can't do the fastpath "return all
+    // results from the set". Clear out the indicator for this.
+    CacheInfo->Pair = BBSkipFirstBlockPair();
+    SkipFirstBlock = false;
+    continue;
+
+  PredTranslationFailure:
+    // The following code is "failure"; we can't produce a sane translation
+    // for the given block. It assumes that we haven't modified any of
+    // our datastructures while processing the current block.
+
+    if (!Cache) {
+      // Refresh the CacheInfo/Cache pointer if it got invalidated.
+      CacheInfo = &NonLocalPointerDeps[CacheKey];
+      Cache = &CacheInfo->NonLocalDeps;
+      NumSortedEntries = Cache->size();
+    }
+
+    // Since we failed phi translation, the "Cache" set won't contain all of
+    // the results for the query. This is ok (we can still use it to
+    // accelerate specific block queries) but we can't do the fastpath
+    // "return all results from the set". Clear out the indicator for this.
+ CacheInfo->Pair = BBSkipFirstBlockPair(); + + // If *nothing* works, mark the pointer as unknown. + // + // If this is the magic first block, return this as a clobber of the whole + // incoming value. Since we can't phi translate to one of the predecessors, + // we have to bail out. + if (SkipFirstBlock) + return false; + + bool foundBlock = false; + for (NonLocalDepEntry &I : llvm::reverse(*Cache)) { + if (I.getBB() != BB) + continue; + + assert((GotWorklistLimit || I.getResult().isNonLocal() || + !DT.isReachableFromEntry(BB)) && + "Should only be here with transparent block"); + foundBlock = true; + I.setResult(MemDepResult::getUnknown()); + Result.push_back( + NonLocalDepResult(I.getBB(), I.getResult(), Pointer.getAddr())); + break; + } + (void)foundBlock; (void)GotWorklistLimit; + assert((foundBlock || GotWorklistLimit) && "Current block not in cache?"); + } + + // Okay, we're done now. If we added new values to the cache, re-sort it. + SortNonLocalDepInfoCache(*Cache, NumSortedEntries); + LLVM_DEBUG(AssertSorted(*Cache)); + return true; +} + +/// If P exists in CachedNonLocalPointerInfo or NonLocalDefsCache, remove it. +void MemoryDependenceResults::RemoveCachedNonLocalPointerDependencies( + ValueIsLoadPair P) { + + // Most of the time this cache is empty. + if (!NonLocalDefsCache.empty()) { + auto it = NonLocalDefsCache.find(P.getPointer()); + if (it != NonLocalDefsCache.end()) { + RemoveFromReverseMap(ReverseNonLocalDefsCache, + it->second.getResult().getInst(), P.getPointer()); + NonLocalDefsCache.erase(it); + } + + if (auto *I = dyn_cast<Instruction>(P.getPointer())) { + auto toRemoveIt = ReverseNonLocalDefsCache.find(I); + if (toRemoveIt != ReverseNonLocalDefsCache.end()) { + for (const auto &entry : toRemoveIt->second) + NonLocalDefsCache.erase(entry); + ReverseNonLocalDefsCache.erase(toRemoveIt); + } + } + } + + CachedNonLocalPointerInfo::iterator It = NonLocalPointerDeps.find(P); + if (It == NonLocalPointerDeps.end()) + return; + + // Remove all of the entries in the BB->val map. This involves removing + // instructions from the reverse map. + NonLocalDepInfo &PInfo = It->second.NonLocalDeps; + + for (unsigned i = 0, e = PInfo.size(); i != e; ++i) { + Instruction *Target = PInfo[i].getResult().getInst(); + if (!Target) + continue; // Ignore non-local dep results. + assert(Target->getParent() == PInfo[i].getBB()); + + // Eliminating the dirty entry from 'Cache', so update the reverse info. + RemoveFromReverseMap(ReverseNonLocalPtrDeps, Target, P); + } + + // Remove P from NonLocalPointerDeps (which deletes NonLocalDepInfo). + NonLocalPointerDeps.erase(It); +} + +void MemoryDependenceResults::invalidateCachedPointerInfo(Value *Ptr) { + // If Ptr isn't really a pointer, just ignore it. + if (!Ptr->getType()->isPointerTy()) + return; + // Flush store info for the pointer. + RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(Ptr, false)); + // Flush load info for the pointer. + RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(Ptr, true)); + // Invalidate phis that use the pointer. + PV.invalidateValue(Ptr); +} + +void MemoryDependenceResults::invalidateCachedPredecessors() { + PredCache.clear(); +} + +void MemoryDependenceResults::removeInstruction(Instruction *RemInst) { + // Walk through the Non-local dependencies, removing this one as the value + // for any cached queries. 
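+  // (Every cache and reverse map must be purged here; a stale entry would
+  //  otherwise hand out a dangling Instruction pointer on a later query.)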
+ NonLocalDepMapType::iterator NLDI = NonLocalDeps.find(RemInst); + if (NLDI != NonLocalDeps.end()) { + NonLocalDepInfo &BlockMap = NLDI->second.first; + for (auto &Entry : BlockMap) + if (Instruction *Inst = Entry.getResult().getInst()) + RemoveFromReverseMap(ReverseNonLocalDeps, Inst, RemInst); + NonLocalDeps.erase(NLDI); + } + + // If we have a cached local dependence query for this instruction, remove it. + LocalDepMapType::iterator LocalDepEntry = LocalDeps.find(RemInst); + if (LocalDepEntry != LocalDeps.end()) { + // Remove us from DepInst's reverse set now that the local dep info is gone. + if (Instruction *Inst = LocalDepEntry->second.getInst()) + RemoveFromReverseMap(ReverseLocalDeps, Inst, RemInst); + + // Remove this local dependency info. + LocalDeps.erase(LocalDepEntry); + } + + // If we have any cached pointer dependencies on this instruction, remove + // them. If the instruction has non-pointer type, then it can't be a pointer + // base. + + // Remove it from both the load info and the store info. The instruction + // can't be in either of these maps if it is non-pointer. + if (RemInst->getType()->isPointerTy()) { + RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(RemInst, false)); + RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(RemInst, true)); + } + + // Loop over all of the things that depend on the instruction we're removing. + SmallVector<std::pair<Instruction *, Instruction *>, 8> ReverseDepsToAdd; + + // If we find RemInst as a clobber or Def in any of the maps for other values, + // we need to replace its entry with a dirty version of the instruction after + // it. If RemInst is a terminator, we use a null dirty value. + // + // Using a dirty version of the instruction after RemInst saves having to scan + // the entire block to get to this point. + MemDepResult NewDirtyVal; + if (!RemInst->isTerminator()) + NewDirtyVal = MemDepResult::getDirty(&*++RemInst->getIterator()); + + ReverseDepMapType::iterator ReverseDepIt = ReverseLocalDeps.find(RemInst); + if (ReverseDepIt != ReverseLocalDeps.end()) { + // RemInst can't be the terminator if it has local stuff depending on it. + assert(!ReverseDepIt->second.empty() && !RemInst->isTerminator() && + "Nothing can locally depend on a terminator"); + + for (Instruction *InstDependingOnRemInst : ReverseDepIt->second) { + assert(InstDependingOnRemInst != RemInst && + "Already removed our local dep info"); + + LocalDeps[InstDependingOnRemInst] = NewDirtyVal; + + // Make sure to remember that new things depend on NewDepInst. + assert(NewDirtyVal.getInst() && + "There is no way something else can have " + "a local dep on this if it is a terminator!"); + ReverseDepsToAdd.push_back( + std::make_pair(NewDirtyVal.getInst(), InstDependingOnRemInst)); + } + + ReverseLocalDeps.erase(ReverseDepIt); + + // Add new reverse deps after scanning the set, to avoid invalidating the + // 'ReverseDeps' reference. + while (!ReverseDepsToAdd.empty()) { + ReverseLocalDeps[ReverseDepsToAdd.back().first].insert( + ReverseDepsToAdd.back().second); + ReverseDepsToAdd.pop_back(); + } + } + + ReverseDepIt = ReverseNonLocalDeps.find(RemInst); + if (ReverseDepIt != ReverseNonLocalDeps.end()) { + for (Instruction *I : ReverseDepIt->second) { + assert(I != RemInst && "Already removed NonLocalDep info for RemInst"); + + PerInstNLInfo &INLD = NonLocalDeps[I]; + // The information is now dirty! 
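+      // (INLD.second is the per-instruction dirty flag; the next query for I
+      //  will rescan only the entries marked dirty below.)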
+ INLD.second = true; + + for (auto &Entry : INLD.first) { + if (Entry.getResult().getInst() != RemInst) + continue; + + // Convert to a dirty entry for the subsequent instruction. + Entry.setResult(NewDirtyVal); + + if (Instruction *NextI = NewDirtyVal.getInst()) + ReverseDepsToAdd.push_back(std::make_pair(NextI, I)); + } + } + + ReverseNonLocalDeps.erase(ReverseDepIt); + + // Add new reverse deps after scanning the set, to avoid invalidating 'Set' + while (!ReverseDepsToAdd.empty()) { + ReverseNonLocalDeps[ReverseDepsToAdd.back().first].insert( + ReverseDepsToAdd.back().second); + ReverseDepsToAdd.pop_back(); + } + } + + // If the instruction is in ReverseNonLocalPtrDeps then it appears as a + // value in the NonLocalPointerDeps info. + ReverseNonLocalPtrDepTy::iterator ReversePtrDepIt = + ReverseNonLocalPtrDeps.find(RemInst); + if (ReversePtrDepIt != ReverseNonLocalPtrDeps.end()) { + SmallVector<std::pair<Instruction *, ValueIsLoadPair>, 8> + ReversePtrDepsToAdd; + + for (ValueIsLoadPair P : ReversePtrDepIt->second) { + assert(P.getPointer() != RemInst && + "Already removed NonLocalPointerDeps info for RemInst"); + + NonLocalDepInfo &NLPDI = NonLocalPointerDeps[P].NonLocalDeps; + + // The cache is not valid for any specific block anymore. + NonLocalPointerDeps[P].Pair = BBSkipFirstBlockPair(); + + // Update any entries for RemInst to use the instruction after it. + for (auto &Entry : NLPDI) { + if (Entry.getResult().getInst() != RemInst) + continue; + + // Convert to a dirty entry for the subsequent instruction. + Entry.setResult(NewDirtyVal); + + if (Instruction *NewDirtyInst = NewDirtyVal.getInst()) + ReversePtrDepsToAdd.push_back(std::make_pair(NewDirtyInst, P)); + } + + // Re-sort the NonLocalDepInfo. Changing the dirty entry to its + // subsequent value may invalidate the sortedness. + llvm::sort(NLPDI); + } + + ReverseNonLocalPtrDeps.erase(ReversePtrDepIt); + + while (!ReversePtrDepsToAdd.empty()) { + ReverseNonLocalPtrDeps[ReversePtrDepsToAdd.back().first].insert( + ReversePtrDepsToAdd.back().second); + ReversePtrDepsToAdd.pop_back(); + } + } + + // Invalidate phis that use the removed instruction. + PV.invalidateValue(RemInst); + + assert(!NonLocalDeps.count(RemInst) && "RemInst got reinserted?"); + LLVM_DEBUG(verifyRemoved(RemInst)); +} + +/// Verify that the specified instruction does not occur in our internal data +/// structures. +/// +/// This function verifies by asserting in debug builds. 
+void MemoryDependenceResults::verifyRemoved(Instruction *D) const { +#ifndef NDEBUG + for (const auto &DepKV : LocalDeps) { + assert(DepKV.first != D && "Inst occurs in data structures"); + assert(DepKV.second.getInst() != D && "Inst occurs in data structures"); + } + + for (const auto &DepKV : NonLocalPointerDeps) { + assert(DepKV.first.getPointer() != D && "Inst occurs in NLPD map key"); + for (const auto &Entry : DepKV.second.NonLocalDeps) + assert(Entry.getResult().getInst() != D && "Inst occurs as NLPD value"); + } + + for (const auto &DepKV : NonLocalDeps) { + assert(DepKV.first != D && "Inst occurs in data structures"); + const PerInstNLInfo &INLD = DepKV.second; + for (const auto &Entry : INLD.first) + assert(Entry.getResult().getInst() != D && + "Inst occurs in data structures"); + } + + for (const auto &DepKV : ReverseLocalDeps) { + assert(DepKV.first != D && "Inst occurs in data structures"); + for (Instruction *Inst : DepKV.second) + assert(Inst != D && "Inst occurs in data structures"); + } + + for (const auto &DepKV : ReverseNonLocalDeps) { + assert(DepKV.first != D && "Inst occurs in data structures"); + for (Instruction *Inst : DepKV.second) + assert(Inst != D && "Inst occurs in data structures"); + } + + for (const auto &DepKV : ReverseNonLocalPtrDeps) { + assert(DepKV.first != D && "Inst occurs in rev NLPD map"); + + for (ValueIsLoadPair P : DepKV.second) + assert(P != ValueIsLoadPair(D, false) && P != ValueIsLoadPair(D, true) && + "Inst occurs in ReverseNonLocalPtrDeps map"); + } +#endif +} + +AnalysisKey MemoryDependenceAnalysis::Key; + +MemoryDependenceAnalysis::MemoryDependenceAnalysis() + : DefaultBlockScanLimit(BlockScanLimit) {} + +MemoryDependenceResults +MemoryDependenceAnalysis::run(Function &F, FunctionAnalysisManager &AM) { + auto &AA = AM.getResult<AAManager>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &PV = AM.getResult<PhiValuesAnalysis>(F); + return MemoryDependenceResults(AA, AC, TLI, DT, PV, DefaultBlockScanLimit); +} + +char MemoryDependenceWrapperPass::ID = 0; + +INITIALIZE_PASS_BEGIN(MemoryDependenceWrapperPass, "memdep", + "Memory Dependence Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PhiValuesWrapperPass) +INITIALIZE_PASS_END(MemoryDependenceWrapperPass, "memdep", + "Memory Dependence Analysis", false, true) + +MemoryDependenceWrapperPass::MemoryDependenceWrapperPass() : FunctionPass(ID) { + initializeMemoryDependenceWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +MemoryDependenceWrapperPass::~MemoryDependenceWrapperPass() = default; + +void MemoryDependenceWrapperPass::releaseMemory() { + MemDep.reset(); +} + +void MemoryDependenceWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<PhiValuesWrapperPass>(); + AU.addRequiredTransitive<AAResultsWrapperPass>(); + AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>(); +} + +bool MemoryDependenceResults::invalidate(Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &Inv) { + // Check whether our analysis is preserved. 
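+  // (For example, a pass returning PreservedAnalyses::all() keeps this
+  //  result alive; anything less falls through to the dependency checks
+  //  below.)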
+ auto PAC = PA.getChecker<MemoryDependenceAnalysis>(); + if (!PAC.preserved() && !PAC.preservedSet<AllAnalysesOn<Function>>()) + // If not, give up now. + return true; + + // Check whether the analyses we depend on became invalid for any reason. + if (Inv.invalidate<AAManager>(F, PA) || + Inv.invalidate<AssumptionAnalysis>(F, PA) || + Inv.invalidate<DominatorTreeAnalysis>(F, PA) || + Inv.invalidate<PhiValuesAnalysis>(F, PA)) + return true; + + // Otherwise this analysis result remains valid. + return false; +} + +unsigned MemoryDependenceResults::getDefaultBlockScanLimit() const { + return DefaultBlockScanLimit; +} + +bool MemoryDependenceWrapperPass::runOnFunction(Function &F) { + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &PV = getAnalysis<PhiValuesWrapperPass>().getResult(); + MemDep.emplace(AA, AC, TLI, DT, PV, BlockScanLimit); + return false; +} diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp new file mode 100644 index 000000000000..163830eee797 --- /dev/null +++ b/llvm/lib/Analysis/MemoryLocation.cpp @@ -0,0 +1,211 @@ +//===- MemoryLocation.cpp - Memory location descriptions -------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +using namespace llvm; + +void LocationSize::print(raw_ostream &OS) const { + OS << "LocationSize::"; + if (*this == unknown()) + OS << "unknown"; + else if (*this == mapEmpty()) + OS << "mapEmpty"; + else if (*this == mapTombstone()) + OS << "mapTombstone"; + else if (isPrecise()) + OS << "precise(" << getValue() << ')'; + else + OS << "upperBound(" << getValue() << ')'; +} + +MemoryLocation MemoryLocation::get(const LoadInst *LI) { + AAMDNodes AATags; + LI->getAAMetadata(AATags); + const auto &DL = LI->getModule()->getDataLayout(); + + return MemoryLocation( + LI->getPointerOperand(), + LocationSize::precise(DL.getTypeStoreSize(LI->getType())), AATags); +} + +MemoryLocation MemoryLocation::get(const StoreInst *SI) { + AAMDNodes AATags; + SI->getAAMetadata(AATags); + const auto &DL = SI->getModule()->getDataLayout(); + + return MemoryLocation(SI->getPointerOperand(), + LocationSize::precise(DL.getTypeStoreSize( + SI->getValueOperand()->getType())), + AATags); +} + +MemoryLocation MemoryLocation::get(const VAArgInst *VI) { + AAMDNodes AATags; + VI->getAAMetadata(AATags); + + return MemoryLocation(VI->getPointerOperand(), LocationSize::unknown(), + AATags); +} + +MemoryLocation MemoryLocation::get(const AtomicCmpXchgInst *CXI) { + AAMDNodes AATags; + CXI->getAAMetadata(AATags); + const auto &DL = CXI->getModule()->getDataLayout(); + + return MemoryLocation(CXI->getPointerOperand(), + LocationSize::precise(DL.getTypeStoreSize( + CXI->getCompareOperand()->getType())), + AATags); +} + +MemoryLocation 
+MemoryLocation::get(const AtomicRMWInst *RMWI) {
+  AAMDNodes AATags;
+  RMWI->getAAMetadata(AATags);
+  const auto &DL = RMWI->getModule()->getDataLayout();
+
+  return MemoryLocation(RMWI->getPointerOperand(),
+                        LocationSize::precise(DL.getTypeStoreSize(
+                            RMWI->getValOperand()->getType())),
+                        AATags);
+}
+
+MemoryLocation MemoryLocation::getForSource(const MemTransferInst *MTI) {
+  return getForSource(cast<AnyMemTransferInst>(MTI));
+}
+
+MemoryLocation MemoryLocation::getForSource(const AtomicMemTransferInst *MTI) {
+  return getForSource(cast<AnyMemTransferInst>(MTI));
+}
+
+MemoryLocation MemoryLocation::getForSource(const AnyMemTransferInst *MTI) {
+  auto Size = LocationSize::unknown();
+  if (ConstantInt *C = dyn_cast<ConstantInt>(MTI->getLength()))
+    Size = LocationSize::precise(C->getValue().getZExtValue());
+
+  // memcpy/memmove can have AA tags. For memcpy, they apply
+  // to both the source and the destination.
+  AAMDNodes AATags;
+  MTI->getAAMetadata(AATags);
+
+  return MemoryLocation(MTI->getRawSource(), Size, AATags);
+}
+
+MemoryLocation MemoryLocation::getForDest(const MemIntrinsic *MI) {
+  return getForDest(cast<AnyMemIntrinsic>(MI));
+}
+
+MemoryLocation MemoryLocation::getForDest(const AtomicMemIntrinsic *MI) {
+  return getForDest(cast<AnyMemIntrinsic>(MI));
+}
+
+MemoryLocation MemoryLocation::getForDest(const AnyMemIntrinsic *MI) {
+  auto Size = LocationSize::unknown();
+  if (ConstantInt *C = dyn_cast<ConstantInt>(MI->getLength()))
+    Size = LocationSize::precise(C->getValue().getZExtValue());
+
+  // memcpy/memmove can have AA tags. For memcpy, they apply
+  // to both the source and the destination.
+  AAMDNodes AATags;
+  MI->getAAMetadata(AATags);
+
+  return MemoryLocation(MI->getRawDest(), Size, AATags);
+}
+
+MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
+                                              unsigned ArgIdx,
+                                              const TargetLibraryInfo *TLI) {
+  AAMDNodes AATags;
+  Call->getAAMetadata(AATags);
+  const Value *Arg = Call->getArgOperand(ArgIdx);
+
+  // We may be able to produce an exact size for known intrinsics.
+  if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Call)) {
+    const DataLayout &DL = II->getModule()->getDataLayout();
+
+    switch (II->getIntrinsicID()) {
+    default:
+      break;
+    case Intrinsic::memset:
+    case Intrinsic::memcpy:
+    case Intrinsic::memmove:
+      assert((ArgIdx == 0 || ArgIdx == 1) &&
+             "Invalid argument index for memory intrinsic");
+      if (ConstantInt *LenCI = dyn_cast<ConstantInt>(II->getArgOperand(2)))
+        return MemoryLocation(Arg, LocationSize::precise(LenCI->getZExtValue()),
+                              AATags);
+      break;
+
+    case Intrinsic::lifetime_start:
+    case Intrinsic::lifetime_end:
+    case Intrinsic::invariant_start:
+      assert(ArgIdx == 1 && "Invalid argument index");
+      return MemoryLocation(
+          Arg,
+          LocationSize::precise(
+              cast<ConstantInt>(II->getArgOperand(0))->getZExtValue()),
+          AATags);
+
+    case Intrinsic::invariant_end:
+      // The first argument to an invariant.end is a "descriptor" type (e.g. a
+      // pointer to an empty struct) which is never actually dereferenced.
+      if (ArgIdx == 0)
+        return MemoryLocation(Arg, LocationSize::precise(0), AATags);
+      assert(ArgIdx == 2 && "Invalid argument index");
+      return MemoryLocation(
+          Arg,
+          LocationSize::precise(
+              cast<ConstantInt>(II->getArgOperand(1))->getZExtValue()),
+          AATags);
+
+    case Intrinsic::arm_neon_vld1:
+      assert(ArgIdx == 0 && "Invalid argument index");
+      // LLVM's vld1 and vst1 intrinsics currently only support a single
+      // vector register.
+ return MemoryLocation( + Arg, LocationSize::precise(DL.getTypeStoreSize(II->getType())), + AATags); + + case Intrinsic::arm_neon_vst1: + assert(ArgIdx == 0 && "Invalid argument index"); + return MemoryLocation(Arg, + LocationSize::precise(DL.getTypeStoreSize( + II->getArgOperand(1)->getType())), + AATags); + } + } + + // We can bound the aliasing properties of memset_pattern16 just as we can + // for memcpy/memset. This is particularly important because the + // LoopIdiomRecognizer likes to turn loops into calls to memset_pattern16 + // whenever possible. + LibFunc F; + if (TLI && Call->getCalledFunction() && + TLI->getLibFunc(*Call->getCalledFunction(), F) && + F == LibFunc_memset_pattern16 && TLI->has(F)) { + assert((ArgIdx == 0 || ArgIdx == 1) && + "Invalid argument index for memset_pattern16"); + if (ArgIdx == 1) + return MemoryLocation(Arg, LocationSize::precise(16), AATags); + if (const ConstantInt *LenCI = + dyn_cast<ConstantInt>(Call->getArgOperand(2))) + return MemoryLocation(Arg, LocationSize::precise(LenCI->getZExtValue()), + AATags); + } + // FIXME: Handle memset_pattern4 and memset_pattern8 also. + + return MemoryLocation(Call->getArgOperand(ArgIdx), LocationSize::unknown(), + AATags); +} diff --git a/llvm/lib/Analysis/MemorySSA.cpp b/llvm/lib/Analysis/MemorySSA.cpp new file mode 100644 index 000000000000..cfb8b7e7dcb5 --- /dev/null +++ b/llvm/lib/Analysis/MemorySSA.cpp @@ -0,0 +1,2475 @@ +//===- MemorySSA.cpp - Memory SSA Builder ---------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the MemorySSA class. 
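Before the MemorySSA implementation begins, a usage sketch for the MemoryLocation helpers above. This is illustrative only and not part of the patch; mayStoreClobberLoad is a hypothetical name, and the AAResults instance is assumed to come from the usual analysis plumbing.

#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Hypothetical helper: might SI write any of the bytes LI reads?
static bool mayStoreClobberLoad(AAResults &AA, const StoreInst *SI,
                                const LoadInst *LI) {
  // MemoryLocation::get() pairs the pointer operand with a precise size
  // from the DataLayout and carries the instruction's AA metadata along.
  MemoryLocation StoreLoc = MemoryLocation::get(SI);
  MemoryLocation LoadLoc = MemoryLocation::get(LI);
  return AA.alias(StoreLoc, LoadLoc) != NoAlias;
}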
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/IteratedDominanceFrontier.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/AssemblyAnnotationWriter.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Use.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/AtomicOrdering.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/FormattedStream.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <cstdlib>
+#include <iterator>
+#include <memory>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "memoryssa"
+
+INITIALIZE_PASS_BEGIN(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false,
+                      true)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_END(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false,
+                    true)
+
+INITIALIZE_PASS_BEGIN(MemorySSAPrinterLegacyPass, "print-memoryssa",
+                      "Memory SSA Printer", false, false)
+INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
+INITIALIZE_PASS_END(MemorySSAPrinterLegacyPass, "print-memoryssa",
+                    "Memory SSA Printer", false, false)
+
+static cl::opt<unsigned> MaxCheckLimit(
+    "memssa-check-limit", cl::Hidden, cl::init(100),
+    cl::desc("The maximum number of stores/phis MemorySSA "
+             "will consider trying to walk past (default = 100)"));
+
+// Always verify MemorySSA if expensive checking is enabled.
+#ifdef EXPENSIVE_CHECKS
+bool llvm::VerifyMemorySSA = true;
+#else
+bool llvm::VerifyMemorySSA = false;
+#endif
+/// Enables MemorySSA as a dependency for loop passes in the legacy pass
+/// manager.
+cl::opt<bool> llvm::EnableMSSALoopDependency(
+    "enable-mssa-loop-dependency", cl::Hidden, cl::init(true),
+    cl::desc("Enable MemorySSA dependency for loop pass manager"));
+
+static cl::opt<bool, true>
+    VerifyMemorySSAX("verify-memoryssa", cl::location(VerifyMemorySSA),
+                     cl::Hidden, cl::desc("Enable verification of MemorySSA."));
+
+namespace llvm {
+
+/// An assembly annotator class to print Memory SSA information in
+/// comments.
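Before that annotator class, a quick orientation sketch of how a new-pass-manager client would request this analysis, honor the verification flag registered above, and print the annotated IR. MyPass is a made-up name; everything else is the public MemorySSA surface.

#include "llvm/Analysis/MemorySSA.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

struct MyPass : PassInfoMixin<MyPass> { // hypothetical pass
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) {
    MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
    if (VerifyMemorySSA) // set by -verify-memoryssa or EXPENSIVE_CHECKS
      MSSA.verifyMemorySSA();
    MSSA.print(errs()); // printing goes through the annotated writer below
    return PreservedAnalyses::all();
  }
};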
+class MemorySSAAnnotatedWriter : public AssemblyAnnotationWriter { + friend class MemorySSA; + + const MemorySSA *MSSA; + +public: + MemorySSAAnnotatedWriter(const MemorySSA *M) : MSSA(M) {} + + void emitBasicBlockStartAnnot(const BasicBlock *BB, + formatted_raw_ostream &OS) override { + if (MemoryAccess *MA = MSSA->getMemoryAccess(BB)) + OS << "; " << *MA << "\n"; + } + + void emitInstructionAnnot(const Instruction *I, + formatted_raw_ostream &OS) override { + if (MemoryAccess *MA = MSSA->getMemoryAccess(I)) + OS << "; " << *MA << "\n"; + } +}; + +} // end namespace llvm + +namespace { + +/// Our current alias analysis API differentiates heavily between calls and +/// non-calls, and functions called on one usually assert on the other. +/// This class encapsulates the distinction to simplify other code that wants +/// "Memory affecting instructions and related data" to use as a key. +/// For example, this class is used as a densemap key in the use optimizer. +class MemoryLocOrCall { +public: + bool IsCall = false; + + MemoryLocOrCall(MemoryUseOrDef *MUD) + : MemoryLocOrCall(MUD->getMemoryInst()) {} + MemoryLocOrCall(const MemoryUseOrDef *MUD) + : MemoryLocOrCall(MUD->getMemoryInst()) {} + + MemoryLocOrCall(Instruction *Inst) { + if (auto *C = dyn_cast<CallBase>(Inst)) { + IsCall = true; + Call = C; + } else { + IsCall = false; + // There is no such thing as a memorylocation for a fence inst, and it is + // unique in that regard. + if (!isa<FenceInst>(Inst)) + Loc = MemoryLocation::get(Inst); + } + } + + explicit MemoryLocOrCall(const MemoryLocation &Loc) : Loc(Loc) {} + + const CallBase *getCall() const { + assert(IsCall); + return Call; + } + + MemoryLocation getLoc() const { + assert(!IsCall); + return Loc; + } + + bool operator==(const MemoryLocOrCall &Other) const { + if (IsCall != Other.IsCall) + return false; + + if (!IsCall) + return Loc == Other.Loc; + + if (Call->getCalledValue() != Other.Call->getCalledValue()) + return false; + + return Call->arg_size() == Other.Call->arg_size() && + std::equal(Call->arg_begin(), Call->arg_end(), + Other.Call->arg_begin()); + } + +private: + union { + const CallBase *Call; + MemoryLocation Loc; + }; +}; + +} // end anonymous namespace + +namespace llvm { + +template <> struct DenseMapInfo<MemoryLocOrCall> { + static inline MemoryLocOrCall getEmptyKey() { + return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getEmptyKey()); + } + + static inline MemoryLocOrCall getTombstoneKey() { + return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getTombstoneKey()); + } + + static unsigned getHashValue(const MemoryLocOrCall &MLOC) { + if (!MLOC.IsCall) + return hash_combine( + MLOC.IsCall, + DenseMapInfo<MemoryLocation>::getHashValue(MLOC.getLoc())); + + hash_code hash = + hash_combine(MLOC.IsCall, DenseMapInfo<const Value *>::getHashValue( + MLOC.getCall()->getCalledValue())); + + for (const Value *Arg : MLOC.getCall()->args()) + hash = hash_combine(hash, DenseMapInfo<const Value *>::getHashValue(Arg)); + return hash; + } + + static bool isEqual(const MemoryLocOrCall &LHS, const MemoryLocOrCall &RHS) { + return LHS == RHS; + } +}; + +} // end namespace llvm + +/// This does one-way checks to see if Use could theoretically be hoisted above +/// MayClobber. This will not check the other way around. +/// +/// This assumes that, for the purposes of MemorySSA, Use comes directly after +/// MayClobber, with no potentially clobbering operations in between them. +/// (Where potentially clobbering ops are memory barriers, aliased stores, etc.) 
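The helper that follows encodes this reordering rule in two checks. As a condensed, self-contained sketch of just the rule (illustrative only; the real function below takes the two LoadInsts directly):

#include "llvm/Support/AtomicOrdering.h"
using llvm::AtomicOrdering;
using llvm::isAtLeastOrStrongerThan;

// Sketch of the predicate: volatile pairs keep their order, a seq_cst use
// can never be hoisted, and nothing may be hoisted above an acquire.
static bool reorderableSketch(bool UseVolatile, AtomicOrdering UseOrd,
                              bool ClobberVolatile,
                              AtomicOrdering ClobberOrd) {
  if (UseVolatile && ClobberVolatile)
    return false;
  if (UseOrd == AtomicOrdering::SequentiallyConsistent)
    return false;
  return !isAtLeastOrStrongerThan(ClobberOrd, AtomicOrdering::Acquire);
}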
+static bool areLoadsReorderable(const LoadInst *Use, + const LoadInst *MayClobber) { + bool VolatileUse = Use->isVolatile(); + bool VolatileClobber = MayClobber->isVolatile(); + // Volatile operations may never be reordered with other volatile operations. + if (VolatileUse && VolatileClobber) + return false; + // Otherwise, volatile doesn't matter here. From the language reference: + // 'optimizers may change the order of volatile operations relative to + // non-volatile operations.'" + + // If a load is seq_cst, it cannot be moved above other loads. If its ordering + // is weaker, it can be moved above other loads. We just need to be sure that + // MayClobber isn't an acquire load, because loads can't be moved above + // acquire loads. + // + // Note that this explicitly *does* allow the free reordering of monotonic (or + // weaker) loads of the same address. + bool SeqCstUse = Use->getOrdering() == AtomicOrdering::SequentiallyConsistent; + bool MayClobberIsAcquire = isAtLeastOrStrongerThan(MayClobber->getOrdering(), + AtomicOrdering::Acquire); + return !(SeqCstUse || MayClobberIsAcquire); +} + +namespace { + +struct ClobberAlias { + bool IsClobber; + Optional<AliasResult> AR; +}; + +} // end anonymous namespace + +// Return a pair of {IsClobber (bool), AR (AliasResult)}. It relies on AR being +// ignored if IsClobber = false. +template <typename AliasAnalysisType> +static ClobberAlias +instructionClobbersQuery(const MemoryDef *MD, const MemoryLocation &UseLoc, + const Instruction *UseInst, AliasAnalysisType &AA) { + Instruction *DefInst = MD->getMemoryInst(); + assert(DefInst && "Defining instruction not actually an instruction"); + const auto *UseCall = dyn_cast<CallBase>(UseInst); + Optional<AliasResult> AR; + + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(DefInst)) { + // These intrinsics will show up as affecting memory, but they are just + // markers, mostly. + // + // FIXME: We probably don't actually want MemorySSA to model these at all + // (including creating MemoryAccesses for them): we just end up inventing + // clobbers where they don't really exist at all. Please see D43269 for + // context. + switch (II->getIntrinsicID()) { + case Intrinsic::lifetime_start: + if (UseCall) + return {false, NoAlias}; + AR = AA.alias(MemoryLocation(II->getArgOperand(1)), UseLoc); + return {AR != NoAlias, AR}; + case Intrinsic::lifetime_end: + case Intrinsic::invariant_start: + case Intrinsic::invariant_end: + case Intrinsic::assume: + return {false, NoAlias}; + case Intrinsic::dbg_addr: + case Intrinsic::dbg_declare: + case Intrinsic::dbg_label: + case Intrinsic::dbg_value: + llvm_unreachable("debuginfo shouldn't have associated defs!"); + default: + break; + } + } + + if (UseCall) { + ModRefInfo I = AA.getModRefInfo(DefInst, UseCall); + AR = isMustSet(I) ? MustAlias : MayAlias; + return {isModOrRefSet(I), AR}; + } + + if (auto *DefLoad = dyn_cast<LoadInst>(DefInst)) + if (auto *UseLoad = dyn_cast<LoadInst>(UseInst)) + return {!areLoadsReorderable(UseLoad, DefLoad), MayAlias}; + + ModRefInfo I = AA.getModRefInfo(DefInst, UseLoc); + AR = isMustSet(I) ? MustAlias : MayAlias; + return {isModSet(I), AR}; +} + +template <typename AliasAnalysisType> +static ClobberAlias instructionClobbersQuery(MemoryDef *MD, + const MemoryUseOrDef *MU, + const MemoryLocOrCall &UseMLOC, + AliasAnalysisType &AA) { + // FIXME: This is a temporary hack to allow a single instructionClobbersQuery + // to exist while MemoryLocOrCall is pushed through places. 
+ if (UseMLOC.IsCall) + return instructionClobbersQuery(MD, MemoryLocation(), MU->getMemoryInst(), + AA); + return instructionClobbersQuery(MD, UseMLOC.getLoc(), MU->getMemoryInst(), + AA); +} + +// Return true when MD may alias MU, return false otherwise. +bool MemorySSAUtil::defClobbersUseOrDef(MemoryDef *MD, const MemoryUseOrDef *MU, + AliasAnalysis &AA) { + return instructionClobbersQuery(MD, MU, MemoryLocOrCall(MU), AA).IsClobber; +} + +namespace { + +struct UpwardsMemoryQuery { + // True if our original query started off as a call + bool IsCall = false; + // The pointer location we started the query with. This will be empty if + // IsCall is true. + MemoryLocation StartingLoc; + // This is the instruction we were querying about. + const Instruction *Inst = nullptr; + // The MemoryAccess we actually got called with, used to test local domination + const MemoryAccess *OriginalAccess = nullptr; + Optional<AliasResult> AR = MayAlias; + bool SkipSelfAccess = false; + + UpwardsMemoryQuery() = default; + + UpwardsMemoryQuery(const Instruction *Inst, const MemoryAccess *Access) + : IsCall(isa<CallBase>(Inst)), Inst(Inst), OriginalAccess(Access) { + if (!IsCall) + StartingLoc = MemoryLocation::get(Inst); + } +}; + +} // end anonymous namespace + +static bool lifetimeEndsAt(MemoryDef *MD, const MemoryLocation &Loc, + BatchAAResults &AA) { + Instruction *Inst = MD->getMemoryInst(); + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + switch (II->getIntrinsicID()) { + case Intrinsic::lifetime_end: + return AA.alias(MemoryLocation(II->getArgOperand(1)), Loc) == MustAlias; + default: + return false; + } + } + return false; +} + +template <typename AliasAnalysisType> +static bool isUseTriviallyOptimizableToLiveOnEntry(AliasAnalysisType &AA, + const Instruction *I) { + // If the memory can't be changed, then loads of the memory can't be + // clobbered. + return isa<LoadInst>(I) && (I->hasMetadata(LLVMContext::MD_invariant_load) || + AA.pointsToConstantMemory(MemoryLocation( + cast<LoadInst>(I)->getPointerOperand()))); +} + +/// Verifies that `Start` is clobbered by `ClobberAt`, and that nothing +/// inbetween `Start` and `ClobberAt` can clobbers `Start`. +/// +/// This is meant to be as simple and self-contained as possible. Because it +/// uses no cache, etc., it can be relatively expensive. +/// +/// \param Start The MemoryAccess that we want to walk from. +/// \param ClobberAt A clobber for Start. +/// \param StartLoc The MemoryLocation for Start. +/// \param MSSA The MemorySSA instance that Start and ClobberAt belong to. +/// \param Query The UpwardsMemoryQuery we used for our search. +/// \param AA The AliasAnalysis we used for our search. +/// \param AllowImpreciseClobber Always false, unless we do relaxed verify. + +template <typename AliasAnalysisType> +LLVM_ATTRIBUTE_UNUSED static void +checkClobberSanity(const MemoryAccess *Start, MemoryAccess *ClobberAt, + const MemoryLocation &StartLoc, const MemorySSA &MSSA, + const UpwardsMemoryQuery &Query, AliasAnalysisType &AA, + bool AllowImpreciseClobber = false) { + assert(MSSA.dominates(ClobberAt, Start) && "Clobber doesn't dominate start?"); + + if (MSSA.isLiveOnEntryDef(Start)) { + assert(MSSA.isLiveOnEntryDef(ClobberAt) && + "liveOnEntry must clobber itself"); + return; + } + + bool FoundClobber = false; + DenseSet<ConstMemoryAccessPair> VisitedPhis; + SmallVector<ConstMemoryAccessPair, 8> Worklist; + Worklist.emplace_back(Start, StartLoc); + // Walk all paths from Start to ClobberAt, while looking for clobbers. 
If one + // is found, complain. + while (!Worklist.empty()) { + auto MAP = Worklist.pop_back_val(); + // All we care about is that nothing from Start to ClobberAt clobbers Start. + // We learn nothing from revisiting nodes. + if (!VisitedPhis.insert(MAP).second) + continue; + + for (const auto *MA : def_chain(MAP.first)) { + if (MA == ClobberAt) { + if (const auto *MD = dyn_cast<MemoryDef>(MA)) { + // instructionClobbersQuery isn't essentially free, so don't use `|=`, + // since it won't let us short-circuit. + // + // Also, note that this can't be hoisted out of the `Worklist` loop, + // since MD may only act as a clobber for 1 of N MemoryLocations. + FoundClobber = FoundClobber || MSSA.isLiveOnEntryDef(MD); + if (!FoundClobber) { + ClobberAlias CA = + instructionClobbersQuery(MD, MAP.second, Query.Inst, AA); + if (CA.IsClobber) { + FoundClobber = true; + // Not used: CA.AR; + } + } + } + break; + } + + // We should never hit liveOnEntry, unless it's the clobber. + assert(!MSSA.isLiveOnEntryDef(MA) && "Hit liveOnEntry before clobber?"); + + if (const auto *MD = dyn_cast<MemoryDef>(MA)) { + // If Start is a Def, skip self. + if (MD == Start) + continue; + + assert(!instructionClobbersQuery(MD, MAP.second, Query.Inst, AA) + .IsClobber && + "Found clobber before reaching ClobberAt!"); + continue; + } + + if (const auto *MU = dyn_cast<MemoryUse>(MA)) { + (void)MU; + assert (MU == Start && + "Can only find use in def chain if Start is a use"); + continue; + } + + assert(isa<MemoryPhi>(MA)); + Worklist.append( + upward_defs_begin({const_cast<MemoryAccess *>(MA), MAP.second}), + upward_defs_end()); + } + } + + // If the verify is done following an optimization, it's possible that + // ClobberAt was a conservative clobbering, that we can now infer is not a + // true clobbering access. Don't fail the verify if that's the case. + // We do have accesses that claim they're optimized, but could be optimized + // further. Updating all these can be expensive, so allow it for now (FIXME). + if (AllowImpreciseClobber) + return; + + // If ClobberAt is a MemoryPhi, we can assume something above it acted as a + // clobber. Otherwise, `ClobberAt` should've acted as a clobber at some point. + assert((isa<MemoryPhi>(ClobberAt) || FoundClobber) && + "ClobberAt never acted as a clobber"); +} + +namespace { + +/// Our algorithm for walking (and trying to optimize) clobbers, all wrapped up +/// in one class. +template <class AliasAnalysisType> class ClobberWalker { + /// Save a few bytes by using unsigned instead of size_t. + using ListIndex = unsigned; + + /// Represents a span of contiguous MemoryDefs, potentially ending in a + /// MemoryPhi. + struct DefPath { + MemoryLocation Loc; + // Note that, because we always walk in reverse, Last will always dominate + // First. Also note that First and Last are inclusive. 
+ MemoryAccess *First; + MemoryAccess *Last; + Optional<ListIndex> Previous; + + DefPath(const MemoryLocation &Loc, MemoryAccess *First, MemoryAccess *Last, + Optional<ListIndex> Previous) + : Loc(Loc), First(First), Last(Last), Previous(Previous) {} + + DefPath(const MemoryLocation &Loc, MemoryAccess *Init, + Optional<ListIndex> Previous) + : DefPath(Loc, Init, Init, Previous) {} + }; + + const MemorySSA &MSSA; + AliasAnalysisType &AA; + DominatorTree &DT; + UpwardsMemoryQuery *Query; + unsigned *UpwardWalkLimit; + + // Phi optimization bookkeeping + SmallVector<DefPath, 32> Paths; + DenseSet<ConstMemoryAccessPair> VisitedPhis; + + /// Find the nearest def or phi that `From` can legally be optimized to. + const MemoryAccess *getWalkTarget(const MemoryPhi *From) const { + assert(From->getNumOperands() && "Phi with no operands?"); + + BasicBlock *BB = From->getBlock(); + MemoryAccess *Result = MSSA.getLiveOnEntryDef(); + DomTreeNode *Node = DT.getNode(BB); + while ((Node = Node->getIDom())) { + auto *Defs = MSSA.getBlockDefs(Node->getBlock()); + if (Defs) + return &*Defs->rbegin(); + } + return Result; + } + + /// Result of calling walkToPhiOrClobber. + struct UpwardsWalkResult { + /// The "Result" of the walk. Either a clobber, the last thing we walked, or + /// both. Include alias info when clobber found. + MemoryAccess *Result; + bool IsKnownClobber; + Optional<AliasResult> AR; + }; + + /// Walk to the next Phi or Clobber in the def chain starting at Desc.Last. + /// This will update Desc.Last as it walks. It will (optionally) also stop at + /// StopAt. + /// + /// This does not test for whether StopAt is a clobber + UpwardsWalkResult + walkToPhiOrClobber(DefPath &Desc, const MemoryAccess *StopAt = nullptr, + const MemoryAccess *SkipStopAt = nullptr) const { + assert(!isa<MemoryUse>(Desc.Last) && "Uses don't exist in my world"); + assert(UpwardWalkLimit && "Need a valid walk limit"); + bool LimitAlreadyReached = false; + // (*UpwardWalkLimit) may be 0 here, due to the loop in tryOptimizePhi. Set + // it to 1. This will not do any alias() calls. It either returns in the + // first iteration in the loop below, or is set back to 0 if all def chains + // are free of MemoryDefs. + if (!*UpwardWalkLimit) { + *UpwardWalkLimit = 1; + LimitAlreadyReached = true; + } + + for (MemoryAccess *Current : def_chain(Desc.Last)) { + Desc.Last = Current; + if (Current == StopAt || Current == SkipStopAt) + return {Current, false, MayAlias}; + + if (auto *MD = dyn_cast<MemoryDef>(Current)) { + if (MSSA.isLiveOnEntryDef(MD)) + return {MD, true, MustAlias}; + + if (!--*UpwardWalkLimit) + return {Current, true, MayAlias}; + + ClobberAlias CA = + instructionClobbersQuery(MD, Desc.Loc, Query->Inst, AA); + if (CA.IsClobber) + return {MD, true, CA.AR}; + } + } + + if (LimitAlreadyReached) + *UpwardWalkLimit = 0; + + assert(isa<MemoryPhi>(Desc.Last) && + "Ended at a non-clobber that's not a phi?"); + return {Desc.Last, false, MayAlias}; + } + + void addSearches(MemoryPhi *Phi, SmallVectorImpl<ListIndex> &PausedSearches, + ListIndex PriorNode) { + auto UpwardDefs = make_range(upward_defs_begin({Phi, Paths[PriorNode].Loc}), + upward_defs_end()); + for (const MemoryAccessPair &P : UpwardDefs) { + PausedSearches.push_back(Paths.size()); + Paths.emplace_back(P.second, P.first, PriorNode); + } + } + + /// Represents a search that terminated after finding a clobber. This clobber + /// may or may not be present in the path of defs from LastNode..SearchStart, + /// since it may have been retrieved from cache. 
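walkToPhiOrClobber above leans on def_chain (from MemorySSA.h) for its straight-line walks; the TerminatedPath struct described by the comment above follows this aside. A standalone sketch of consuming that iterator range, assuming a built MemorySSA:

#include "llvm/Analysis/MemorySSA.h"
using namespace llvm;

// Sketch: count the MemoryDefs sitting on the def chain above MA before
// the walk has to pause at a MemoryPhi or bottoms out at liveOnEntry.
static unsigned defsAbove(MemorySSA &MSSA, MemoryAccess *MA) {
  unsigned N = 0;
  for (MemoryAccess *Cur : def_chain(MA)) {
    if (MSSA.isLiveOnEntryDef(Cur) || isa<MemoryPhi>(Cur))
      break;
    if (isa<MemoryDef>(Cur))
      ++N;
  }
  return N;
}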
+ struct TerminatedPath { + MemoryAccess *Clobber; + ListIndex LastNode; + }; + + /// Get an access that keeps us from optimizing to the given phi. + /// + /// PausedSearches is an array of indices into the Paths array. Its incoming + /// value is the indices of searches that stopped at the last phi optimization + /// target. It's left in an unspecified state. + /// + /// If this returns None, NewPaused is a vector of searches that terminated + /// at StopWhere. Otherwise, NewPaused is left in an unspecified state. + Optional<TerminatedPath> + getBlockingAccess(const MemoryAccess *StopWhere, + SmallVectorImpl<ListIndex> &PausedSearches, + SmallVectorImpl<ListIndex> &NewPaused, + SmallVectorImpl<TerminatedPath> &Terminated) { + assert(!PausedSearches.empty() && "No searches to continue?"); + + // BFS vs DFS really doesn't make a difference here, so just do a DFS with + // PausedSearches as our stack. + while (!PausedSearches.empty()) { + ListIndex PathIndex = PausedSearches.pop_back_val(); + DefPath &Node = Paths[PathIndex]; + + // If we've already visited this path with this MemoryLocation, we don't + // need to do so again. + // + // NOTE: That we just drop these paths on the ground makes caching + // behavior sporadic. e.g. given a diamond: + // A + // B C + // D + // + // ...If we walk D, B, A, C, we'll only cache the result of phi + // optimization for A, B, and D; C will be skipped because it dies here. + // This arguably isn't the worst thing ever, since: + // - We generally query things in a top-down order, so if we got below D + // without needing cache entries for {C, MemLoc}, then chances are + // that those cache entries would end up ultimately unused. + // - We still cache things for A, so C only needs to walk up a bit. + // If this behavior becomes problematic, we can fix without a ton of extra + // work. + if (!VisitedPhis.insert({Node.Last, Node.Loc}).second) + continue; + + const MemoryAccess *SkipStopWhere = nullptr; + if (Query->SkipSelfAccess && Node.Loc == Query->StartingLoc) { + assert(isa<MemoryDef>(Query->OriginalAccess)); + SkipStopWhere = Query->OriginalAccess; + } + + UpwardsWalkResult Res = walkToPhiOrClobber(Node, + /*StopAt=*/StopWhere, + /*SkipStopAt=*/SkipStopWhere); + if (Res.IsKnownClobber) { + assert(Res.Result != StopWhere && Res.Result != SkipStopWhere); + + // If this wasn't a cache hit, we hit a clobber when walking. That's a + // failure. + TerminatedPath Term{Res.Result, PathIndex}; + if (!MSSA.dominates(Res.Result, StopWhere)) + return Term; + + // Otherwise, it's a valid thing to potentially optimize to. + Terminated.push_back(Term); + continue; + } + + if (Res.Result == StopWhere || Res.Result == SkipStopWhere) { + // We've hit our target. Save this path off for if we want to continue + // walking. If we are in the mode of skipping the OriginalAccess, and + // we've reached back to the OriginalAccess, do not save path, we've + // just looped back to self. 
+ if (Res.Result != SkipStopWhere) + NewPaused.push_back(PathIndex); + continue; + } + + assert(!MSSA.isLiveOnEntryDef(Res.Result) && "liveOnEntry is a clobber"); + addSearches(cast<MemoryPhi>(Res.Result), PausedSearches, PathIndex); + } + + return None; + } + + template <typename T, typename Walker> + struct generic_def_path_iterator + : public iterator_facade_base<generic_def_path_iterator<T, Walker>, + std::forward_iterator_tag, T *> { + generic_def_path_iterator() {} + generic_def_path_iterator(Walker *W, ListIndex N) : W(W), N(N) {} + + T &operator*() const { return curNode(); } + + generic_def_path_iterator &operator++() { + N = curNode().Previous; + return *this; + } + + bool operator==(const generic_def_path_iterator &O) const { + if (N.hasValue() != O.N.hasValue()) + return false; + return !N.hasValue() || *N == *O.N; + } + + private: + T &curNode() const { return W->Paths[*N]; } + + Walker *W = nullptr; + Optional<ListIndex> N = None; + }; + + using def_path_iterator = generic_def_path_iterator<DefPath, ClobberWalker>; + using const_def_path_iterator = + generic_def_path_iterator<const DefPath, const ClobberWalker>; + + iterator_range<def_path_iterator> def_path(ListIndex From) { + return make_range(def_path_iterator(this, From), def_path_iterator()); + } + + iterator_range<const_def_path_iterator> const_def_path(ListIndex From) const { + return make_range(const_def_path_iterator(this, From), + const_def_path_iterator()); + } + + struct OptznResult { + /// The path that contains our result. + TerminatedPath PrimaryClobber; + /// The paths that we can legally cache back from, but that aren't + /// necessarily the result of the Phi optimization. + SmallVector<TerminatedPath, 4> OtherClobbers; + }; + + ListIndex defPathIndex(const DefPath &N) const { + // The assert looks nicer if we don't need to do &N + const DefPath *NP = &N; + assert(!Paths.empty() && NP >= &Paths.front() && NP <= &Paths.back() && + "Out of bounds DefPath!"); + return NP - &Paths.front(); + } + + /// Try to optimize a phi as best as we can. Returns a SmallVector of Paths + /// that act as legal clobbers. Note that this won't return *all* clobbers. + /// + /// Phi optimization algorithm tl;dr: + /// - Find the earliest def/phi, A, we can optimize to + /// - Find if all paths from the starting memory access ultimately reach A + /// - If not, optimization isn't possible. + /// - Otherwise, walk from A to another clobber or phi, A'. + /// - If A' is a def, we're done. + /// - If A' is a phi, try to optimize it. + /// + /// A path is a series of {MemoryAccess, MemoryLocation} pairs. A path + /// terminates when a MemoryAccess that clobbers said MemoryLocation is found. + OptznResult tryOptimizePhi(MemoryPhi *Phi, MemoryAccess *Start, + const MemoryLocation &Loc) { + assert(Paths.empty() && VisitedPhis.empty() && + "Reset the optimization state."); + + Paths.emplace_back(Loc, Start, Phi, None); + // Stores how many "valid" optimization nodes we had prior to calling + // addSearches/getBlockingAccess. Necessary for caching if we had a blocker. + auto PriorPathsSize = Paths.size(); + + SmallVector<ListIndex, 16> PausedSearches; + SmallVector<ListIndex, 8> NewPaused; + SmallVector<TerminatedPath, 4> TerminatedPaths; + + addSearches(Phi, PausedSearches, 0); + + // Moves the TerminatedPath with the "most dominated" Clobber to the end of + // Paths. 
+ auto MoveDominatedPathToEnd = [&](SmallVectorImpl<TerminatedPath> &Paths) { + assert(!Paths.empty() && "Need a path to move"); + auto Dom = Paths.begin(); + for (auto I = std::next(Dom), E = Paths.end(); I != E; ++I) + if (!MSSA.dominates(I->Clobber, Dom->Clobber)) + Dom = I; + auto Last = Paths.end() - 1; + if (Last != Dom) + std::iter_swap(Last, Dom); + }; + + MemoryPhi *Current = Phi; + while (true) { + assert(!MSSA.isLiveOnEntryDef(Current) && + "liveOnEntry wasn't treated as a clobber?"); + + const auto *Target = getWalkTarget(Current); + // If a TerminatedPath doesn't dominate Target, then it wasn't a legal + // optimization for the prior phi. + assert(all_of(TerminatedPaths, [&](const TerminatedPath &P) { + return MSSA.dominates(P.Clobber, Target); + })); + + // FIXME: This is broken, because the Blocker may be reported to be + // liveOnEntry, and we'll happily wait for that to disappear (read: never) + // For the moment, this is fine, since we do nothing with blocker info. + if (Optional<TerminatedPath> Blocker = getBlockingAccess( + Target, PausedSearches, NewPaused, TerminatedPaths)) { + + // Find the node we started at. We can't search based on N->Last, since + // we may have gone around a loop with a different MemoryLocation. + auto Iter = find_if(def_path(Blocker->LastNode), [&](const DefPath &N) { + return defPathIndex(N) < PriorPathsSize; + }); + assert(Iter != def_path_iterator()); + + DefPath &CurNode = *Iter; + assert(CurNode.Last == Current); + + // Two things: + // A. We can't reliably cache all of NewPaused back. Consider a case + // where we have two paths in NewPaused; one of which can't optimize + // above this phi, whereas the other can. If we cache the second path + // back, we'll end up with suboptimal cache entries. We can handle + // cases like this a bit better when we either try to find all + // clobbers that block phi optimization, or when our cache starts + // supporting unfinished searches. + // B. We can't reliably cache TerminatedPaths back here without doing + // extra checks; consider a case like: + // T + // / \ + // D C + // \ / + // S + // Where T is our target, C is a node with a clobber on it, D is a + // diamond (with a clobber *only* on the left or right node, N), and + // S is our start. Say we walk to D, through the node opposite N + // (read: ignoring the clobber), and see a cache entry in the top + // node of D. That cache entry gets put into TerminatedPaths. We then + // walk up to C (N is later in our worklist), find the clobber, and + // quit. If we append TerminatedPaths to OtherClobbers, we'll cache + // the bottom part of D to the cached clobber, ignoring the clobber + // in N. Again, this problem goes away if we start tracking all + // blockers for a given phi optimization. + TerminatedPath Result{CurNode.Last, defPathIndex(CurNode)}; + return {Result, {}}; + } + + // If there's nothing left to search, then all paths led to valid clobbers + // that we got from our cache; pick the nearest to the start, and allow + // the rest to be cached back. + if (NewPaused.empty()) { + MoveDominatedPathToEnd(TerminatedPaths); + TerminatedPath Result = TerminatedPaths.pop_back_val(); + return {Result, std::move(TerminatedPaths)}; + } + + MemoryAccess *DefChainEnd = nullptr; + SmallVector<TerminatedPath, 4> Clobbers; + for (ListIndex Paused : NewPaused) { + UpwardsWalkResult WR = walkToPhiOrClobber(Paths[Paused]); + if (WR.IsKnownClobber) + Clobbers.push_back({WR.Result, Paused}); + else + // Micro-opt: If we hit the end of the chain, save it. 
+ DefChainEnd = WR.Result; + } + + if (!TerminatedPaths.empty()) { + // If we couldn't find the dominating phi/liveOnEntry in the above loop, + // do it now. + if (!DefChainEnd) + for (auto *MA : def_chain(const_cast<MemoryAccess *>(Target))) + DefChainEnd = MA; + assert(DefChainEnd && "Failed to find dominating phi/liveOnEntry"); + + // If any of the terminated paths don't dominate the phi we'll try to + // optimize, we need to figure out what they are and quit. + const BasicBlock *ChainBB = DefChainEnd->getBlock(); + for (const TerminatedPath &TP : TerminatedPaths) { + // Because we know that DefChainEnd is as "high" as we can go, we + // don't need local dominance checks; BB dominance is sufficient. + if (DT.dominates(ChainBB, TP.Clobber->getBlock())) + Clobbers.push_back(TP); + } + } + + // If we have clobbers in the def chain, find the one closest to Current + // and quit. + if (!Clobbers.empty()) { + MoveDominatedPathToEnd(Clobbers); + TerminatedPath Result = Clobbers.pop_back_val(); + return {Result, std::move(Clobbers)}; + } + + assert(all_of(NewPaused, + [&](ListIndex I) { return Paths[I].Last == DefChainEnd; })); + + // Because liveOnEntry is a clobber, this must be a phi. + auto *DefChainPhi = cast<MemoryPhi>(DefChainEnd); + + PriorPathsSize = Paths.size(); + PausedSearches.clear(); + for (ListIndex I : NewPaused) + addSearches(DefChainPhi, PausedSearches, I); + NewPaused.clear(); + + Current = DefChainPhi; + } + } + + void verifyOptResult(const OptznResult &R) const { + assert(all_of(R.OtherClobbers, [&](const TerminatedPath &P) { + return MSSA.dominates(P.Clobber, R.PrimaryClobber.Clobber); + })); + } + + void resetPhiOptznState() { + Paths.clear(); + VisitedPhis.clear(); + } + +public: + ClobberWalker(const MemorySSA &MSSA, AliasAnalysisType &AA, DominatorTree &DT) + : MSSA(MSSA), AA(AA), DT(DT) {} + + AliasAnalysisType *getAA() { return &AA; } + /// Finds the nearest clobber for the given query, optimizing phis if + /// possible. + MemoryAccess *findClobber(MemoryAccess *Start, UpwardsMemoryQuery &Q, + unsigned &UpWalkLimit) { + Query = &Q; + UpwardWalkLimit = &UpWalkLimit; + // Starting limit must be > 0. + if (!UpWalkLimit) + UpWalkLimit++; + + MemoryAccess *Current = Start; + // This walker pretends uses don't exist. If we're handed one, silently grab + // its def. 
(This has the nice side-effect of ensuring we never cache uses) + if (auto *MU = dyn_cast<MemoryUse>(Start)) + Current = MU->getDefiningAccess(); + + DefPath FirstDesc(Q.StartingLoc, Current, Current, None); + // Fast path for the overly-common case (no crazy phi optimization + // necessary) + UpwardsWalkResult WalkResult = walkToPhiOrClobber(FirstDesc); + MemoryAccess *Result; + if (WalkResult.IsKnownClobber) { + Result = WalkResult.Result; + Q.AR = WalkResult.AR; + } else { + OptznResult OptRes = tryOptimizePhi(cast<MemoryPhi>(FirstDesc.Last), + Current, Q.StartingLoc); + verifyOptResult(OptRes); + resetPhiOptznState(); + Result = OptRes.PrimaryClobber.Clobber; + } + +#ifdef EXPENSIVE_CHECKS + if (!Q.SkipSelfAccess && *UpwardWalkLimit > 0) + checkClobberSanity(Current, Result, Q.StartingLoc, MSSA, Q, AA); +#endif + return Result; + } +}; + +struct RenamePassData { + DomTreeNode *DTN; + DomTreeNode::const_iterator ChildIt; + MemoryAccess *IncomingVal; + + RenamePassData(DomTreeNode *D, DomTreeNode::const_iterator It, + MemoryAccess *M) + : DTN(D), ChildIt(It), IncomingVal(M) {} + + void swap(RenamePassData &RHS) { + std::swap(DTN, RHS.DTN); + std::swap(ChildIt, RHS.ChildIt); + std::swap(IncomingVal, RHS.IncomingVal); + } +}; + +} // end anonymous namespace + +namespace llvm { + +template <class AliasAnalysisType> class MemorySSA::ClobberWalkerBase { + ClobberWalker<AliasAnalysisType> Walker; + MemorySSA *MSSA; + +public: + ClobberWalkerBase(MemorySSA *M, AliasAnalysisType *A, DominatorTree *D) + : Walker(*M, *A, *D), MSSA(M) {} + + MemoryAccess *getClobberingMemoryAccessBase(MemoryAccess *, + const MemoryLocation &, + unsigned &); + // Third argument (bool), defines whether the clobber search should skip the + // original queried access. If true, there will be a follow-up query searching + // for a clobber access past "self". Note that the Optimized access is not + // updated if a new clobber is found by this SkipSelf search. If this + // additional query becomes heavily used we may decide to cache the result. + // Walker instantiations will decide how to set the SkipSelf bool. + MemoryAccess *getClobberingMemoryAccessBase(MemoryAccess *, unsigned &, bool); +}; + +/// A MemorySSAWalker that does AA walks to disambiguate accesses. It no +/// longer does caching on its own, but the name has been retained for the +/// moment. 
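The walker classes that follow are what clients reach through MemorySSAWalker::getClobberingMemoryAccess. A typical query, sketched under the assumption that MemorySSA has already been built:

#include "llvm/Analysis/MemorySSA.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Sketch: find the access that actually clobbers the memory a load reads.
// The result may be a MemoryDef, a MemoryPhi, or the liveOnEntry def.
static MemoryAccess *clobberFor(MemorySSA &MSSA, LoadInst *LI) {
  MemorySSAWalker *Walker = MSSA.getWalker();
  MemoryUseOrDef *UseAccess = MSSA.getMemoryAccess(LI);
  return Walker->getClobberingMemoryAccess(UseAccess);
}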
+template <class AliasAnalysisType> +class MemorySSA::CachingWalker final : public MemorySSAWalker { + ClobberWalkerBase<AliasAnalysisType> *Walker; + +public: + CachingWalker(MemorySSA *M, ClobberWalkerBase<AliasAnalysisType> *W) + : MemorySSAWalker(M), Walker(W) {} + ~CachingWalker() override = default; + + using MemorySSAWalker::getClobberingMemoryAccess; + + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA, unsigned &UWL) { + return Walker->getClobberingMemoryAccessBase(MA, UWL, false); + } + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA, + const MemoryLocation &Loc, + unsigned &UWL) { + return Walker->getClobberingMemoryAccessBase(MA, Loc, UWL); + } + + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA) override { + unsigned UpwardWalkLimit = MaxCheckLimit; + return getClobberingMemoryAccess(MA, UpwardWalkLimit); + } + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA, + const MemoryLocation &Loc) override { + unsigned UpwardWalkLimit = MaxCheckLimit; + return getClobberingMemoryAccess(MA, Loc, UpwardWalkLimit); + } + + void invalidateInfo(MemoryAccess *MA) override { + if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA)) + MUD->resetOptimized(); + } +}; + +template <class AliasAnalysisType> +class MemorySSA::SkipSelfWalker final : public MemorySSAWalker { + ClobberWalkerBase<AliasAnalysisType> *Walker; + +public: + SkipSelfWalker(MemorySSA *M, ClobberWalkerBase<AliasAnalysisType> *W) + : MemorySSAWalker(M), Walker(W) {} + ~SkipSelfWalker() override = default; + + using MemorySSAWalker::getClobberingMemoryAccess; + + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA, unsigned &UWL) { + return Walker->getClobberingMemoryAccessBase(MA, UWL, true); + } + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA, + const MemoryLocation &Loc, + unsigned &UWL) { + return Walker->getClobberingMemoryAccessBase(MA, Loc, UWL); + } + + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA) override { + unsigned UpwardWalkLimit = MaxCheckLimit; + return getClobberingMemoryAccess(MA, UpwardWalkLimit); + } + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA, + const MemoryLocation &Loc) override { + unsigned UpwardWalkLimit = MaxCheckLimit; + return getClobberingMemoryAccess(MA, Loc, UpwardWalkLimit); + } + + void invalidateInfo(MemoryAccess *MA) override { + if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA)) + MUD->resetOptimized(); + } +}; + +} // end namespace llvm + +void MemorySSA::renameSuccessorPhis(BasicBlock *BB, MemoryAccess *IncomingVal, + bool RenameAllUses) { + // Pass through values to our successors + for (const BasicBlock *S : successors(BB)) { + auto It = PerBlockAccesses.find(S); + // Rename the phi nodes in our successor block + if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front())) + continue; + AccessList *Accesses = It->second.get(); + auto *Phi = cast<MemoryPhi>(&Accesses->front()); + if (RenameAllUses) { + bool ReplacementDone = false; + for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) + if (Phi->getIncomingBlock(I) == BB) { + Phi->setIncomingValue(I, IncomingVal); + ReplacementDone = true; + } + (void) ReplacementDone; + assert(ReplacementDone && "Incomplete phi during partial rename"); + } else + Phi->addIncoming(IncomingVal, BB); + } +} + +/// Rename a single basic block into MemorySSA form. +/// Uses the standard SSA renaming algorithm. +/// \returns The new incoming value. 
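renameSuccessorPhis above fills in MemoryPhi operands edge by edge; renameBlock, documented above, follows this aside. Once MemorySSA is built, a MemoryPhi reads much like an ordinary PHINode; a small inspection sketch:

#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

// Sketch: print the incoming memory state for each predecessor edge.
static void dumpMemoryPhi(MemorySSA &MSSA, BasicBlock &BB) {
  if (MemoryPhi *MP = MSSA.getMemoryAccess(&BB))
    for (unsigned I = 0, E = MP->getNumIncomingValues(); I != E; ++I)
      errs() << MP->getIncomingBlock(I)->getName() << " -> "
             << *MP->getIncomingValue(I) << "\n";
}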
+MemoryAccess *MemorySSA::renameBlock(BasicBlock *BB, MemoryAccess *IncomingVal, + bool RenameAllUses) { + auto It = PerBlockAccesses.find(BB); + // Skip most processing if the list is empty. + if (It != PerBlockAccesses.end()) { + AccessList *Accesses = It->second.get(); + for (MemoryAccess &L : *Accesses) { + if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(&L)) { + if (MUD->getDefiningAccess() == nullptr || RenameAllUses) + MUD->setDefiningAccess(IncomingVal); + if (isa<MemoryDef>(&L)) + IncomingVal = &L; + } else { + IncomingVal = &L; + } + } + } + return IncomingVal; +} + +/// This is the standard SSA renaming algorithm. +/// +/// We walk the dominator tree in preorder, renaming accesses, and then filling +/// in phi nodes in our successors. +void MemorySSA::renamePass(DomTreeNode *Root, MemoryAccess *IncomingVal, + SmallPtrSetImpl<BasicBlock *> &Visited, + bool SkipVisited, bool RenameAllUses) { + assert(Root && "Trying to rename accesses in an unreachable block"); + + SmallVector<RenamePassData, 32> WorkStack; + // Skip everything if we already renamed this block and we are skipping. + // Note: You can't sink this into the if, because we need it to occur + // regardless of whether we skip blocks or not. + bool AlreadyVisited = !Visited.insert(Root->getBlock()).second; + if (SkipVisited && AlreadyVisited) + return; + + IncomingVal = renameBlock(Root->getBlock(), IncomingVal, RenameAllUses); + renameSuccessorPhis(Root->getBlock(), IncomingVal, RenameAllUses); + WorkStack.push_back({Root, Root->begin(), IncomingVal}); + + while (!WorkStack.empty()) { + DomTreeNode *Node = WorkStack.back().DTN; + DomTreeNode::const_iterator ChildIt = WorkStack.back().ChildIt; + IncomingVal = WorkStack.back().IncomingVal; + + if (ChildIt == Node->end()) { + WorkStack.pop_back(); + } else { + DomTreeNode *Child = *ChildIt; + ++WorkStack.back().ChildIt; + BasicBlock *BB = Child->getBlock(); + // Note: You can't sink this into the if, because we need it to occur + // regardless of whether we skip blocks or not. + AlreadyVisited = !Visited.insert(BB).second; + if (SkipVisited && AlreadyVisited) { + // We already visited this during our renaming, which can happen when + // being asked to rename multiple blocks. Figure out the incoming val, + // which is the last def. + // Incoming value can only change if there is a block def, and in that + // case, it's the last block def in the list. + if (auto *BlockDefs = getWritableBlockDefs(BB)) + IncomingVal = &*BlockDefs->rbegin(); + } else + IncomingVal = renameBlock(BB, IncomingVal, RenameAllUses); + renameSuccessorPhis(BB, IncomingVal, RenameAllUses); + WorkStack.push_back({Child, Child->begin(), IncomingVal}); + } + } +} + +/// This handles unreachable block accesses by deleting phi nodes in +/// unreachable blocks, and marking all other unreachable MemoryAccess's as +/// being uses of the live on entry definition. +void MemorySSA::markUnreachableAsLiveOnEntry(BasicBlock *BB) { + assert(!DT->isReachableFromEntry(BB) && + "Reachable block found while handling unreachable blocks"); + + // Make sure phi nodes in our reachable successors end up with a + // LiveOnEntryDef for our incoming edge, even though our block is forward + // unreachable. We could just disconnect these blocks from the CFG fully, + // but we do not right now. 
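As an aside before the loop that patches up those successor phis: what liveOnEntry means to a client of the analysis, sketched (readsEntryState is a hypothetical name):

#include "llvm/Analysis/MemorySSA.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Sketch: a load whose defining access is liveOnEntry reads memory that no
// instruction in the function clobbers before it.
static bool readsEntryState(MemorySSA &MSSA, LoadInst *LI) {
  MemoryUseOrDef *MA = MSSA.getMemoryAccess(LI);
  return MA && MSSA.isLiveOnEntryDef(MA->getDefiningAccess());
}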
+  for (const BasicBlock *S : successors(BB)) {
+    if (!DT->isReachableFromEntry(S))
+      continue;
+    auto It = PerBlockAccesses.find(S);
+    // Rename the phi nodes in our successor block.
+    if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front()))
+      continue;
+    AccessList *Accesses = It->second.get();
+    auto *Phi = cast<MemoryPhi>(&Accesses->front());
+    Phi->addIncoming(LiveOnEntryDef.get(), BB);
+  }
+
+  auto It = PerBlockAccesses.find(BB);
+  if (It == PerBlockAccesses.end())
+    return;
+
+  auto &Accesses = It->second;
+  for (auto AI = Accesses->begin(), AE = Accesses->end(); AI != AE;) {
+    auto Next = std::next(AI);
+    // If we have a phi, just remove it. We are going to replace all
+    // users with live on entry.
+    if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(AI))
+      UseOrDef->setDefiningAccess(LiveOnEntryDef.get());
+    else
+      Accesses->erase(AI);
+    AI = Next;
+  }
+}
+
+MemorySSA::MemorySSA(Function &Func, AliasAnalysis *AA, DominatorTree *DT)
+    : AA(nullptr), DT(DT), F(Func), LiveOnEntryDef(nullptr), Walker(nullptr),
+      SkipWalker(nullptr), NextID(0) {
+  // Build MemorySSA using a batch alias analysis. This reuses the internal
+  // state that AA collects during an alias()/getModRefInfo() call. This is
+  // safe because there are no CFG changes while building MemorySSA, and it
+  // can significantly reduce the time spent by the compiler in AA, because
+  // we will make queries about all the instructions in the Function.
+  BatchAAResults BatchAA(*AA);
+  buildMemorySSA(BatchAA);
+  // Intentionally leave AA set to nullptr while building so we don't
+  // accidentally use the non-batch AliasAnalysis.
+  this->AA = AA;
+  // Also create the walker here.
+  getWalker();
+}
+
+MemorySSA::~MemorySSA() {
+  // Drop all our references.
+  for (const auto &Pair : PerBlockAccesses)
+    for (MemoryAccess &MA : *Pair.second)
+      MA.dropAllReferences();
+}
+
+MemorySSA::AccessList *MemorySSA::getOrCreateAccessList(const BasicBlock *BB) {
+  auto Res = PerBlockAccesses.insert(std::make_pair(BB, nullptr));
+
+  if (Res.second)
+    Res.first->second = std::make_unique<AccessList>();
+  return Res.first->second.get();
+}
+
+MemorySSA::DefsList *MemorySSA::getOrCreateDefsList(const BasicBlock *BB) {
+  auto Res = PerBlockDefs.insert(std::make_pair(BB, nullptr));
+
+  if (Res.second)
+    Res.first->second = std::make_unique<DefsList>();
+  return Res.first->second.get();
+}
+
+namespace llvm {
+
+/// This class is a batch walker of all MemoryUses in the program, and points
+/// their defining access at the thing that actually clobbers them. Because it
+/// is a batch walker that touches everything, it does not operate like the
+/// other walkers. This walker is basically performing a top-down SSA renaming
+/// pass, where the version stack is used as the cache. This enables it to be
+/// significantly more time and memory efficient than using the regular
+/// walker, which walks bottom-up.
+class MemorySSA::OptimizeUses {
+public:
+  OptimizeUses(MemorySSA *MSSA, CachingWalker<BatchAAResults> *Walker,
+               BatchAAResults *BAA, DominatorTree *DT)
+      : MSSA(MSSA), Walker(Walker), AA(BAA), DT(DT) {}
+
+  void optimizeUses();
+
+private:
+  /// This represents where a given MemoryLocation is in the stack.
+  struct MemlocStackInfo {
+    // This essentially is keeping track of versions of the stack. Whenever
+    // the stack changes due to pushes or pops, these versions increase.
+    unsigned long StackEpoch;
+    unsigned long PopEpoch;
+    // This is the lower bound of places on the stack to check.
It is equal to + // the place the last stack walk ended. + // Note: Correctness depends on this being initialized to 0, which densemap + // does + unsigned long LowerBound; + const BasicBlock *LowerBoundBlock; + // This is where the last walk for this memory location ended. + unsigned long LastKill; + bool LastKillValid; + Optional<AliasResult> AR; + }; + + void optimizeUsesInBlock(const BasicBlock *, unsigned long &, unsigned long &, + SmallVectorImpl<MemoryAccess *> &, + DenseMap<MemoryLocOrCall, MemlocStackInfo> &); + + MemorySSA *MSSA; + CachingWalker<BatchAAResults> *Walker; + BatchAAResults *AA; + DominatorTree *DT; +}; + +} // end namespace llvm + +/// Optimize the uses in a given block This is basically the SSA renaming +/// algorithm, with one caveat: We are able to use a single stack for all +/// MemoryUses. This is because the set of *possible* reaching MemoryDefs is +/// the same for every MemoryUse. The *actual* clobbering MemoryDef is just +/// going to be some position in that stack of possible ones. +/// +/// We track the stack positions that each MemoryLocation needs +/// to check, and last ended at. This is because we only want to check the +/// things that changed since last time. The same MemoryLocation should +/// get clobbered by the same store (getModRefInfo does not use invariantness or +/// things like this, and if they start, we can modify MemoryLocOrCall to +/// include relevant data) +void MemorySSA::OptimizeUses::optimizeUsesInBlock( + const BasicBlock *BB, unsigned long &StackEpoch, unsigned long &PopEpoch, + SmallVectorImpl<MemoryAccess *> &VersionStack, + DenseMap<MemoryLocOrCall, MemlocStackInfo> &LocStackInfo) { + + /// If no accesses, nothing to do. + MemorySSA::AccessList *Accesses = MSSA->getWritableBlockAccesses(BB); + if (Accesses == nullptr) + return; + + // Pop everything that doesn't dominate the current block off the stack, + // increment the PopEpoch to account for this. + while (true) { + assert( + !VersionStack.empty() && + "Version stack should have liveOnEntry sentinel dominating everything"); + BasicBlock *BackBlock = VersionStack.back()->getBlock(); + if (DT->dominates(BackBlock, BB)) + break; + while (VersionStack.back()->getBlock() == BackBlock) + VersionStack.pop_back(); + ++PopEpoch; + } + + for (MemoryAccess &MA : *Accesses) { + auto *MU = dyn_cast<MemoryUse>(&MA); + if (!MU) { + VersionStack.push_back(&MA); + ++StackEpoch; + continue; + } + + if (isUseTriviallyOptimizableToLiveOnEntry(*AA, MU->getMemoryInst())) { + MU->setDefiningAccess(MSSA->getLiveOnEntryDef(), true, None); + continue; + } + + MemoryLocOrCall UseMLOC(MU); + auto &LocInfo = LocStackInfo[UseMLOC]; + // If the pop epoch changed, it means we've removed stuff from top of + // stack due to changing blocks. We may have to reset the lower bound or + // last kill info. + if (LocInfo.PopEpoch != PopEpoch) { + LocInfo.PopEpoch = PopEpoch; + LocInfo.StackEpoch = StackEpoch; + // If the lower bound was in something that no longer dominates us, we + // have to reset it. + // We can't simply track stack size, because the stack may have had + // pushes/pops in the meantime. + // XXX: This is non-optimal, but only is slower cases with heavily + // branching dominator trees. To get the optimal number of queries would + // be to make lowerbound and lastkill a per-loc stack, and pop it until + // the top of that stack dominates us. This does not seem worth it ATM. + // A much cheaper optimization would be to always explore the deepest + // branch of the dominator tree first. 
This will guarantee this resets on + // the smallest set of blocks. + if (LocInfo.LowerBoundBlock && LocInfo.LowerBoundBlock != BB && + !DT->dominates(LocInfo.LowerBoundBlock, BB)) { + // Reset the lower bound of things to check. + // TODO: Some day we should be able to reset to last kill, rather than + // 0. + LocInfo.LowerBound = 0; + LocInfo.LowerBoundBlock = VersionStack[0]->getBlock(); + LocInfo.LastKillValid = false; + } + } else if (LocInfo.StackEpoch != StackEpoch) { + // If all that has changed is the StackEpoch, we only have to check the + // new things on the stack, because we've checked everything before. In + // this case, the lower bound of things to check remains the same. + LocInfo.PopEpoch = PopEpoch; + LocInfo.StackEpoch = StackEpoch; + } + if (!LocInfo.LastKillValid) { + LocInfo.LastKill = VersionStack.size() - 1; + LocInfo.LastKillValid = true; + LocInfo.AR = MayAlias; + } + + // At this point, we should have corrected last kill and LowerBound to be + // in bounds. + assert(LocInfo.LowerBound < VersionStack.size() && + "Lower bound out of range"); + assert(LocInfo.LastKill < VersionStack.size() && + "Last kill info out of range"); + // In any case, the new upper bound is the top of the stack. + unsigned long UpperBound = VersionStack.size() - 1; + + if (UpperBound - LocInfo.LowerBound > MaxCheckLimit) { + LLVM_DEBUG(dbgs() << "MemorySSA skipping optimization of " << *MU << " (" + << *(MU->getMemoryInst()) << ")" + << " because there are " + << UpperBound - LocInfo.LowerBound + << " stores to disambiguate\n"); + // Because we did not walk, LastKill is no longer valid, as this may + // have been a kill. + LocInfo.LastKillValid = false; + continue; + } + bool FoundClobberResult = false; + unsigned UpwardWalkLimit = MaxCheckLimit; + while (UpperBound > LocInfo.LowerBound) { + if (isa<MemoryPhi>(VersionStack[UpperBound])) { + // For phis, use the walker, see where we ended up, go there + MemoryAccess *Result = + Walker->getClobberingMemoryAccess(MU, UpwardWalkLimit); + // We are guaranteed to find it or something is wrong + while (VersionStack[UpperBound] != Result) { + assert(UpperBound != 0); + --UpperBound; + } + FoundClobberResult = true; + break; + } + + MemoryDef *MD = cast<MemoryDef>(VersionStack[UpperBound]); + // If the lifetime of the pointer ends at this instruction, it's live on + // entry. + if (!UseMLOC.IsCall && lifetimeEndsAt(MD, UseMLOC.getLoc(), *AA)) { + // Reset UpperBound to liveOnEntryDef's place in the stack + UpperBound = 0; + FoundClobberResult = true; + LocInfo.AR = MustAlias; + break; + } + ClobberAlias CA = instructionClobbersQuery(MD, MU, UseMLOC, *AA); + if (CA.IsClobber) { + FoundClobberResult = true; + LocInfo.AR = CA.AR; + break; + } + --UpperBound; + } + + // Note: Phis always have AliasResult AR set to MayAlias ATM. + + // At the end of this loop, UpperBound is either a clobber, or lower bound + // PHI walking may cause it to be < LowerBound, and in fact, < LastKill. + if (FoundClobberResult || UpperBound < LocInfo.LastKill) { + // We were last killed now by where we got to + if (MSSA->isLiveOnEntryDef(VersionStack[UpperBound])) + LocInfo.AR = None; + MU->setDefiningAccess(VersionStack[UpperBound], true, LocInfo.AR); + LocInfo.LastKill = UpperBound; + } else { + // Otherwise, we checked all the new ones, and now we know we can get to + // LastKill. 
+ MU->setDefiningAccess(VersionStack[LocInfo.LastKill], true, LocInfo.AR); + } + LocInfo.LowerBound = VersionStack.size() - 1; + LocInfo.LowerBoundBlock = BB; + } +} + +/// Optimize uses to point to their actual clobbering definitions. +void MemorySSA::OptimizeUses::optimizeUses() { + SmallVector<MemoryAccess *, 16> VersionStack; + DenseMap<MemoryLocOrCall, MemlocStackInfo> LocStackInfo; + VersionStack.push_back(MSSA->getLiveOnEntryDef()); + + unsigned long StackEpoch = 1; + unsigned long PopEpoch = 1; + // We perform a non-recursive top-down dominator tree walk. + for (const auto *DomNode : depth_first(DT->getRootNode())) + optimizeUsesInBlock(DomNode->getBlock(), StackEpoch, PopEpoch, VersionStack, + LocStackInfo); +} + +void MemorySSA::placePHINodes( + const SmallPtrSetImpl<BasicBlock *> &DefiningBlocks) { + // Determine where our MemoryPhi's should go + ForwardIDFCalculator IDFs(*DT); + IDFs.setDefiningBlocks(DefiningBlocks); + SmallVector<BasicBlock *, 32> IDFBlocks; + IDFs.calculate(IDFBlocks); + + // Now place MemoryPhi nodes. + for (auto &BB : IDFBlocks) + createMemoryPhi(BB); +} + +void MemorySSA::buildMemorySSA(BatchAAResults &BAA) { + // We create an access to represent "live on entry", for things like + // arguments or users of globals, where the memory they use is defined before + // the beginning of the function. We do not actually insert it into the IR. + // We do not define a live on exit for the immediate uses, and thus our + // semantics do *not* imply that something with no immediate uses can simply + // be removed. + BasicBlock &StartingPoint = F.getEntryBlock(); + LiveOnEntryDef.reset(new MemoryDef(F.getContext(), nullptr, nullptr, + &StartingPoint, NextID++)); + + // We maintain lists of memory accesses per-block, trading memory for time. We + // could just look up the memory access for every possible instruction in the + // stream. + SmallPtrSet<BasicBlock *, 32> DefiningBlocks; + // Go through each block, figure out where defs occur, and chain together all + // the accesses. + for (BasicBlock &B : F) { + bool InsertIntoDef = false; + AccessList *Accesses = nullptr; + DefsList *Defs = nullptr; + for (Instruction &I : B) { + MemoryUseOrDef *MUD = createNewAccess(&I, &BAA); + if (!MUD) + continue; + + if (!Accesses) + Accesses = getOrCreateAccessList(&B); + Accesses->push_back(MUD); + if (isa<MemoryDef>(MUD)) { + InsertIntoDef = true; + if (!Defs) + Defs = getOrCreateDefsList(&B); + Defs->push_back(*MUD); + } + } + if (InsertIntoDef) + DefiningBlocks.insert(&B); + } + placePHINodes(DefiningBlocks); + + // Now do regular SSA renaming on the MemoryDef/MemoryUse. Visited will get + // filled in with all blocks. + SmallPtrSet<BasicBlock *, 16> Visited; + renamePass(DT->getRootNode(), LiveOnEntryDef.get(), Visited); + + ClobberWalkerBase<BatchAAResults> WalkerBase(this, &BAA, DT); + CachingWalker<BatchAAResults> WalkerLocal(this, &WalkerBase); + OptimizeUses(this, &WalkerLocal, &BAA, DT).optimizeUses(); + + // Mark the uses in unreachable blocks as live on entry, so that they go + // somewhere. 
+ for (auto &BB : F) + if (!Visited.count(&BB)) + markUnreachableAsLiveOnEntry(&BB); +} + +MemorySSAWalker *MemorySSA::getWalker() { return getWalkerImpl(); } + +MemorySSA::CachingWalker<AliasAnalysis> *MemorySSA::getWalkerImpl() { + if (Walker) + return Walker.get(); + + if (!WalkerBase) + WalkerBase = + std::make_unique<ClobberWalkerBase<AliasAnalysis>>(this, AA, DT); + + Walker = + std::make_unique<CachingWalker<AliasAnalysis>>(this, WalkerBase.get()); + return Walker.get(); +} + +MemorySSAWalker *MemorySSA::getSkipSelfWalker() { + if (SkipWalker) + return SkipWalker.get(); + + if (!WalkerBase) + WalkerBase = + std::make_unique<ClobberWalkerBase<AliasAnalysis>>(this, AA, DT); + + SkipWalker = + std::make_unique<SkipSelfWalker<AliasAnalysis>>(this, WalkerBase.get()); + return SkipWalker.get(); + } + + +// This is a helper function used by the creation routines. It places NewAccess +// into the access and defs lists for a given basic block, at the given +// insertion point. +void MemorySSA::insertIntoListsForBlock(MemoryAccess *NewAccess, + const BasicBlock *BB, + InsertionPlace Point) { + auto *Accesses = getOrCreateAccessList(BB); + if (Point == Beginning) { + // If it's a phi node, it goes first, otherwise, it goes after any phi + // nodes. + if (isa<MemoryPhi>(NewAccess)) { + Accesses->push_front(NewAccess); + auto *Defs = getOrCreateDefsList(BB); + Defs->push_front(*NewAccess); + } else { + auto AI = find_if_not( + *Accesses, [](const MemoryAccess &MA) { return isa<MemoryPhi>(MA); }); + Accesses->insert(AI, NewAccess); + if (!isa<MemoryUse>(NewAccess)) { + auto *Defs = getOrCreateDefsList(BB); + auto DI = find_if_not( + *Defs, [](const MemoryAccess &MA) { return isa<MemoryPhi>(MA); }); + Defs->insert(DI, *NewAccess); + } + } + } else { + Accesses->push_back(NewAccess); + if (!isa<MemoryUse>(NewAccess)) { + auto *Defs = getOrCreateDefsList(BB); + Defs->push_back(*NewAccess); + } + } + BlockNumberingValid.erase(BB); +} + +void MemorySSA::insertIntoListsBefore(MemoryAccess *What, const BasicBlock *BB, + AccessList::iterator InsertPt) { + auto *Accesses = getWritableBlockAccesses(BB); + bool WasEnd = InsertPt == Accesses->end(); + Accesses->insert(AccessList::iterator(InsertPt), What); + if (!isa<MemoryUse>(What)) { + auto *Defs = getOrCreateDefsList(BB); + // If we got asked to insert at the end, we have an easy job, just shove it + // at the end. If we got asked to insert before an existing def, we also get + // an iterator. If we got asked to insert before a use, we have to hunt for + // the next def. + if (WasEnd) { + Defs->push_back(*What); + } else if (isa<MemoryDef>(InsertPt)) { + Defs->insert(InsertPt->getDefsIterator(), *What); + } else { + while (InsertPt != Accesses->end() && !isa<MemoryDef>(InsertPt)) + ++InsertPt; + // Either we found a def, or we are inserting at the end + if (InsertPt == Accesses->end()) + Defs->push_back(*What); + else + Defs->insert(InsertPt->getDefsIterator(), *What); + } + } + BlockNumberingValid.erase(BB); +} + +void MemorySSA::prepareForMoveTo(MemoryAccess *What, BasicBlock *BB) { + // Keep it in the lookup tables, remove from the lists + removeFromLists(What, false); + + // Note that moving should implicitly invalidate the optimized state of a + // MemoryUse (and Phis can't be optimized). However, it doesn't do so for a + // MemoryDef. + if (auto *MD = dyn_cast<MemoryDef>(What)) + MD->resetOptimized(); + What->setBlock(BB); +} + +// Move What before Where in the IR. 
The end result is that What will belong to
+// the right lists and have the right Block set, but will not otherwise be
+// correct. It will not have the right defining access, and if it is a def,
+// things below it will not properly be updated.
+void MemorySSA::moveTo(MemoryUseOrDef *What, BasicBlock *BB,
+ AccessList::iterator Where) {
+ prepareForMoveTo(What, BB);
+ insertIntoListsBefore(What, BB, Where);
+}
+
+void MemorySSA::moveTo(MemoryAccess *What, BasicBlock *BB,
+ InsertionPlace Point) {
+ if (isa<MemoryPhi>(What)) {
+ assert(Point == Beginning &&
+ "Can only move a Phi at the beginning of the block");
+ // Update lookup table entry.
+ ValueToMemoryAccess.erase(What->getBlock());
+ bool Inserted = ValueToMemoryAccess.insert({BB, What}).second;
+ (void)Inserted;
+ assert(Inserted && "Cannot move a Phi to a block that already has one");
+ }
+
+ prepareForMoveTo(What, BB);
+ insertIntoListsForBlock(What, BB, Point);
+}
+
+MemoryPhi *MemorySSA::createMemoryPhi(BasicBlock *BB) {
+ assert(!getMemoryAccess(BB) && "MemoryPhi already exists for this BB");
+ MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++);
+ // Phis are always placed at the front of the block.
+ insertIntoListsForBlock(Phi, BB, Beginning);
+ ValueToMemoryAccess[BB] = Phi;
+ return Phi;
+}
+
+MemoryUseOrDef *MemorySSA::createDefinedAccess(Instruction *I,
+ MemoryAccess *Definition,
+ const MemoryUseOrDef *Template,
+ bool CreationMustSucceed) {
+ assert(!isa<PHINode>(I) && "Cannot create a defined access for a PHI");
+ MemoryUseOrDef *NewAccess = createNewAccess(I, AA, Template);
+ if (CreationMustSucceed)
+ assert(NewAccess != nullptr && "Tried to create a memory access for a "
+ "non-memory touching instruction");
+ if (NewAccess)
+ NewAccess->setDefiningAccess(Definition);
+ return NewAccess;
+}
+
+// Return true if the instruction has ordering constraints.
+// Note specifically that this only considers stores and loads
+// because others are still considered ModRef by getModRefInfo.
+static inline bool isOrdered(const Instruction *I) {
+ if (auto *SI = dyn_cast<StoreInst>(I)) {
+ if (!SI->isUnordered())
+ return true;
+ } else if (auto *LI = dyn_cast<LoadInst>(I)) {
+ if (!LI->isUnordered())
+ return true;
+ }
+ return false;
+}
+
+/// Helper function to create new memory accesses.
+template <typename AliasAnalysisType>
+MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I,
+ AliasAnalysisType *AAP,
+ const MemoryUseOrDef *Template) {
+ // The assume intrinsic has a control dependency which we model by claiming
+ // that it writes arbitrarily. Debuginfo intrinsics may be considered
+ // clobbers when we have a nonstandard AA pipeline. Ignore these fake memory
+ // dependencies here.
+ // FIXME: Replace this special casing with a more accurate modelling of
+ // assume's control dependency.
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
+ if (II->getIntrinsicID() == Intrinsic::assume)
+ return nullptr;
+
+ // Using a nonstandard AA pipeline might leave us with unexpected modref
+ // results for I, so add a check to not model instructions that may not read
+ // from or write to memory. This is necessary for correctness.
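+ // (Editor's example: an instruction such as `%a = add i32 %x, %y` neither
+ // reads nor writes memory, so the check below bails out and no
+ // MemoryAccess is created for it.)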
+ if (!I->mayReadFromMemory() && !I->mayWriteToMemory())
+ return nullptr;
+
+ bool Def, Use;
+ if (Template) {
+ Def = dyn_cast_or_null<MemoryDef>(Template) != nullptr;
+ Use = dyn_cast_or_null<MemoryUse>(Template) != nullptr;
+#if !defined(NDEBUG)
+ ModRefInfo ModRef = AAP->getModRefInfo(I, None);
+ bool DefCheck, UseCheck;
+ DefCheck = isModSet(ModRef) || isOrdered(I);
+ UseCheck = isRefSet(ModRef);
+ assert(Def == DefCheck && (Def || Use == UseCheck) && "Invalid template");
+#endif
+ } else {
+ // Find out what effect this instruction has on memory.
+ ModRefInfo ModRef = AAP->getModRefInfo(I, None);
+ // The isOrdered check is used to ensure that volatiles end up as defs
+ // (atomics end up as ModRef right now anyway). Until we separate the
+ // ordering chain from the memory chain, this enables people to see at
+ // least some relative ordering to volatiles. Note that
+ // getClobberingMemoryAccess will still give an answer that bypasses other
+ // volatile loads. TODO: Separate memory aliasing and ordering into two
+ // different chains so that we can precisely represent both "what memory
+ // will this read/write/is clobbered by" and "what instructions can I move
+ // this past".
+ Def = isModSet(ModRef) || isOrdered(I);
+ Use = isRefSet(ModRef);
+ }
+
+ // It's possible for an instruction to not modify memory at all. During
+ // construction, we ignore such instructions.
+ if (!Def && !Use)
+ return nullptr;
+
+ MemoryUseOrDef *MUD;
+ if (Def)
+ MUD = new MemoryDef(I->getContext(), nullptr, I, I->getParent(), NextID++);
+ else
+ MUD = new MemoryUse(I->getContext(), nullptr, I, I->getParent());
+ ValueToMemoryAccess[I] = MUD;
+ return MUD;
+}
+
+/// Returns true if \p Replacer dominates \p Replacee.
+bool MemorySSA::dominatesUse(const MemoryAccess *Replacer,
+ const MemoryAccess *Replacee) const {
+ if (isa<MemoryUseOrDef>(Replacee))
+ return DT->dominates(Replacer->getBlock(), Replacee->getBlock());
+ const auto *MP = cast<MemoryPhi>(Replacee);
+ // For a phi node, the use occurs in the predecessor block of the phi node.
+ // Since Replacee may occur multiple times in the phi node, we have to check
+ // each operand to ensure Replacer dominates each operand where Replacee
+ // occurs.
+ for (const Use &Arg : MP->operands()) {
+ if (Arg.get() != Replacee &&
+ !DT->dominates(Replacer->getBlock(), MP->getIncomingBlock(Arg)))
+ return false;
+ }
+ return true;
+}
+
+/// Properly remove \p MA from all of MemorySSA's lookup tables.
+void MemorySSA::removeFromLookups(MemoryAccess *MA) {
+ assert(MA->use_empty() &&
+ "Trying to remove memory access that still has uses");
+ BlockNumbering.erase(MA);
+ if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA))
+ MUD->setDefiningAccess(nullptr);
+ // Invalidate our walker's cache if necessary.
+ if (!isa<MemoryUse>(MA))
+ getWalker()->invalidateInfo(MA);
+
+ Value *MemoryInst;
+ if (const auto *MUD = dyn_cast<MemoryUseOrDef>(MA))
+ MemoryInst = MUD->getMemoryInst();
+ else
+ MemoryInst = MA->getBlock();
+
+ auto VMA = ValueToMemoryAccess.find(MemoryInst);
+ if (VMA->second == MA)
+ ValueToMemoryAccess.erase(VMA);
+}
+
+/// Properly remove \p MA from all of MemorySSA's lists.
+///
+/// Because of the way the intrusive list and use lists work, it is important
+/// to do removal in the right order.
+/// ShouldDelete defaults to true, and will cause the memory access to also be
+/// deleted, not just removed.
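+///
+/// Editor's sketch of typical internal usage (hypothetical, mirroring what
+/// prepareForMoveTo does):
+/// \code
+///   removeFromLists(MA, /*ShouldDelete=*/false); // detach, keep MA alive
+///   insertIntoListsForBlock(MA, NewBB, End);     // re-insert elsewhere
+/// \endcode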
+void MemorySSA::removeFromLists(MemoryAccess *MA, bool ShouldDelete) {
+ BasicBlock *BB = MA->getBlock();
+ // The access list owns the reference, so we erase it from the non-owning
+ // list first.
+ if (!isa<MemoryUse>(MA)) {
+ auto DefsIt = PerBlockDefs.find(BB);
+ std::unique_ptr<DefsList> &Defs = DefsIt->second;
+ Defs->remove(*MA);
+ if (Defs->empty())
+ PerBlockDefs.erase(DefsIt);
+ }
+
+ // The erase call here will delete it. If we don't want it deleted, we call
+ // remove instead.
+ auto AccessIt = PerBlockAccesses.find(BB);
+ std::unique_ptr<AccessList> &Accesses = AccessIt->second;
+ if (ShouldDelete)
+ Accesses->erase(MA);
+ else
+ Accesses->remove(MA);
+
+ if (Accesses->empty()) {
+ PerBlockAccesses.erase(AccessIt);
+ BlockNumberingValid.erase(BB);
+ }
+}
+
+void MemorySSA::print(raw_ostream &OS) const {
+ MemorySSAAnnotatedWriter Writer(this);
+ F.print(OS, &Writer);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void MemorySSA::dump() const { print(dbgs()); }
+#endif
+
+void MemorySSA::verifyMemorySSA() const {
+ verifyDefUses(F);
+ verifyDomination(F);
+ verifyOrdering(F);
+ verifyDominationNumbers(F);
+ verifyPrevDefInPhis(F);
+ // Previously, the verification used to also check that the clobbering
+ // access cached by MemorySSA is the same as the clobbering access found by
+ // a later query to AA. This does not hold true in general due to the
+ // current fragility of BasicAA, which has arbitrary caps on the things it
+ // analyzes before giving up. As a result, transformations that are correct
+ // will lead to BasicAA returning different alias answers before and after
+ // that transformation. Invalidating MemorySSA is not an option, because the
+ // BasicAA results are unstable enough that, in the worst case, we would
+ // need to rebuild MemorySSA from scratch after every transformation, which
+ // defeats the purpose of using it. For such an example, see test4 added in
+ // D51960.
+}
+
+void MemorySSA::verifyPrevDefInPhis(Function &F) const {
+#if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
+ for (const BasicBlock &BB : F) {
+ if (MemoryPhi *Phi = getMemoryAccess(&BB)) {
+ for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) {
+ auto *Pred = Phi->getIncomingBlock(I);
+ auto *IncAcc = Phi->getIncomingValue(I);
+ // If Pred has no unreachable predecessors, get the last def looking
+ // at IDoms. If, while walking IDoms, any of these has an unreachable
+ // predecessor, then the incoming def can be any access.
+ if (auto *DTNode = DT->getNode(Pred)) {
+ while (DTNode) {
+ if (auto *DefList = getBlockDefs(DTNode->getBlock())) {
+ auto *LastAcc = &*(--DefList->end());
+ assert(LastAcc == IncAcc &&
+ "Incorrect incoming access into phi.");
+ break;
+ }
+ DTNode = DTNode->getIDom();
+ }
+ } else {
+ // If Pred has unreachable predecessors, but has at least a Def, the
+ // incoming access can be the last Def in Pred, or it could have been
+ // optimized to LoE. After an update, though, the LoE may have been
+ // replaced by another access, so IncAcc may be any access.
+ // If Pred has unreachable predecessors and no Defs, the incoming
+ // access should be LoE; however, after an update, it may be any
+ // access.
+ }
+ }
+ }
+ }
+#endif
+}
+
+/// Verify that all of the blocks we believe to have valid domination numbers
+/// actually have valid domination numbers.
+void MemorySSA::verifyDominationNumbers(const Function &F) const { +#ifndef NDEBUG + if (BlockNumberingValid.empty()) + return; + + SmallPtrSet<const BasicBlock *, 16> ValidBlocks = BlockNumberingValid; + for (const BasicBlock &BB : F) { + if (!ValidBlocks.count(&BB)) + continue; + + ValidBlocks.erase(&BB); + + const AccessList *Accesses = getBlockAccesses(&BB); + // It's correct to say an empty block has valid numbering. + if (!Accesses) + continue; + + // Block numbering starts at 1. + unsigned long LastNumber = 0; + for (const MemoryAccess &MA : *Accesses) { + auto ThisNumberIter = BlockNumbering.find(&MA); + assert(ThisNumberIter != BlockNumbering.end() && + "MemoryAccess has no domination number in a valid block!"); + + unsigned long ThisNumber = ThisNumberIter->second; + assert(ThisNumber > LastNumber && + "Domination numbers should be strictly increasing!"); + LastNumber = ThisNumber; + } + } + + assert(ValidBlocks.empty() && + "All valid BasicBlocks should exist in F -- dangling pointers?"); +#endif +} + +/// Verify that the order and existence of MemoryAccesses matches the +/// order and existence of memory affecting instructions. +void MemorySSA::verifyOrdering(Function &F) const { +#ifndef NDEBUG + // Walk all the blocks, comparing what the lookups think and what the access + // lists think, as well as the order in the blocks vs the order in the access + // lists. + SmallVector<MemoryAccess *, 32> ActualAccesses; + SmallVector<MemoryAccess *, 32> ActualDefs; + for (BasicBlock &B : F) { + const AccessList *AL = getBlockAccesses(&B); + const auto *DL = getBlockDefs(&B); + MemoryAccess *Phi = getMemoryAccess(&B); + if (Phi) { + ActualAccesses.push_back(Phi); + ActualDefs.push_back(Phi); + } + + for (Instruction &I : B) { + MemoryAccess *MA = getMemoryAccess(&I); + assert((!MA || (AL && (isa<MemoryUse>(MA) || DL))) && + "We have memory affecting instructions " + "in this block but they are not in the " + "access list or defs list"); + if (MA) { + ActualAccesses.push_back(MA); + if (isa<MemoryDef>(MA)) + ActualDefs.push_back(MA); + } + } + // Either we hit the assert, really have no accesses, or we have both + // accesses and an access list. + // Same with defs. + if (!AL && !DL) + continue; + assert(AL->size() == ActualAccesses.size() && + "We don't have the same number of accesses in the block as on the " + "access list"); + assert((DL || ActualDefs.size() == 0) && + "Either we should have a defs list, or we should have no defs"); + assert((!DL || DL->size() == ActualDefs.size()) && + "We don't have the same number of defs in the block as on the " + "def list"); + auto ALI = AL->begin(); + auto AAI = ActualAccesses.begin(); + while (ALI != AL->end() && AAI != ActualAccesses.end()) { + assert(&*ALI == *AAI && "Not the same accesses in the same order"); + ++ALI; + ++AAI; + } + ActualAccesses.clear(); + if (DL) { + auto DLI = DL->begin(); + auto ADI = ActualDefs.begin(); + while (DLI != DL->end() && ADI != ActualDefs.end()) { + assert(&*DLI == *ADI && "Not the same defs in the same order"); + ++DLI; + ++ADI; + } + } + ActualDefs.clear(); + } +#endif +} + +/// Verify the domination properties of MemorySSA by checking that each +/// definition dominates all of its uses. 
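+/// (Editor's note: e.g., a MemoryUse in one arm of a diamond whose defining
+/// access is a MemoryDef in the other, non-dominating arm would fail this
+/// check; its defining access must instead be something that dominates it,
+/// such as a def above the branch or liveOnEntry.)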
+void MemorySSA::verifyDomination(Function &F) const {
+#ifndef NDEBUG
+ for (BasicBlock &B : F) {
+ // Phi nodes are attached to basic blocks.
+ if (MemoryPhi *MP = getMemoryAccess(&B))
+ for (const Use &U : MP->uses())
+ assert(dominates(MP, U) && "Memory PHI does not dominate its uses");
+
+ for (Instruction &I : B) {
+ MemoryAccess *MD = dyn_cast_or_null<MemoryDef>(getMemoryAccess(&I));
+ if (!MD)
+ continue;
+
+ for (const Use &U : MD->uses())
+ assert(dominates(MD, U) && "Memory Def does not dominate its uses");
+ }
+ }
+#endif
+}
+
+/// Verify the def-use lists in MemorySSA, by verifying that \p Use
+/// appears in the use list of \p Def.
+void MemorySSA::verifyUseInDefs(MemoryAccess *Def, MemoryAccess *Use) const {
+#ifndef NDEBUG
+ // The live on entry use may cause us to get a null def here.
+ if (!Def)
+ assert(isLiveOnEntryDef(Use) &&
+ "Null def but use does not point to live on entry def");
+ else
+ assert(is_contained(Def->users(), Use) &&
+ "Did not find use in def's use list");
+#endif
+}
+
+/// Verify the immediate use information, by walking all the memory
+/// accesses and verifying that, for each use, it appears in the
+/// appropriate def's use list.
+void MemorySSA::verifyDefUses(Function &F) const {
+#if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
+ for (BasicBlock &B : F) {
+ // Phi nodes are attached to basic blocks.
+ if (MemoryPhi *Phi = getMemoryAccess(&B)) {
+ assert(Phi->getNumOperands() == static_cast<unsigned>(std::distance(
+ pred_begin(&B), pred_end(&B))) &&
+ "Incomplete MemoryPhi Node");
+ for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) {
+ verifyUseInDefs(Phi->getIncomingValue(I), Phi);
+ assert(find(predecessors(&B), Phi->getIncomingBlock(I)) !=
+ pred_end(&B) &&
+ "Incoming phi block not a block predecessor");
+ }
+ }
+
+ for (Instruction &I : B) {
+ if (MemoryUseOrDef *MA = getMemoryAccess(&I)) {
+ verifyUseInDefs(MA->getDefiningAccess(), MA);
+ }
+ }
+ }
+#endif
+}
+
+/// Perform a local numbering on blocks so that instruction ordering can be
+/// determined in constant time.
+/// TODO: We currently just number in order. If we numbered by N, we could
+/// allow at least N-1 sequences of insertBefore or insertAfter (and at least
+/// log2(N) sequences of mixed before and after) without needing to invalidate
+/// the numbering.
+void MemorySSA::renumberBlock(const BasicBlock *B) const {
+ // The pre-increment ensures the numbers really start at 1.
+ unsigned long CurrentNumber = 0;
+ const AccessList *AL = getBlockAccesses(B);
+ assert(AL != nullptr && "Asking to renumber an empty block");
+ for (const auto &I : *AL)
+ BlockNumbering[&I] = ++CurrentNumber;
+ BlockNumberingValid.insert(B);
+}
+
+/// Determine, for two memory accesses in the same block,
+/// whether \p Dominator dominates \p Dominatee.
+/// \returns True if \p Dominator dominates \p Dominatee.
+bool MemorySSA::locallyDominates(const MemoryAccess *Dominator,
+ const MemoryAccess *Dominatee) const {
+ const BasicBlock *DominatorBlock = Dominator->getBlock();
+
+ assert((DominatorBlock == Dominatee->getBlock()) &&
+ "Asking for local domination when accesses are in different blocks!");
+ // A node dominates itself.
+ if (Dominatee == Dominator)
+ return true;
+
+ // When Dominatee is defined on function entry, it is not dominated by
+ // another memory access.
+ if (isLiveOnEntryDef(Dominatee))
+ return false;
+
+ // When Dominator is defined on function entry, it dominates the other
+ // memory access.
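+ // (Editor's illustration: past these two checks, the query reduces to
+ // comparing the per-block numbers maintained below; with accesses numbered
+ // 1, 2, 3 in list order, access 1 locally dominates access 3 simply
+ // because 1 < 3.)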
+ if (isLiveOnEntryDef(Dominator)) + return true; + + if (!BlockNumberingValid.count(DominatorBlock)) + renumberBlock(DominatorBlock); + + unsigned long DominatorNum = BlockNumbering.lookup(Dominator); + // All numbers start with 1 + assert(DominatorNum != 0 && "Block was not numbered properly"); + unsigned long DominateeNum = BlockNumbering.lookup(Dominatee); + assert(DominateeNum != 0 && "Block was not numbered properly"); + return DominatorNum < DominateeNum; +} + +bool MemorySSA::dominates(const MemoryAccess *Dominator, + const MemoryAccess *Dominatee) const { + if (Dominator == Dominatee) + return true; + + if (isLiveOnEntryDef(Dominatee)) + return false; + + if (Dominator->getBlock() != Dominatee->getBlock()) + return DT->dominates(Dominator->getBlock(), Dominatee->getBlock()); + return locallyDominates(Dominator, Dominatee); +} + +bool MemorySSA::dominates(const MemoryAccess *Dominator, + const Use &Dominatee) const { + if (MemoryPhi *MP = dyn_cast<MemoryPhi>(Dominatee.getUser())) { + BasicBlock *UseBB = MP->getIncomingBlock(Dominatee); + // The def must dominate the incoming block of the phi. + if (UseBB != Dominator->getBlock()) + return DT->dominates(Dominator->getBlock(), UseBB); + // If the UseBB and the DefBB are the same, compare locally. + return locallyDominates(Dominator, cast<MemoryAccess>(Dominatee)); + } + // If it's not a PHI node use, the normal dominates can already handle it. + return dominates(Dominator, cast<MemoryAccess>(Dominatee.getUser())); +} + +const static char LiveOnEntryStr[] = "liveOnEntry"; + +void MemoryAccess::print(raw_ostream &OS) const { + switch (getValueID()) { + case MemoryPhiVal: return static_cast<const MemoryPhi *>(this)->print(OS); + case MemoryDefVal: return static_cast<const MemoryDef *>(this)->print(OS); + case MemoryUseVal: return static_cast<const MemoryUse *>(this)->print(OS); + } + llvm_unreachable("invalid value id"); +} + +void MemoryDef::print(raw_ostream &OS) const { + MemoryAccess *UO = getDefiningAccess(); + + auto printID = [&OS](MemoryAccess *A) { + if (A && A->getID()) + OS << A->getID(); + else + OS << LiveOnEntryStr; + }; + + OS << getID() << " = MemoryDef("; + printID(UO); + OS << ")"; + + if (isOptimized()) { + OS << "->"; + printID(getOptimized()); + + if (Optional<AliasResult> AR = getOptimizedAccessType()) + OS << " " << *AR; + } +} + +void MemoryPhi::print(raw_ostream &OS) const { + bool First = true; + OS << getID() << " = MemoryPhi("; + for (const auto &Op : operands()) { + BasicBlock *BB = getIncomingBlock(Op); + MemoryAccess *MA = cast<MemoryAccess>(Op); + if (!First) + OS << ','; + else + First = false; + + OS << '{'; + if (BB->hasName()) + OS << BB->getName(); + else + BB->printAsOperand(OS, false); + OS << ','; + if (unsigned ID = MA->getID()) + OS << ID; + else + OS << LiveOnEntryStr; + OS << '}'; + } + OS << ')'; +} + +void MemoryUse::print(raw_ostream &OS) const { + MemoryAccess *UO = getDefiningAccess(); + OS << "MemoryUse("; + if (UO && UO->getID()) + OS << UO->getID(); + else + OS << LiveOnEntryStr; + OS << ')'; + + if (Optional<AliasResult> AR = getOptimizedAccessType()) + OS << " " << *AR; +} + +void MemoryAccess::dump() const { +// Cannot completely remove virtual function even in release mode. 
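+// (Editor's note: for reference, the printers above emit forms such as
+//   2 = MemoryDef(1)->liveOnEntry MustAlias
+//   3 = MemoryPhi({entry,liveOnEntry},{loop,2})
+//   MemoryUse(3)
+// which is what MemorySSA::print interleaves with the IR.)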
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + print(dbgs()); + dbgs() << "\n"; +#endif +} + +char MemorySSAPrinterLegacyPass::ID = 0; + +MemorySSAPrinterLegacyPass::MemorySSAPrinterLegacyPass() : FunctionPass(ID) { + initializeMemorySSAPrinterLegacyPassPass(*PassRegistry::getPassRegistry()); +} + +void MemorySSAPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<MemorySSAWrapperPass>(); +} + +bool MemorySSAPrinterLegacyPass::runOnFunction(Function &F) { + auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); + MSSA.print(dbgs()); + if (VerifyMemorySSA) + MSSA.verifyMemorySSA(); + return false; +} + +AnalysisKey MemorySSAAnalysis::Key; + +MemorySSAAnalysis::Result MemorySSAAnalysis::run(Function &F, + FunctionAnalysisManager &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + return MemorySSAAnalysis::Result(std::make_unique<MemorySSA>(F, &AA, &DT)); +} + +bool MemorySSAAnalysis::Result::invalidate( + Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &Inv) { + auto PAC = PA.getChecker<MemorySSAAnalysis>(); + return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) || + Inv.invalidate<AAManager>(F, PA) || + Inv.invalidate<DominatorTreeAnalysis>(F, PA); +} + +PreservedAnalyses MemorySSAPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + OS << "MemorySSA for function: " << F.getName() << "\n"; + AM.getResult<MemorySSAAnalysis>(F).getMSSA().print(OS); + + return PreservedAnalyses::all(); +} + +PreservedAnalyses MemorySSAVerifierPass::run(Function &F, + FunctionAnalysisManager &AM) { + AM.getResult<MemorySSAAnalysis>(F).getMSSA().verifyMemorySSA(); + + return PreservedAnalyses::all(); +} + +char MemorySSAWrapperPass::ID = 0; + +MemorySSAWrapperPass::MemorySSAWrapperPass() : FunctionPass(ID) { + initializeMemorySSAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void MemorySSAWrapperPass::releaseMemory() { MSSA.reset(); } + +void MemorySSAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequiredTransitive<DominatorTreeWrapperPass>(); + AU.addRequiredTransitive<AAResultsWrapperPass>(); +} + +bool MemorySSAWrapperPass::runOnFunction(Function &F) { + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + MSSA.reset(new MemorySSA(F, &AA, &DT)); + return false; +} + +void MemorySSAWrapperPass::verifyAnalysis() const { MSSA->verifyMemorySSA(); } + +void MemorySSAWrapperPass::print(raw_ostream &OS, const Module *M) const { + MSSA->print(OS); +} + +MemorySSAWalker::MemorySSAWalker(MemorySSA *M) : MSSA(M) {} + +/// Walk the use-def chains starting at \p StartingAccess and find +/// the MemoryAccess that actually clobbers Loc. +/// +/// \returns our clobbering memory access +template <typename AliasAnalysisType> +MemoryAccess * +MemorySSA::ClobberWalkerBase<AliasAnalysisType>::getClobberingMemoryAccessBase( + MemoryAccess *StartingAccess, const MemoryLocation &Loc, + unsigned &UpwardWalkLimit) { + if (isa<MemoryPhi>(StartingAccess)) + return StartingAccess; + + auto *StartingUseOrDef = cast<MemoryUseOrDef>(StartingAccess); + if (MSSA->isLiveOnEntryDef(StartingUseOrDef)) + return StartingUseOrDef; + + Instruction *I = StartingUseOrDef->getMemoryInst(); + + // Conservatively, fences are always clobbers, so don't perform the walk if we + // hit a fence. 
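+ // (Editor's example: for an instruction like `fence seq_cst` we simply
+ // return the starting access below rather than trying to walk past it.)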
+ if (!isa<CallBase>(I) && I->isFenceLike()) + return StartingUseOrDef; + + UpwardsMemoryQuery Q; + Q.OriginalAccess = StartingUseOrDef; + Q.StartingLoc = Loc; + Q.Inst = I; + Q.IsCall = false; + + // Unlike the other function, do not walk to the def of a def, because we are + // handed something we already believe is the clobbering access. + // We never set SkipSelf to true in Q in this method. + MemoryAccess *DefiningAccess = isa<MemoryUse>(StartingUseOrDef) + ? StartingUseOrDef->getDefiningAccess() + : StartingUseOrDef; + + MemoryAccess *Clobber = + Walker.findClobber(DefiningAccess, Q, UpwardWalkLimit); + LLVM_DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); + LLVM_DEBUG(dbgs() << *StartingUseOrDef << "\n"); + LLVM_DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); + LLVM_DEBUG(dbgs() << *Clobber << "\n"); + return Clobber; +} + +template <typename AliasAnalysisType> +MemoryAccess * +MemorySSA::ClobberWalkerBase<AliasAnalysisType>::getClobberingMemoryAccessBase( + MemoryAccess *MA, unsigned &UpwardWalkLimit, bool SkipSelf) { + auto *StartingAccess = dyn_cast<MemoryUseOrDef>(MA); + // If this is a MemoryPhi, we can't do anything. + if (!StartingAccess) + return MA; + + bool IsOptimized = false; + + // If this is an already optimized use or def, return the optimized result. + // Note: Currently, we store the optimized def result in a separate field, + // since we can't use the defining access. + if (StartingAccess->isOptimized()) { + if (!SkipSelf || !isa<MemoryDef>(StartingAccess)) + return StartingAccess->getOptimized(); + IsOptimized = true; + } + + const Instruction *I = StartingAccess->getMemoryInst(); + // We can't sanely do anything with a fence, since they conservatively clobber + // all memory, and have no locations to get pointers from to try to + // disambiguate. + if (!isa<CallBase>(I) && I->isFenceLike()) + return StartingAccess; + + UpwardsMemoryQuery Q(I, StartingAccess); + + if (isUseTriviallyOptimizableToLiveOnEntry(*Walker.getAA(), I)) { + MemoryAccess *LiveOnEntry = MSSA->getLiveOnEntryDef(); + StartingAccess->setOptimized(LiveOnEntry); + StartingAccess->setOptimizedAccessType(None); + return LiveOnEntry; + } + + MemoryAccess *OptimizedAccess; + if (!IsOptimized) { + // Start with the thing we already think clobbers this location + MemoryAccess *DefiningAccess = StartingAccess->getDefiningAccess(); + + // At this point, DefiningAccess may be the live on entry def. + // If it is, we will not get a better result. 
+ if (MSSA->isLiveOnEntryDef(DefiningAccess)) { + StartingAccess->setOptimized(DefiningAccess); + StartingAccess->setOptimizedAccessType(None); + return DefiningAccess; + } + + OptimizedAccess = Walker.findClobber(DefiningAccess, Q, UpwardWalkLimit); + StartingAccess->setOptimized(OptimizedAccess); + if (MSSA->isLiveOnEntryDef(OptimizedAccess)) + StartingAccess->setOptimizedAccessType(None); + else if (Q.AR == MustAlias) + StartingAccess->setOptimizedAccessType(MustAlias); + } else + OptimizedAccess = StartingAccess->getOptimized(); + + LLVM_DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); + LLVM_DEBUG(dbgs() << *StartingAccess << "\n"); + LLVM_DEBUG(dbgs() << "Optimized Memory SSA clobber for " << *I << " is "); + LLVM_DEBUG(dbgs() << *OptimizedAccess << "\n"); + + MemoryAccess *Result; + if (SkipSelf && isa<MemoryPhi>(OptimizedAccess) && + isa<MemoryDef>(StartingAccess) && UpwardWalkLimit) { + assert(isa<MemoryDef>(Q.OriginalAccess)); + Q.SkipSelfAccess = true; + Result = Walker.findClobber(OptimizedAccess, Q, UpwardWalkLimit); + } else + Result = OptimizedAccess; + + LLVM_DEBUG(dbgs() << "Result Memory SSA clobber [SkipSelf = " << SkipSelf); + LLVM_DEBUG(dbgs() << "] for " << *I << " is " << *Result << "\n"); + + return Result; +} + +MemoryAccess * +DoNothingMemorySSAWalker::getClobberingMemoryAccess(MemoryAccess *MA) { + if (auto *Use = dyn_cast<MemoryUseOrDef>(MA)) + return Use->getDefiningAccess(); + return MA; +} + +MemoryAccess *DoNothingMemorySSAWalker::getClobberingMemoryAccess( + MemoryAccess *StartingAccess, const MemoryLocation &) { + if (auto *Use = dyn_cast<MemoryUseOrDef>(StartingAccess)) + return Use->getDefiningAccess(); + return StartingAccess; +} + +void MemoryPhi::deleteMe(DerivedUser *Self) { + delete static_cast<MemoryPhi *>(Self); +} + +void MemoryDef::deleteMe(DerivedUser *Self) { + delete static_cast<MemoryDef *>(Self); +} + +void MemoryUse::deleteMe(DerivedUser *Self) { + delete static_cast<MemoryUse *>(Self); +} diff --git a/llvm/lib/Analysis/MemorySSAUpdater.cpp b/llvm/lib/Analysis/MemorySSAUpdater.cpp new file mode 100644 index 000000000000..f2d56b05d968 --- /dev/null +++ b/llvm/lib/Analysis/MemorySSAUpdater.cpp @@ -0,0 +1,1440 @@ +//===-- MemorySSAUpdater.cpp - Memory SSA Updater--------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------===// +// +// This file implements the MemorySSAUpdater class. 
+// +//===----------------------------------------------------------------===// +#include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/IteratedDominanceFrontier.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormattedStream.h" +#include <algorithm> + +#define DEBUG_TYPE "memoryssa" +using namespace llvm; + +// This is the marker algorithm from "Simple and Efficient Construction of +// Static Single Assignment Form" +// The simple, non-marker algorithm places phi nodes at any join +// Here, we place markers, and only place phi nodes if they end up necessary. +// They are only necessary if they break a cycle (IE we recursively visit +// ourselves again), or we discover, while getting the value of the operands, +// that there are two or more definitions needing to be merged. +// This still will leave non-minimal form in the case of irreducible control +// flow, where phi nodes may be in cycles with themselves, but unnecessary. +MemoryAccess *MemorySSAUpdater::getPreviousDefRecursive( + BasicBlock *BB, + DenseMap<BasicBlock *, TrackingVH<MemoryAccess>> &CachedPreviousDef) { + // First, do a cache lookup. Without this cache, certain CFG structures + // (like a series of if statements) take exponential time to visit. + auto Cached = CachedPreviousDef.find(BB); + if (Cached != CachedPreviousDef.end()) + return Cached->second; + + // If this method is called from an unreachable block, return LoE. + if (!MSSA->DT->isReachableFromEntry(BB)) + return MSSA->getLiveOnEntryDef(); + + if (BasicBlock *Pred = BB->getUniquePredecessor()) { + VisitedBlocks.insert(BB); + // Single predecessor case, just recurse, we can only have one definition. + MemoryAccess *Result = getPreviousDefFromEnd(Pred, CachedPreviousDef); + CachedPreviousDef.insert({BB, Result}); + return Result; + } + + if (VisitedBlocks.count(BB)) { + // We hit our node again, meaning we had a cycle, we must insert a phi + // node to break it so we have an operand. The only case this will + // insert useless phis is if we have irreducible control flow. + MemoryAccess *Result = MSSA->createMemoryPhi(BB); + CachedPreviousDef.insert({BB, Result}); + return Result; + } + + if (VisitedBlocks.insert(BB).second) { + // Mark us visited so we can detect a cycle + SmallVector<TrackingVH<MemoryAccess>, 8> PhiOps; + + // Recurse to get the values in our predecessors for placement of a + // potential phi node. This will insert phi nodes if we cycle in order to + // break the cycle and have an operand. + bool UniqueIncomingAccess = true; + MemoryAccess *SingleAccess = nullptr; + for (auto *Pred : predecessors(BB)) { + if (MSSA->DT->isReachableFromEntry(Pred)) { + auto *IncomingAccess = getPreviousDefFromEnd(Pred, CachedPreviousDef); + if (!SingleAccess) + SingleAccess = IncomingAccess; + else if (IncomingAccess != SingleAccess) + UniqueIncomingAccess = false; + PhiOps.push_back(IncomingAccess); + } else + PhiOps.push_back(MSSA->getLiveOnEntryDef()); + } + + // Now try to simplify the ops to avoid placing a phi. 
+ // This may return null if we haven't created a phi yet; that's okay.
+ MemoryPhi *Phi = dyn_cast_or_null<MemoryPhi>(MSSA->getMemoryAccess(BB));
+
+ // See if we can avoid the phi by simplifying it.
+ auto *Result = tryRemoveTrivialPhi(Phi, PhiOps);
+ // If we couldn't simplify, we may have to create a phi.
+ if (Result == Phi && UniqueIncomingAccess && SingleAccess) {
+ // A concrete Phi only exists if we created an empty one to break a cycle.
+ if (Phi) {
+ assert(Phi->operands().empty() && "Expected empty Phi");
+ Phi->replaceAllUsesWith(SingleAccess);
+ removeMemoryAccess(Phi);
+ }
+ Result = SingleAccess;
+ } else if (Result == Phi && !(UniqueIncomingAccess && SingleAccess)) {
+ if (!Phi)
+ Phi = MSSA->createMemoryPhi(BB);
+
+ // See if the existing phi operands match what we need.
+ // Unlike normal SSA, we only allow one phi node per block, so we can't
+ // just create a new one.
+ if (Phi->getNumOperands() != 0) {
+ // FIXME: Figure out whether this is dead code and if so remove it.
+ if (!std::equal(Phi->op_begin(), Phi->op_end(), PhiOps.begin())) {
+ // These will have been filled in by the recursive read we did above.
+ llvm::copy(PhiOps, Phi->op_begin());
+ std::copy(pred_begin(BB), pred_end(BB), Phi->block_begin());
+ }
+ } else {
+ unsigned i = 0;
+ for (auto *Pred : predecessors(BB))
+ Phi->addIncoming(&*PhiOps[i++], Pred);
+ InsertedPHIs.push_back(Phi);
+ }
+ Result = Phi;
+ }
+
+ // Set ourselves up for the next variable by resetting visited state.
+ VisitedBlocks.erase(BB);
+ CachedPreviousDef.insert({BB, Result});
+ return Result;
+ }
+ llvm_unreachable("Should have hit one of the three cases above");
+}
+
+// This starts at the memory access and goes backwards in the block to find
+// the previous definition. If a definition is not found in the block of the
+// access, it continues globally, creating phi nodes to ensure we have a
+// single definition.
+MemoryAccess *MemorySSAUpdater::getPreviousDef(MemoryAccess *MA) {
+ if (auto *LocalResult = getPreviousDefInBlock(MA))
+ return LocalResult;
+ DenseMap<BasicBlock *, TrackingVH<MemoryAccess>> CachedPreviousDef;
+ return getPreviousDefRecursive(MA->getBlock(), CachedPreviousDef);
+}
+
+// This starts at the memory access and goes backwards in the block to find
+// the previous definition. If the definition is not found in the block of
+// the access, it returns nullptr.
+MemoryAccess *MemorySSAUpdater::getPreviousDefInBlock(MemoryAccess *MA) {
+ auto *Defs = MSSA->getWritableBlockDefs(MA->getBlock());
+
+ // It's possible there are no defs, or we got handed the first def to start.
+ if (Defs) {
+ // If this is a def, we can just use the def iterators.
+ if (!isa<MemoryUse>(MA)) {
+ auto Iter = MA->getReverseDefsIterator();
+ ++Iter;
+ if (Iter != Defs->rend())
+ return &*Iter;
+ } else {
+ // Otherwise, we have to walk the whole access list in reverse.
+ auto End = MSSA->getWritableBlockAccesses(MA->getBlock())->rend();
+ for (auto &U : make_range(++MA->getReverseIterator(), End))
+ if (!isa<MemoryUse>(U))
+ return cast<MemoryAccess>(&U);
+ // Note that if MA comes before Defs->begin(), we won't hit a def.
+ return nullptr; + } + } + return nullptr; +} + +// This starts at the end of block +MemoryAccess *MemorySSAUpdater::getPreviousDefFromEnd( + BasicBlock *BB, + DenseMap<BasicBlock *, TrackingVH<MemoryAccess>> &CachedPreviousDef) { + auto *Defs = MSSA->getWritableBlockDefs(BB); + + if (Defs) { + CachedPreviousDef.insert({BB, &*Defs->rbegin()}); + return &*Defs->rbegin(); + } + + return getPreviousDefRecursive(BB, CachedPreviousDef); +} +// Recurse over a set of phi uses to eliminate the trivial ones +MemoryAccess *MemorySSAUpdater::recursePhi(MemoryAccess *Phi) { + if (!Phi) + return nullptr; + TrackingVH<MemoryAccess> Res(Phi); + SmallVector<TrackingVH<Value>, 8> Uses; + std::copy(Phi->user_begin(), Phi->user_end(), std::back_inserter(Uses)); + for (auto &U : Uses) + if (MemoryPhi *UsePhi = dyn_cast<MemoryPhi>(&*U)) + tryRemoveTrivialPhi(UsePhi); + return Res; +} + +// Eliminate trivial phis +// Phis are trivial if they are defined either by themselves, or all the same +// argument. +// IE phi(a, a) or b = phi(a, b) or c = phi(a, a, c) +// We recursively try to remove them. +MemoryAccess *MemorySSAUpdater::tryRemoveTrivialPhi(MemoryPhi *Phi) { + assert(Phi && "Can only remove concrete Phi."); + auto OperRange = Phi->operands(); + return tryRemoveTrivialPhi(Phi, OperRange); +} +template <class RangeType> +MemoryAccess *MemorySSAUpdater::tryRemoveTrivialPhi(MemoryPhi *Phi, + RangeType &Operands) { + // Bail out on non-opt Phis. + if (NonOptPhis.count(Phi)) + return Phi; + + // Detect equal or self arguments + MemoryAccess *Same = nullptr; + for (auto &Op : Operands) { + // If the same or self, good so far + if (Op == Phi || Op == Same) + continue; + // not the same, return the phi since it's not eliminatable by us + if (Same) + return Phi; + Same = cast<MemoryAccess>(&*Op); + } + // Never found a non-self reference, the phi is undef + if (Same == nullptr) + return MSSA->getLiveOnEntryDef(); + if (Phi) { + Phi->replaceAllUsesWith(Same); + removeMemoryAccess(Phi); + } + + // We should only end up recursing in case we replaced something, in which + // case, we may have made other Phis trivial. + return recursePhi(Same); +} + +void MemorySSAUpdater::insertUse(MemoryUse *MU, bool RenameUses) { + InsertedPHIs.clear(); + MU->setDefiningAccess(getPreviousDef(MU)); + + // In cases without unreachable blocks, because uses do not create new + // may-defs, there are only two cases: + // 1. There was a def already below us, and therefore, we should not have + // created a phi node because it was already needed for the def. + // + // 2. There is no def below us, and therefore, there is no extra renaming work + // to do. + + // In cases with unreachable blocks, where the unnecessary Phis were + // optimized out, adding the Use may re-insert those Phis. Hence, when + // inserting Uses outside of the MSSA creation process, and new Phis were + // added, rename all uses if we are asked. + + if (!RenameUses && !InsertedPHIs.empty()) { + auto *Defs = MSSA->getBlockDefs(MU->getBlock()); + (void)Defs; + assert((!Defs || (++Defs->begin() == Defs->end())) && + "Block may have only a Phi or no defs"); + } + + if (RenameUses && InsertedPHIs.size()) { + SmallPtrSet<BasicBlock *, 16> Visited; + BasicBlock *StartBlock = MU->getBlock(); + + if (auto *Defs = MSSA->getWritableBlockDefs(StartBlock)) { + MemoryAccess *FirstDef = &*Defs->begin(); + // Convert to incoming value if it's a memorydef. A phi *is* already an + // incoming value. 
+ if (auto *MD = dyn_cast<MemoryDef>(FirstDef))
+ FirstDef = MD->getDefiningAccess();
+
+ MSSA->renamePass(MU->getBlock(), FirstDef, Visited);
+ }
+ // We just inserted a phi into this block, so the incoming value will
+ // become the phi anyway, so it does not matter what we pass.
+ for (auto &MP : InsertedPHIs)
+ if (MemoryPhi *Phi = cast_or_null<MemoryPhi>(MP))
+ MSSA->renamePass(Phi->getBlock(), nullptr, Visited);
+ }
+}
+
+// Set every incoming edge {BB, MP->getBlock()} of MemoryPhi MP to NewDef.
+static void setMemoryPhiValueForBlock(MemoryPhi *MP, const BasicBlock *BB,
+ MemoryAccess *NewDef) {
+ // Replace every operand whose incoming block is BB with the new defining
+ // access.
+ int i = MP->getBasicBlockIndex(BB);
+ assert(i != -1 && "Should have found the basic block in the phi");
+ // We can't just compare i against getNumOperands since one is signed and
+ // the other not. So use it to index into the block iterator.
+ for (auto BBIter = MP->block_begin() + i; BBIter != MP->block_end();
+ ++BBIter) {
+ if (*BBIter != BB)
+ break;
+ MP->setIncomingValue(i, NewDef);
+ ++i;
+ }
+}
+
+// A brief description of the algorithm:
+// First, we compute what should define the new def, using the SSA
+// construction algorithm.
+// Then, we update the defs below us (and any new phi nodes) in the graph to
+// point to the correct new defs, to ensure we only have one variable, and no
+// disconnected stores.
+void MemorySSAUpdater::insertDef(MemoryDef *MD, bool RenameUses) {
+ InsertedPHIs.clear();
+
+ // See if we had a local def, and if not, go hunting.
+ MemoryAccess *DefBefore = getPreviousDef(MD);
+ bool DefBeforeSameBlock = false;
+ if (DefBefore->getBlock() == MD->getBlock() &&
+ !(isa<MemoryPhi>(DefBefore) &&
+ std::find(InsertedPHIs.begin(), InsertedPHIs.end(), DefBefore) !=
+ InsertedPHIs.end()))
+ DefBeforeSameBlock = true;
+
+ // There is a def before us, which means we can replace any store/phi uses
+ // of that thing with us, since we are in the way of whatever was there
+ // before.
+ // We now define that def's MemoryDefs and MemoryPhis.
+ if (DefBeforeSameBlock) {
+ DefBefore->replaceUsesWithIf(MD, [MD](Use &U) {
+ // Leave the MemoryUses alone.
+ // Also make sure we skip ourselves to avoid self references.
+ User *Usr = U.getUser();
+ return !isa<MemoryUse>(Usr) && Usr != MD;
+ // Defs are automatically unoptimized when the user is set to MD below,
+ // because the isOptimized() call will fail to find the same ID.
+ });
+ }
+
+ // And that def is now our defining access.
+ MD->setDefiningAccess(DefBefore);
+
+ SmallVector<WeakVH, 8> FixupList(InsertedPHIs.begin(), InsertedPHIs.end());
+
+ // Remember the index where we may insert new phis.
+ unsigned NewPhiIndex = InsertedPHIs.size();
+ if (!DefBeforeSameBlock) {
+ // If there was a local def before us, we must have the same effect it
+ // did: because every may-def is treated the same, any phis we would
+ // create, it would already have created. If there was no local def
+ // before us, we performed a global update, and have to search all
+ // successors and make sure we update the first def in each of them
+ // (following all paths until we hit the first def along each path). This
+ // may also insert phi nodes.
+ // TODO: There are other cases where we can skip this work, such as when
+ // we have a single successor, and only used a straight line of
+ // single-predecessor blocks backwards to find the def. To make that
+ // work, we'd have to track whether getDefRecursive only ever used the
+ // single predecessor case. These types of paths also only exist in
+ // between CFG simplifications.
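+
+ // (Editor's illustration, hypothetical CFG: if MD was inserted into the
+ // then-arm of a diamond entry -> {then, else} -> join, the IDF of MD's
+ // block is {join}, so the code below creates a MemoryPhi there merging MD
+ // with the last def reaching the else-arm.)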
+
+ // If this is the first def in the block and this insert is in an arbitrary
+ // place, compute IDF and place phis.
+ SmallPtrSet<BasicBlock *, 2> DefiningBlocks;
+
+ // If this is the last Def in the block, also compute IDF based on MD, since
+ // this may be a newly added Def and we may need additional Phis.
+ auto Iter = MD->getDefsIterator();
+ ++Iter;
+ auto IterEnd = MSSA->getBlockDefs(MD->getBlock())->end();
+ if (Iter == IterEnd)
+ DefiningBlocks.insert(MD->getBlock());
+
+ for (const auto &VH : InsertedPHIs)
+ if (const auto *RealPHI = cast_or_null<MemoryPhi>(VH))
+ DefiningBlocks.insert(RealPHI->getBlock());
+ ForwardIDFCalculator IDFs(*MSSA->DT);
+ SmallVector<BasicBlock *, 32> IDFBlocks;
+ IDFs.setDefiningBlocks(DefiningBlocks);
+ IDFs.calculate(IDFBlocks);
+ SmallVector<AssertingVH<MemoryPhi>, 4> NewInsertedPHIs;
+ for (auto *BBIDF : IDFBlocks) {
+ auto *MPhi = MSSA->getMemoryAccess(BBIDF);
+ if (!MPhi) {
+ MPhi = MSSA->createMemoryPhi(BBIDF);
+ NewInsertedPHIs.push_back(MPhi);
+ }
+ // Add the phis created into the IDF blocks to NonOptPhis, so they are not
+ // optimized out as trivial by the call to getPreviousDefFromEnd below.
+ // Once they are complete, all these Phis are added to the FixupList, and
+ // removed from NonOptPhis inside fixupDefs(). Existing Phis in the IDF
+ // may need fixing as well, and may have been trivial before this
+ // insertion, hence add all IDF Phis. See PR43044.
+ NonOptPhis.insert(MPhi);
+ }
+ for (auto &MPhi : NewInsertedPHIs) {
+ auto *BBIDF = MPhi->getBlock();
+ for (auto *Pred : predecessors(BBIDF)) {
+ DenseMap<BasicBlock *, TrackingVH<MemoryAccess>> CachedPreviousDef;
+ MPhi->addIncoming(getPreviousDefFromEnd(Pred, CachedPreviousDef), Pred);
+ }
+ }
+
+ // Retake the index where we're adding the new phis, because the above call
+ // to getPreviousDefFromEnd may have inserted into InsertedPHIs.
+ NewPhiIndex = InsertedPHIs.size();
+ for (auto &MPhi : NewInsertedPHIs) {
+ InsertedPHIs.push_back(&*MPhi);
+ FixupList.push_back(&*MPhi);
+ }
+
+ FixupList.push_back(MD);
+ }
+
+ // Remember the index where we stopped inserting new phis above, since the
+ // fixupDefs call in the loop below may insert more that are already
+ // minimal.
+ unsigned NewPhiIndexEnd = InsertedPHIs.size();
+
+ while (!FixupList.empty()) {
+ unsigned StartingPHISize = InsertedPHIs.size();
+ fixupDefs(FixupList);
+ FixupList.clear();
+ // Put any new phis on the fixup list, and process them.
+ FixupList.append(InsertedPHIs.begin() + StartingPHISize,
+ InsertedPHIs.end());
+ }
+
+ // Optimize potentially non-minimal phis added in this method.
+ unsigned NewPhiSize = NewPhiIndexEnd - NewPhiIndex;
+ if (NewPhiSize)
+ tryRemoveTrivialPhis(
+ ArrayRef<WeakVH>(&InsertedPHIs[NewPhiIndex], NewPhiSize));
+
+ // Now that all fixups are done, rename all uses if we are asked.
+ if (RenameUses) {
+ SmallPtrSet<BasicBlock *, 16> Visited;
+ BasicBlock *StartBlock = MD->getBlock();
+ // We are guaranteed there is a def in the block, because we just got it
+ // handed to us in this function.
+ MemoryAccess *FirstDef = &*MSSA->getWritableBlockDefs(StartBlock)->begin();
+ // Convert to incoming value if it's a MemoryDef. A phi *is* already an
+ // incoming value.
+ if (auto *MD = dyn_cast<MemoryDef>(FirstDef))
+ FirstDef = MD->getDefiningAccess();
+
+ MSSA->renamePass(MD->getBlock(), FirstDef, Visited);
+ // We just inserted a phi into this block, so the incoming value will
+ // become the phi anyway, so it does not matter what we pass.
+ for (auto &MP : InsertedPHIs) {
+ MemoryPhi *Phi = dyn_cast_or_null<MemoryPhi>(MP);
+ if (Phi)
+ MSSA->renamePass(Phi->getBlock(), nullptr, Visited);
+ }
+ }
+}
+
+void MemorySSAUpdater::fixupDefs(const SmallVectorImpl<WeakVH> &Vars) {
+ SmallPtrSet<const BasicBlock *, 8> Seen;
+ SmallVector<const BasicBlock *, 16> Worklist;
+ for (auto &Var : Vars) {
+ MemoryAccess *NewDef = dyn_cast_or_null<MemoryAccess>(Var);
+ if (!NewDef)
+ continue;
+ // First, see if there is a local def after the operand.
+ auto *Defs = MSSA->getWritableBlockDefs(NewDef->getBlock());
+ auto DefIter = NewDef->getDefsIterator();
+
+ // The temporary Phi is being fixed; unmark it so it can be optimized
+ // again.
+ if (MemoryPhi *Phi = dyn_cast<MemoryPhi>(NewDef))
+ NonOptPhis.erase(Phi);
+
+ // If there is a local def after us, we only have to rename that.
+ if (++DefIter != Defs->end()) {
+ cast<MemoryDef>(DefIter)->setDefiningAccess(NewDef);
+ continue;
+ }
+
+ // Otherwise, we need to search down through the CFG.
+ // For each of our successors, handle it directly if there is a phi, or
+ // place it on the fixup worklist.
+ for (const auto *S : successors(NewDef->getBlock())) {
+ if (auto *MP = MSSA->getMemoryAccess(S))
+ setMemoryPhiValueForBlock(MP, NewDef->getBlock(), NewDef);
+ else
+ Worklist.push_back(S);
+ }
+
+ while (!Worklist.empty()) {
+ const BasicBlock *FixupBlock = Worklist.back();
+ Worklist.pop_back();
+
+ // Get the first def in the block that isn't a phi node.
+ if (auto *Defs = MSSA->getWritableBlockDefs(FixupBlock)) {
+ auto *FirstDef = &*Defs->begin();
+ // The loops above and below should have taken care of phi nodes.
+ assert(!isa<MemoryPhi>(FirstDef) &&
+ "Should have already handled phi nodes!");
+ // We are now this def's defining access; make sure we actually
+ // dominate it.
+ assert(MSSA->dominates(NewDef, FirstDef) &&
+ "Should have dominated the new access");
+
+ // This may insert new phi nodes, because we are not guaranteed the
+ // block we are processing has a single pred, and depending on where
+ // the store was inserted, it may require phi nodes below it.
+ cast<MemoryDef>(FirstDef)->setDefiningAccess(getPreviousDef(FirstDef));
+ return;
+ }
+ // We didn't find a def, so we must continue.
+ for (const auto *S : successors(FixupBlock)) {
+ // If there is a phi node, handle it.
+ // Otherwise, put the block on the worklist.
+ if (auto *MP = MSSA->getMemoryAccess(S))
+ setMemoryPhiValueForBlock(MP, FixupBlock, NewDef);
+ else {
+ // If we cycle, we should have ended up at a phi node that we already
+ // processed.
FIXME: Double check this + if (!Seen.insert(S).second) + continue; + Worklist.push_back(S); + } + } + } + } +} + +void MemorySSAUpdater::removeEdge(BasicBlock *From, BasicBlock *To) { + if (MemoryPhi *MPhi = MSSA->getMemoryAccess(To)) { + MPhi->unorderedDeleteIncomingBlock(From); + tryRemoveTrivialPhi(MPhi); + } +} + +void MemorySSAUpdater::removeDuplicatePhiEdgesBetween(const BasicBlock *From, + const BasicBlock *To) { + if (MemoryPhi *MPhi = MSSA->getMemoryAccess(To)) { + bool Found = false; + MPhi->unorderedDeleteIncomingIf([&](const MemoryAccess *, BasicBlock *B) { + if (From != B) + return false; + if (Found) + return true; + Found = true; + return false; + }); + tryRemoveTrivialPhi(MPhi); + } +} + +static MemoryAccess *getNewDefiningAccessForClone(MemoryAccess *MA, + const ValueToValueMapTy &VMap, + PhiToDefMap &MPhiMap, + bool CloneWasSimplified, + MemorySSA *MSSA) { + MemoryAccess *InsnDefining = MA; + if (MemoryDef *DefMUD = dyn_cast<MemoryDef>(InsnDefining)) { + if (!MSSA->isLiveOnEntryDef(DefMUD)) { + Instruction *DefMUDI = DefMUD->getMemoryInst(); + assert(DefMUDI && "Found MemoryUseOrDef with no Instruction."); + if (Instruction *NewDefMUDI = + cast_or_null<Instruction>(VMap.lookup(DefMUDI))) { + InsnDefining = MSSA->getMemoryAccess(NewDefMUDI); + if (!CloneWasSimplified) + assert(InsnDefining && "Defining instruction cannot be nullptr."); + else if (!InsnDefining || isa<MemoryUse>(InsnDefining)) { + // The clone was simplified, it's no longer a MemoryDef, look up. + auto DefIt = DefMUD->getDefsIterator(); + // Since simplified clones only occur in single block cloning, a + // previous definition must exist, otherwise NewDefMUDI would not + // have been found in VMap. + assert(DefIt != MSSA->getBlockDefs(DefMUD->getBlock())->begin() && + "Previous def must exist"); + InsnDefining = getNewDefiningAccessForClone( + &*(--DefIt), VMap, MPhiMap, CloneWasSimplified, MSSA); + } + } + } + } else { + MemoryPhi *DefPhi = cast<MemoryPhi>(InsnDefining); + if (MemoryAccess *NewDefPhi = MPhiMap.lookup(DefPhi)) + InsnDefining = NewDefPhi; + } + assert(InsnDefining && "Defining instruction cannot be nullptr."); + return InsnDefining; +} + +void MemorySSAUpdater::cloneUsesAndDefs(BasicBlock *BB, BasicBlock *NewBB, + const ValueToValueMapTy &VMap, + PhiToDefMap &MPhiMap, + bool CloneWasSimplified) { + const MemorySSA::AccessList *Acc = MSSA->getBlockAccesses(BB); + if (!Acc) + return; + for (const MemoryAccess &MA : *Acc) { + if (const MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(&MA)) { + Instruction *Insn = MUD->getMemoryInst(); + // Entry does not exist if the clone of the block did not clone all + // instructions. This occurs in LoopRotate when cloning instructions + // from the old header to the old preheader. The cloned instruction may + // also be a simplified Value, not an Instruction (see LoopRotate). + // Also in LoopRotate, even when it's an instruction, due to it being + // simplified, it may be a Use rather than a Def, so we cannot use MUD as + // template. Calls coming from updateForClonedBlockIntoPred, ensure this. + if (Instruction *NewInsn = + dyn_cast_or_null<Instruction>(VMap.lookup(Insn))) { + MemoryAccess *NewUseOrDef = MSSA->createDefinedAccess( + NewInsn, + getNewDefiningAccessForClone(MUD->getDefiningAccess(), VMap, + MPhiMap, CloneWasSimplified, MSSA), + /*Template=*/CloneWasSimplified ? nullptr : MUD, + /*CreationMustSucceed=*/CloneWasSimplified ? 
false : true); + if (NewUseOrDef) + MSSA->insertIntoListsForBlock(NewUseOrDef, NewBB, MemorySSA::End); + } + } + } +} + +void MemorySSAUpdater::updatePhisWhenInsertingUniqueBackedgeBlock( + BasicBlock *Header, BasicBlock *Preheader, BasicBlock *BEBlock) { + auto *MPhi = MSSA->getMemoryAccess(Header); + if (!MPhi) + return; + + // Create phi node in the backedge block and populate it with the same + // incoming values as MPhi. Skip incoming values coming from Preheader. + auto *NewMPhi = MSSA->createMemoryPhi(BEBlock); + bool HasUniqueIncomingValue = true; + MemoryAccess *UniqueValue = nullptr; + for (unsigned I = 0, E = MPhi->getNumIncomingValues(); I != E; ++I) { + BasicBlock *IBB = MPhi->getIncomingBlock(I); + MemoryAccess *IV = MPhi->getIncomingValue(I); + if (IBB != Preheader) { + NewMPhi->addIncoming(IV, IBB); + if (HasUniqueIncomingValue) { + if (!UniqueValue) + UniqueValue = IV; + else if (UniqueValue != IV) + HasUniqueIncomingValue = false; + } + } + } + + // Update incoming edges into MPhi. Remove all but the incoming edge from + // Preheader. Add an edge from NewMPhi + auto *AccFromPreheader = MPhi->getIncomingValueForBlock(Preheader); + MPhi->setIncomingValue(0, AccFromPreheader); + MPhi->setIncomingBlock(0, Preheader); + for (unsigned I = MPhi->getNumIncomingValues() - 1; I >= 1; --I) + MPhi->unorderedDeleteIncoming(I); + MPhi->addIncoming(NewMPhi, BEBlock); + + // If NewMPhi is a trivial phi, remove it. Its use in the header MPhi will be + // replaced with the unique value. + tryRemoveTrivialPhi(NewMPhi); +} + +void MemorySSAUpdater::updateForClonedLoop(const LoopBlocksRPO &LoopBlocks, + ArrayRef<BasicBlock *> ExitBlocks, + const ValueToValueMapTy &VMap, + bool IgnoreIncomingWithNoClones) { + PhiToDefMap MPhiMap; + + auto FixPhiIncomingValues = [&](MemoryPhi *Phi, MemoryPhi *NewPhi) { + assert(Phi && NewPhi && "Invalid Phi nodes."); + BasicBlock *NewPhiBB = NewPhi->getBlock(); + SmallPtrSet<BasicBlock *, 4> NewPhiBBPreds(pred_begin(NewPhiBB), + pred_end(NewPhiBB)); + for (unsigned It = 0, E = Phi->getNumIncomingValues(); It < E; ++It) { + MemoryAccess *IncomingAccess = Phi->getIncomingValue(It); + BasicBlock *IncBB = Phi->getIncomingBlock(It); + + if (BasicBlock *NewIncBB = cast_or_null<BasicBlock>(VMap.lookup(IncBB))) + IncBB = NewIncBB; + else if (IgnoreIncomingWithNoClones) + continue; + + // Now we have IncBB, and will need to add incoming from it to NewPhi. + + // If IncBB is not a predecessor of NewPhiBB, then do not add it. + // NewPhiBB was cloned without that edge. + if (!NewPhiBBPreds.count(IncBB)) + continue; + + // Determine incoming value and add it as incoming from IncBB. 
+ if (MemoryUseOrDef *IncMUD = dyn_cast<MemoryUseOrDef>(IncomingAccess)) { + if (!MSSA->isLiveOnEntryDef(IncMUD)) { + Instruction *IncI = IncMUD->getMemoryInst(); + assert(IncI && "Found MemoryUseOrDef with no Instruction."); + if (Instruction *NewIncI = + cast_or_null<Instruction>(VMap.lookup(IncI))) { + IncMUD = MSSA->getMemoryAccess(NewIncI); + assert(IncMUD && + "MemoryUseOrDef cannot be null, all preds processed."); + } + } + NewPhi->addIncoming(IncMUD, IncBB); + } else { + MemoryPhi *IncPhi = cast<MemoryPhi>(IncomingAccess); + if (MemoryAccess *NewDefPhi = MPhiMap.lookup(IncPhi)) + NewPhi->addIncoming(NewDefPhi, IncBB); + else + NewPhi->addIncoming(IncPhi, IncBB); + } + } + }; + + auto ProcessBlock = [&](BasicBlock *BB) { + BasicBlock *NewBlock = cast_or_null<BasicBlock>(VMap.lookup(BB)); + if (!NewBlock) + return; + + assert(!MSSA->getWritableBlockAccesses(NewBlock) && + "Cloned block should have no accesses"); + + // Add MemoryPhi. + if (MemoryPhi *MPhi = MSSA->getMemoryAccess(BB)) { + MemoryPhi *NewPhi = MSSA->createMemoryPhi(NewBlock); + MPhiMap[MPhi] = NewPhi; + } + // Update Uses and Defs. + cloneUsesAndDefs(BB, NewBlock, VMap, MPhiMap); + }; + + for (auto BB : llvm::concat<BasicBlock *const>(LoopBlocks, ExitBlocks)) + ProcessBlock(BB); + + for (auto BB : llvm::concat<BasicBlock *const>(LoopBlocks, ExitBlocks)) + if (MemoryPhi *MPhi = MSSA->getMemoryAccess(BB)) + if (MemoryAccess *NewPhi = MPhiMap.lookup(MPhi)) + FixPhiIncomingValues(MPhi, cast<MemoryPhi>(NewPhi)); +} + +void MemorySSAUpdater::updateForClonedBlockIntoPred( + BasicBlock *BB, BasicBlock *P1, const ValueToValueMapTy &VM) { + // All defs/phis from outside BB that are used in BB, are valid uses in P1. + // Since those defs/phis must have dominated BB, and also dominate P1. + // Defs from BB being used in BB will be replaced with the cloned defs from + // VM. The uses of BB's Phi (if it exists) in BB will be replaced by the + // incoming def into the Phi from P1. + // Instructions cloned into the predecessor are in practice sometimes + // simplified, so disable the use of the template, and create an access from + // scratch. + PhiToDefMap MPhiMap; + if (MemoryPhi *MPhi = MSSA->getMemoryAccess(BB)) + MPhiMap[MPhi] = MPhi->getIncomingValueForBlock(P1); + cloneUsesAndDefs(BB, P1, VM, MPhiMap, /*CloneWasSimplified=*/true); +} + +template <typename Iter> +void MemorySSAUpdater::privateUpdateExitBlocksForClonedLoop( + ArrayRef<BasicBlock *> ExitBlocks, Iter ValuesBegin, Iter ValuesEnd, + DominatorTree &DT) { + SmallVector<CFGUpdate, 4> Updates; + // Update/insert phis in all successors of exit blocks. 
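+ // (Editor's note: each cloned exit NewExit is assumed, per the callers'
+ // contract, to end in an unconditional branch, hence getSuccessor(0)
+ // below; the collected {NewExit, ExitSucc} edges are then applied as
+ // insert updates.)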
+ for (auto *Exit : ExitBlocks)
+ for (const ValueToValueMapTy *VMap : make_range(ValuesBegin, ValuesEnd))
+ if (BasicBlock *NewExit = cast_or_null<BasicBlock>(VMap->lookup(Exit))) {
+ BasicBlock *ExitSucc = NewExit->getTerminator()->getSuccessor(0);
+ Updates.push_back({DT.Insert, NewExit, ExitSucc});
+ }
+ applyInsertUpdates(Updates, DT);
+}
+
+void MemorySSAUpdater::updateExitBlocksForClonedLoop(
+ ArrayRef<BasicBlock *> ExitBlocks, const ValueToValueMapTy &VMap,
+ DominatorTree &DT) {
+ const ValueToValueMapTy *const Arr[] = {&VMap};
+ privateUpdateExitBlocksForClonedLoop(ExitBlocks, std::begin(Arr),
+ std::end(Arr), DT);
+}
+
+void MemorySSAUpdater::updateExitBlocksForClonedLoop(
+ ArrayRef<BasicBlock *> ExitBlocks,
+ ArrayRef<std::unique_ptr<ValueToValueMapTy>> VMaps, DominatorTree &DT) {
+ auto GetPtr = [&](const std::unique_ptr<ValueToValueMapTy> &I) {
+ return I.get();
+ };
+ using MappedIteratorType =
+ mapped_iterator<const std::unique_ptr<ValueToValueMapTy> *,
+ decltype(GetPtr)>;
+ auto MapBegin = MappedIteratorType(VMaps.begin(), GetPtr);
+ auto MapEnd = MappedIteratorType(VMaps.end(), GetPtr);
+ privateUpdateExitBlocksForClonedLoop(ExitBlocks, MapBegin, MapEnd, DT);
+}
+
+void MemorySSAUpdater::applyUpdates(ArrayRef<CFGUpdate> Updates,
+ DominatorTree &DT) {
+ SmallVector<CFGUpdate, 4> RevDeleteUpdates;
+ SmallVector<CFGUpdate, 4> InsertUpdates;
+ for (auto &Update : Updates) {
+ if (Update.getKind() == DT.Insert)
+ InsertUpdates.push_back({DT.Insert, Update.getFrom(), Update.getTo()});
+ else
+ RevDeleteUpdates.push_back({DT.Insert, Update.getFrom(), Update.getTo()});
+ }
+
+ if (!RevDeleteUpdates.empty()) {
+ // Update for inserted edges: use NewDT and snapshot the CFG as if the
+ // deletes had not occurred.
+ // FIXME: This creates a new DT, so it's more expensive to mix deletes and
+ // inserts than to do inserts alone. We could do an incremental update on
+ // the DT to revert the deletes, then re-delete the edges. Teaching DT to
+ // do this is part of a pending cleanup.
+ DominatorTree NewDT(DT, RevDeleteUpdates);
+ GraphDiff<BasicBlock *> GD(RevDeleteUpdates);
+ applyInsertUpdates(InsertUpdates, NewDT, &GD);
+ } else {
+ GraphDiff<BasicBlock *> GD;
+ applyInsertUpdates(InsertUpdates, DT, &GD);
+ }
+
+ // Update for deleted edges.
+ for (auto &Update : RevDeleteUpdates)
+ removeEdge(Update.getFrom(), Update.getTo());
+}
+
+void MemorySSAUpdater::applyInsertUpdates(ArrayRef<CFGUpdate> Updates,
+ DominatorTree &DT) {
+ GraphDiff<BasicBlock *> GD;
+ applyInsertUpdates(Updates, DT, &GD);
+}
+
+void MemorySSAUpdater::applyInsertUpdates(ArrayRef<CFGUpdate> Updates,
+ DominatorTree &DT,
+ const GraphDiff<BasicBlock *> *GD) {
+ // Get the recursive last Def, assuming well-formed MSSA and an updated DT.
+ auto GetLastDef = [&](BasicBlock *BB) -> MemoryAccess * {
+ while (true) {
+ MemorySSA::DefsList *Defs = MSSA->getWritableBlockDefs(BB);
+ // Return last Def or Phi in BB, if it exists.
+ if (Defs)
+ return &*(--Defs->end());
+
+ // Check the number of predecessors; we only care if there's more than
+ // one.
+ unsigned Count = 0;
+ BasicBlock *Pred = nullptr;
+ for (auto &Pair : children<GraphDiffInvBBPair>({GD, BB})) {
+ Pred = Pair.second;
+ Count++;
+ if (Count == 2)
+ break;
+ }
+
+ // If BB has multiple predecessors, get the last definition from IDom.
+ if (Count != 1) {
+ // [SimpleLoopUnswitch] If BB is a dead block, about to be deleted, its
+ // DT is invalidated. Return LoE as its last def. This will be added to
+ // the MemoryPhi node, and later deleted when the block is deleted.
+        if (!DT.getNode(BB))
+          return MSSA->getLiveOnEntryDef();
+        if (auto *IDom = DT.getNode(BB)->getIDom())
+          if (IDom->getBlock() != BB) {
+            BB = IDom->getBlock();
+            continue;
+          }
+        return MSSA->getLiveOnEntryDef();
+      } else {
+        // Single predecessor, BB cannot be dead. GetLastDef of Pred.
+        assert(Count == 1 && Pred && "Single predecessor expected.");
+        // BB can be unreachable though, return LoE if that is the case.
+        if (!DT.getNode(BB))
+          return MSSA->getLiveOnEntryDef();
+        BB = Pred;
+      }
+    }
+    llvm_unreachable("Unable to get last definition.");
+  };
+
+  // Get the nearest common dominator of a set of blocks.
+  // TODO: this can be optimized by starting the search at the node with the
+  // lowest level (highest in the tree).
+  auto FindNearestCommonDominator =
+      [&](const SmallSetVector<BasicBlock *, 2> &BBSet) -> BasicBlock * {
+    BasicBlock *PrevIDom = *BBSet.begin();
+    for (auto *BB : BBSet)
+      PrevIDom = DT.findNearestCommonDominator(PrevIDom, BB);
+    return PrevIDom;
+  };
+
+  // Get all blocks that dominate PrevIDom, stopping when reaching CurrIDom.
+  // Do not include CurrIDom.
+  auto GetNoLongerDomBlocks =
+      [&](BasicBlock *PrevIDom, BasicBlock *CurrIDom,
+          SmallVectorImpl<BasicBlock *> &BlocksPrevDom) {
+        if (PrevIDom == CurrIDom)
+          return;
+        BlocksPrevDom.push_back(PrevIDom);
+        BasicBlock *NextIDom = PrevIDom;
+        while (BasicBlock *UpIDom =
+                   DT.getNode(NextIDom)->getIDom()->getBlock()) {
+          if (UpIDom == CurrIDom)
+            break;
+          BlocksPrevDom.push_back(UpIDom);
+          NextIDom = UpIDom;
+        }
+      };
+
+  // Map a BB to its predecessors: added + previously existing. To get a
+  // deterministic order, store predecessors as SetVectors. The order in each
+  // will be defined by the order in Updates (fixed) and the order given by
+  // children<> (also fixed). Since we further iterate over these ordered sets,
+  // we lose the information of multiple edges possibly existing between two
+  // blocks, so we'll keep an EdgeCount map for that.
+  // An alternate implementation could keep an unordered set for the
+  // predecessors, traverse either Updates or children<> each time to get the
+  // deterministic order, and drop the usage of EdgeCount. This alternate
+  // approach would still require querying the maps for each predecessor, and
+  // the children<> call performs additional computation to create the
+  // snapshot-graph predecessors. As such, we favor using a little additional
+  // storage and less compute time. This decision can be revisited if we find
+  // the alternative more favorable.
+
+  struct PredInfo {
+    SmallSetVector<BasicBlock *, 2> Added;
+    SmallSetVector<BasicBlock *, 2> Prev;
+  };
+  SmallDenseMap<BasicBlock *, PredInfo> PredMap;
+
+  for (auto &Edge : Updates) {
+    BasicBlock *BB = Edge.getTo();
+    auto &AddedBlockSet = PredMap[BB].Added;
+    AddedBlockSet.insert(Edge.getFrom());
+  }
+
+  // Store all existing predecessors for each BB; at least one must exist.
+  SmallDenseMap<std::pair<BasicBlock *, BasicBlock *>, int> EdgeCountMap;
+  SmallPtrSet<BasicBlock *, 2> NewBlocks;
+  for (auto &BBPredPair : PredMap) {
+    auto *BB = BBPredPair.first;
+    const auto &AddedBlockSet = BBPredPair.second.Added;
+    auto &PrevBlockSet = BBPredPair.second.Prev;
+    for (auto &Pair : children<GraphDiffInvBBPair>({GD, BB})) {
+      BasicBlock *Pi = Pair.second;
+      if (!AddedBlockSet.count(Pi))
+        PrevBlockSet.insert(Pi);
+      EdgeCountMap[{Pi, BB}]++;
+    }
+
+    if (PrevBlockSet.empty()) {
+      assert(pred_size(BB) == AddedBlockSet.size() && "Duplicate edges added.");
+      LLVM_DEBUG(
+          dbgs()
+          << "Adding a predecessor to a block with no predecessors. "
+             "This must be an edge added to a new, likely cloned, block. "
+             "Its memory accesses must be already correct, assuming completed "
+             "via the updateExitBlocksForClonedLoop API. "
+             "Assert a single such edge is added so no phi addition or "
+             "additional processing is required.\n");
+      assert(AddedBlockSet.size() == 1 &&
+             "Can only handle adding one predecessor to a new block.");
+      // Need to remove new blocks from PredMap. Remove below to not invalidate
+      // the iterator here.
+      NewBlocks.insert(BB);
+    }
+  }
+  // Nothing to process for new/cloned blocks.
+  for (auto *BB : NewBlocks)
+    PredMap.erase(BB);
+
+  SmallVector<BasicBlock *, 16> BlocksWithDefsToReplace;
+  SmallVector<WeakVH, 8> InsertedPhis;
+
+  // First create MemoryPhis in all blocks that don't have one. Create in the
+  // order found in Updates, not in PredMap, to get deterministic numbering.
+  for (auto &Edge : Updates) {
+    BasicBlock *BB = Edge.getTo();
+    if (PredMap.count(BB) && !MSSA->getMemoryAccess(BB))
+      InsertedPhis.push_back(MSSA->createMemoryPhi(BB));
+  }
+
+  // Now we'll fill in the MemoryPhis with the right incoming values.
+  for (auto &BBPredPair : PredMap) {
+    auto *BB = BBPredPair.first;
+    const auto &PrevBlockSet = BBPredPair.second.Prev;
+    const auto &AddedBlockSet = BBPredPair.second.Added;
+    assert(!PrevBlockSet.empty() &&
+           "At least one previous predecessor must exist.");
+
+    // TODO: if this becomes a bottleneck, we can save on GetLastDef calls by
+    // keeping this map before the loop. We can reuse already populated entries
+    // if an edge is added from the same predecessor to two different blocks,
+    // and this does happen in rotate. Note that the map needs to be updated
+    // when deleting non-necessary phis below, if the phi is in the map, by
+    // replacing the value with DefP1.
+    SmallDenseMap<BasicBlock *, MemoryAccess *> LastDefAddedPred;
+    for (auto *AddedPred : AddedBlockSet) {
+      auto *DefPn = GetLastDef(AddedPred);
+      assert(DefPn != nullptr && "Unable to find last definition.");
+      LastDefAddedPred[AddedPred] = DefPn;
+    }
+
+    MemoryPhi *NewPhi = MSSA->getMemoryAccess(BB);
+    // If Phi is not empty, add an incoming edge from each added pred. Must
+    // still compute blocks with defs to replace for this block below.
+    if (NewPhi->getNumOperands()) {
+      for (auto *Pred : AddedBlockSet) {
+        auto *LastDefForPred = LastDefAddedPred[Pred];
+        for (int I = 0, E = EdgeCountMap[{Pred, BB}]; I < E; ++I)
+          NewPhi->addIncoming(LastDefForPred, Pred);
+      }
+    } else {
+      // Pick any existing predecessor and get its definition. All other
+      // existing predecessors should have the same one, since no phi existed.
+      auto *P1 = *PrevBlockSet.begin();
+      MemoryAccess *DefP1 = GetLastDef(P1);
+
+      // Check DefP1 against all Defs in LastDefAddedPred. If they are all the
+      // same, there is nothing to add.
+      bool InsertPhi = false;
+      for (auto LastDefPredPair : LastDefAddedPred)
+        if (DefP1 != LastDefPredPair.second) {
+          InsertPhi = true;
+          break;
+        }
+      if (!InsertPhi) {
+        // Since NewPhi may be used in other newly added Phis, replace all uses
+        // of NewPhi with the definition coming from all predecessors (DefP1),
+        // before deleting it.
+        NewPhi->replaceAllUsesWith(DefP1);
+        removeMemoryAccess(NewPhi);
+        continue;
+      }
+
+      // Update Phi with new values for new predecessors and old value for all
+      // other predecessors. Since AddedBlockSet and PrevBlockSet are ordered
+      // sets, the order of entries in NewPhi is deterministic.
+      for (auto *Pred : AddedBlockSet) {
+        auto *LastDefForPred = LastDefAddedPred[Pred];
+        for (int I = 0, E = EdgeCountMap[{Pred, BB}]; I < E; ++I)
+          NewPhi->addIncoming(LastDefForPred, Pred);
+      }
+      for (auto *Pred : PrevBlockSet)
+        for (int I = 0, E = EdgeCountMap[{Pred, BB}]; I < E; ++I)
+          NewPhi->addIncoming(DefP1, Pred);
+    }
+
+    // Get all blocks that used to dominate BB and no longer do after adding
+    // AddedBlockSet, where PrevBlockSet are the previously known predecessors.
+    assert(DT.getNode(BB)->getIDom() && "BB does not have a valid IDom");
+    BasicBlock *PrevIDom = FindNearestCommonDominator(PrevBlockSet);
+    assert(PrevIDom && "Previous IDom should exist");
+    BasicBlock *NewIDom = DT.getNode(BB)->getIDom()->getBlock();
+    assert(NewIDom && "BB should have a new valid IDom");
+    assert(DT.dominates(NewIDom, PrevIDom) &&
+           "New IDom should dominate the old IDom");
+    GetNoLongerDomBlocks(PrevIDom, NewIDom, BlocksWithDefsToReplace);
+  }
+
+  tryRemoveTrivialPhis(InsertedPhis);
+  // Create the set of blocks that now have a definition. We'll use this to
+  // compute IDF and add Phis there next.
+  SmallVector<BasicBlock *, 8> BlocksToProcess;
+  for (auto &VH : InsertedPhis)
+    if (auto *MPhi = cast_or_null<MemoryPhi>(VH))
+      BlocksToProcess.push_back(MPhi->getBlock());
+
+  // Compute IDF and add Phis in all IDF blocks that do not have one.
+  SmallVector<BasicBlock *, 32> IDFBlocks;
+  if (!BlocksToProcess.empty()) {
+    ForwardIDFCalculator IDFs(DT, GD);
+    SmallPtrSet<BasicBlock *, 16> DefiningBlocks(BlocksToProcess.begin(),
+                                                 BlocksToProcess.end());
+    IDFs.setDefiningBlocks(DefiningBlocks);
+    IDFs.calculate(IDFBlocks);
+
+    SmallSetVector<MemoryPhi *, 4> PhisToFill;
+    // First create all needed Phis.
+    for (auto *BBIDF : IDFBlocks)
+      if (!MSSA->getMemoryAccess(BBIDF)) {
+        auto *IDFPhi = MSSA->createMemoryPhi(BBIDF);
+        InsertedPhis.push_back(IDFPhi);
+        PhisToFill.insert(IDFPhi);
+      }
+    // Then update or insert their correct incoming values.
+    for (auto *BBIDF : IDFBlocks) {
+      auto *IDFPhi = MSSA->getMemoryAccess(BBIDF);
+      assert(IDFPhi && "Phi must exist");
+      if (!PhisToFill.count(IDFPhi)) {
+        // Update an existing Phi.
+        // FIXME: some updates may be redundant, try to optimize and skip some.
+        for (unsigned I = 0, E = IDFPhi->getNumIncomingValues(); I < E; ++I)
+          IDFPhi->setIncomingValue(I, GetLastDef(IDFPhi->getIncomingBlock(I)));
+      } else {
+        for (auto &Pair : children<GraphDiffInvBBPair>({GD, BBIDF})) {
+          BasicBlock *Pi = Pair.second;
+          IDFPhi->addIncoming(GetLastDef(Pi), Pi);
+        }
+      }
+    }
+  }
+
+  // Now, for all defs in BlocksWithDefsToReplace, if there are uses they no
+  // longer dominate, replace those uses with the closest dominating def.
+  // This will also update optimized accesses, as they're also uses.
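+  // Sketch of the fixup below: a use in a MemoryPhi that is no longer
+  // dominated by its def is rewritten to the last definition in the
+  // corresponding incoming block; any other non-dominated use is rewritten to
+  // the Phi of its own block, or to the last definition in its IDom. Cached
+  // (optimized) clobbers are reset, as they are uses too.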
+ for (auto *BlockWithDefsToReplace : BlocksWithDefsToReplace) { + if (auto DefsList = MSSA->getWritableBlockDefs(BlockWithDefsToReplace)) { + for (auto &DefToReplaceUses : *DefsList) { + BasicBlock *DominatingBlock = DefToReplaceUses.getBlock(); + Value::use_iterator UI = DefToReplaceUses.use_begin(), + E = DefToReplaceUses.use_end(); + for (; UI != E;) { + Use &U = *UI; + ++UI; + MemoryAccess *Usr = cast<MemoryAccess>(U.getUser()); + if (MemoryPhi *UsrPhi = dyn_cast<MemoryPhi>(Usr)) { + BasicBlock *DominatedBlock = UsrPhi->getIncomingBlock(U); + if (!DT.dominates(DominatingBlock, DominatedBlock)) + U.set(GetLastDef(DominatedBlock)); + } else { + BasicBlock *DominatedBlock = Usr->getBlock(); + if (!DT.dominates(DominatingBlock, DominatedBlock)) { + if (auto *DomBlPhi = MSSA->getMemoryAccess(DominatedBlock)) + U.set(DomBlPhi); + else { + auto *IDom = DT.getNode(DominatedBlock)->getIDom(); + assert(IDom && "Block must have a valid IDom."); + U.set(GetLastDef(IDom->getBlock())); + } + cast<MemoryUseOrDef>(Usr)->resetOptimized(); + } + } + } + } + } + } + tryRemoveTrivialPhis(InsertedPhis); +} + +// Move What before Where in the MemorySSA IR. +template <class WhereType> +void MemorySSAUpdater::moveTo(MemoryUseOrDef *What, BasicBlock *BB, + WhereType Where) { + // Mark MemoryPhi users of What not to be optimized. + for (auto *U : What->users()) + if (MemoryPhi *PhiUser = dyn_cast<MemoryPhi>(U)) + NonOptPhis.insert(PhiUser); + + // Replace all our users with our defining access. + What->replaceAllUsesWith(What->getDefiningAccess()); + + // Let MemorySSA take care of moving it around in the lists. + MSSA->moveTo(What, BB, Where); + + // Now reinsert it into the IR and do whatever fixups needed. + if (auto *MD = dyn_cast<MemoryDef>(What)) + insertDef(MD, /*RenameUses=*/true); + else + insertUse(cast<MemoryUse>(What), /*RenameUses=*/true); + + // Clear dangling pointers. We added all MemoryPhi users, but not all + // of them are removed by fixupDefs(). + NonOptPhis.clear(); +} + +// Move What before Where in the MemorySSA IR. +void MemorySSAUpdater::moveBefore(MemoryUseOrDef *What, MemoryUseOrDef *Where) { + moveTo(What, Where->getBlock(), Where->getIterator()); +} + +// Move What after Where in the MemorySSA IR. +void MemorySSAUpdater::moveAfter(MemoryUseOrDef *What, MemoryUseOrDef *Where) { + moveTo(What, Where->getBlock(), ++Where->getIterator()); +} + +void MemorySSAUpdater::moveToPlace(MemoryUseOrDef *What, BasicBlock *BB, + MemorySSA::InsertionPlace Where) { + return moveTo(What, BB, Where); +} + +// All accesses in To used to be in From. Move to end and update access lists. +void MemorySSAUpdater::moveAllAccesses(BasicBlock *From, BasicBlock *To, + Instruction *Start) { + + MemorySSA::AccessList *Accs = MSSA->getWritableBlockAccesses(From); + if (!Accs) + return; + + assert(Start->getParent() == To && "Incorrect Start instruction"); + MemoryAccess *FirstInNew = nullptr; + for (Instruction &I : make_range(Start->getIterator(), To->end())) + if ((FirstInNew = MSSA->getMemoryAccess(&I))) + break; + if (FirstInNew) { + auto *MUD = cast<MemoryUseOrDef>(FirstInNew); + do { + auto NextIt = ++MUD->getIterator(); + MemoryUseOrDef *NextMUD = (!Accs || NextIt == Accs->end()) + ? nullptr + : cast<MemoryUseOrDef>(&*NextIt); + MSSA->moveTo(MUD, To, MemorySSA::End); + // Moving MUD from Accs in the moveTo above, may delete Accs, so we need + // to retrieve it again. 
+ Accs = MSSA->getWritableBlockAccesses(From); + MUD = NextMUD; + } while (MUD); + } + + // If all accesses were moved and only a trivial Phi remains, we try to remove + // that Phi. This is needed when From is going to be deleted. + auto *Defs = MSSA->getWritableBlockDefs(From); + if (Defs && !Defs->empty()) + if (auto *Phi = dyn_cast<MemoryPhi>(&*Defs->begin())) + tryRemoveTrivialPhi(Phi); +} + +void MemorySSAUpdater::moveAllAfterSpliceBlocks(BasicBlock *From, + BasicBlock *To, + Instruction *Start) { + assert(MSSA->getBlockAccesses(To) == nullptr && + "To block is expected to be free of MemoryAccesses."); + moveAllAccesses(From, To, Start); + for (BasicBlock *Succ : successors(To)) + if (MemoryPhi *MPhi = MSSA->getMemoryAccess(Succ)) + MPhi->setIncomingBlock(MPhi->getBasicBlockIndex(From), To); +} + +void MemorySSAUpdater::moveAllAfterMergeBlocks(BasicBlock *From, BasicBlock *To, + Instruction *Start) { + assert(From->getUniquePredecessor() == To && + "From block is expected to have a single predecessor (To)."); + moveAllAccesses(From, To, Start); + for (BasicBlock *Succ : successors(From)) + if (MemoryPhi *MPhi = MSSA->getMemoryAccess(Succ)) + MPhi->setIncomingBlock(MPhi->getBasicBlockIndex(From), To); +} + +/// If all arguments of a MemoryPHI are defined by the same incoming +/// argument, return that argument. +static MemoryAccess *onlySingleValue(MemoryPhi *MP) { + MemoryAccess *MA = nullptr; + + for (auto &Arg : MP->operands()) { + if (!MA) + MA = cast<MemoryAccess>(Arg); + else if (MA != Arg) + return nullptr; + } + return MA; +} + +void MemorySSAUpdater::wireOldPredecessorsToNewImmediatePredecessor( + BasicBlock *Old, BasicBlock *New, ArrayRef<BasicBlock *> Preds, + bool IdenticalEdgesWereMerged) { + assert(!MSSA->getWritableBlockAccesses(New) && + "Access list should be null for a new block."); + MemoryPhi *Phi = MSSA->getMemoryAccess(Old); + if (!Phi) + return; + if (Old->hasNPredecessors(1)) { + assert(pred_size(New) == Preds.size() && + "Should have moved all predecessors."); + MSSA->moveTo(Phi, New, MemorySSA::Beginning); + } else { + assert(!Preds.empty() && "Must be moving at least one predecessor to the " + "new immediate predecessor."); + MemoryPhi *NewPhi = MSSA->createMemoryPhi(New); + SmallPtrSet<BasicBlock *, 16> PredsSet(Preds.begin(), Preds.end()); + // Currently only support the case of removing a single incoming edge when + // identical edges were not merged. + if (!IdenticalEdgesWereMerged) + assert(PredsSet.size() == Preds.size() && + "If identical edges were not merged, we cannot have duplicate " + "blocks in the predecessors"); + Phi->unorderedDeleteIncomingIf([&](MemoryAccess *MA, BasicBlock *B) { + if (PredsSet.count(B)) { + NewPhi->addIncoming(MA, B); + if (!IdenticalEdgesWereMerged) + PredsSet.erase(B); + return true; + } + return false; + }); + Phi->addIncoming(NewPhi, New); + tryRemoveTrivialPhi(NewPhi); + } +} + +void MemorySSAUpdater::removeMemoryAccess(MemoryAccess *MA, bool OptimizePhis) { + assert(!MSSA->isLiveOnEntryDef(MA) && + "Trying to remove the live on entry def"); + // We can only delete phi nodes if they have no uses, or we can replace all + // uses with a single definition. + MemoryAccess *NewDefTarget = nullptr; + if (MemoryPhi *MP = dyn_cast<MemoryPhi>(MA)) { + // Note that it is sufficient to know that all edges of the phi node have + // the same argument. 
If they do, by the definition of dominance frontiers + // (which we used to place this phi), that argument must dominate this phi, + // and thus, must dominate the phi's uses, and so we will not hit the assert + // below. + NewDefTarget = onlySingleValue(MP); + assert((NewDefTarget || MP->use_empty()) && + "We can't delete this memory phi"); + } else { + NewDefTarget = cast<MemoryUseOrDef>(MA)->getDefiningAccess(); + } + + SmallSetVector<MemoryPhi *, 4> PhisToCheck; + + // Re-point the uses at our defining access + if (!isa<MemoryUse>(MA) && !MA->use_empty()) { + // Reset optimized on users of this store, and reset the uses. + // A few notes: + // 1. This is a slightly modified version of RAUW to avoid walking the + // uses twice here. + // 2. If we wanted to be complete, we would have to reset the optimized + // flags on users of phi nodes if doing the below makes a phi node have all + // the same arguments. Instead, we prefer users to removeMemoryAccess those + // phi nodes, because doing it here would be N^3. + if (MA->hasValueHandle()) + ValueHandleBase::ValueIsRAUWd(MA, NewDefTarget); + // Note: We assume MemorySSA is not used in metadata since it's not really + // part of the IR. + + while (!MA->use_empty()) { + Use &U = *MA->use_begin(); + if (auto *MUD = dyn_cast<MemoryUseOrDef>(U.getUser())) + MUD->resetOptimized(); + if (OptimizePhis) + if (MemoryPhi *MP = dyn_cast<MemoryPhi>(U.getUser())) + PhisToCheck.insert(MP); + U.set(NewDefTarget); + } + } + + // The call below to erase will destroy MA, so we can't change the order we + // are doing things here + MSSA->removeFromLookups(MA); + MSSA->removeFromLists(MA); + + // Optionally optimize Phi uses. This will recursively remove trivial phis. + if (!PhisToCheck.empty()) { + SmallVector<WeakVH, 16> PhisToOptimize{PhisToCheck.begin(), + PhisToCheck.end()}; + PhisToCheck.clear(); + + unsigned PhisSize = PhisToOptimize.size(); + while (PhisSize-- > 0) + if (MemoryPhi *MP = + cast_or_null<MemoryPhi>(PhisToOptimize.pop_back_val())) + tryRemoveTrivialPhi(MP); + } +} + +void MemorySSAUpdater::removeBlocks( + const SmallSetVector<BasicBlock *, 8> &DeadBlocks) { + // First delete all uses of BB in MemoryPhis. + for (BasicBlock *BB : DeadBlocks) { + Instruction *TI = BB->getTerminator(); + assert(TI && "Basic block expected to have a terminator instruction"); + for (BasicBlock *Succ : successors(TI)) + if (!DeadBlocks.count(Succ)) + if (MemoryPhi *MP = MSSA->getMemoryAccess(Succ)) { + MP->unorderedDeleteIncomingBlock(BB); + tryRemoveTrivialPhi(MP); + } + // Drop all references of all accesses in BB + if (MemorySSA::AccessList *Acc = MSSA->getWritableBlockAccesses(BB)) + for (MemoryAccess &MA : *Acc) + MA.dropAllReferences(); + } + + // Next, delete all memory accesses in each block + for (BasicBlock *BB : DeadBlocks) { + MemorySSA::AccessList *Acc = MSSA->getWritableBlockAccesses(BB); + if (!Acc) + continue; + for (auto AB = Acc->begin(), AE = Acc->end(); AB != AE;) { + MemoryAccess *MA = &*AB; + ++AB; + MSSA->removeFromLookups(MA); + MSSA->removeFromLists(MA); + } + } +} + +void MemorySSAUpdater::tryRemoveTrivialPhis(ArrayRef<WeakVH> UpdatedPHIs) { + for (auto &VH : UpdatedPHIs) + if (auto *MPhi = cast_or_null<MemoryPhi>(VH)) + tryRemoveTrivialPhi(MPhi); +} + +void MemorySSAUpdater::changeToUnreachable(const Instruction *I) { + const BasicBlock *BB = I->getParent(); + // Remove memory accesses in BB for I and all following instructions. 
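+  // E.g., in a (hypothetical) block containing "store1; I; store2", the
+  // accesses for I and store2 are removed, while store1's access is kept.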
+ auto BBI = I->getIterator(), BBE = BB->end(); + // FIXME: If this becomes too expensive, iterate until the first instruction + // with a memory access, then iterate over MemoryAccesses. + while (BBI != BBE) + removeMemoryAccess(&*(BBI++)); + // Update phis in BB's successors to remove BB. + SmallVector<WeakVH, 16> UpdatedPHIs; + for (const BasicBlock *Successor : successors(BB)) { + removeDuplicatePhiEdgesBetween(BB, Successor); + if (MemoryPhi *MPhi = MSSA->getMemoryAccess(Successor)) { + MPhi->unorderedDeleteIncomingBlock(BB); + UpdatedPHIs.push_back(MPhi); + } + } + // Optimize trivial phis. + tryRemoveTrivialPhis(UpdatedPHIs); +} + +void MemorySSAUpdater::changeCondBranchToUnconditionalTo(const BranchInst *BI, + const BasicBlock *To) { + const BasicBlock *BB = BI->getParent(); + SmallVector<WeakVH, 16> UpdatedPHIs; + for (const BasicBlock *Succ : successors(BB)) { + removeDuplicatePhiEdgesBetween(BB, Succ); + if (Succ != To) + if (auto *MPhi = MSSA->getMemoryAccess(Succ)) { + MPhi->unorderedDeleteIncomingBlock(BB); + UpdatedPHIs.push_back(MPhi); + } + } + // Optimize trivial phis. + tryRemoveTrivialPhis(UpdatedPHIs); +} + +MemoryAccess *MemorySSAUpdater::createMemoryAccessInBB( + Instruction *I, MemoryAccess *Definition, const BasicBlock *BB, + MemorySSA::InsertionPlace Point) { + MemoryUseOrDef *NewAccess = MSSA->createDefinedAccess(I, Definition); + MSSA->insertIntoListsForBlock(NewAccess, BB, Point); + return NewAccess; +} + +MemoryUseOrDef *MemorySSAUpdater::createMemoryAccessBefore( + Instruction *I, MemoryAccess *Definition, MemoryUseOrDef *InsertPt) { + assert(I->getParent() == InsertPt->getBlock() && + "New and old access must be in the same block"); + MemoryUseOrDef *NewAccess = MSSA->createDefinedAccess(I, Definition); + MSSA->insertIntoListsBefore(NewAccess, InsertPt->getBlock(), + InsertPt->getIterator()); + return NewAccess; +} + +MemoryUseOrDef *MemorySSAUpdater::createMemoryAccessAfter( + Instruction *I, MemoryAccess *Definition, MemoryAccess *InsertPt) { + assert(I->getParent() == InsertPt->getBlock() && + "New and old access must be in the same block"); + MemoryUseOrDef *NewAccess = MSSA->createDefinedAccess(I, Definition); + MSSA->insertIntoListsBefore(NewAccess, InsertPt->getBlock(), + ++InsertPt->getIterator()); + return NewAccess; +} diff --git a/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp b/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp new file mode 100644 index 000000000000..519242759824 --- /dev/null +++ b/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp @@ -0,0 +1,127 @@ +//===-- ModuleDebugInfoPrinter.cpp - Prints module debug info metadata ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass decodes the debug info metadata in a module and prints in a +// (sufficiently-prepared-) human-readable form. +// +// For example, run this pass from opt along with the -analyze option, and +// it'll print to standard output. 
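+// A typical invocation might be (hypothetical input file):
+//   opt -analyze -module-debuginfo foo.ll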
+// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/IR/DebugInfo.h" +#include "llvm/Pass.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +namespace { + class ModuleDebugInfoPrinter : public ModulePass { + DebugInfoFinder Finder; + public: + static char ID; // Pass identification, replacement for typeid + ModuleDebugInfoPrinter() : ModulePass(ID) { + initializeModuleDebugInfoPrinterPass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + } + void print(raw_ostream &O, const Module *M) const override; + }; +} + +char ModuleDebugInfoPrinter::ID = 0; +INITIALIZE_PASS(ModuleDebugInfoPrinter, "module-debuginfo", + "Decodes module-level debug info", false, true) + +ModulePass *llvm::createModuleDebugInfoPrinterPass() { + return new ModuleDebugInfoPrinter(); +} + +bool ModuleDebugInfoPrinter::runOnModule(Module &M) { + Finder.processModule(M); + return false; +} + +static void printFile(raw_ostream &O, StringRef Filename, StringRef Directory, + unsigned Line = 0) { + if (Filename.empty()) + return; + + O << " from "; + if (!Directory.empty()) + O << Directory << "/"; + O << Filename; + if (Line) + O << ":" << Line; +} + +void ModuleDebugInfoPrinter::print(raw_ostream &O, const Module *M) const { + // Printing the nodes directly isn't particularly helpful (since they + // reference other nodes that won't be printed, particularly for the + // filenames), so just print a few useful things. + for (DICompileUnit *CU : Finder.compile_units()) { + O << "Compile unit: "; + auto Lang = dwarf::LanguageString(CU->getSourceLanguage()); + if (!Lang.empty()) + O << Lang; + else + O << "unknown-language(" << CU->getSourceLanguage() << ")"; + printFile(O, CU->getFilename(), CU->getDirectory()); + O << '\n'; + } + + for (DISubprogram *S : Finder.subprograms()) { + O << "Subprogram: " << S->getName(); + printFile(O, S->getFilename(), S->getDirectory(), S->getLine()); + if (!S->getLinkageName().empty()) + O << " ('" << S->getLinkageName() << "')"; + O << '\n'; + } + + for (auto GVU : Finder.global_variables()) { + const auto *GV = GVU->getVariable(); + O << "Global variable: " << GV->getName(); + printFile(O, GV->getFilename(), GV->getDirectory(), GV->getLine()); + if (!GV->getLinkageName().empty()) + O << " ('" << GV->getLinkageName() << "')"; + O << '\n'; + } + + for (const DIType *T : Finder.types()) { + O << "Type:"; + if (!T->getName().empty()) + O << ' ' << T->getName(); + printFile(O, T->getFilename(), T->getDirectory(), T->getLine()); + if (auto *BT = dyn_cast<DIBasicType>(T)) { + O << " "; + auto Encoding = dwarf::AttributeEncodingString(BT->getEncoding()); + if (!Encoding.empty()) + O << Encoding; + else + O << "unknown-encoding(" << BT->getEncoding() << ')'; + } else { + O << ' '; + auto Tag = dwarf::TagString(T->getTag()); + if (!Tag.empty()) + O << Tag; + else + O << "unknown-tag(" << T->getTag() << ")"; + } + if (auto *CT = dyn_cast<DICompositeType>(T)) { + if (auto *S = CT->getRawIdentifier()) + O << " (identifier: '" << S->getString() << "')"; + } + O << '\n'; + } +} diff --git a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp new file mode 100644 index 000000000000..8232bf07cafc --- /dev/null +++ b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp @@ -0,0 +1,881 @@ +//===- 
ModuleSummaryAnalysis.cpp - Module summary index builder -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass builds a ModuleSummaryIndex object for the module, to be written +// to bitcode or LLVM assembly. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ModuleSummaryAnalysis.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/IndirectCallPromotionAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/TypeMetadataUtils.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/ModuleSummaryIndex.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/Object/ModuleSymbolTable.h" +#include "llvm/Object/SymbolicFile.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "module-summary-analysis" + +// Option to force edges cold which will block importing when the +// -import-cold-multiplier is set to 0. Useful for debugging. +FunctionSummary::ForceSummaryHotnessType ForceSummaryEdgesCold = + FunctionSummary::FSHT_None; +cl::opt<FunctionSummary::ForceSummaryHotnessType, true> FSEC( + "force-summary-edges-cold", cl::Hidden, cl::location(ForceSummaryEdgesCold), + cl::desc("Force all edges in the function summary to cold"), + cl::values(clEnumValN(FunctionSummary::FSHT_None, "none", "None."), + clEnumValN(FunctionSummary::FSHT_AllNonCritical, + "all-non-critical", "All non-critical edges."), + clEnumValN(FunctionSummary::FSHT_All, "all", "All edges."))); + +cl::opt<std::string> ModuleSummaryDotFile( + "module-summary-dot-file", cl::init(""), cl::Hidden, + cl::value_desc("filename"), + cl::desc("File to emit dot graph of new summary into.")); + +// Walk through the operands of a given User via worklist iteration and populate +// the set of GlobalValue references encountered. Invoked either on an +// Instruction or a GlobalVariable (which walks its initializer). +// Return true if any of the operands contains blockaddress. This is important +// to know when computing summary for global var, because if global variable +// references basic block address we can't import it separately from function +// containing that basic block. For simplicity we currently don't import such +// global vars at all. 
When importing function we aren't interested if any +// instruction in it takes an address of any basic block, because instruction +// can only take an address of basic block located in the same function. +static bool findRefEdges(ModuleSummaryIndex &Index, const User *CurUser, + SetVector<ValueInfo> &RefEdges, + SmallPtrSet<const User *, 8> &Visited) { + bool HasBlockAddress = false; + SmallVector<const User *, 32> Worklist; + Worklist.push_back(CurUser); + + while (!Worklist.empty()) { + const User *U = Worklist.pop_back_val(); + + if (!Visited.insert(U).second) + continue; + + ImmutableCallSite CS(U); + + for (const auto &OI : U->operands()) { + const User *Operand = dyn_cast<User>(OI); + if (!Operand) + continue; + if (isa<BlockAddress>(Operand)) { + HasBlockAddress = true; + continue; + } + if (auto *GV = dyn_cast<GlobalValue>(Operand)) { + // We have a reference to a global value. This should be added to + // the reference set unless it is a callee. Callees are handled + // specially by WriteFunction and are added to a separate list. + if (!(CS && CS.isCallee(&OI))) + RefEdges.insert(Index.getOrInsertValueInfo(GV)); + continue; + } + Worklist.push_back(Operand); + } + } + return HasBlockAddress; +} + +static CalleeInfo::HotnessType getHotness(uint64_t ProfileCount, + ProfileSummaryInfo *PSI) { + if (!PSI) + return CalleeInfo::HotnessType::Unknown; + if (PSI->isHotCount(ProfileCount)) + return CalleeInfo::HotnessType::Hot; + if (PSI->isColdCount(ProfileCount)) + return CalleeInfo::HotnessType::Cold; + return CalleeInfo::HotnessType::None; +} + +static bool isNonRenamableLocal(const GlobalValue &GV) { + return GV.hasSection() && GV.hasLocalLinkage(); +} + +/// Determine whether this call has all constant integer arguments (excluding +/// "this") and summarize it to VCalls or ConstVCalls as appropriate. +static void addVCallToSet(DevirtCallSite Call, GlobalValue::GUID Guid, + SetVector<FunctionSummary::VFuncId> &VCalls, + SetVector<FunctionSummary::ConstVCall> &ConstVCalls) { + std::vector<uint64_t> Args; + // Start from the second argument to skip the "this" pointer. + for (auto &Arg : make_range(Call.CS.arg_begin() + 1, Call.CS.arg_end())) { + auto *CI = dyn_cast<ConstantInt>(Arg); + if (!CI || CI->getBitWidth() > 64) { + VCalls.insert({Guid, Call.Offset}); + return; + } + Args.push_back(CI->getZExtValue()); + } + ConstVCalls.insert({{Guid, Call.Offset}, std::move(Args)}); +} + +/// If this intrinsic call requires that we add information to the function +/// summary, do so via the non-constant reference arguments. +static void addIntrinsicToSummary( + const CallInst *CI, SetVector<GlobalValue::GUID> &TypeTests, + SetVector<FunctionSummary::VFuncId> &TypeTestAssumeVCalls, + SetVector<FunctionSummary::VFuncId> &TypeCheckedLoadVCalls, + SetVector<FunctionSummary::ConstVCall> &TypeTestAssumeConstVCalls, + SetVector<FunctionSummary::ConstVCall> &TypeCheckedLoadConstVCalls, + DominatorTree &DT) { + switch (CI->getCalledFunction()->getIntrinsicID()) { + case Intrinsic::type_test: { + auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(1)); + auto *TypeId = dyn_cast<MDString>(TypeMDVal->getMetadata()); + if (!TypeId) + break; + GlobalValue::GUID Guid = GlobalValue::getGUID(TypeId->getString()); + + // Produce a summary from type.test intrinsics. We only summarize type.test + // intrinsics that are used other than by an llvm.assume intrinsic. + // Intrinsics that are assumed are relevant only to the devirtualization + // pass, not the type test lowering pass. 
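+    // E.g. (hypothetical IR):
+    //   %t = call i1 @llvm.type.test(i8* %p, metadata !"_ZTS1A")
+    //   call void @llvm.assume(i1 %t)
+    // has only an assume use, so it alone does not add a TypeTests entry;
+    // a use of %t in a branch would.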
+ bool HasNonAssumeUses = llvm::any_of(CI->uses(), [](const Use &CIU) { + auto *AssumeCI = dyn_cast<CallInst>(CIU.getUser()); + if (!AssumeCI) + return true; + Function *F = AssumeCI->getCalledFunction(); + return !F || F->getIntrinsicID() != Intrinsic::assume; + }); + if (HasNonAssumeUses) + TypeTests.insert(Guid); + + SmallVector<DevirtCallSite, 4> DevirtCalls; + SmallVector<CallInst *, 4> Assumes; + findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT); + for (auto &Call : DevirtCalls) + addVCallToSet(Call, Guid, TypeTestAssumeVCalls, + TypeTestAssumeConstVCalls); + + break; + } + + case Intrinsic::type_checked_load: { + auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(2)); + auto *TypeId = dyn_cast<MDString>(TypeMDVal->getMetadata()); + if (!TypeId) + break; + GlobalValue::GUID Guid = GlobalValue::getGUID(TypeId->getString()); + + SmallVector<DevirtCallSite, 4> DevirtCalls; + SmallVector<Instruction *, 4> LoadedPtrs; + SmallVector<Instruction *, 4> Preds; + bool HasNonCallUses = false; + findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds, + HasNonCallUses, CI, DT); + // Any non-call uses of the result of llvm.type.checked.load will + // prevent us from optimizing away the llvm.type.test. + if (HasNonCallUses) + TypeTests.insert(Guid); + for (auto &Call : DevirtCalls) + addVCallToSet(Call, Guid, TypeCheckedLoadVCalls, + TypeCheckedLoadConstVCalls); + + break; + } + default: + break; + } +} + +static bool isNonVolatileLoad(const Instruction *I) { + if (const auto *LI = dyn_cast<LoadInst>(I)) + return !LI->isVolatile(); + + return false; +} + +static bool isNonVolatileStore(const Instruction *I) { + if (const auto *SI = dyn_cast<StoreInst>(I)) + return !SI->isVolatile(); + + return false; +} + +static void computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M, + const Function &F, BlockFrequencyInfo *BFI, + ProfileSummaryInfo *PSI, DominatorTree &DT, + bool HasLocalsInUsedOrAsm, + DenseSet<GlobalValue::GUID> &CantBePromoted, + bool IsThinLTO) { + // Summary not currently supported for anonymous functions, they should + // have been named. + assert(F.hasName()); + + unsigned NumInsts = 0; + // Map from callee ValueId to profile count. Used to accumulate profile + // counts for all static calls to a given callee. + MapVector<ValueInfo, CalleeInfo> CallGraphEdges; + SetVector<ValueInfo> RefEdges, LoadRefEdges, StoreRefEdges; + SetVector<GlobalValue::GUID> TypeTests; + SetVector<FunctionSummary::VFuncId> TypeTestAssumeVCalls, + TypeCheckedLoadVCalls; + SetVector<FunctionSummary::ConstVCall> TypeTestAssumeConstVCalls, + TypeCheckedLoadConstVCalls; + ICallPromotionAnalysis ICallAnalysis; + SmallPtrSet<const User *, 8> Visited; + + // Add personality function, prefix data and prologue data to function's ref + // list. 
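+  // (Invoked on the Function itself, findRefEdges visits only F's own
+  // constant operands, such as those above, not its instructions.)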
+ findRefEdges(Index, &F, RefEdges, Visited); + std::vector<const Instruction *> NonVolatileLoads; + std::vector<const Instruction *> NonVolatileStores; + + bool HasInlineAsmMaybeReferencingInternal = false; + for (const BasicBlock &BB : F) + for (const Instruction &I : BB) { + if (isa<DbgInfoIntrinsic>(I)) + continue; + ++NumInsts; + // Regular LTO module doesn't participate in ThinLTO import, + // so no reference from it can be read/writeonly, since this + // would require importing variable as local copy + if (IsThinLTO) { + if (isNonVolatileLoad(&I)) { + // Postpone processing of non-volatile load instructions + // See comments below + Visited.insert(&I); + NonVolatileLoads.push_back(&I); + continue; + } else if (isNonVolatileStore(&I)) { + Visited.insert(&I); + NonVolatileStores.push_back(&I); + // All references from second operand of store (destination address) + // can be considered write-only if they're not referenced by any + // non-store instruction. References from first operand of store + // (stored value) can't be treated either as read- or as write-only + // so we add them to RefEdges as we do with all other instructions + // except non-volatile load. + Value *Stored = I.getOperand(0); + if (auto *GV = dyn_cast<GlobalValue>(Stored)) + // findRefEdges will try to examine GV operands, so instead + // of calling it we should add GV to RefEdges directly. + RefEdges.insert(Index.getOrInsertValueInfo(GV)); + else if (auto *U = dyn_cast<User>(Stored)) + findRefEdges(Index, U, RefEdges, Visited); + continue; + } + } + findRefEdges(Index, &I, RefEdges, Visited); + auto CS = ImmutableCallSite(&I); + if (!CS) + continue; + + const auto *CI = dyn_cast<CallInst>(&I); + // Since we don't know exactly which local values are referenced in inline + // assembly, conservatively mark the function as possibly referencing + // a local value from inline assembly to ensure we don't export a + // reference (which would require renaming and promotion of the + // referenced value). + if (HasLocalsInUsedOrAsm && CI && CI->isInlineAsm()) + HasInlineAsmMaybeReferencingInternal = true; + + auto *CalledValue = CS.getCalledValue(); + auto *CalledFunction = CS.getCalledFunction(); + if (CalledValue && !CalledFunction) { + CalledValue = CalledValue->stripPointerCasts(); + // Stripping pointer casts can reveal a called function. + CalledFunction = dyn_cast<Function>(CalledValue); + } + // Check if this is an alias to a function. If so, get the + // called aliasee for the checks below. + if (auto *GA = dyn_cast<GlobalAlias>(CalledValue)) { + assert(!CalledFunction && "Expected null called function in callsite for alias"); + CalledFunction = dyn_cast<Function>(GA->getBaseObject()); + } + // Check if this is a direct call to a known function or a known + // intrinsic, or an indirect call with profile data. + if (CalledFunction) { + if (CI && CalledFunction->isIntrinsic()) { + addIntrinsicToSummary( + CI, TypeTests, TypeTestAssumeVCalls, TypeCheckedLoadVCalls, + TypeTestAssumeConstVCalls, TypeCheckedLoadConstVCalls, DT); + continue; + } + // We should have named any anonymous globals + assert(CalledFunction->hasName()); + auto ScaledCount = PSI->getProfileCount(&I, BFI); + auto Hotness = ScaledCount ? getHotness(ScaledCount.getValue(), PSI) + : CalleeInfo::HotnessType::Unknown; + if (ForceSummaryEdgesCold != FunctionSummary::FSHT_None) + Hotness = CalleeInfo::HotnessType::Cold; + + // Use the original CalledValue, in case it was an alias. We want + // to record the call edge to the alias in that case. 
Eventually + // an alias summary will be created to associate the alias and + // aliasee. + auto &ValueInfo = CallGraphEdges[Index.getOrInsertValueInfo( + cast<GlobalValue>(CalledValue))]; + ValueInfo.updateHotness(Hotness); + // Add the relative block frequency to CalleeInfo if there is no profile + // information. + if (BFI != nullptr && Hotness == CalleeInfo::HotnessType::Unknown) { + uint64_t BBFreq = BFI->getBlockFreq(&BB).getFrequency(); + uint64_t EntryFreq = BFI->getEntryFreq(); + ValueInfo.updateRelBlockFreq(BBFreq, EntryFreq); + } + } else { + // Skip inline assembly calls. + if (CI && CI->isInlineAsm()) + continue; + // Skip direct calls. + if (!CalledValue || isa<Constant>(CalledValue)) + continue; + + // Check if the instruction has a callees metadata. If so, add callees + // to CallGraphEdges to reflect the references from the metadata, and + // to enable importing for subsequent indirect call promotion and + // inlining. + if (auto *MD = I.getMetadata(LLVMContext::MD_callees)) { + for (auto &Op : MD->operands()) { + Function *Callee = mdconst::extract_or_null<Function>(Op); + if (Callee) + CallGraphEdges[Index.getOrInsertValueInfo(Callee)]; + } + } + + uint32_t NumVals, NumCandidates; + uint64_t TotalCount; + auto CandidateProfileData = + ICallAnalysis.getPromotionCandidatesForInstruction( + &I, NumVals, TotalCount, NumCandidates); + for (auto &Candidate : CandidateProfileData) + CallGraphEdges[Index.getOrInsertValueInfo(Candidate.Value)] + .updateHotness(getHotness(Candidate.Count, PSI)); + } + } + + std::vector<ValueInfo> Refs; + if (IsThinLTO) { + auto AddRefEdges = [&](const std::vector<const Instruction *> &Instrs, + SetVector<ValueInfo> &Edges, + SmallPtrSet<const User *, 8> &Cache) { + for (const auto *I : Instrs) { + Cache.erase(I); + findRefEdges(Index, I, Edges, Cache); + } + }; + + // By now we processed all instructions in a function, except + // non-volatile loads and non-volatile value stores. Let's find + // ref edges for both of instruction sets + AddRefEdges(NonVolatileLoads, LoadRefEdges, Visited); + // We can add some values to the Visited set when processing load + // instructions which are also used by stores in NonVolatileStores. + // For example this can happen if we have following code: + // + // store %Derived* @foo, %Derived** bitcast (%Base** @bar to %Derived**) + // %42 = load %Derived*, %Derived** bitcast (%Base** @bar to %Derived**) + // + // After processing loads we'll add bitcast to the Visited set, and if + // we use the same set while processing stores, we'll never see store + // to @bar and @bar will be mistakenly treated as readonly. + SmallPtrSet<const llvm::User *, 8> StoreCache; + AddRefEdges(NonVolatileStores, StoreRefEdges, StoreCache); + + // If both load and store instruction reference the same variable + // we won't be able to optimize it. Add all such reference edges + // to RefEdges set. + for (auto &VI : StoreRefEdges) + if (LoadRefEdges.remove(VI)) + RefEdges.insert(VI); + + unsigned RefCnt = RefEdges.size(); + // All new reference edges inserted in two loops below are either + // read or write only. They will be grouped in the end of RefEdges + // vector, so we can use a single integer value to identify them. 
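+    // Worked example (hypothetical sizes): with 2 plain refs already in
+    // RefEdges, 3 load-only refs and 2 store-only refs, the final vector is
+    // [r0 r1 | l0 l1 l2 | s0 s1]: indices [RefCnt, FirstWORef) are marked
+    // read-only and [FirstWORef, size) write-only below.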
+    for (auto &VI : LoadRefEdges)
+      RefEdges.insert(VI);
+
+    unsigned FirstWORef = RefEdges.size();
+    for (auto &VI : StoreRefEdges)
+      RefEdges.insert(VI);
+
+    Refs = RefEdges.takeVector();
+    for (; RefCnt < FirstWORef; ++RefCnt)
+      Refs[RefCnt].setReadOnly();
+
+    for (; RefCnt < Refs.size(); ++RefCnt)
+      Refs[RefCnt].setWriteOnly();
+  } else {
+    Refs = RefEdges.takeVector();
+  }
+  // Explicitly add hot edges to enforce importing for designated GUIDs for
+  // sample PGO, to enable the same inlines as the profiled optimized binary.
+  for (auto &I : F.getImportGUIDs())
+    CallGraphEdges[Index.getOrInsertValueInfo(I)].updateHotness(
+        ForceSummaryEdgesCold == FunctionSummary::FSHT_All
+            ? CalleeInfo::HotnessType::Cold
+            : CalleeInfo::HotnessType::Critical);
+
+  bool NonRenamableLocal = isNonRenamableLocal(F);
+  bool NotEligibleForImport =
+      NonRenamableLocal || HasInlineAsmMaybeReferencingInternal;
+  GlobalValueSummary::GVFlags Flags(F.getLinkage(), NotEligibleForImport,
+                                    /* Live = */ false, F.isDSOLocal(),
+                                    F.hasLinkOnceODRLinkage() &&
+                                        F.hasGlobalUnnamedAddr());
+  FunctionSummary::FFlags FunFlags{
+      F.hasFnAttribute(Attribute::ReadNone),
+      F.hasFnAttribute(Attribute::ReadOnly),
+      F.hasFnAttribute(Attribute::NoRecurse), F.returnDoesNotAlias(),
+      // FIXME: refactor this to use the same code that the inliner is using.
+      // Don't try to import functions with the noinline attribute.
+      F.getAttributes().hasFnAttribute(Attribute::NoInline)};
+  auto FuncSummary = std::make_unique<FunctionSummary>(
+      Flags, NumInsts, FunFlags, /*EntryCount=*/0, std::move(Refs),
+      CallGraphEdges.takeVector(), TypeTests.takeVector(),
+      TypeTestAssumeVCalls.takeVector(), TypeCheckedLoadVCalls.takeVector(),
+      TypeTestAssumeConstVCalls.takeVector(),
+      TypeCheckedLoadConstVCalls.takeVector());
+  if (NonRenamableLocal)
+    CantBePromoted.insert(F.getGUID());
+  Index.addGlobalValueSummary(F, std::move(FuncSummary));
+}
+
+/// Find function pointers referenced within the given vtable initializer
+/// (or subset of an initializer) \p I. The starting offset of \p I within
+/// the vtable initializer is \p StartingOffset. Any discovered function
+/// pointers are added to \p VTableFuncs along with their cumulative offset
+/// within the initializer.
+static void findFuncPointers(const Constant *I, uint64_t StartingOffset,
+                             const Module &M, ModuleSummaryIndex &Index,
+                             VTableFuncList &VTableFuncs) {
+  // First check if this is a function pointer.
+  if (I->getType()->isPointerTy()) {
+    auto Fn = dyn_cast<Function>(I->stripPointerCasts());
+    // We can disregard __cxa_pure_virtual as a possible call target, as
+    // calls to pure virtuals are UB.
+    if (Fn && Fn->getName() != "__cxa_pure_virtual")
+      VTableFuncs.push_back({Index.getOrInsertValueInfo(Fn), StartingOffset});
+    return;
+  }
+
+  // Walk through the elements in the constant struct or array and recursively
+  // look for virtual function pointers.
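+  // E.g. (hypothetical 64-bit layout): for an initializer
+  //   { [3 x i8*] [i8* bitcast (void (%T*)* @f0 to i8*), ...] }
+  // the recursion visits the array elements at cumulative offsets 0, 8, 16.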
+ const DataLayout &DL = M.getDataLayout(); + if (auto *C = dyn_cast<ConstantStruct>(I)) { + StructType *STy = dyn_cast<StructType>(C->getType()); + assert(STy); + const StructLayout *SL = DL.getStructLayout(C->getType()); + + for (StructType::element_iterator EB = STy->element_begin(), EI = EB, + EE = STy->element_end(); + EI != EE; ++EI) { + auto Offset = SL->getElementOffset(EI - EB); + unsigned Op = SL->getElementContainingOffset(Offset); + findFuncPointers(cast<Constant>(I->getOperand(Op)), + StartingOffset + Offset, M, Index, VTableFuncs); + } + } else if (auto *C = dyn_cast<ConstantArray>(I)) { + ArrayType *ATy = C->getType(); + Type *EltTy = ATy->getElementType(); + uint64_t EltSize = DL.getTypeAllocSize(EltTy); + for (unsigned i = 0, e = ATy->getNumElements(); i != e; ++i) { + findFuncPointers(cast<Constant>(I->getOperand(i)), + StartingOffset + i * EltSize, M, Index, VTableFuncs); + } + } +} + +// Identify the function pointers referenced by vtable definition \p V. +static void computeVTableFuncs(ModuleSummaryIndex &Index, + const GlobalVariable &V, const Module &M, + VTableFuncList &VTableFuncs) { + if (!V.isConstant()) + return; + + findFuncPointers(V.getInitializer(), /*StartingOffset=*/0, M, Index, + VTableFuncs); + +#ifndef NDEBUG + // Validate that the VTableFuncs list is ordered by offset. + uint64_t PrevOffset = 0; + for (auto &P : VTableFuncs) { + // The findVFuncPointers traversal should have encountered the + // functions in offset order. We need to use ">=" since PrevOffset + // starts at 0. + assert(P.VTableOffset >= PrevOffset); + PrevOffset = P.VTableOffset; + } +#endif +} + +/// Record vtable definition \p V for each type metadata it references. +static void +recordTypeIdCompatibleVtableReferences(ModuleSummaryIndex &Index, + const GlobalVariable &V, + SmallVectorImpl<MDNode *> &Types) { + for (MDNode *Type : Types) { + auto TypeID = Type->getOperand(1).get(); + + uint64_t Offset = + cast<ConstantInt>( + cast<ConstantAsMetadata>(Type->getOperand(0))->getValue()) + ->getZExtValue(); + + if (auto *TypeId = dyn_cast<MDString>(TypeID)) + Index.getOrInsertTypeIdCompatibleVtableSummary(TypeId->getString()) + .push_back({Offset, Index.getOrInsertValueInfo(&V)}); + } +} + +static void computeVariableSummary(ModuleSummaryIndex &Index, + const GlobalVariable &V, + DenseSet<GlobalValue::GUID> &CantBePromoted, + const Module &M, + SmallVectorImpl<MDNode *> &Types) { + SetVector<ValueInfo> RefEdges; + SmallPtrSet<const User *, 8> Visited; + bool HasBlockAddress = findRefEdges(Index, &V, RefEdges, Visited); + bool NonRenamableLocal = isNonRenamableLocal(V); + GlobalValueSummary::GVFlags Flags(V.getLinkage(), NonRenamableLocal, + /* Live = */ false, V.isDSOLocal(), + V.hasLinkOnceODRLinkage() && V.hasGlobalUnnamedAddr()); + + VTableFuncList VTableFuncs; + // If splitting is not enabled, then we compute the summary information + // necessary for index-based whole program devirtualization. + if (!Index.enableSplitLTOUnit()) { + Types.clear(); + V.getMetadata(LLVMContext::MD_type, Types); + if (!Types.empty()) { + // Identify the function pointers referenced by this vtable definition. + computeVTableFuncs(Index, V, M, VTableFuncs); + + // Record this vtable definition for each type metadata it references. + recordTypeIdCompatibleVtableReferences(Index, V, Types); + } + } + + // Don't mark variables we won't be able to internalize as read/write-only. 
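+  // (A variable we cannot internalize may be read or written from outside
+  // this module, so read/write-only attribution would be unsound for it.)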
+ bool CanBeInternalized = + !V.hasComdat() && !V.hasAppendingLinkage() && !V.isInterposable() && + !V.hasAvailableExternallyLinkage() && !V.hasDLLExportStorageClass(); + GlobalVarSummary::GVarFlags VarFlags(CanBeInternalized, CanBeInternalized); + auto GVarSummary = std::make_unique<GlobalVarSummary>(Flags, VarFlags, + RefEdges.takeVector()); + if (NonRenamableLocal) + CantBePromoted.insert(V.getGUID()); + if (HasBlockAddress) + GVarSummary->setNotEligibleToImport(); + if (!VTableFuncs.empty()) + GVarSummary->setVTableFuncs(VTableFuncs); + Index.addGlobalValueSummary(V, std::move(GVarSummary)); +} + +static void +computeAliasSummary(ModuleSummaryIndex &Index, const GlobalAlias &A, + DenseSet<GlobalValue::GUID> &CantBePromoted) { + bool NonRenamableLocal = isNonRenamableLocal(A); + GlobalValueSummary::GVFlags Flags(A.getLinkage(), NonRenamableLocal, + /* Live = */ false, A.isDSOLocal(), + A.hasLinkOnceODRLinkage() && A.hasGlobalUnnamedAddr()); + auto AS = std::make_unique<AliasSummary>(Flags); + auto *Aliasee = A.getBaseObject(); + auto AliaseeVI = Index.getValueInfo(Aliasee->getGUID()); + assert(AliaseeVI && "Alias expects aliasee summary to be available"); + assert(AliaseeVI.getSummaryList().size() == 1 && + "Expected a single entry per aliasee in per-module index"); + AS->setAliasee(AliaseeVI, AliaseeVI.getSummaryList()[0].get()); + if (NonRenamableLocal) + CantBePromoted.insert(A.getGUID()); + Index.addGlobalValueSummary(A, std::move(AS)); +} + +// Set LiveRoot flag on entries matching the given value name. +static void setLiveRoot(ModuleSummaryIndex &Index, StringRef Name) { + if (ValueInfo VI = Index.getValueInfo(GlobalValue::getGUID(Name))) + for (auto &Summary : VI.getSummaryList()) + Summary->setLive(true); +} + +ModuleSummaryIndex llvm::buildModuleSummaryIndex( + const Module &M, + std::function<BlockFrequencyInfo *(const Function &F)> GetBFICallback, + ProfileSummaryInfo *PSI) { + assert(PSI); + bool EnableSplitLTOUnit = false; + if (auto *MD = mdconst::extract_or_null<ConstantInt>( + M.getModuleFlag("EnableSplitLTOUnit"))) + EnableSplitLTOUnit = MD->getZExtValue(); + ModuleSummaryIndex Index(/*HaveGVs=*/true, EnableSplitLTOUnit); + + // Identify the local values in the llvm.used and llvm.compiler.used sets, + // which should not be exported as they would then require renaming and + // promotion, but we may have opaque uses e.g. in inline asm. We collect them + // here because we use this information to mark functions containing inline + // assembly calls as not importable. + SmallPtrSet<GlobalValue *, 8> LocalsUsed; + SmallPtrSet<GlobalValue *, 8> Used; + // First collect those in the llvm.used set. + collectUsedGlobalVariables(M, Used, /*CompilerUsed*/ false); + // Next collect those in the llvm.compiler.used set. + collectUsedGlobalVariables(M, Used, /*CompilerUsed*/ true); + DenseSet<GlobalValue::GUID> CantBePromoted; + for (auto *V : Used) { + if (V->hasLocalLinkage()) { + LocalsUsed.insert(V); + CantBePromoted.insert(V->getGUID()); + } + } + + bool HasLocalInlineAsmSymbol = false; + if (!M.getModuleInlineAsm().empty()) { + // Collect the local values defined by module level asm, and set up + // summaries for these symbols so that they can be marked as NoRename, + // to prevent export of any use of them in regular IR that would require + // renaming within the module level asm. Note we don't need to create a + // summary for weak or global defs, as they don't need to be flagged as + // NoRename, and defs in module level asm can't be imported anyway. 
+ // Also, any values used but not defined within module level asm should + // be listed on the llvm.used or llvm.compiler.used global and marked as + // referenced from there. + ModuleSymbolTable::CollectAsmSymbols( + M, [&](StringRef Name, object::BasicSymbolRef::Flags Flags) { + // Symbols not marked as Weak or Global are local definitions. + if (Flags & (object::BasicSymbolRef::SF_Weak | + object::BasicSymbolRef::SF_Global)) + return; + HasLocalInlineAsmSymbol = true; + GlobalValue *GV = M.getNamedValue(Name); + if (!GV) + return; + assert(GV->isDeclaration() && "Def in module asm already has definition"); + GlobalValueSummary::GVFlags GVFlags(GlobalValue::InternalLinkage, + /* NotEligibleToImport = */ true, + /* Live = */ true, + /* Local */ GV->isDSOLocal(), + GV->hasLinkOnceODRLinkage() && GV->hasGlobalUnnamedAddr()); + CantBePromoted.insert(GV->getGUID()); + // Create the appropriate summary type. + if (Function *F = dyn_cast<Function>(GV)) { + std::unique_ptr<FunctionSummary> Summary = + std::make_unique<FunctionSummary>( + GVFlags, /*InstCount=*/0, + FunctionSummary::FFlags{ + F->hasFnAttribute(Attribute::ReadNone), + F->hasFnAttribute(Attribute::ReadOnly), + F->hasFnAttribute(Attribute::NoRecurse), + F->returnDoesNotAlias(), + /* NoInline = */ false}, + /*EntryCount=*/0, ArrayRef<ValueInfo>{}, + ArrayRef<FunctionSummary::EdgeTy>{}, + ArrayRef<GlobalValue::GUID>{}, + ArrayRef<FunctionSummary::VFuncId>{}, + ArrayRef<FunctionSummary::VFuncId>{}, + ArrayRef<FunctionSummary::ConstVCall>{}, + ArrayRef<FunctionSummary::ConstVCall>{}); + Index.addGlobalValueSummary(*GV, std::move(Summary)); + } else { + std::unique_ptr<GlobalVarSummary> Summary = + std::make_unique<GlobalVarSummary>( + GVFlags, GlobalVarSummary::GVarFlags(false, false), + ArrayRef<ValueInfo>{}); + Index.addGlobalValueSummary(*GV, std::move(Summary)); + } + }); + } + + bool IsThinLTO = true; + if (auto *MD = + mdconst::extract_or_null<ConstantInt>(M.getModuleFlag("ThinLTO"))) + IsThinLTO = MD->getZExtValue(); + + // Compute summaries for all functions defined in module, and save in the + // index. + for (auto &F : M) { + if (F.isDeclaration()) + continue; + + DominatorTree DT(const_cast<Function &>(F)); + BlockFrequencyInfo *BFI = nullptr; + std::unique_ptr<BlockFrequencyInfo> BFIPtr; + if (GetBFICallback) + BFI = GetBFICallback(F); + else if (F.hasProfileData()) { + LoopInfo LI{DT}; + BranchProbabilityInfo BPI{F, LI}; + BFIPtr = std::make_unique<BlockFrequencyInfo>(F, BPI, LI); + BFI = BFIPtr.get(); + } + + computeFunctionSummary(Index, M, F, BFI, PSI, DT, + !LocalsUsed.empty() || HasLocalInlineAsmSymbol, + CantBePromoted, IsThinLTO); + } + + // Compute summaries for all variables defined in module, and save in the + // index. + SmallVector<MDNode *, 2> Types; + for (const GlobalVariable &G : M.globals()) { + if (G.isDeclaration()) + continue; + computeVariableSummary(Index, G, CantBePromoted, M, Types); + } + + // Compute summaries for all aliases defined in module, and save in the + // index. + for (const GlobalAlias &A : M.aliases()) + computeAliasSummary(Index, A, CantBePromoted); + + for (auto *V : LocalsUsed) { + auto *Summary = Index.getGlobalValueSummary(*V); + assert(Summary && "Missing summary for global value"); + Summary->setNotEligibleToImport(); + } + + // The linker doesn't know about these LLVM produced values, so we need + // to flag them as live in the index to ensure index-based dead value + // analysis treats them as live roots of the analysis. 
+ setLiveRoot(Index, "llvm.used"); + setLiveRoot(Index, "llvm.compiler.used"); + setLiveRoot(Index, "llvm.global_ctors"); + setLiveRoot(Index, "llvm.global_dtors"); + setLiveRoot(Index, "llvm.global.annotations"); + + for (auto &GlobalList : Index) { + // Ignore entries for references that are undefined in the current module. + if (GlobalList.second.SummaryList.empty()) + continue; + + assert(GlobalList.second.SummaryList.size() == 1 && + "Expected module's index to have one summary per GUID"); + auto &Summary = GlobalList.second.SummaryList[0]; + if (!IsThinLTO) { + Summary->setNotEligibleToImport(); + continue; + } + + bool AllRefsCanBeExternallyReferenced = + llvm::all_of(Summary->refs(), [&](const ValueInfo &VI) { + return !CantBePromoted.count(VI.getGUID()); + }); + if (!AllRefsCanBeExternallyReferenced) { + Summary->setNotEligibleToImport(); + continue; + } + + if (auto *FuncSummary = dyn_cast<FunctionSummary>(Summary.get())) { + bool AllCallsCanBeExternallyReferenced = llvm::all_of( + FuncSummary->calls(), [&](const FunctionSummary::EdgeTy &Edge) { + return !CantBePromoted.count(Edge.first.getGUID()); + }); + if (!AllCallsCanBeExternallyReferenced) + Summary->setNotEligibleToImport(); + } + } + + if (!ModuleSummaryDotFile.empty()) { + std::error_code EC; + raw_fd_ostream OSDot(ModuleSummaryDotFile, EC, sys::fs::OpenFlags::OF_None); + if (EC) + report_fatal_error(Twine("Failed to open dot file ") + + ModuleSummaryDotFile + ": " + EC.message() + "\n"); + Index.exportToDot(OSDot); + } + + return Index; +} + +AnalysisKey ModuleSummaryIndexAnalysis::Key; + +ModuleSummaryIndex +ModuleSummaryIndexAnalysis::run(Module &M, ModuleAnalysisManager &AM) { + ProfileSummaryInfo &PSI = AM.getResult<ProfileSummaryAnalysis>(M); + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + return buildModuleSummaryIndex( + M, + [&FAM](const Function &F) { + return &FAM.getResult<BlockFrequencyAnalysis>( + *const_cast<Function *>(&F)); + }, + &PSI); +} + +char ModuleSummaryIndexWrapperPass::ID = 0; + +INITIALIZE_PASS_BEGIN(ModuleSummaryIndexWrapperPass, "module-summary-analysis", + "Module Summary Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) +INITIALIZE_PASS_END(ModuleSummaryIndexWrapperPass, "module-summary-analysis", + "Module Summary Analysis", false, true) + +ModulePass *llvm::createModuleSummaryIndexWrapperPass() { + return new ModuleSummaryIndexWrapperPass(); +} + +ModuleSummaryIndexWrapperPass::ModuleSummaryIndexWrapperPass() + : ModulePass(ID) { + initializeModuleSummaryIndexWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool ModuleSummaryIndexWrapperPass::runOnModule(Module &M) { + auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + Index.emplace(buildModuleSummaryIndex( + M, + [this](const Function &F) { + return &(this->getAnalysis<BlockFrequencyInfoWrapperPass>( + *const_cast<Function *>(&F)) + .getBFI()); + }, + PSI)); + return false; +} + +bool ModuleSummaryIndexWrapperPass::doFinalization(Module &M) { + Index.reset(); + return false; +} + +void ModuleSummaryIndexWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<BlockFrequencyInfoWrapperPass>(); + AU.addRequired<ProfileSummaryInfoWrapperPass>(); +} diff --git a/llvm/lib/Analysis/MustExecute.cpp b/llvm/lib/Analysis/MustExecute.cpp new file mode 100644 index 000000000000..44527773115d --- /dev/null +++ b/llvm/lib/Analysis/MustExecute.cpp @@ -0,0 
+1,516 @@
+//===- MustExecute.cpp - Printer for isGuaranteedToExecute ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/MustExecute.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/AssemblyAnnotationWriter.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/FormattedStream.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "must-execute"
+
+const DenseMap<BasicBlock *, ColorVector> &
+LoopSafetyInfo::getBlockColors() const {
+ return BlockColors;
+}
+
+void LoopSafetyInfo::copyColors(BasicBlock *New, BasicBlock *Old) {
+ ColorVector &ColorsForNewBlock = BlockColors[New];
+ ColorVector &ColorsForOldBlock = BlockColors[Old];
+ ColorsForNewBlock = ColorsForOldBlock;
+}
+
+bool SimpleLoopSafetyInfo::blockMayThrow(const BasicBlock *BB) const {
+ (void)BB;
+ return anyBlockMayThrow();
+}
+
+bool SimpleLoopSafetyInfo::anyBlockMayThrow() const {
+ return MayThrow;
+}
+
+void SimpleLoopSafetyInfo::computeLoopSafetyInfo(const Loop *CurLoop) {
+ assert(CurLoop != nullptr && "CurLoop can't be null");
+ BasicBlock *Header = CurLoop->getHeader();
+ // Iterate over the header and compute safety info.
+ HeaderMayThrow = !isGuaranteedToTransferExecutionToSuccessor(Header);
+ MayThrow = HeaderMayThrow;
+ // Iterate over loop instructions and compute safety info.
+ // Skip the header, as it has been computed and stored in HeaderMayThrow.
+ // The first block in loopinfo.Blocks is guaranteed to be the header.
+ assert(Header == *CurLoop->getBlocks().begin() &&
+ "First block must be header");
+ for (Loop::block_iterator BB = std::next(CurLoop->block_begin()),
+ BBE = CurLoop->block_end();
+ (BB != BBE) && !MayThrow; ++BB)
+ MayThrow |= !isGuaranteedToTransferExecutionToSuccessor(*BB);
+
+ computeBlockColors(CurLoop);
+}
+
+bool ICFLoopSafetyInfo::blockMayThrow(const BasicBlock *BB) const {
+ return ICF.hasICF(BB);
+}
+
+bool ICFLoopSafetyInfo::anyBlockMayThrow() const {
+ return MayThrow;
+}
+
+void ICFLoopSafetyInfo::computeLoopSafetyInfo(const Loop *CurLoop) {
+ assert(CurLoop != nullptr && "CurLoop can't be null");
+ ICF.clear();
+ MW.clear();
+ MayThrow = false;
+ // Determine whether at least one block in the loop may throw.
+ for (auto &BB : CurLoop->blocks())
+ if (ICF.hasICF(&*BB)) {
+ MayThrow = true;
+ break;
+ }
+ computeBlockColors(CurLoop);
+}
+
+void ICFLoopSafetyInfo::insertInstructionTo(const Instruction *Inst,
+ const BasicBlock *BB) {
+ ICF.insertInstructionTo(Inst, BB);
+ MW.insertInstructionTo(Inst, BB);
+}
+
+void ICFLoopSafetyInfo::removeInstruction(const Instruction *Inst) {
+ ICF.removeInstruction(Inst);
+ MW.removeInstruction(Inst);
+}
+
+void LoopSafetyInfo::computeBlockColors(const Loop *CurLoop) {
+ // Compute funclet colors if we might sink/hoist in a function with a funclet
+ // personality routine.
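+ // Only scoped EH personalities (e.g. MSVC C++ EH and SEH) use funclets; for
+ // all other personalities BlockColors stays empty and imposes no coloring
+ // constraints on sinking or hoisting.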
+ Function *Fn = CurLoop->getHeader()->getParent();
+ if (Fn->hasPersonalityFn())
+ if (Constant *PersonalityFn = Fn->getPersonalityFn())
+ if (isScopedEHPersonality(classifyEHPersonality(PersonalityFn)))
+ BlockColors = colorEHFunclets(*Fn);
+}
+
+/// Return true if we can prove that the given ExitBlock is not reached on the
+/// first iteration of the given loop. That is, the backedge of the loop must
+/// be executed before the ExitBlock is executed in any dynamic execution trace.
+static bool CanProveNotTakenFirstIteration(const BasicBlock *ExitBlock,
+ const DominatorTree *DT,
+ const Loop *CurLoop) {
+ auto *CondExitBlock = ExitBlock->getSinglePredecessor();
+ if (!CondExitBlock)
+ // Expect unique exits.
+ return false;
+ assert(CurLoop->contains(CondExitBlock) && "meaning of exit block");
+ auto *BI = dyn_cast<BranchInst>(CondExitBlock->getTerminator());
+ if (!BI || !BI->isConditional())
+ return false;
+ // If the condition is constant and false leads to ExitBlock, then we always
+ // execute the true branch.
+ if (auto *Cond = dyn_cast<ConstantInt>(BI->getCondition()))
+ return BI->getSuccessor(Cond->getZExtValue() ? 1 : 0) == ExitBlock;
+ auto *Cond = dyn_cast<CmpInst>(BI->getCondition());
+ if (!Cond)
+ return false;
+ // TODO: This would be a lot more powerful if we used SCEV, but all the
+ // plumbing to pass a ScalarEvolution pointer in from the pass is currently
+ // missing.
+ // Check for cmp (phi [x, preheader] ...), y where (pred x, y) is known.
+ auto *LHS = dyn_cast<PHINode>(Cond->getOperand(0));
+ auto *RHS = Cond->getOperand(1);
+ if (!LHS || LHS->getParent() != CurLoop->getHeader())
+ return false;
+ auto DL = ExitBlock->getModule()->getDataLayout();
+ auto *IVStart = LHS->getIncomingValueForBlock(CurLoop->getLoopPreheader());
+ auto *SimpleValOrNull = SimplifyCmpInst(Cond->getPredicate(),
+ IVStart, RHS,
+ {DL, /*TLI*/ nullptr,
+ DT, /*AC*/ nullptr, BI});
+ auto *SimpleCst = dyn_cast_or_null<Constant>(SimpleValOrNull);
+ if (!SimpleCst)
+ return false;
+ if (ExitBlock == BI->getSuccessor(0))
+ return SimpleCst->isZeroValue();
+ assert(ExitBlock == BI->getSuccessor(1) && "implied by above");
+ return SimpleCst->isAllOnesValue();
+}
+
+/// Collect all blocks from \p CurLoop which lie on all possible paths from
+/// the header of \p CurLoop (inclusive) to BB (exclusive) into the set
+/// \p Predecessors. If \p BB is the header, \p Predecessors will be empty.
+static void collectTransitivePredecessors(
+ const Loop *CurLoop, const BasicBlock *BB,
+ SmallPtrSetImpl<const BasicBlock *> &Predecessors) {
+ assert(Predecessors.empty() && "Garbage in predecessors set?");
+ assert(CurLoop->contains(BB) && "Should only be called for loop blocks!");
+ if (BB == CurLoop->getHeader())
+ return;
+ SmallVector<const BasicBlock *, 4> WorkList;
+ for (auto *Pred : predecessors(BB)) {
+ Predecessors.insert(Pred);
+ WorkList.push_back(Pred);
+ }
+ while (!WorkList.empty()) {
+ auto *Pred = WorkList.pop_back_val();
+ assert(CurLoop->contains(Pred) && "Should only reach loop blocks!");
+ // We are not interested in backedges, and we don't want to leave the loop.
+ if (Pred == CurLoop->getHeader())
+ continue;
+ // TODO: If BB lies in an inner loop of CurLoop, this will traverse all
+ // blocks of that inner loop, even those that are always executed AFTER
+ // BB. This may make our analysis more conservative than it could be; see
+ // tests @nested and @nested_no_throw in
+ // test/Analysis/MustExecute/loop-header.ll. We could ignore the backedges
+ // of all loops containing BB to get a slightly more optimistic result.
+ for (auto *PredPred : predecessors(Pred))
+ if (Predecessors.insert(PredPred).second)
+ WorkList.push_back(PredPred);
+ }
+}
+
+bool LoopSafetyInfo::allLoopPathsLeadToBlock(const Loop *CurLoop,
+ const BasicBlock *BB,
+ const DominatorTree *DT) const {
+ assert(CurLoop->contains(BB) && "Should only be called for loop blocks!");
+
+ // Fast path: the header is always reached once the loop is entered.
+ if (BB == CurLoop->getHeader())
+ return true;
+
+ // Collect all transitive predecessors of BB in the same loop. This set will
+ // be a subset of the blocks within the loop.
+ SmallPtrSet<const BasicBlock *, 4> Predecessors;
+ collectTransitivePredecessors(CurLoop, BB, Predecessors);
+
+ // Make sure that every successor of every predecessor of BB that is not
+ // dominated by BB is either:
+ // 1) BB itself,
+ // 2) also a predecessor of BB, or
+ // 3) an exit block which is not taken on the 1st iteration.
+ // Memoize blocks we've already checked.
+ SmallPtrSet<const BasicBlock *, 4> CheckedSuccessors;
+ for (auto *Pred : Predecessors) {
+ // The predecessor block may throw, so it has a side exit.
+ if (blockMayThrow(Pred))
+ return false;
+
+ // BB dominates Pred, so if Pred runs, BB must run.
+ // This is true when Pred is a loop latch.
+ if (DT->dominates(BB, Pred))
+ continue;
+
+ for (auto *Succ : successors(Pred))
+ if (CheckedSuccessors.insert(Succ).second &&
+ Succ != BB && !Predecessors.count(Succ))
+ // By discharging conditions that are not executed on the 1st iteration,
+ // we guarantee that *at least* on the first iteration all paths from
+ // the header that *may* execute will lead us to the block of interest:
+ // if we had virtually peeled one iteration away, then in this peeled
+ // iteration the set of predecessors would contain only paths from the
+ // header to BB without any exiting edges that may execute.
+ //
+ // TODO: We currently only do this for exiting edges. We could use the
+ // same function to skip some of the edges within the loop if we know
+ // that they will not be taken on the 1st iteration.
+ //
+ // TODO: If we somehow know the number of iterations of the loop, the
+ // same check may be done for any arbitrary N-th iteration as long as N
+ // is not greater than the minimum number of iterations of this loop.
+ if (CurLoop->contains(Succ) ||
+ !CanProveNotTakenFirstIteration(Succ, DT, CurLoop))
+ return false;
+ }
+
+ // All predecessors can only lead us to BB.
+ return true;
+}
+
+/// Returns true if the instruction in a loop is guaranteed to execute at least
+/// once.
+bool SimpleLoopSafetyInfo::isGuaranteedToExecute(const Instruction &Inst,
+ const DominatorTree *DT,
+ const Loop *CurLoop) const {
+ // If the instruction is in the header block for the loop (which is very
+ // common), it is always guaranteed to dominate the exit blocks. Since this
+ // is a common case, and can save some work, check it now.
+ if (Inst.getParent() == CurLoop->getHeader())
+ // If there's a throw in the header block, we can't guarantee we'll reach
+ // Inst unless we can prove that Inst comes before the potential implicit
+ // exit. At the moment, we use a (cheap) hack for the common case where
+ // the instruction of interest is the first one in the block.
+ return !HeaderMayThrow ||
+ Inst.getParent()->getFirstNonPHIOrDbg() == &Inst;
+
+ // If there is a path from the header to an exit or latch that doesn't lead
+ // to our instruction's block, return false.
+ return allLoopPathsLeadToBlock(CurLoop, Inst.getParent(), DT);
+}
+
+bool ICFLoopSafetyInfo::isGuaranteedToExecute(const Instruction &Inst,
+ const DominatorTree *DT,
+ const Loop *CurLoop) const {
+ return !ICF.isDominatedByICFIFromSameBlock(&Inst) &&
+ allLoopPathsLeadToBlock(CurLoop, Inst.getParent(), DT);
+}
+
+bool ICFLoopSafetyInfo::doesNotWriteMemoryBefore(const BasicBlock *BB,
+ const Loop *CurLoop) const {
+ assert(CurLoop->contains(BB) && "Should only be called for loop blocks!");
+
+ // Fast path: there are no instructions before the header.
+ if (BB == CurLoop->getHeader())
+ return true;
+
+ // Collect all transitive predecessors of BB in the same loop. This set will
+ // be a subset of the blocks within the loop.
+ SmallPtrSet<const BasicBlock *, 4> Predecessors;
+ collectTransitivePredecessors(CurLoop, BB, Predecessors);
+ // Check whether any instruction in any predecessor could write to memory.
+ for (auto *Pred : Predecessors)
+ if (MW.mayWriteToMemory(Pred))
+ return false;
+ return true;
+}
+
+bool ICFLoopSafetyInfo::doesNotWriteMemoryBefore(const Instruction &I,
+ const Loop *CurLoop) const {
+ auto *BB = I.getParent();
+ assert(CurLoop->contains(BB) && "Should only be called for loop blocks!");
+ return !MW.isDominatedByMemoryWriteFromSameBlock(&I) &&
+ doesNotWriteMemoryBefore(BB, CurLoop);
+}
+
+namespace {
+ struct MustExecutePrinter : public FunctionPass {
+ static char ID; // Pass identification, replacement for typeid
+ MustExecutePrinter() : FunctionPass(ID) {
+ initializeMustExecutePrinterPass(*PassRegistry::getPassRegistry());
+ }
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+ }
+ bool runOnFunction(Function &F) override;
+ };
+ struct MustBeExecutedContextPrinter : public ModulePass {
+ static char ID;
+
+ MustBeExecutedContextPrinter() : ModulePass(ID) {
+ initializeMustBeExecutedContextPrinterPass(*PassRegistry::getPassRegistry());
+ }
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ }
+ bool runOnModule(Module &M) override;
+ };
+}
+
+char MustExecutePrinter::ID = 0;
+INITIALIZE_PASS_BEGIN(MustExecutePrinter, "print-mustexecute",
+ "Instructions which execute on loop entry", false, true)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(MustExecutePrinter, "print-mustexecute",
+ "Instructions which execute on loop entry", false, true)
+
+FunctionPass *llvm::createMustExecutePrinter() {
+ return new MustExecutePrinter();
+}
+
+char MustBeExecutedContextPrinter::ID = 0;
+INITIALIZE_PASS_BEGIN(
+ MustBeExecutedContextPrinter, "print-must-be-executed-contexts",
+ "print the must-be-executed context for all instructions", false, true)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(MustBeExecutedContextPrinter,
+ "print-must-be-executed-contexts",
+ "print the must-be-executed context for all instructions",
+ false, true)
+
+ModulePass *llvm::createMustBeExecutedContextPrinter() {
+ return new MustBeExecutedContextPrinter();
+}
+
+bool MustBeExecutedContextPrinter::runOnModule(Module &M) {
+ MustBeExecutedContextExplorer Explorer(true);
+ for (Function &F : M) {
+ for (Instruction &I : instructions(F)) {
+ dbgs() << "-- Explore context of: " << I << "\n";
+ for (const Instruction *CI : Explorer.range(&I))
+ dbgs() << " [F: " << CI->getFunction()->getName() << "] " << *CI
+ << "\n";
+ }
+ }
+
+ return false;
+}
+
+static bool isMustExecuteIn(const Instruction &I, Loop *L, DominatorTree *DT) {
+ // TODO: Merge these two routines. For the moment, we display the best
+ // result obtained by *either* implementation. This is a bit unfair since no
+ // caller actually gets the full power at the moment.
+ SimpleLoopSafetyInfo LSI;
+ LSI.computeLoopSafetyInfo(L);
+ return LSI.isGuaranteedToExecute(I, DT, L) ||
+ isGuaranteedToExecuteForEveryIteration(&I, L);
+}
+
+namespace {
+/// An assembly annotator class to print must execute information in
+/// comments.
+class MustExecuteAnnotatedWriter : public AssemblyAnnotationWriter {
+ DenseMap<const Value*, SmallVector<Loop*, 4> > MustExec;
+
+public:
+ MustExecuteAnnotatedWriter(const Function &F,
+ DominatorTree &DT, LoopInfo &LI) {
+ for (auto &I: instructions(F)) {
+ Loop *L = LI.getLoopFor(I.getParent());
+ while (L) {
+ if (isMustExecuteIn(I, L, &DT)) {
+ MustExec[&I].push_back(L);
+ }
+ L = L->getParentLoop();
+ }
+ }
+ }
+ MustExecuteAnnotatedWriter(const Module &M,
+ DominatorTree &DT, LoopInfo &LI) {
+ for (auto &F : M)
+ for (auto &I: instructions(F)) {
+ Loop *L = LI.getLoopFor(I.getParent());
+ while (L) {
+ if (isMustExecuteIn(I, L, &DT)) {
+ MustExec[&I].push_back(L);
+ }
+ L = L->getParentLoop();
+ }
+ }
+ }
+
+ void printInfoComment(const Value &V, formatted_raw_ostream &OS) override {
+ if (!MustExec.count(&V))
+ return;
+
+ const auto &Loops = MustExec.lookup(&V);
+ const auto NumLoops = Loops.size();
+ if (NumLoops > 1)
+ OS << " ; (mustexec in " << NumLoops << " loops: ";
+ else
+ OS << " ; (mustexec in: ";
+
+ bool first = true;
+ for (const Loop *L : Loops) {
+ if (!first)
+ OS << ", ";
+ first = false;
+ OS << L->getHeader()->getName();
+ }
+ OS << ")";
+ }
+};
+} // namespace
+
+bool MustExecutePrinter::runOnFunction(Function &F) {
+ auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+
+ MustExecuteAnnotatedWriter Writer(F, DT, LI);
+ F.print(dbgs(), &Writer);
+
+ return false;
+}
+
+const Instruction *
+MustBeExecutedContextExplorer::getMustBeExecutedNextInstruction(
+ MustBeExecutedIterator &It, const Instruction *PP) {
+ if (!PP)
+ return PP;
+ LLVM_DEBUG(dbgs() << "Find next instruction for " << *PP << "\n");
+
+ // If we explore only inside a given basic block, we stop at terminators.
+ if (!ExploreInterBlock && PP->isTerminator()) {
+ LLVM_DEBUG(dbgs() << "\tReached terminator in intra-block mode, done\n");
+ return nullptr;
+ }
+
+ // If we do not traverse the call graph, we check if we can make progress in
+ // the current function. First, check if the instruction is guaranteed to
+ // transfer execution to the successor.
+ bool TransfersExecution = isGuaranteedToTransferExecutionToSuccessor(PP);
+ if (!TransfersExecution)
+ return nullptr;
+
+ // If this is not a terminator, we know that there is a single instruction
+ // after this one that is executed next if control is transferred. If not,
+ // we can try to go back to a call site we entered earlier. If none exists,
+ // we do not know any instruction that has to be executed next.
+ if (!PP->isTerminator()) {
+ const Instruction *NextPP = PP->getNextNode();
+ LLVM_DEBUG(dbgs() << "\tIntermediate instruction does transfer control\n");
+ return NextPP;
+ }
+
+ // Finally, we have to handle terminators, trivial ones first.
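+ // For example, an unconditional `br` must continue at the start of its
+ // single successor, while a `ret` or `unreachable` has no successor to
+ // continue with; both cases are handled below.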
+ assert(PP->isTerminator() && "Expected a terminator!"); + + // A terminator without a successor is not handled yet. + if (PP->getNumSuccessors() == 0) { + LLVM_DEBUG(dbgs() << "\tUnhandled terminator\n"); + return nullptr; + } + + // A terminator with a single successor, we will continue at the beginning of + // that one. + if (PP->getNumSuccessors() == 1) { + LLVM_DEBUG( + dbgs() << "\tUnconditional terminator, continue with successor\n"); + return &PP->getSuccessor(0)->front(); + } + + LLVM_DEBUG(dbgs() << "\tNo join point found\n"); + return nullptr; +} + +MustBeExecutedIterator::MustBeExecutedIterator( + MustBeExecutedContextExplorer &Explorer, const Instruction *I) + : Explorer(Explorer), CurInst(I) { + reset(I); +} + +void MustBeExecutedIterator::reset(const Instruction *I) { + CurInst = I; + Visited.clear(); + Visited.insert(I); +} + +const Instruction *MustBeExecutedIterator::advance() { + assert(CurInst && "Cannot advance an end iterator!"); + const Instruction *Next = + Explorer.getMustBeExecutedNextInstruction(*this, CurInst); + if (Next && !Visited.insert(Next).second) + Next = nullptr; + return Next; +} diff --git a/llvm/lib/Analysis/ObjCARCAliasAnalysis.cpp b/llvm/lib/Analysis/ObjCARCAliasAnalysis.cpp new file mode 100644 index 000000000000..811033e73147 --- /dev/null +++ b/llvm/lib/Analysis/ObjCARCAliasAnalysis.cpp @@ -0,0 +1,164 @@ +//===- ObjCARCAliasAnalysis.cpp - ObjC ARC Optimization -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// This file defines a simple ARC-aware AliasAnalysis using special knowledge +/// of Objective C to enhance other optimization passes which rely on the Alias +/// Analysis infrastructure. +/// +/// WARNING: This file knows about certain library functions. It recognizes them +/// by name, and hardwires knowledge of their semantics. +/// +/// WARNING: This file knows about how certain Objective-C library functions are +/// used. Naive LLVM IR transformations which would otherwise be +/// behavior-preserving may break these assumptions. +/// +/// TODO: Theoretically we could check for dependencies between objc_* calls +/// and FMRB_OnlyAccessesArgumentPointees calls or other well-behaved calls. +/// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ObjCARCAliasAnalysis.h" +#include "llvm/Analysis/ObjCARCAnalysisUtils.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Value.h" +#include "llvm/InitializePasses.h" +#include "llvm/PassAnalysisSupport.h" +#include "llvm/PassSupport.h" + +#define DEBUG_TYPE "objc-arc-aa" + +using namespace llvm; +using namespace llvm::objcarc; + +AliasResult ObjCARCAAResult::alias(const MemoryLocation &LocA, + const MemoryLocation &LocB, + AAQueryInfo &AAQI) { + if (!EnableARCOpts) + return AAResultBase::alias(LocA, LocB, AAQI); + + // First, strip off no-ops, including ObjC-specific no-ops, and try making a + // precise alias query. 
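+ // GetRCIdentityRoot looks through operations that forward their argument
+ // verbatim (e.g. objc_retain calls and bitcasts), so a retained pointer is
+ // queried as the underlying pointer itself.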
+ const Value *SA = GetRCIdentityRoot(LocA.Ptr); + const Value *SB = GetRCIdentityRoot(LocB.Ptr); + AliasResult Result = + AAResultBase::alias(MemoryLocation(SA, LocA.Size, LocA.AATags), + MemoryLocation(SB, LocB.Size, LocB.AATags), AAQI); + if (Result != MayAlias) + return Result; + + // If that failed, climb to the underlying object, including climbing through + // ObjC-specific no-ops, and try making an imprecise alias query. + const Value *UA = GetUnderlyingObjCPtr(SA, DL); + const Value *UB = GetUnderlyingObjCPtr(SB, DL); + if (UA != SA || UB != SB) { + Result = AAResultBase::alias(MemoryLocation(UA), MemoryLocation(UB), AAQI); + // We can't use MustAlias or PartialAlias results here because + // GetUnderlyingObjCPtr may return an offsetted pointer value. + if (Result == NoAlias) + return NoAlias; + } + + // If that failed, fail. We don't need to chain here, since that's covered + // by the earlier precise query. + return MayAlias; +} + +bool ObjCARCAAResult::pointsToConstantMemory(const MemoryLocation &Loc, + AAQueryInfo &AAQI, bool OrLocal) { + if (!EnableARCOpts) + return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal); + + // First, strip off no-ops, including ObjC-specific no-ops, and try making + // a precise alias query. + const Value *S = GetRCIdentityRoot(Loc.Ptr); + if (AAResultBase::pointsToConstantMemory( + MemoryLocation(S, Loc.Size, Loc.AATags), AAQI, OrLocal)) + return true; + + // If that failed, climb to the underlying object, including climbing through + // ObjC-specific no-ops, and try making an imprecise alias query. + const Value *U = GetUnderlyingObjCPtr(S, DL); + if (U != S) + return AAResultBase::pointsToConstantMemory(MemoryLocation(U), AAQI, + OrLocal); + + // If that failed, fail. We don't need to chain here, since that's covered + // by the earlier precise query. + return false; +} + +FunctionModRefBehavior ObjCARCAAResult::getModRefBehavior(const Function *F) { + if (!EnableARCOpts) + return AAResultBase::getModRefBehavior(F); + + switch (GetFunctionClass(F)) { + case ARCInstKind::NoopCast: + return FMRB_DoesNotAccessMemory; + default: + break; + } + + return AAResultBase::getModRefBehavior(F); +} + +ModRefInfo ObjCARCAAResult::getModRefInfo(const CallBase *Call, + const MemoryLocation &Loc, + AAQueryInfo &AAQI) { + if (!EnableARCOpts) + return AAResultBase::getModRefInfo(Call, Loc, AAQI); + + switch (GetBasicARCInstKind(Call)) { + case ARCInstKind::Retain: + case ARCInstKind::RetainRV: + case ARCInstKind::Autorelease: + case ARCInstKind::AutoreleaseRV: + case ARCInstKind::NoopCast: + case ARCInstKind::AutoreleasepoolPush: + case ARCInstKind::FusedRetainAutorelease: + case ARCInstKind::FusedRetainAutoreleaseRV: + // These functions don't access any memory visible to the compiler. + // Note that this doesn't include objc_retainBlock, because it updates + // pointers when it copies block data. 
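+ // (objc_retainBlock is not listed above, so it falls through to the
+ // conservative default below.)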
+ return ModRefInfo::NoModRef; + default: + break; + } + + return AAResultBase::getModRefInfo(Call, Loc, AAQI); +} + +ObjCARCAAResult ObjCARCAA::run(Function &F, FunctionAnalysisManager &AM) { + return ObjCARCAAResult(F.getParent()->getDataLayout()); +} + +char ObjCARCAAWrapperPass::ID = 0; +INITIALIZE_PASS(ObjCARCAAWrapperPass, "objc-arc-aa", + "ObjC-ARC-Based Alias Analysis", false, true) + +ImmutablePass *llvm::createObjCARCAAWrapperPass() { + return new ObjCARCAAWrapperPass(); +} + +ObjCARCAAWrapperPass::ObjCARCAAWrapperPass() : ImmutablePass(ID) { + initializeObjCARCAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool ObjCARCAAWrapperPass::doInitialization(Module &M) { + Result.reset(new ObjCARCAAResult(M.getDataLayout())); + return false; +} + +bool ObjCARCAAWrapperPass::doFinalization(Module &M) { + Result.reset(); + return false; +} + +void ObjCARCAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); +} diff --git a/llvm/lib/Analysis/ObjCARCAnalysisUtils.cpp b/llvm/lib/Analysis/ObjCARCAnalysisUtils.cpp new file mode 100644 index 000000000000..56d1cb421225 --- /dev/null +++ b/llvm/lib/Analysis/ObjCARCAnalysisUtils.cpp @@ -0,0 +1,25 @@ +//===- ObjCARCAnalysisUtils.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements common infrastructure for libLLVMObjCARCOpts.a, which +// implements several scalar transformations over the LLVM intermediate +// representation, including the C bindings for that library. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ObjCARCAnalysisUtils.h" +#include "llvm/Support/CommandLine.h" + +using namespace llvm; +using namespace llvm::objcarc; + +/// A handy option to enable/disable all ARC Optimizations. +bool llvm::objcarc::EnableARCOpts; +static cl::opt<bool, true> EnableARCOptimizations( + "enable-objc-arc-opts", cl::desc("enable/disable all ARC Optimizations"), + cl::location(EnableARCOpts), cl::init(true), cl::Hidden); diff --git a/llvm/lib/Analysis/ObjCARCInstKind.cpp b/llvm/lib/Analysis/ObjCARCInstKind.cpp new file mode 100644 index 000000000000..0e96c6e975c9 --- /dev/null +++ b/llvm/lib/Analysis/ObjCARCInstKind.cpp @@ -0,0 +1,705 @@ +//===- ARCInstKind.cpp - ObjC ARC Optimization ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// This file defines several utility functions used by various ARC +/// optimizations which are IMHO too big to be in a header file. +/// +/// WARNING: This file knows about certain library functions. It recognizes them +/// by name, and hardwires knowledge of their semantics. +/// +/// WARNING: This file knows about how certain Objective-C library functions are +/// used. Naive LLVM IR transformations which would otherwise be +/// behavior-preserving may break these assumptions. 
+/// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ObjCARCInstKind.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Analysis/ObjCARCAnalysisUtils.h" +#include "llvm/IR/Intrinsics.h" + +using namespace llvm; +using namespace llvm::objcarc; + +raw_ostream &llvm::objcarc::operator<<(raw_ostream &OS, + const ARCInstKind Class) { + switch (Class) { + case ARCInstKind::Retain: + return OS << "ARCInstKind::Retain"; + case ARCInstKind::RetainRV: + return OS << "ARCInstKind::RetainRV"; + case ARCInstKind::ClaimRV: + return OS << "ARCInstKind::ClaimRV"; + case ARCInstKind::RetainBlock: + return OS << "ARCInstKind::RetainBlock"; + case ARCInstKind::Release: + return OS << "ARCInstKind::Release"; + case ARCInstKind::Autorelease: + return OS << "ARCInstKind::Autorelease"; + case ARCInstKind::AutoreleaseRV: + return OS << "ARCInstKind::AutoreleaseRV"; + case ARCInstKind::AutoreleasepoolPush: + return OS << "ARCInstKind::AutoreleasepoolPush"; + case ARCInstKind::AutoreleasepoolPop: + return OS << "ARCInstKind::AutoreleasepoolPop"; + case ARCInstKind::NoopCast: + return OS << "ARCInstKind::NoopCast"; + case ARCInstKind::FusedRetainAutorelease: + return OS << "ARCInstKind::FusedRetainAutorelease"; + case ARCInstKind::FusedRetainAutoreleaseRV: + return OS << "ARCInstKind::FusedRetainAutoreleaseRV"; + case ARCInstKind::LoadWeakRetained: + return OS << "ARCInstKind::LoadWeakRetained"; + case ARCInstKind::StoreWeak: + return OS << "ARCInstKind::StoreWeak"; + case ARCInstKind::InitWeak: + return OS << "ARCInstKind::InitWeak"; + case ARCInstKind::LoadWeak: + return OS << "ARCInstKind::LoadWeak"; + case ARCInstKind::MoveWeak: + return OS << "ARCInstKind::MoveWeak"; + case ARCInstKind::CopyWeak: + return OS << "ARCInstKind::CopyWeak"; + case ARCInstKind::DestroyWeak: + return OS << "ARCInstKind::DestroyWeak"; + case ARCInstKind::StoreStrong: + return OS << "ARCInstKind::StoreStrong"; + case ARCInstKind::CallOrUser: + return OS << "ARCInstKind::CallOrUser"; + case ARCInstKind::Call: + return OS << "ARCInstKind::Call"; + case ARCInstKind::User: + return OS << "ARCInstKind::User"; + case ARCInstKind::IntrinsicUser: + return OS << "ARCInstKind::IntrinsicUser"; + case ARCInstKind::None: + return OS << "ARCInstKind::None"; + } + llvm_unreachable("Unknown instruction class!"); +} + +ARCInstKind llvm::objcarc::GetFunctionClass(const Function *F) { + + Intrinsic::ID ID = F->getIntrinsicID(); + switch (ID) { + default: + return ARCInstKind::CallOrUser; + case Intrinsic::objc_autorelease: + return ARCInstKind::Autorelease; + case Intrinsic::objc_autoreleasePoolPop: + return ARCInstKind::AutoreleasepoolPop; + case Intrinsic::objc_autoreleasePoolPush: + return ARCInstKind::AutoreleasepoolPush; + case Intrinsic::objc_autoreleaseReturnValue: + return ARCInstKind::AutoreleaseRV; + case Intrinsic::objc_copyWeak: + return ARCInstKind::CopyWeak; + case Intrinsic::objc_destroyWeak: + return ARCInstKind::DestroyWeak; + case Intrinsic::objc_initWeak: + return ARCInstKind::InitWeak; + case Intrinsic::objc_loadWeak: + return ARCInstKind::LoadWeak; + case Intrinsic::objc_loadWeakRetained: + return ARCInstKind::LoadWeakRetained; + case Intrinsic::objc_moveWeak: + return ARCInstKind::MoveWeak; + case Intrinsic::objc_release: + return ARCInstKind::Release; + case Intrinsic::objc_retain: + return ARCInstKind::Retain; + case Intrinsic::objc_retainAutorelease: + return ARCInstKind::FusedRetainAutorelease; + case Intrinsic::objc_retainAutoreleaseReturnValue: + return 
ARCInstKind::FusedRetainAutoreleaseRV;
+ case Intrinsic::objc_retainAutoreleasedReturnValue:
+ return ARCInstKind::RetainRV;
+ case Intrinsic::objc_retainBlock:
+ return ARCInstKind::RetainBlock;
+ case Intrinsic::objc_storeStrong:
+ return ARCInstKind::StoreStrong;
+ case Intrinsic::objc_storeWeak:
+ return ARCInstKind::StoreWeak;
+ case Intrinsic::objc_clang_arc_use:
+ return ARCInstKind::IntrinsicUser;
+ case Intrinsic::objc_unsafeClaimAutoreleasedReturnValue:
+ return ARCInstKind::ClaimRV;
+ case Intrinsic::objc_retainedObject:
+ return ARCInstKind::NoopCast;
+ case Intrinsic::objc_unretainedObject:
+ return ARCInstKind::NoopCast;
+ case Intrinsic::objc_unretainedPointer:
+ return ARCInstKind::NoopCast;
+ case Intrinsic::objc_retain_autorelease:
+ return ARCInstKind::FusedRetainAutorelease;
+ case Intrinsic::objc_sync_enter:
+ return ARCInstKind::User;
+ case Intrinsic::objc_sync_exit:
+ return ARCInstKind::User;
+ case Intrinsic::objc_arc_annotation_topdown_bbstart:
+ case Intrinsic::objc_arc_annotation_topdown_bbend:
+ case Intrinsic::objc_arc_annotation_bottomup_bbstart:
+ case Intrinsic::objc_arc_annotation_bottomup_bbend:
+ // Ignore annotation calls. This is important to stop the optimizer from
+ // treating annotations as uses, which would corrupt the state of the
+ // pointers they are attempting to elucidate.
+ return ARCInstKind::None;
+ }
+}
+
+// A whitelist of intrinsics that we know do not use objc pointers or decrement
+// ref counts.
+static bool isInertIntrinsic(unsigned ID) {
+ // TODO: Make this into a covered switch.
+ switch (ID) {
+ case Intrinsic::returnaddress:
+ case Intrinsic::addressofreturnaddress:
+ case Intrinsic::frameaddress:
+ case Intrinsic::stacksave:
+ case Intrinsic::stackrestore:
+ case Intrinsic::vastart:
+ case Intrinsic::vacopy:
+ case Intrinsic::vaend:
+ case Intrinsic::objectsize:
+ case Intrinsic::prefetch:
+ case Intrinsic::stackprotector:
+ case Intrinsic::eh_return_i32:
+ case Intrinsic::eh_return_i64:
+ case Intrinsic::eh_typeid_for:
+ case Intrinsic::eh_dwarf_cfa:
+ case Intrinsic::eh_sjlj_lsda:
+ case Intrinsic::eh_sjlj_functioncontext:
+ case Intrinsic::init_trampoline:
+ case Intrinsic::adjust_trampoline:
+ case Intrinsic::lifetime_start:
+ case Intrinsic::lifetime_end:
+ case Intrinsic::invariant_start:
+ case Intrinsic::invariant_end:
+ // Don't let dbg info affect our results.
+ case Intrinsic::dbg_declare:
+ case Intrinsic::dbg_value:
+ case Intrinsic::dbg_label:
+ // Short cut: Some intrinsics obviously don't use ObjC pointers.
+ return true;
+ default:
+ return false;
+ }
+}
+
+// A whitelist of intrinsics that we know only use their pointer arguments
+// (without retaining them) and do not decrement ref counts.
+static bool isUseOnlyIntrinsic(unsigned ID) {
+ // We are conservative and even though intrinsics are unlikely to touch
+ // reference counts, we whitelist them for safety.
+ //
+ // TODO: Expand this into a covered switch. There is a lot more here.
+ switch (ID) {
+ case Intrinsic::memcpy:
+ case Intrinsic::memmove:
+ case Intrinsic::memset:
+ return true;
+ default:
+ return false;
+ }
+}
+
+/// Determine what kind of construct V is.
+ARCInstKind llvm::objcarc::GetARCInstKind(const Value *V) {
+ if (const Instruction *I = dyn_cast<Instruction>(V)) {
+ // Any instruction other than a bitcast or gep with a pointer operand has a
+ // use of an objc pointer. Bitcasts, GEPs, Selects, PHIs transfer a pointer
+ // to a subsequent use, rather than using it themselves, in this sense.
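+ // For example, a getelementptr forwards its pointer operand to later users
+ // without reading or writing the pointee itself.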
+ // As a short cut, several other opcodes are known to have no pointer + // operands of interest. And ret is never followed by a release, so it's + // not interesting to examine. + switch (I->getOpcode()) { + case Instruction::Call: { + const CallInst *CI = cast<CallInst>(I); + // See if we have a function that we know something about. + if (const Function *F = CI->getCalledFunction()) { + ARCInstKind Class = GetFunctionClass(F); + if (Class != ARCInstKind::CallOrUser) + return Class; + Intrinsic::ID ID = F->getIntrinsicID(); + if (isInertIntrinsic(ID)) + return ARCInstKind::None; + if (isUseOnlyIntrinsic(ID)) + return ARCInstKind::User; + } + + // Otherwise, be conservative. + return GetCallSiteClass(CI); + } + case Instruction::Invoke: + // Otherwise, be conservative. + return GetCallSiteClass(cast<InvokeInst>(I)); + case Instruction::BitCast: + case Instruction::GetElementPtr: + case Instruction::Select: + case Instruction::PHI: + case Instruction::Ret: + case Instruction::Br: + case Instruction::Switch: + case Instruction::IndirectBr: + case Instruction::Alloca: + case Instruction::VAArg: + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::SDiv: + case Instruction::UDiv: + case Instruction::FDiv: + case Instruction::SRem: + case Instruction::URem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::SExt: + case Instruction::ZExt: + case Instruction::Trunc: + case Instruction::IntToPtr: + case Instruction::FCmp: + case Instruction::FPTrunc: + case Instruction::FPExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::UIToFP: + case Instruction::SIToFP: + case Instruction::InsertElement: + case Instruction::ExtractElement: + case Instruction::ShuffleVector: + case Instruction::ExtractValue: + break; + case Instruction::ICmp: + // Comparing a pointer with null, or any other constant, isn't an + // interesting use, because we don't care what the pointer points to, or + // about the values of any other dynamic reference-counted pointers. + if (IsPotentialRetainableObjPtr(I->getOperand(1))) + return ARCInstKind::User; + break; + default: + // For anything else, check all the operands. + // Note that this includes both operands of a Store: while the first + // operand isn't actually being dereferenced, it is being stored to + // memory where we can no longer track who might read it and dereference + // it, so we have to consider it potentially used. + for (User::const_op_iterator OI = I->op_begin(), OE = I->op_end(); + OI != OE; ++OI) + if (IsPotentialRetainableObjPtr(*OI)) + return ARCInstKind::User; + } + } + + // Otherwise, it's totally inert for ARC purposes. + return ARCInstKind::None; +} + +/// Test if the given class is a kind of user. 
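+/// A "user" may use an ObjC pointer without changing its reference count;
+/// this covers plain users, calls that may also be users, and intrinsic
+/// users such as clang.arc.use.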
+bool llvm::objcarc::IsUser(ARCInstKind Class) { + switch (Class) { + case ARCInstKind::User: + case ARCInstKind::CallOrUser: + case ARCInstKind::IntrinsicUser: + return true; + case ARCInstKind::Retain: + case ARCInstKind::RetainRV: + case ARCInstKind::RetainBlock: + case ARCInstKind::Release: + case ARCInstKind::Autorelease: + case ARCInstKind::AutoreleaseRV: + case ARCInstKind::AutoreleasepoolPush: + case ARCInstKind::AutoreleasepoolPop: + case ARCInstKind::NoopCast: + case ARCInstKind::FusedRetainAutorelease: + case ARCInstKind::FusedRetainAutoreleaseRV: + case ARCInstKind::LoadWeakRetained: + case ARCInstKind::StoreWeak: + case ARCInstKind::InitWeak: + case ARCInstKind::LoadWeak: + case ARCInstKind::MoveWeak: + case ARCInstKind::CopyWeak: + case ARCInstKind::DestroyWeak: + case ARCInstKind::StoreStrong: + case ARCInstKind::Call: + case ARCInstKind::None: + case ARCInstKind::ClaimRV: + return false; + } + llvm_unreachable("covered switch isn't covered?"); +} + +/// Test if the given class is objc_retain or equivalent. +bool llvm::objcarc::IsRetain(ARCInstKind Class) { + switch (Class) { + case ARCInstKind::Retain: + case ARCInstKind::RetainRV: + return true; + // I believe we treat retain block as not a retain since it can copy its + // block. + case ARCInstKind::RetainBlock: + case ARCInstKind::Release: + case ARCInstKind::Autorelease: + case ARCInstKind::AutoreleaseRV: + case ARCInstKind::AutoreleasepoolPush: + case ARCInstKind::AutoreleasepoolPop: + case ARCInstKind::NoopCast: + case ARCInstKind::FusedRetainAutorelease: + case ARCInstKind::FusedRetainAutoreleaseRV: + case ARCInstKind::LoadWeakRetained: + case ARCInstKind::StoreWeak: + case ARCInstKind::InitWeak: + case ARCInstKind::LoadWeak: + case ARCInstKind::MoveWeak: + case ARCInstKind::CopyWeak: + case ARCInstKind::DestroyWeak: + case ARCInstKind::StoreStrong: + case ARCInstKind::IntrinsicUser: + case ARCInstKind::CallOrUser: + case ARCInstKind::Call: + case ARCInstKind::User: + case ARCInstKind::None: + case ARCInstKind::ClaimRV: + return false; + } + llvm_unreachable("covered switch isn't covered?"); +} + +/// Test if the given class is objc_autorelease or equivalent. +bool llvm::objcarc::IsAutorelease(ARCInstKind Class) { + switch (Class) { + case ARCInstKind::Autorelease: + case ARCInstKind::AutoreleaseRV: + return true; + case ARCInstKind::Retain: + case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: + case ARCInstKind::RetainBlock: + case ARCInstKind::Release: + case ARCInstKind::AutoreleasepoolPush: + case ARCInstKind::AutoreleasepoolPop: + case ARCInstKind::NoopCast: + case ARCInstKind::FusedRetainAutorelease: + case ARCInstKind::FusedRetainAutoreleaseRV: + case ARCInstKind::LoadWeakRetained: + case ARCInstKind::StoreWeak: + case ARCInstKind::InitWeak: + case ARCInstKind::LoadWeak: + case ARCInstKind::MoveWeak: + case ARCInstKind::CopyWeak: + case ARCInstKind::DestroyWeak: + case ARCInstKind::StoreStrong: + case ARCInstKind::IntrinsicUser: + case ARCInstKind::CallOrUser: + case ARCInstKind::Call: + case ARCInstKind::User: + case ARCInstKind::None: + return false; + } + llvm_unreachable("covered switch isn't covered?"); +} + +/// Test if the given class represents instructions which return their +/// argument verbatim. 
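+/// For example, objc_retain returns its argument, so optimizations may look
+/// through forwarding calls when chasing the RC identity of a pointer.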
+bool llvm::objcarc::IsForwarding(ARCInstKind Class) { + switch (Class) { + case ARCInstKind::Retain: + case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: + case ARCInstKind::Autorelease: + case ARCInstKind::AutoreleaseRV: + case ARCInstKind::NoopCast: + return true; + case ARCInstKind::RetainBlock: + case ARCInstKind::Release: + case ARCInstKind::AutoreleasepoolPush: + case ARCInstKind::AutoreleasepoolPop: + case ARCInstKind::FusedRetainAutorelease: + case ARCInstKind::FusedRetainAutoreleaseRV: + case ARCInstKind::LoadWeakRetained: + case ARCInstKind::StoreWeak: + case ARCInstKind::InitWeak: + case ARCInstKind::LoadWeak: + case ARCInstKind::MoveWeak: + case ARCInstKind::CopyWeak: + case ARCInstKind::DestroyWeak: + case ARCInstKind::StoreStrong: + case ARCInstKind::IntrinsicUser: + case ARCInstKind::CallOrUser: + case ARCInstKind::Call: + case ARCInstKind::User: + case ARCInstKind::None: + return false; + } + llvm_unreachable("covered switch isn't covered?"); +} + +/// Test if the given class represents instructions which do nothing if +/// passed a null pointer. +bool llvm::objcarc::IsNoopOnNull(ARCInstKind Class) { + switch (Class) { + case ARCInstKind::Retain: + case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: + case ARCInstKind::Release: + case ARCInstKind::Autorelease: + case ARCInstKind::AutoreleaseRV: + case ARCInstKind::RetainBlock: + return true; + case ARCInstKind::AutoreleasepoolPush: + case ARCInstKind::AutoreleasepoolPop: + case ARCInstKind::FusedRetainAutorelease: + case ARCInstKind::FusedRetainAutoreleaseRV: + case ARCInstKind::LoadWeakRetained: + case ARCInstKind::StoreWeak: + case ARCInstKind::InitWeak: + case ARCInstKind::LoadWeak: + case ARCInstKind::MoveWeak: + case ARCInstKind::CopyWeak: + case ARCInstKind::DestroyWeak: + case ARCInstKind::StoreStrong: + case ARCInstKind::IntrinsicUser: + case ARCInstKind::CallOrUser: + case ARCInstKind::Call: + case ARCInstKind::User: + case ARCInstKind::None: + case ARCInstKind::NoopCast: + return false; + } + llvm_unreachable("covered switch isn't covered?"); +} + +/// Test if the given class represents instructions which do nothing if +/// passed a global variable. +bool llvm::objcarc::IsNoopOnGlobal(ARCInstKind Class) { + switch (Class) { + case ARCInstKind::Retain: + case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: + case ARCInstKind::Release: + case ARCInstKind::Autorelease: + case ARCInstKind::AutoreleaseRV: + case ARCInstKind::RetainBlock: + case ARCInstKind::FusedRetainAutorelease: + case ARCInstKind::FusedRetainAutoreleaseRV: + return true; + case ARCInstKind::AutoreleasepoolPush: + case ARCInstKind::AutoreleasepoolPop: + case ARCInstKind::LoadWeakRetained: + case ARCInstKind::StoreWeak: + case ARCInstKind::InitWeak: + case ARCInstKind::LoadWeak: + case ARCInstKind::MoveWeak: + case ARCInstKind::CopyWeak: + case ARCInstKind::DestroyWeak: + case ARCInstKind::StoreStrong: + case ARCInstKind::IntrinsicUser: + case ARCInstKind::CallOrUser: + case ARCInstKind::Call: + case ARCInstKind::User: + case ARCInstKind::None: + case ARCInstKind::NoopCast: + return false; + } + llvm_unreachable("covered switch isn't covered?"); +} + +/// Test if the given class represents instructions which are always safe +/// to mark with the "tail" keyword. +bool llvm::objcarc::IsAlwaysTail(ARCInstKind Class) { + // ARCInstKind::RetainBlock may be given a stack argument. 
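+ // A tail call must not pass a pointer into the caller's stack frame, which
+ // is why RetainBlock is not in the "true" list below.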
+ switch (Class) {
+ case ARCInstKind::Retain:
+ case ARCInstKind::RetainRV:
+ case ARCInstKind::ClaimRV:
+ case ARCInstKind::AutoreleaseRV:
+ return true;
+ case ARCInstKind::Release:
+ case ARCInstKind::Autorelease:
+ case ARCInstKind::RetainBlock:
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::AutoreleasepoolPop:
+ case ARCInstKind::FusedRetainAutorelease:
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ case ARCInstKind::LoadWeakRetained:
+ case ARCInstKind::StoreWeak:
+ case ARCInstKind::InitWeak:
+ case ARCInstKind::LoadWeak:
+ case ARCInstKind::MoveWeak:
+ case ARCInstKind::CopyWeak:
+ case ARCInstKind::DestroyWeak:
+ case ARCInstKind::StoreStrong:
+ case ARCInstKind::IntrinsicUser:
+ case ARCInstKind::CallOrUser:
+ case ARCInstKind::Call:
+ case ARCInstKind::User:
+ case ARCInstKind::None:
+ case ARCInstKind::NoopCast:
+ return false;
+ }
+ llvm_unreachable("covered switch isn't covered?");
+}
+
+/// Test if the given class represents instructions which are never safe
+/// to mark with the "tail" keyword.
+bool llvm::objcarc::IsNeverTail(ARCInstKind Class) {
+ // It is never safe to tail call objc_autorelease: tail calling it enables
+ // fast autoreleasing, which can cause our object to be reclaimed from the
+ // autorelease pool prematurely, violating the semantics of __autoreleasing
+ // types in ARC.
+ switch (Class) {
+ case ARCInstKind::Autorelease:
+ return true;
+ case ARCInstKind::Retain:
+ case ARCInstKind::RetainRV:
+ case ARCInstKind::ClaimRV:
+ case ARCInstKind::AutoreleaseRV:
+ case ARCInstKind::Release:
+ case ARCInstKind::RetainBlock:
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::AutoreleasepoolPop:
+ case ARCInstKind::FusedRetainAutorelease:
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ case ARCInstKind::LoadWeakRetained:
+ case ARCInstKind::StoreWeak:
+ case ARCInstKind::InitWeak:
+ case ARCInstKind::LoadWeak:
+ case ARCInstKind::MoveWeak:
+ case ARCInstKind::CopyWeak:
+ case ARCInstKind::DestroyWeak:
+ case ARCInstKind::StoreStrong:
+ case ARCInstKind::IntrinsicUser:
+ case ARCInstKind::CallOrUser:
+ case ARCInstKind::Call:
+ case ARCInstKind::User:
+ case ARCInstKind::None:
+ case ARCInstKind::NoopCast:
+ return false;
+ }
+ llvm_unreachable("covered switch isn't covered?");
+}
+
+/// Test if the given class represents instructions which are always safe
+/// to mark with the nounwind attribute.
+bool llvm::objcarc::IsNoThrow(ARCInstKind Class) {
+ // objc_retainBlock is not nounwind because it calls user copy constructors
+ // which could theoretically throw.
+ switch (Class) { + case ARCInstKind::Retain: + case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: + case ARCInstKind::Release: + case ARCInstKind::Autorelease: + case ARCInstKind::AutoreleaseRV: + case ARCInstKind::AutoreleasepoolPush: + case ARCInstKind::AutoreleasepoolPop: + return true; + case ARCInstKind::RetainBlock: + case ARCInstKind::FusedRetainAutorelease: + case ARCInstKind::FusedRetainAutoreleaseRV: + case ARCInstKind::LoadWeakRetained: + case ARCInstKind::StoreWeak: + case ARCInstKind::InitWeak: + case ARCInstKind::LoadWeak: + case ARCInstKind::MoveWeak: + case ARCInstKind::CopyWeak: + case ARCInstKind::DestroyWeak: + case ARCInstKind::StoreStrong: + case ARCInstKind::IntrinsicUser: + case ARCInstKind::CallOrUser: + case ARCInstKind::Call: + case ARCInstKind::User: + case ARCInstKind::None: + case ARCInstKind::NoopCast: + return false; + } + llvm_unreachable("covered switch isn't covered?"); +} + +/// Test whether the given instruction can autorelease any pointer or cause an +/// autoreleasepool pop. +/// +/// This means that it *could* interrupt the RV optimization. +bool llvm::objcarc::CanInterruptRV(ARCInstKind Class) { + switch (Class) { + case ARCInstKind::AutoreleasepoolPop: + case ARCInstKind::CallOrUser: + case ARCInstKind::Call: + case ARCInstKind::Autorelease: + case ARCInstKind::AutoreleaseRV: + case ARCInstKind::FusedRetainAutorelease: + case ARCInstKind::FusedRetainAutoreleaseRV: + return true; + case ARCInstKind::Retain: + case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: + case ARCInstKind::Release: + case ARCInstKind::AutoreleasepoolPush: + case ARCInstKind::RetainBlock: + case ARCInstKind::LoadWeakRetained: + case ARCInstKind::StoreWeak: + case ARCInstKind::InitWeak: + case ARCInstKind::LoadWeak: + case ARCInstKind::MoveWeak: + case ARCInstKind::CopyWeak: + case ARCInstKind::DestroyWeak: + case ARCInstKind::StoreStrong: + case ARCInstKind::IntrinsicUser: + case ARCInstKind::User: + case ARCInstKind::None: + case ARCInstKind::NoopCast: + return false; + } + llvm_unreachable("covered switch isn't covered?"); +} + +bool llvm::objcarc::CanDecrementRefCount(ARCInstKind Kind) { + switch (Kind) { + case ARCInstKind::Retain: + case ARCInstKind::RetainRV: + case ARCInstKind::Autorelease: + case ARCInstKind::AutoreleaseRV: + case ARCInstKind::NoopCast: + case ARCInstKind::FusedRetainAutorelease: + case ARCInstKind::FusedRetainAutoreleaseRV: + case ARCInstKind::IntrinsicUser: + case ARCInstKind::User: + case ARCInstKind::None: + return false; + + // The cases below are conservative. + + // RetainBlock can result in user defined copy constructors being called + // implying releases may occur. 
+ case ARCInstKind::RetainBlock:
+ case ARCInstKind::Release:
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::AutoreleasepoolPop:
+ case ARCInstKind::LoadWeakRetained:
+ case ARCInstKind::StoreWeak:
+ case ARCInstKind::InitWeak:
+ case ARCInstKind::LoadWeak:
+ case ARCInstKind::MoveWeak:
+ case ARCInstKind::CopyWeak:
+ case ARCInstKind::DestroyWeak:
+ case ARCInstKind::StoreStrong:
+ case ARCInstKind::CallOrUser:
+ case ARCInstKind::Call:
+ case ARCInstKind::ClaimRV:
+ return true;
+ }
+
+ llvm_unreachable("covered switch isn't covered?");
+}
diff --git a/llvm/lib/Analysis/OptimizationRemarkEmitter.cpp b/llvm/lib/Analysis/OptimizationRemarkEmitter.cpp
new file mode 100644
index 000000000000..07a5619a35b9
--- /dev/null
+++ b/llvm/lib/Analysis/OptimizationRemarkEmitter.cpp
@@ -0,0 +1,133 @@
+//===- OptimizationRemarkEmitter.cpp - Optimization Diagnostic --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Optimization diagnostic interfaces. This is packaged as an analysis pass so
+// that passes using the service become dependent on BFI as well; BFI is used
+// to compute the "hotness" of the diagnostic message.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/LazyBlockFrequencyInfo.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/LLVMContext.h"
+
+using namespace llvm;
+
+OptimizationRemarkEmitter::OptimizationRemarkEmitter(const Function *F)
+ : F(F), BFI(nullptr) {
+ if (!F->getContext().getDiagnosticsHotnessRequested())
+ return;
+
+ // First create a dominator tree.
+ DominatorTree DT;
+ DT.recalculate(*const_cast<Function *>(F));
+
+ // Generate LoopInfo from it.
+ LoopInfo LI;
+ LI.analyze(DT);
+
+ // Then compute BranchProbabilityInfo.
+ BranchProbabilityInfo BPI;
+ BPI.calculate(*F, LI);
+
+ // Finally compute BFI.
+ OwnedBFI = std::make_unique<BlockFrequencyInfo>(*F, BPI, LI);
+ BFI = OwnedBFI.get();
+}
+
+bool OptimizationRemarkEmitter::invalidate(
+ Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &Inv) {
+ // This analysis has no state and so can be trivially preserved, but it needs
+ // a fresh view of BFI if it was constructed with one.
+ if (BFI && Inv.invalidate<BlockFrequencyAnalysis>(F, PA))
+ return true;
+
+ // Otherwise this analysis result remains valid.
+ return false;
+}
+
+Optional<uint64_t> OptimizationRemarkEmitter::computeHotness(const Value *V) {
+ if (!BFI)
+ return None;
+
+ return BFI->getBlockProfileCount(cast<BasicBlock>(V));
+}
+
+void OptimizationRemarkEmitter::computeHotness(
+ DiagnosticInfoIROptimization &OptDiag) {
+ const Value *V = OptDiag.getCodeRegion();
+ if (V)
+ OptDiag.setHotness(computeHotness(V));
+}
+
+void OptimizationRemarkEmitter::emit(
+ DiagnosticInfoOptimizationBase &OptDiagBase) {
+ auto &OptDiag = cast<DiagnosticInfoIROptimization>(OptDiagBase);
+ computeHotness(OptDiag);
+
+ // Only emit it if its hotness meets the threshold.
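+ // A remark with no hotness value counts as hotness 0 here, so it is dropped
+ // whenever a nonzero hotness threshold is configured.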
+ if (OptDiag.getHotness().getValueOr(0) <
+ F->getContext().getDiagnosticsHotnessThreshold()) {
+ return;
+ }
+
+ F->getContext().diagnose(OptDiag);
+}
+
+OptimizationRemarkEmitterWrapperPass::OptimizationRemarkEmitterWrapperPass()
+ : FunctionPass(ID) {
+ initializeOptimizationRemarkEmitterWrapperPassPass(
+ *PassRegistry::getPassRegistry());
+}
+
+bool OptimizationRemarkEmitterWrapperPass::runOnFunction(Function &Fn) {
+ BlockFrequencyInfo *BFI;
+
+ if (Fn.getContext().getDiagnosticsHotnessRequested())
+ BFI = &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI();
+ else
+ BFI = nullptr;
+
+ ORE = std::make_unique<OptimizationRemarkEmitter>(&Fn, BFI);
+ return false;
+}
+
+void OptimizationRemarkEmitterWrapperPass::getAnalysisUsage(
+ AnalysisUsage &AU) const {
+ LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU);
+ AU.setPreservesAll();
+}
+
+AnalysisKey OptimizationRemarkEmitterAnalysis::Key;
+
+OptimizationRemarkEmitter
+OptimizationRemarkEmitterAnalysis::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ BlockFrequencyInfo *BFI;
+
+ if (F.getContext().getDiagnosticsHotnessRequested())
+ BFI = &AM.getResult<BlockFrequencyAnalysis>(F);
+ else
+ BFI = nullptr;
+
+ return OptimizationRemarkEmitter(&F, BFI);
+}
+
+char OptimizationRemarkEmitterWrapperPass::ID = 0;
+static const char ore_name[] = "Optimization Remark Emitter";
+#define ORE_NAME "opt-remark-emitter"
+
+INITIALIZE_PASS_BEGIN(OptimizationRemarkEmitterWrapperPass, ORE_NAME, ore_name,
+ false, true)
+INITIALIZE_PASS_DEPENDENCY(LazyBFIPass)
+INITIALIZE_PASS_END(OptimizationRemarkEmitterWrapperPass, ORE_NAME, ore_name,
+ false, true)
diff --git a/llvm/lib/Analysis/OrderedBasicBlock.cpp b/llvm/lib/Analysis/OrderedBasicBlock.cpp
new file mode 100644
index 000000000000..48f2a4020c66
--- /dev/null
+++ b/llvm/lib/Analysis/OrderedBasicBlock.cpp
@@ -0,0 +1,111 @@
+//===- OrderedBasicBlock.cpp --------------------------------- -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the OrderedBasicBlock class. OrderedBasicBlock
+// maintains an interface where clients can query if one instruction comes
+// before another in a BasicBlock. Since BasicBlock currently lacks a reliable
+// way to query the relative position of instructions, one can use
+// OrderedBasicBlock to do such queries. OrderedBasicBlock is lazily built on a
+// source BasicBlock and maintains an internal Instruction -> Position map. An
+// OrderedBasicBlock instance should be discarded whenever the source
+// BasicBlock changes.
+//
+// It's currently used by the CaptureTracker in order to find relative
+// positions of a pair of instructions inside a BasicBlock.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/OrderedBasicBlock.h"
+#include "llvm/IR/Instruction.h"
+using namespace llvm;
+
+OrderedBasicBlock::OrderedBasicBlock(const BasicBlock *BasicB)
+ : NextInstPos(0), BB(BasicB) {
+ LastInstFound = BB->end();
+}
+
+/// Given no cached results, find if \p A comes before \p B in \p BB.
+/// Cache and number the instructions encountered while walking \p BB.
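+/// A minimal usage sketch (hypothetical client code; the underlying block
+/// must not be mutated between queries):
+/// \code
+///   OrderedBasicBlock OBB(BB);
+///   // Numbers instructions lazily up to the first of I1/I2, then caches.
+///   bool I1First = OBB.dominates(I1, I2);
+/// \endcode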
+bool OrderedBasicBlock::comesBefore(const Instruction *A, + const Instruction *B) { + const Instruction *Inst = nullptr; + assert(!(LastInstFound == BB->end() && NextInstPos != 0) && + "Instruction supposed to be in NumberedInsts"); + assert(A->getParent() == BB && "Instruction supposed to be in the block!"); + assert(B->getParent() == BB && "Instruction supposed to be in the block!"); + + // Start the search with the instruction found in the last lookup round. + auto II = BB->begin(); + auto IE = BB->end(); + if (LastInstFound != IE) + II = std::next(LastInstFound); + + // Number all instructions up to the point where we find 'A' or 'B'. + for (; II != IE; ++II) { + Inst = cast<Instruction>(II); + NumberedInsts[Inst] = NextInstPos++; + if (Inst == A || Inst == B) + break; + } + + assert(II != IE && "Instruction not found?"); + assert((Inst == A || Inst == B) && "Should find A or B"); + LastInstFound = II; + return Inst != B; +} + +/// Find out whether \p A dominates \p B, meaning whether \p A +/// comes before \p B in \p BB. This is a simplification that considers +/// cached instruction positions and ignores other basic blocks, being +/// only relevant to compare relative instructions positions inside \p BB. +bool OrderedBasicBlock::dominates(const Instruction *A, const Instruction *B) { + assert(A->getParent() == B->getParent() && + "Instructions must be in the same basic block!"); + assert(A->getParent() == BB && "Instructions must be in the tracked block!"); + + // First we lookup the instructions. If they don't exist, lookup will give us + // back ::end(). If they both exist, we compare the numbers. Otherwise, if NA + // exists and NB doesn't, it means NA must come before NB because we would + // have numbered NB as well if it didn't. The same is true for NB. If it + // exists, but NA does not, NA must come after it. If neither exist, we need + // to number the block and cache the results (by calling comesBefore). + auto NAI = NumberedInsts.find(A); + auto NBI = NumberedInsts.find(B); + if (NAI != NumberedInsts.end() && NBI != NumberedInsts.end()) + return NAI->second < NBI->second; + if (NAI != NumberedInsts.end()) + return true; + if (NBI != NumberedInsts.end()) + return false; + + return comesBefore(A, B); +} + +void OrderedBasicBlock::eraseInstruction(const Instruction *I) { + if (LastInstFound != BB->end() && I == &*LastInstFound) { + if (LastInstFound == BB->begin()) { + LastInstFound = BB->end(); + NextInstPos = 0; + } else + LastInstFound--; + } + + NumberedInsts.erase(I); +} + +void OrderedBasicBlock::replaceInstruction(const Instruction *Old, + const Instruction *New) { + auto OI = NumberedInsts.find(Old); + if (OI == NumberedInsts.end()) + return; + + NumberedInsts.insert({New, OI->second}); + if (LastInstFound != BB->end() && Old == &*LastInstFound) + LastInstFound = New->getIterator(); + NumberedInsts.erase(Old); +} diff --git a/llvm/lib/Analysis/OrderedInstructions.cpp b/llvm/lib/Analysis/OrderedInstructions.cpp new file mode 100644 index 000000000000..e947e5e388a8 --- /dev/null +++ b/llvm/lib/Analysis/OrderedInstructions.cpp @@ -0,0 +1,50 @@ +//===-- OrderedInstructions.cpp - Instruction dominance function ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a utility to check the dominance relation of two
+// instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/OrderedInstructions.h"
+using namespace llvm;
+
+bool OrderedInstructions::localDominates(const Instruction *InstA,
+                                         const Instruction *InstB) const {
+  assert(InstA->getParent() == InstB->getParent() &&
+         "Instructions must be in the same basic block");
+
+  const BasicBlock *IBB = InstA->getParent();
+  auto OBB = OBBMap.find(IBB);
+  if (OBB == OBBMap.end())
+    OBB = OBBMap.insert({IBB, std::make_unique<OrderedBasicBlock>(IBB)}).first;
+  return OBB->second->dominates(InstA, InstB);
+}
+
+/// Given two instructions, use OrderedBasicBlock to check the dominance
+/// relation if they are in the same basic block. Otherwise, use the dominator
+/// tree.
+bool OrderedInstructions::dominates(const Instruction *InstA,
+                                    const Instruction *InstB) const {
+  // Use the ordered basic block to do the dominance check in case the two
+  // instructions are in the same basic block.
+  if (InstA->getParent() == InstB->getParent())
+    return localDominates(InstA, InstB);
+  return DT->dominates(InstA->getParent(), InstB->getParent());
+}
+
+bool OrderedInstructions::dfsBefore(const Instruction *InstA,
+                                    const Instruction *InstB) const {
+  // Use the ordered basic block in case the two instructions are in the same
+  // basic block.
+  if (InstA->getParent() == InstB->getParent())
+    return localDominates(InstA, InstB);
+
+  DomTreeNode *DA = DT->getNode(InstA->getParent());
+  DomTreeNode *DB = DT->getNode(InstB->getParent());
+  return DA->getDFSNumIn() < DB->getDFSNumIn();
+}
diff --git a/llvm/lib/Analysis/PHITransAddr.cpp b/llvm/lib/Analysis/PHITransAddr.cpp
new file mode 100644
index 000000000000..7f77ab146c4c
--- /dev/null
+++ b/llvm/lib/Analysis/PHITransAddr.cpp
@@ -0,0 +1,439 @@
+//===- PHITransAddr.cpp - PHI Translation for Addresses -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the PHITransAddr class.
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/PHITransAddr.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +static bool CanPHITrans(Instruction *Inst) { + if (isa<PHINode>(Inst) || + isa<GetElementPtrInst>(Inst)) + return true; + + if (isa<CastInst>(Inst) && + isSafeToSpeculativelyExecute(Inst)) + return true; + + if (Inst->getOpcode() == Instruction::Add && + isa<ConstantInt>(Inst->getOperand(1))) + return true; + + // cerr << "MEMDEP: Could not PHI translate: " << *Pointer; + // if (isa<BitCastInst>(PtrInst) || isa<GetElementPtrInst>(PtrInst)) + // cerr << "OP:\t\t\t\t" << *PtrInst->getOperand(0); + return false; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void PHITransAddr::dump() const { + if (!Addr) { + dbgs() << "PHITransAddr: null\n"; + return; + } + dbgs() << "PHITransAddr: " << *Addr << "\n"; + for (unsigned i = 0, e = InstInputs.size(); i != e; ++i) + dbgs() << " Input #" << i << " is " << *InstInputs[i] << "\n"; +} +#endif + + +static bool VerifySubExpr(Value *Expr, + SmallVectorImpl<Instruction*> &InstInputs) { + // If this is a non-instruction value, there is nothing to do. + Instruction *I = dyn_cast<Instruction>(Expr); + if (!I) return true; + + // If it's an instruction, it is either in Tmp or its operands recursively + // are. + SmallVectorImpl<Instruction *>::iterator Entry = find(InstInputs, I); + if (Entry != InstInputs.end()) { + InstInputs.erase(Entry); + return true; + } + + // If it isn't in the InstInputs list it is a subexpr incorporated into the + // address. Sanity check that it is phi translatable. + if (!CanPHITrans(I)) { + errs() << "Instruction in PHITransAddr is not phi-translatable:\n"; + errs() << *I << '\n'; + llvm_unreachable("Either something is missing from InstInputs or " + "CanPHITrans is wrong."); + } + + // Validate the operands of the instruction. + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) + if (!VerifySubExpr(I->getOperand(i), InstInputs)) + return false; + + return true; +} + +/// Verify - Check internal consistency of this data structure. If the +/// structure is valid, it returns true. If invalid, it prints errors and +/// returns false. +bool PHITransAddr::Verify() const { + if (!Addr) return true; + + SmallVector<Instruction*, 8> Tmp(InstInputs.begin(), InstInputs.end()); + + if (!VerifySubExpr(Addr, Tmp)) + return false; + + if (!Tmp.empty()) { + errs() << "PHITransAddr contains extra instructions:\n"; + for (unsigned i = 0, e = InstInputs.size(); i != e; ++i) + errs() << " InstInput #" << i << " is " << *InstInputs[i] << "\n"; + llvm_unreachable("This is unexpected."); + } + + // a-ok. + return true; +} + + +/// IsPotentiallyPHITranslatable - If this needs PHI translation, return true +/// if we have some hope of doing it. This should be used as a filter to +/// avoid calling PHITranslateValue in hopeless situations. +bool PHITransAddr::IsPotentiallyPHITranslatable() const { + // If the input value is not an instruction, or if it is not defined in CurBB, + // then we don't need to phi translate it. 
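+  // Illustrative example (hedged; the IR names are hypothetical): for
+  //   bb:  %phi  = phi i8* [ %p0, %pred ], [ %p1, %other ]
+  //        %addr = getelementptr i8, i8* %phi, i64 8
+  // the address %addr is an instruction accepted by CanPHITrans above (a
+  // GEP), so it is potentially translatable; a non-instruction address such
+  // as a global or an argument trivially needs no translation at all.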
+  Instruction *Inst = dyn_cast<Instruction>(Addr);
+  return !Inst || CanPHITrans(Inst);
+}
+
+
+static void RemoveInstInputs(Value *V,
+                             SmallVectorImpl<Instruction*> &InstInputs) {
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I) return;
+
+  // If the instruction is in the InstInputs list, remove it.
+  SmallVectorImpl<Instruction *>::iterator Entry = find(InstInputs, I);
+  if (Entry != InstInputs.end()) {
+    InstInputs.erase(Entry);
+    return;
+  }
+
+  assert(!isa<PHINode>(I) && "Error, removing something that isn't an input");
+
+  // Otherwise, it must have instruction inputs itself. Zap them recursively.
+  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
+    if (Instruction *Op = dyn_cast<Instruction>(I->getOperand(i)))
+      RemoveInstInputs(Op, InstInputs);
+  }
+}
+
+Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB,
+                                         BasicBlock *PredBB,
+                                         const DominatorTree *DT) {
+  // If this is a non-instruction value, it can't require PHI translation.
+  Instruction *Inst = dyn_cast<Instruction>(V);
+  if (!Inst) return V;
+
+  // Determine whether 'Inst' is an input to our PHI translatable expression.
+  bool isInput = is_contained(InstInputs, Inst);
+
+  // Handle input instructions if needed.
+  if (isInput) {
+    if (Inst->getParent() != CurBB) {
+      // If it is an input defined in a different block, then it remains an
+      // input.
+      return Inst;
+    }
+
+    // If 'Inst' is defined in this block and is an input that needs to be phi
+    // translated, we need to incorporate the value into the expression or fail.
+
+    // In either case, the instruction itself isn't an input any longer.
+    InstInputs.erase(find(InstInputs, Inst));
+
+    // If this is a PHI, go ahead and translate it.
+    if (PHINode *PN = dyn_cast<PHINode>(Inst))
+      return AddAsInput(PN->getIncomingValueForBlock(PredBB));
+
+    // If this is a non-phi value, and it is analyzable, we can incorporate it
+    // into the expression by making all instruction operands be inputs.
+    if (!CanPHITrans(Inst))
+      return nullptr;
+
+    // All instruction operands are now inputs (and of course, they may also be
+    // defined in this block, so they may need to be phi translated themselves).
+    for (unsigned i = 0, e = Inst->getNumOperands(); i != e; ++i)
+      if (Instruction *Op = dyn_cast<Instruction>(Inst->getOperand(i)))
+        InstInputs.push_back(Op);
+  }
+
+  // Ok, it must be an intermediate result (either because it started that way
+  // or because we just incorporated it into the expression). See if its
+  // operands need to be phi translated, and if so, reconstruct it.
+
+  if (CastInst *Cast = dyn_cast<CastInst>(Inst)) {
+    if (!isSafeToSpeculativelyExecute(Cast)) return nullptr;
+    Value *PHIIn = PHITranslateSubExpr(Cast->getOperand(0), CurBB, PredBB, DT);
+    if (!PHIIn) return nullptr;
+    if (PHIIn == Cast->getOperand(0))
+      return Cast;
+
+    // Find an available version of this cast.
+
+    // Constants are trivial to find.
+    if (Constant *C = dyn_cast<Constant>(PHIIn))
+      return AddAsInput(ConstantExpr::getCast(Cast->getOpcode(),
+                                              C, Cast->getType()));
+
+    // Otherwise we have to see if a casted version of the incoming pointer
+    // is available. If so, we can use it, otherwise we have to fail.
+    for (User *U : PHIIn->users()) {
+      if (CastInst *CastI = dyn_cast<CastInst>(U))
+        if (CastI->getOpcode() == Cast->getOpcode() &&
+            CastI->getType() == Cast->getType() &&
+            (!DT || DT->dominates(CastI->getParent(), PredBB)))
+          return CastI;
+    }
+    return nullptr;
+  }
+
+  // Handle getelementptr with at least one PHI translatable operand.
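+  // Worked example (hedged, hypothetical IR): translating
+  //   bb:   %addr = getelementptr i8, i8* %phi, i64 8
+  // across the edge pred -> bb, where %phi has incoming value %p0 from %pred,
+  // first rewrites the operand to %p0. The code below then either simplifies
+  // the substituted GEP away, reuses an identical GEP already available in a
+  // block dominating %pred, or fails by returning nullptr; this path never
+  // inserts new instructions.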
+ if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Inst)) { + SmallVector<Value*, 8> GEPOps; + bool AnyChanged = false; + for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) { + Value *GEPOp = PHITranslateSubExpr(GEP->getOperand(i), CurBB, PredBB, DT); + if (!GEPOp) return nullptr; + + AnyChanged |= GEPOp != GEP->getOperand(i); + GEPOps.push_back(GEPOp); + } + + if (!AnyChanged) + return GEP; + + // Simplify the GEP to handle 'gep x, 0' -> x etc. + if (Value *V = SimplifyGEPInst(GEP->getSourceElementType(), + GEPOps, {DL, TLI, DT, AC})) { + for (unsigned i = 0, e = GEPOps.size(); i != e; ++i) + RemoveInstInputs(GEPOps[i], InstInputs); + + return AddAsInput(V); + } + + // Scan to see if we have this GEP available. + Value *APHIOp = GEPOps[0]; + for (User *U : APHIOp->users()) { + if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(U)) + if (GEPI->getType() == GEP->getType() && + GEPI->getNumOperands() == GEPOps.size() && + GEPI->getParent()->getParent() == CurBB->getParent() && + (!DT || DT->dominates(GEPI->getParent(), PredBB))) { + if (std::equal(GEPOps.begin(), GEPOps.end(), GEPI->op_begin())) + return GEPI; + } + } + return nullptr; + } + + // Handle add with a constant RHS. + if (Inst->getOpcode() == Instruction::Add && + isa<ConstantInt>(Inst->getOperand(1))) { + // PHI translate the LHS. + Constant *RHS = cast<ConstantInt>(Inst->getOperand(1)); + bool isNSW = cast<BinaryOperator>(Inst)->hasNoSignedWrap(); + bool isNUW = cast<BinaryOperator>(Inst)->hasNoUnsignedWrap(); + + Value *LHS = PHITranslateSubExpr(Inst->getOperand(0), CurBB, PredBB, DT); + if (!LHS) return nullptr; + + // If the PHI translated LHS is an add of a constant, fold the immediates. + if (BinaryOperator *BOp = dyn_cast<BinaryOperator>(LHS)) + if (BOp->getOpcode() == Instruction::Add) + if (ConstantInt *CI = dyn_cast<ConstantInt>(BOp->getOperand(1))) { + LHS = BOp->getOperand(0); + RHS = ConstantExpr::getAdd(RHS, CI); + isNSW = isNUW = false; + + // If the old 'LHS' was an input, add the new 'LHS' as an input. + if (is_contained(InstInputs, BOp)) { + RemoveInstInputs(BOp, InstInputs); + AddAsInput(LHS); + } + } + + // See if the add simplifies away. + if (Value *Res = SimplifyAddInst(LHS, RHS, isNSW, isNUW, {DL, TLI, DT, AC})) { + // If we simplified the operands, the LHS is no longer an input, but Res + // is. + RemoveInstInputs(LHS, InstInputs); + return AddAsInput(Res); + } + + // If we didn't modify the add, just return it. + if (LHS == Inst->getOperand(0) && RHS == Inst->getOperand(1)) + return Inst; + + // Otherwise, see if we have this add available somewhere. + for (User *U : LHS->users()) { + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) + if (BO->getOpcode() == Instruction::Add && + BO->getOperand(0) == LHS && BO->getOperand(1) == RHS && + BO->getParent()->getParent() == CurBB->getParent() && + (!DT || DT->dominates(BO->getParent(), PredBB))) + return BO; + } + + return nullptr; + } + + // Otherwise, we failed. + return nullptr; +} + + +/// PHITranslateValue - PHI translate the current address up the CFG from +/// CurBB to Pred, updating our state to reflect any needed changes. If +/// 'MustDominate' is true, the translated value must dominate +/// PredBB. This returns true on failure and sets Addr to null. 
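+// A hedged caller-side sketch (names like Ptr, CurBB, PredBB are illustrative,
+// not from this file):
+//
+//   PHITransAddr Address(Ptr, DL, AC);
+//   if (Address.IsPotentiallyPHITranslatable() &&
+//       !Address.PHITranslateValue(CurBB, PredBB, &DT, /*MustDominate=*/true)) {
+//     Value *AvailableAddr = Address.getAddr(); // Now live in PredBB.
+//     ...
+//   }
+//
+// Note the inverted sense of the return value: true means translation failed.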
+bool PHITransAddr::PHITranslateValue(BasicBlock *CurBB, BasicBlock *PredBB, + const DominatorTree *DT, + bool MustDominate) { + assert(DT || !MustDominate); + assert(Verify() && "Invalid PHITransAddr!"); + if (DT && DT->isReachableFromEntry(PredBB)) + Addr = + PHITranslateSubExpr(Addr, CurBB, PredBB, MustDominate ? DT : nullptr); + else + Addr = nullptr; + assert(Verify() && "Invalid PHITransAddr!"); + + if (MustDominate) + // Make sure the value is live in the predecessor. + if (Instruction *Inst = dyn_cast_or_null<Instruction>(Addr)) + if (!DT->dominates(Inst->getParent(), PredBB)) + Addr = nullptr; + + return Addr == nullptr; +} + +/// PHITranslateWithInsertion - PHI translate this value into the specified +/// predecessor block, inserting a computation of the value if it is +/// unavailable. +/// +/// All newly created instructions are added to the NewInsts list. This +/// returns null on failure. +/// +Value *PHITransAddr:: +PHITranslateWithInsertion(BasicBlock *CurBB, BasicBlock *PredBB, + const DominatorTree &DT, + SmallVectorImpl<Instruction*> &NewInsts) { + unsigned NISize = NewInsts.size(); + + // Attempt to PHI translate with insertion. + Addr = InsertPHITranslatedSubExpr(Addr, CurBB, PredBB, DT, NewInsts); + + // If successful, return the new value. + if (Addr) return Addr; + + // If not, destroy any intermediate instructions inserted. + while (NewInsts.size() != NISize) + NewInsts.pop_back_val()->eraseFromParent(); + return nullptr; +} + + +/// InsertPHITranslatedPointer - Insert a computation of the PHI translated +/// version of 'V' for the edge PredBB->CurBB into the end of the PredBB +/// block. All newly created instructions are added to the NewInsts list. +/// This returns null on failure. +/// +Value *PHITransAddr:: +InsertPHITranslatedSubExpr(Value *InVal, BasicBlock *CurBB, + BasicBlock *PredBB, const DominatorTree &DT, + SmallVectorImpl<Instruction*> &NewInsts) { + // See if we have a version of this value already available and dominating + // PredBB. If so, there is no need to insert a new instance of it. + PHITransAddr Tmp(InVal, DL, AC); + if (!Tmp.PHITranslateValue(CurBB, PredBB, &DT, /*MustDominate=*/true)) + return Tmp.getAddr(); + + // We don't need to PHI translate values which aren't instructions. + auto *Inst = dyn_cast<Instruction>(InVal); + if (!Inst) + return nullptr; + + // Handle cast of PHI translatable value. + if (CastInst *Cast = dyn_cast<CastInst>(Inst)) { + if (!isSafeToSpeculativelyExecute(Cast)) return nullptr; + Value *OpVal = InsertPHITranslatedSubExpr(Cast->getOperand(0), + CurBB, PredBB, DT, NewInsts); + if (!OpVal) return nullptr; + + // Otherwise insert a cast at the end of PredBB. + CastInst *New = CastInst::Create(Cast->getOpcode(), OpVal, InVal->getType(), + InVal->getName() + ".phi.trans.insert", + PredBB->getTerminator()); + New->setDebugLoc(Inst->getDebugLoc()); + NewInsts.push_back(New); + return New; + } + + // Handle getelementptr with at least one PHI operand. 
+ if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Inst)) { + SmallVector<Value*, 8> GEPOps; + BasicBlock *CurBB = GEP->getParent(); + for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) { + Value *OpVal = InsertPHITranslatedSubExpr(GEP->getOperand(i), + CurBB, PredBB, DT, NewInsts); + if (!OpVal) return nullptr; + GEPOps.push_back(OpVal); + } + + GetElementPtrInst *Result = GetElementPtrInst::Create( + GEP->getSourceElementType(), GEPOps[0], makeArrayRef(GEPOps).slice(1), + InVal->getName() + ".phi.trans.insert", PredBB->getTerminator()); + Result->setDebugLoc(Inst->getDebugLoc()); + Result->setIsInBounds(GEP->isInBounds()); + NewInsts.push_back(Result); + return Result; + } + +#if 0 + // FIXME: This code works, but it is unclear that we actually want to insert + // a big chain of computation in order to make a value available in a block. + // This needs to be evaluated carefully to consider its cost trade offs. + + // Handle add with a constant RHS. + if (Inst->getOpcode() == Instruction::Add && + isa<ConstantInt>(Inst->getOperand(1))) { + // PHI translate the LHS. + Value *OpVal = InsertPHITranslatedSubExpr(Inst->getOperand(0), + CurBB, PredBB, DT, NewInsts); + if (OpVal == 0) return 0; + + BinaryOperator *Res = BinaryOperator::CreateAdd(OpVal, Inst->getOperand(1), + InVal->getName()+".phi.trans.insert", + PredBB->getTerminator()); + Res->setHasNoSignedWrap(cast<BinaryOperator>(Inst)->hasNoSignedWrap()); + Res->setHasNoUnsignedWrap(cast<BinaryOperator>(Inst)->hasNoUnsignedWrap()); + NewInsts.push_back(Res); + return Res; + } +#endif + + return nullptr; +} diff --git a/llvm/lib/Analysis/PhiValues.cpp b/llvm/lib/Analysis/PhiValues.cpp new file mode 100644 index 000000000000..49749bc44746 --- /dev/null +++ b/llvm/lib/Analysis/PhiValues.cpp @@ -0,0 +1,212 @@ +//===- PhiValues.cpp - Phi Value Analysis ---------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/PhiValues.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +void PhiValues::PhiValuesCallbackVH::deleted() { + PV->invalidateValue(getValPtr()); +} + +void PhiValues::PhiValuesCallbackVH::allUsesReplacedWith(Value *) { + // We could potentially update the cached values we have with the new value, + // but it's simpler to just treat the old value as invalidated. + PV->invalidateValue(getValPtr()); +} + +bool PhiValues::invalidate(Function &, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &) { + // PhiValues is invalidated if it isn't preserved. + auto PAC = PA.getChecker<PhiValuesAnalysis>(); + return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()); +} + +// The goal here is to find all of the non-phi values reachable from this phi, +// and to do the same for all of the phis reachable from this phi, as doing so +// is necessary anyway in order to get the values for this phi. We do this using +// Tarjan's algorithm with Nuutila's improvements to find the strongly connected +// components of the phi graph rooted in this phi: +// * All phis in a strongly connected component will have the same reachable +// non-phi values. 
The SCC may not be the maximal subgraph for that set of
+//   reachable values, but finding that out isn't really necessary (it would
+//   only reduce the amount of memory needed to store the values).
+// * Tarjan's algorithm completes components in a bottom-up manner, i.e. it
+//   never completes a component before the components reachable from it have
+//   been completed. This means that when we complete a component we have
+//   everything we need to collect the values reachable from that component.
+// * We collect both the non-phi values reachable from each SCC, as that's what
+//   we're ultimately interested in, and all of the reachable values, i.e.
+//   including phis, as that makes invalidateValue easier.
+void PhiValues::processPhi(const PHINode *Phi,
+                           SmallVector<const PHINode *, 8> &Stack) {
+  // Initialize the phi with the next depth number.
+  assert(DepthMap.lookup(Phi) == 0);
+  assert(NextDepthNumber != UINT_MAX);
+  unsigned int DepthNumber = ++NextDepthNumber;
+  DepthMap[Phi] = DepthNumber;
+
+  // Recursively process the incoming phis of this phi.
+  TrackedValues.insert(PhiValuesCallbackVH(const_cast<PHINode *>(Phi), this));
+  for (Value *PhiOp : Phi->incoming_values()) {
+    if (PHINode *PhiPhiOp = dyn_cast<PHINode>(PhiOp)) {
+      // Recurse if the phi has not yet been visited.
+      if (DepthMap.lookup(PhiPhiOp) == 0)
+        processPhi(PhiPhiOp, Stack);
+      assert(DepthMap.lookup(PhiPhiOp) != 0);
+      // If the phi did not become part of a component then this phi and that
+      // phi are part of the same component, so adjust the depth number.
+      if (!ReachableMap.count(DepthMap[PhiPhiOp]))
+        DepthMap[Phi] = std::min(DepthMap[Phi], DepthMap[PhiPhiOp]);
+    } else {
+      TrackedValues.insert(PhiValuesCallbackVH(PhiOp, this));
+    }
+  }
+
+  // Now that incoming phis have been handled, push this phi to the stack.
+  Stack.push_back(Phi);
+
+  // If the depth number has not changed then we've finished collecting the phis
+  // of a strongly connected component.
+  if (DepthMap[Phi] == DepthNumber) {
+    // Collect the reachable values for this component. The phis of this
+    // component will be those on top of the depth stack with the same or
+    // greater depth number.
+    ConstValueSet Reachable;
+    while (!Stack.empty() && DepthMap[Stack.back()] >= DepthNumber) {
+      const PHINode *ComponentPhi = Stack.pop_back_val();
+      Reachable.insert(ComponentPhi);
+      DepthMap[ComponentPhi] = DepthNumber;
+      for (Value *Op : ComponentPhi->incoming_values()) {
+        if (PHINode *PhiOp = dyn_cast<PHINode>(Op)) {
+          // If this phi is not part of the same component then that component
+          // is guaranteed to have been completed before this one. Therefore we
+          // can just add its reachable values to the reachable values of this
+          // component.
+          auto It = ReachableMap.find(DepthMap[PhiOp]);
+          if (It != ReachableMap.end())
+            Reachable.insert(It->second.begin(), It->second.end());
+        } else {
+          Reachable.insert(Op);
+        }
+      }
+    }
+    ReachableMap.insert({DepthNumber,Reachable});
+
+    // Filter out phis to get the non-phi reachable values.
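+    // Worked example (hedged, hypothetical IR): for the cycle
+    //   header: %h = phi i32* [ %a, %entry ], [ %l, %latch ]
+    //   latch:  %l = phi i32* [ %h, %header ], [ %b, %bb ]
+    // both phis form one component whose Reachable set is {%h, %l, %a, %b};
+    // the filtering below leaves the non-phi values {%a, %b}, which is what
+    // getValuesForPhi ultimately returns for either phi.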
+ ValueSet NonPhi; + for (const Value *V : Reachable) + if (!isa<PHINode>(V)) + NonPhi.insert(const_cast<Value*>(V)); + NonPhiReachableMap.insert({DepthNumber,NonPhi}); + } +} + +const PhiValues::ValueSet &PhiValues::getValuesForPhi(const PHINode *PN) { + if (DepthMap.count(PN) == 0) { + SmallVector<const PHINode *, 8> Stack; + processPhi(PN, Stack); + assert(Stack.empty()); + } + assert(DepthMap.lookup(PN) != 0); + return NonPhiReachableMap[DepthMap[PN]]; +} + +void PhiValues::invalidateValue(const Value *V) { + // Components that can reach V are invalid. + SmallVector<unsigned int, 8> InvalidComponents; + for (auto &Pair : ReachableMap) + if (Pair.second.count(V)) + InvalidComponents.push_back(Pair.first); + + for (unsigned int N : InvalidComponents) { + for (const Value *V : ReachableMap[N]) + if (const PHINode *PN = dyn_cast<PHINode>(V)) + DepthMap.erase(PN); + NonPhiReachableMap.erase(N); + ReachableMap.erase(N); + } + // This value is no longer tracked + auto It = TrackedValues.find_as(V); + if (It != TrackedValues.end()) + TrackedValues.erase(It); +} + +void PhiValues::releaseMemory() { + DepthMap.clear(); + NonPhiReachableMap.clear(); + ReachableMap.clear(); +} + +void PhiValues::print(raw_ostream &OS) const { + // Iterate through the phi nodes of the function rather than iterating through + // DepthMap in order to get predictable ordering. + for (const BasicBlock &BB : F) { + for (const PHINode &PN : BB.phis()) { + OS << "PHI "; + PN.printAsOperand(OS, false); + OS << " has values:\n"; + unsigned int N = DepthMap.lookup(&PN); + auto It = NonPhiReachableMap.find(N); + if (It == NonPhiReachableMap.end()) + OS << " UNKNOWN\n"; + else if (It->second.empty()) + OS << " NONE\n"; + else + for (Value *V : It->second) + // Printing of an instruction prints two spaces at the start, so + // handle instructions and everything else slightly differently in + // order to get consistent indenting. + if (Instruction *I = dyn_cast<Instruction>(V)) + OS << *I << "\n"; + else + OS << " " << *V << "\n"; + } + } +} + +AnalysisKey PhiValuesAnalysis::Key; +PhiValues PhiValuesAnalysis::run(Function &F, FunctionAnalysisManager &) { + return PhiValues(F); +} + +PreservedAnalyses PhiValuesPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + OS << "PHI Values for function: " << F.getName() << "\n"; + PhiValues &PI = AM.getResult<PhiValuesAnalysis>(F); + for (const BasicBlock &BB : F) + for (const PHINode &PN : BB.phis()) + PI.getValuesForPhi(&PN); + PI.print(OS); + return PreservedAnalyses::all(); +} + +PhiValuesWrapperPass::PhiValuesWrapperPass() : FunctionPass(ID) { + initializePhiValuesWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool PhiValuesWrapperPass::runOnFunction(Function &F) { + Result.reset(new PhiValues(F)); + return false; +} + +void PhiValuesWrapperPass::releaseMemory() { + Result->releaseMemory(); +} + +void PhiValuesWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); +} + +char PhiValuesWrapperPass::ID = 0; + +INITIALIZE_PASS(PhiValuesWrapperPass, "phi-values", "Phi Values Analysis", false, + true) diff --git a/llvm/lib/Analysis/PostDominators.cpp b/llvm/lib/Analysis/PostDominators.cpp new file mode 100644 index 000000000000..4afe22bd5342 --- /dev/null +++ b/llvm/lib/Analysis/PostDominators.cpp @@ -0,0 +1,84 @@ +//===- PostDominators.cpp - Post-Dominator Calculation --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the post-dominator construction algorithms. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/PostDominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DEBUG_TYPE "postdomtree" + +#ifdef EXPENSIVE_CHECKS +static constexpr bool ExpensiveChecksEnabled = true; +#else +static constexpr bool ExpensiveChecksEnabled = false; +#endif + +//===----------------------------------------------------------------------===// +// PostDominatorTree Implementation +//===----------------------------------------------------------------------===// + +char PostDominatorTreeWrapperPass::ID = 0; + +INITIALIZE_PASS(PostDominatorTreeWrapperPass, "postdomtree", + "Post-Dominator Tree Construction", true, true) + +bool PostDominatorTree::invalidate(Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &) { + // Check whether the analysis, all analyses on functions, or the function's + // CFG have been preserved. + auto PAC = PA.getChecker<PostDominatorTreeAnalysis>(); + return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || + PAC.preservedSet<CFGAnalyses>()); +} + +bool PostDominatorTreeWrapperPass::runOnFunction(Function &F) { + DT.recalculate(F); + return false; +} + +void PostDominatorTreeWrapperPass::verifyAnalysis() const { + if (VerifyDomInfo) + assert(DT.verify(PostDominatorTree::VerificationLevel::Full)); + else if (ExpensiveChecksEnabled) + assert(DT.verify(PostDominatorTree::VerificationLevel::Basic)); +} + +void PostDominatorTreeWrapperPass::print(raw_ostream &OS, const Module *) const { + DT.print(OS); +} + +FunctionPass* llvm::createPostDomTree() { + return new PostDominatorTreeWrapperPass(); +} + +AnalysisKey PostDominatorTreeAnalysis::Key; + +PostDominatorTree PostDominatorTreeAnalysis::run(Function &F, + FunctionAnalysisManager &) { + PostDominatorTree PDT(F); + return PDT; +} + +PostDominatorTreePrinterPass::PostDominatorTreePrinterPass(raw_ostream &OS) + : OS(OS) {} + +PreservedAnalyses +PostDominatorTreePrinterPass::run(Function &F, FunctionAnalysisManager &AM) { + OS << "PostDominatorTree for function: " << F.getName() << "\n"; + AM.getResult<PostDominatorTreeAnalysis>(F).print(OS); + + return PreservedAnalyses::all(); +} diff --git a/llvm/lib/Analysis/ProfileSummaryInfo.cpp b/llvm/lib/Analysis/ProfileSummaryInfo.cpp new file mode 100644 index 000000000000..b99b75715025 --- /dev/null +++ b/llvm/lib/Analysis/ProfileSummaryInfo.cpp @@ -0,0 +1,392 @@ +//===- ProfileSummaryInfo.cpp - Global profile summary information --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a pass that provides access to the global profile summary +// information. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/ProfileSummary.h" +using namespace llvm; + +// The following two parameters determine the threshold for a count to be +// considered hot/cold. These two parameters are percentile values (multiplied +// by 10000). If the counts are sorted in descending order, the minimum count to +// reach ProfileSummaryCutoffHot gives the threshold to determine a hot count. +// Similarly, the minimum count to reach ProfileSummaryCutoffCold gives the +// threshold for determining cold count (everything <= this threshold is +// considered cold). + +static cl::opt<int> ProfileSummaryCutoffHot( + "profile-summary-cutoff-hot", cl::Hidden, cl::init(990000), cl::ZeroOrMore, + cl::desc("A count is hot if it exceeds the minimum count to" + " reach this percentile of total counts.")); + +static cl::opt<int> ProfileSummaryCutoffCold( + "profile-summary-cutoff-cold", cl::Hidden, cl::init(999999), cl::ZeroOrMore, + cl::desc("A count is cold if it is below the minimum count" + " to reach this percentile of total counts.")); + +static cl::opt<unsigned> ProfileSummaryHugeWorkingSetSizeThreshold( + "profile-summary-huge-working-set-size-threshold", cl::Hidden, + cl::init(15000), cl::ZeroOrMore, + cl::desc("The code working set size is considered huge if the number of" + " blocks required to reach the -profile-summary-cutoff-hot" + " percentile exceeds this count.")); + +static cl::opt<unsigned> ProfileSummaryLargeWorkingSetSizeThreshold( + "profile-summary-large-working-set-size-threshold", cl::Hidden, + cl::init(12500), cl::ZeroOrMore, + cl::desc("The code working set size is considered large if the number of" + " blocks required to reach the -profile-summary-cutoff-hot" + " percentile exceeds this count.")); + +// The next two options override the counts derived from summary computation and +// are useful for debugging purposes. +static cl::opt<int> ProfileSummaryHotCount( + "profile-summary-hot-count", cl::ReallyHidden, cl::ZeroOrMore, + cl::desc("A fixed hot count that overrides the count derived from" + " profile-summary-cutoff-hot")); + +static cl::opt<int> ProfileSummaryColdCount( + "profile-summary-cold-count", cl::ReallyHidden, cl::ZeroOrMore, + cl::desc("A fixed cold count that overrides the count derived from" + " profile-summary-cutoff-cold")); + +// Find the summary entry for a desired percentile of counts. +static const ProfileSummaryEntry &getEntryForPercentile(SummaryEntryVector &DS, + uint64_t Percentile) { + auto It = partition_point(DS, [=](const ProfileSummaryEntry &Entry) { + return Entry.Cutoff < Percentile; + }); + // The required percentile has to be <= one of the percentiles in the + // detailed summary. + if (It == DS.end()) + report_fatal_error("Desired percentile exceeds the maximum cutoff"); + return *It; +} + +// The profile summary metadata may be attached either by the frontend or by +// any backend passes (IR level instrumentation, for example). This method +// checks if the Summary is null and if so checks if the summary metadata is now +// available in the module and parses it to get the Summary object. Returns true +// if a valid Summary is available. 
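+// For orientation (a hedged, abbreviated sketch; the exact field list lives
+// with ProfileSummary itself): the metadata consulted here is the module-level
+// "ProfileSummary" flag that the profiling passes attach, roughly of the form
+//   !llvm.module.flags = !{..., !0}
+//   !0 = !{i32 1, !"ProfileSummary", !1}
+// where !1 carries entries such as ProfileFormat, TotalCount, NumCounts and
+// the detailed summary of cutoff/count pairs used below.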
+bool ProfileSummaryInfo::computeSummary() {
+  if (Summary)
+    return true;
+  // First try to get context sensitive ProfileSummary.
+  auto *SummaryMD = M.getProfileSummary(/* IsCS */ true);
+  if (SummaryMD) {
+    Summary.reset(ProfileSummary::getFromMD(SummaryMD));
+    return true;
+  }
+  // This will actually return a PSK_Instr or PSK_Sample summary.
+  SummaryMD = M.getProfileSummary(/* IsCS */ false);
+  if (!SummaryMD)
+    return false;
+  Summary.reset(ProfileSummary::getFromMD(SummaryMD));
+  return true;
+}
+
+Optional<uint64_t>
+ProfileSummaryInfo::getProfileCount(const Instruction *Inst,
+                                    BlockFrequencyInfo *BFI,
+                                    bool AllowSynthetic) {
+  if (!Inst)
+    return None;
+  assert((isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) &&
+         "We can only get profile count for call/invoke instruction.");
+  if (hasSampleProfile()) {
+    // In sample PGO mode, check if there is profile metadata on the
+    // instruction. If it is present, determine hotness solely based on that,
+    // since the sampled entry count may not be accurate. If there is no
+    // annotation on the instruction, return None.
+    uint64_t TotalCount;
+    if (Inst->extractProfTotalWeight(TotalCount))
+      return TotalCount;
+    return None;
+  }
+  if (BFI)
+    return BFI->getBlockProfileCount(Inst->getParent(), AllowSynthetic);
+  return None;
+}
+
+/// Returns true if the function's entry is hot. If it returns false, it
+/// either means it is not hot or it is unknown whether it is hot or not (for
+/// example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionEntryHot(const Function *F) {
+  if (!F || !computeSummary())
+    return false;
+  auto FunctionCount = F->getEntryCount();
+  // FIXME: The heuristic used below for determining hotness is based on
+  // preliminary SPEC tuning for inliner. This will eventually be a
+  // convenience method that calls isHotCount.
+  return FunctionCount && isHotCount(FunctionCount.getCount());
+}
+
+/// Returns true if the function contains hot code. This can include a hot
+/// function entry count, hot basic block, or (in the case of Sample PGO)
+/// hot total call edge count.
+/// If it returns false, it either means it is not hot or it is unknown
+/// (for example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionHotInCallGraph(const Function *F,
+                                                  BlockFrequencyInfo &BFI) {
+  if (!F || !computeSummary())
+    return false;
+  if (auto FunctionCount = F->getEntryCount())
+    if (isHotCount(FunctionCount.getCount()))
+      return true;
+
+  if (hasSampleProfile()) {
+    uint64_t TotalCallCount = 0;
+    for (const auto &BB : *F)
+      for (const auto &I : BB)
+        if (isa<CallInst>(I) || isa<InvokeInst>(I))
+          if (auto CallCount = getProfileCount(&I, nullptr))
+            TotalCallCount += CallCount.getValue();
+    if (isHotCount(TotalCallCount))
+      return true;
+  }
+  for (const auto &BB : *F)
+    if (isHotBlock(&BB, &BFI))
+      return true;
+  return false;
+}
+
+/// Returns true if the function only contains cold code. This means that
+/// the function entry and blocks are all cold, and (in the case of Sample PGO)
+/// the total call edge count is cold.
+/// If it returns false, it either means it is not cold or it is unknown
+/// (for example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionColdInCallGraph(const Function *F,
+                                                   BlockFrequencyInfo &BFI) {
+  if (!F || !computeSummary())
+    return false;
+  if (auto FunctionCount = F->getEntryCount())
+    if (!isColdCount(FunctionCount.getCount()))
+      return false;
+
+  if (hasSampleProfile()) {
+    uint64_t TotalCallCount = 0;
+    for (const auto &BB : *F)
+      for (const auto &I : BB)
+        if (isa<CallInst>(I) || isa<InvokeInst>(I))
+          if (auto CallCount = getProfileCount(&I, nullptr))
+            TotalCallCount += CallCount.getValue();
+    if (!isColdCount(TotalCallCount))
+      return false;
+  }
+  for (const auto &BB : *F)
+    if (!isColdBlock(&BB, &BFI))
+      return false;
+  return true;
+}
+
+// Like isFunctionHotInCallGraph but for a given cutoff.
+bool ProfileSummaryInfo::isFunctionHotInCallGraphNthPercentile(
+    int PercentileCutoff, const Function *F, BlockFrequencyInfo &BFI) {
+  if (!F || !computeSummary())
+    return false;
+  if (auto FunctionCount = F->getEntryCount())
+    if (isHotCountNthPercentile(PercentileCutoff, FunctionCount.getCount()))
+      return true;
+
+  if (hasSampleProfile()) {
+    uint64_t TotalCallCount = 0;
+    for (const auto &BB : *F)
+      for (const auto &I : BB)
+        if (isa<CallInst>(I) || isa<InvokeInst>(I))
+          if (auto CallCount = getProfileCount(&I, nullptr))
+            TotalCallCount += CallCount.getValue();
+    if (isHotCountNthPercentile(PercentileCutoff, TotalCallCount))
+      return true;
+  }
+  for (const auto &BB : *F)
+    if (isHotBlockNthPercentile(PercentileCutoff, &BB, &BFI))
+      return true;
+  return false;
+}
+
+/// Returns true if the function's entry is cold. If it returns false, it
+/// either means it is not cold or it is unknown whether it is cold or not (for
+/// example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionEntryCold(const Function *F) {
+  if (!F)
+    return false;
+  if (F->hasFnAttribute(Attribute::Cold))
+    return true;
+  if (!computeSummary())
+    return false;
+  auto FunctionCount = F->getEntryCount();
+  // FIXME: The heuristic used below for determining coldness is based on
+  // preliminary SPEC tuning for inliner. This will eventually be a
+  // convenience method that calls isColdCount.
+  return FunctionCount && isColdCount(FunctionCount.getCount());
+}
+
+/// Compute the hot and cold thresholds.
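+// Worked example (hedged, made-up numbers): with the default
+// -profile-summary-cutoff-hot of 990000 (a percentile scaled by 10000, i.e.
+// the 99th percentile), a detailed-summary entry of
+//   {Cutoff: 990000, MinCount: 1000, NumCounts: 14000}
+// makes HotCountThreshold 1000 (any count >= 1000 is hot), and NumCounts
+// (14000) is compared against the working-set-size thresholds above, here
+// classifying the working set as large (> 12500) but not huge (< 15000).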
+void ProfileSummaryInfo::computeThresholds() { + if (!computeSummary()) + return; + auto &DetailedSummary = Summary->getDetailedSummary(); + auto &HotEntry = + getEntryForPercentile(DetailedSummary, ProfileSummaryCutoffHot); + HotCountThreshold = HotEntry.MinCount; + if (ProfileSummaryHotCount.getNumOccurrences() > 0) + HotCountThreshold = ProfileSummaryHotCount; + auto &ColdEntry = + getEntryForPercentile(DetailedSummary, ProfileSummaryCutoffCold); + ColdCountThreshold = ColdEntry.MinCount; + if (ProfileSummaryColdCount.getNumOccurrences() > 0) + ColdCountThreshold = ProfileSummaryColdCount; + assert(ColdCountThreshold <= HotCountThreshold && + "Cold count threshold cannot exceed hot count threshold!"); + HasHugeWorkingSetSize = + HotEntry.NumCounts > ProfileSummaryHugeWorkingSetSizeThreshold; + HasLargeWorkingSetSize = + HotEntry.NumCounts > ProfileSummaryLargeWorkingSetSizeThreshold; +} + +Optional<uint64_t> ProfileSummaryInfo::computeThreshold(int PercentileCutoff) { + if (!computeSummary()) + return None; + auto iter = ThresholdCache.find(PercentileCutoff); + if (iter != ThresholdCache.end()) { + return iter->second; + } + auto &DetailedSummary = Summary->getDetailedSummary(); + auto &Entry = + getEntryForPercentile(DetailedSummary, PercentileCutoff); + uint64_t CountThreshold = Entry.MinCount; + ThresholdCache[PercentileCutoff] = CountThreshold; + return CountThreshold; +} + +bool ProfileSummaryInfo::hasHugeWorkingSetSize() { + if (!HasHugeWorkingSetSize) + computeThresholds(); + return HasHugeWorkingSetSize && HasHugeWorkingSetSize.getValue(); +} + +bool ProfileSummaryInfo::hasLargeWorkingSetSize() { + if (!HasLargeWorkingSetSize) + computeThresholds(); + return HasLargeWorkingSetSize && HasLargeWorkingSetSize.getValue(); +} + +bool ProfileSummaryInfo::isHotCount(uint64_t C) { + if (!HotCountThreshold) + computeThresholds(); + return HotCountThreshold && C >= HotCountThreshold.getValue(); +} + +bool ProfileSummaryInfo::isColdCount(uint64_t C) { + if (!ColdCountThreshold) + computeThresholds(); + return ColdCountThreshold && C <= ColdCountThreshold.getValue(); +} + +bool ProfileSummaryInfo::isHotCountNthPercentile(int PercentileCutoff, uint64_t C) { + auto CountThreshold = computeThreshold(PercentileCutoff); + return CountThreshold && C >= CountThreshold.getValue(); +} + +uint64_t ProfileSummaryInfo::getOrCompHotCountThreshold() { + if (!HotCountThreshold) + computeThresholds(); + return HotCountThreshold ? HotCountThreshold.getValue() : UINT64_MAX; +} + +uint64_t ProfileSummaryInfo::getOrCompColdCountThreshold() { + if (!ColdCountThreshold) + computeThresholds(); + return ColdCountThreshold ? 
ColdCountThreshold.getValue() : 0;
+}
+
+bool ProfileSummaryInfo::isHotBlock(const BasicBlock *BB, BlockFrequencyInfo *BFI) {
+  auto Count = BFI->getBlockProfileCount(BB);
+  return Count && isHotCount(*Count);
+}
+
+bool ProfileSummaryInfo::isColdBlock(const BasicBlock *BB,
+                                     BlockFrequencyInfo *BFI) {
+  auto Count = BFI->getBlockProfileCount(BB);
+  return Count && isColdCount(*Count);
+}
+
+bool ProfileSummaryInfo::isHotBlockNthPercentile(int PercentileCutoff,
+                                                 const BasicBlock *BB,
+                                                 BlockFrequencyInfo *BFI) {
+  auto Count = BFI->getBlockProfileCount(BB);
+  return Count && isHotCountNthPercentile(PercentileCutoff, *Count);
+}
+
+bool ProfileSummaryInfo::isHotCallSite(const CallSite &CS,
+                                       BlockFrequencyInfo *BFI) {
+  auto C = getProfileCount(CS.getInstruction(), BFI);
+  return C && isHotCount(*C);
+}
+
+bool ProfileSummaryInfo::isColdCallSite(const CallSite &CS,
+                                        BlockFrequencyInfo *BFI) {
+  auto C = getProfileCount(CS.getInstruction(), BFI);
+  if (C)
+    return isColdCount(*C);
+
+  // In SamplePGO, if the caller has been sampled, and there is no profile
+  // annotated on the callsite, we consider the callsite as cold.
+  return hasSampleProfile() && CS.getCaller()->hasProfileData();
+}
+
+INITIALIZE_PASS(ProfileSummaryInfoWrapperPass, "profile-summary-info",
+                "Profile summary info", false, true)
+
+ProfileSummaryInfoWrapperPass::ProfileSummaryInfoWrapperPass()
+    : ImmutablePass(ID) {
+  initializeProfileSummaryInfoWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+bool ProfileSummaryInfoWrapperPass::doInitialization(Module &M) {
+  PSI.reset(new ProfileSummaryInfo(M));
+  return false;
+}
+
+bool ProfileSummaryInfoWrapperPass::doFinalization(Module &M) {
+  PSI.reset();
+  return false;
+}
+
+AnalysisKey ProfileSummaryAnalysis::Key;
+ProfileSummaryInfo ProfileSummaryAnalysis::run(Module &M,
+                                               ModuleAnalysisManager &) {
+  return ProfileSummaryInfo(M);
+}
+
+PreservedAnalyses ProfileSummaryPrinterPass::run(Module &M,
+                                                 ModuleAnalysisManager &AM) {
+  ProfileSummaryInfo &PSI = AM.getResult<ProfileSummaryAnalysis>(M);
+
+  OS << "Functions in " << M.getName() << " with hot/cold annotations: \n";
+  for (auto &F : M) {
+    OS << F.getName();
+    if (PSI.isFunctionEntryHot(&F))
+      OS << " :hot entry ";
+    else if (PSI.isFunctionEntryCold(&F))
+      OS << " :cold entry ";
+    OS << "\n";
+  }
+  return PreservedAnalyses::all();
+}
+
+char ProfileSummaryInfoWrapperPass::ID = 0;
diff --git a/llvm/lib/Analysis/PtrUseVisitor.cpp b/llvm/lib/Analysis/PtrUseVisitor.cpp
new file mode 100644
index 000000000000..9a834ba4866a
--- /dev/null
+++ b/llvm/lib/Analysis/PtrUseVisitor.cpp
@@ -0,0 +1,44 @@
+//===- PtrUseVisitor.cpp - InstVisitors over a pointer's uses -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Implementation of the pointer use visitors.
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/PtrUseVisitor.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include <algorithm> + +using namespace llvm; + +void detail::PtrUseVisitorBase::enqueueUsers(Instruction &I) { + for (Use &U : I.uses()) { + if (VisitedUses.insert(&U).second) { + UseToVisit NewU = { + UseToVisit::UseAndIsOffsetKnownPair(&U, IsOffsetKnown), + Offset + }; + Worklist.push_back(std::move(NewU)); + } + } +} + +bool detail::PtrUseVisitorBase::adjustOffsetForGEP(GetElementPtrInst &GEPI) { + if (!IsOffsetKnown) + return false; + + APInt TmpOffset(DL.getIndexTypeSizeInBits(GEPI.getType()), 0); + if (GEPI.accumulateConstantOffset(DL, TmpOffset)) { + Offset += TmpOffset.sextOrTrunc(Offset.getBitWidth()); + return true; + } + + return false; +} diff --git a/llvm/lib/Analysis/RegionInfo.cpp b/llvm/lib/Analysis/RegionInfo.cpp new file mode 100644 index 000000000000..8ba38adfb0d2 --- /dev/null +++ b/llvm/lib/Analysis/RegionInfo.cpp @@ -0,0 +1,215 @@ +//===- RegionInfo.cpp - SESE region detection analysis --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Detects single entry single exit regions in the control flow graph. +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/RegionInfo.h" +#include "llvm/ADT/Statistic.h" +#ifndef NDEBUG +#include "llvm/Analysis/RegionPrinter.h" +#endif +#include "llvm/Analysis/RegionInfoImpl.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DEBUG_TYPE "region" + +namespace llvm { + +template class RegionBase<RegionTraits<Function>>; +template class RegionNodeBase<RegionTraits<Function>>; +template class RegionInfoBase<RegionTraits<Function>>; + +} // end namespace llvm + +STATISTIC(numRegions, "The # of regions"); +STATISTIC(numSimpleRegions, "The # of simple regions"); + +// Always verify if expensive checking is enabled. 
+ +static cl::opt<bool,true> +VerifyRegionInfoX( + "verify-region-info", + cl::location(RegionInfoBase<RegionTraits<Function>>::VerifyRegionInfo), + cl::desc("Verify region info (time consuming)")); + +static cl::opt<Region::PrintStyle, true> printStyleX("print-region-style", + cl::location(RegionInfo::printStyle), + cl::Hidden, + cl::desc("style of printing regions"), + cl::values( + clEnumValN(Region::PrintNone, "none", "print no details"), + clEnumValN(Region::PrintBB, "bb", + "print regions in detail with block_iterator"), + clEnumValN(Region::PrintRN, "rn", + "print regions in detail with element_iterator"))); + +//===----------------------------------------------------------------------===// +// Region implementation +// + +Region::Region(BasicBlock *Entry, BasicBlock *Exit, + RegionInfo* RI, + DominatorTree *DT, Region *Parent) : + RegionBase<RegionTraits<Function>>(Entry, Exit, RI, DT, Parent) { + +} + +Region::~Region() = default; + +//===----------------------------------------------------------------------===// +// RegionInfo implementation +// + +RegionInfo::RegionInfo() = default; + +RegionInfo::~RegionInfo() = default; + +bool RegionInfo::invalidate(Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &) { + // Check whether the analysis, all analyses on functions, or the function's + // CFG has been preserved. + auto PAC = PA.getChecker<RegionInfoAnalysis>(); + return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || + PAC.preservedSet<CFGAnalyses>()); +} + +void RegionInfo::updateStatistics(Region *R) { + ++numRegions; + + // TODO: Slow. Should only be enabled if -stats is used. + if (R->isSimple()) + ++numSimpleRegions; +} + +void RegionInfo::recalculate(Function &F, DominatorTree *DT_, + PostDominatorTree *PDT_, DominanceFrontier *DF_) { + DT = DT_; + PDT = PDT_; + DF = DF_; + + TopLevelRegion = new Region(&F.getEntryBlock(), nullptr, + this, DT, nullptr); + updateStatistics(TopLevelRegion); + calculate(F); +} + +#ifndef NDEBUG +void RegionInfo::view() { viewRegion(this); } + +void RegionInfo::viewOnly() { viewRegionOnly(this); } +#endif + +//===----------------------------------------------------------------------===// +// RegionInfoPass implementation +// + +RegionInfoPass::RegionInfoPass() : FunctionPass(ID) { + initializeRegionInfoPassPass(*PassRegistry::getPassRegistry()); +} + +RegionInfoPass::~RegionInfoPass() = default; + +bool RegionInfoPass::runOnFunction(Function &F) { + releaseMemory(); + + auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto PDT = &getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); + auto DF = &getAnalysis<DominanceFrontierWrapperPass>().getDominanceFrontier(); + + RI.recalculate(F, DT, PDT, DF); + return false; +} + +void RegionInfoPass::releaseMemory() { + RI.releaseMemory(); +} + +void RegionInfoPass::verifyAnalysis() const { + RI.verifyAnalysis(); +} + +void RegionInfoPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequiredTransitive<DominatorTreeWrapperPass>(); + AU.addRequired<PostDominatorTreeWrapperPass>(); + AU.addRequired<DominanceFrontierWrapperPass>(); +} + +void RegionInfoPass::print(raw_ostream &OS, const Module *) const { + RI.print(OS); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void RegionInfoPass::dump() const { + RI.dump(); +} +#endif + +char RegionInfoPass::ID = 0; + +INITIALIZE_PASS_BEGIN(RegionInfoPass, "regions", + "Detect single entry single exit regions", true, true) 
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominanceFrontierWrapperPass) +INITIALIZE_PASS_END(RegionInfoPass, "regions", + "Detect single entry single exit regions", true, true) + +// Create methods available outside of this file, to use them +// "include/llvm/LinkAllPasses.h". Otherwise the pass would be deleted by +// the link time optimization. + +namespace llvm { + + FunctionPass *createRegionInfoPass() { + return new RegionInfoPass(); + } + +} // end namespace llvm + +//===----------------------------------------------------------------------===// +// RegionInfoAnalysis implementation +// + +AnalysisKey RegionInfoAnalysis::Key; + +RegionInfo RegionInfoAnalysis::run(Function &F, FunctionAnalysisManager &AM) { + RegionInfo RI; + auto *DT = &AM.getResult<DominatorTreeAnalysis>(F); + auto *PDT = &AM.getResult<PostDominatorTreeAnalysis>(F); + auto *DF = &AM.getResult<DominanceFrontierAnalysis>(F); + + RI.recalculate(F, DT, PDT, DF); + return RI; +} + +RegionInfoPrinterPass::RegionInfoPrinterPass(raw_ostream &OS) + : OS(OS) {} + +PreservedAnalyses RegionInfoPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + OS << "Region Tree for function: " << F.getName() << "\n"; + AM.getResult<RegionInfoAnalysis>(F).print(OS); + + return PreservedAnalyses::all(); +} + +PreservedAnalyses RegionInfoVerifierPass::run(Function &F, + FunctionAnalysisManager &AM) { + AM.getResult<RegionInfoAnalysis>(F).verifyAnalysis(); + + return PreservedAnalyses::all(); +} diff --git a/llvm/lib/Analysis/RegionPass.cpp b/llvm/lib/Analysis/RegionPass.cpp new file mode 100644 index 000000000000..6c0d17b45c62 --- /dev/null +++ b/llvm/lib/Analysis/RegionPass.cpp @@ -0,0 +1,299 @@ +//===- RegionPass.cpp - Region Pass and Region Pass Manager ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements RegionPass and RGPassManager. All region optimization +// and transformation passes are derived from RegionPass. RGPassManager is +// responsible for managing RegionPasses. +// Most of this code has been COPIED from LoopPass.cpp +// +//===----------------------------------------------------------------------===// +#include "llvm/Analysis/RegionPass.h" +#include "llvm/IR/OptBisect.h" +#include "llvm/IR/PassTimingInfo.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Timer.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +#define DEBUG_TYPE "regionpassmgr" + +//===----------------------------------------------------------------------===// +// RGPassManager +// + +char RGPassManager::ID = 0; + +RGPassManager::RGPassManager() + : FunctionPass(ID), PMDataManager() { + skipThisRegion = false; + redoThisRegion = false; + RI = nullptr; + CurrentRegion = nullptr; +} + +// Recurse through all subregions and all regions into RQ. +static void addRegionIntoQueue(Region &R, std::deque<Region *> &RQ) { + RQ.push_back(&R); + for (const auto &E : R) + addRegionIntoQueue(*E, RQ); +} + +/// Pass Manager itself does not invalidate any analysis info. 
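+// A minimal client sketch (hedged; registration boilerplate omitted, names
+// are illustrative) of the kind of pass RGPassManager drives:
+//
+//   struct CountBlocks : public RegionPass {
+//     static char ID;
+//     CountBlocks() : RegionPass(ID) {}
+//     bool runOnRegion(Region *R, RGPassManager &RGM) override {
+//       unsigned N = 0;
+//       for (const auto *BB : R->blocks()) { (void)BB; ++N; }
+//       errs() << R->getNameStr() << ": " << N << " blocks\n";
+//       return false; // The IR was not modified.
+//     }
+//     void getAnalysisUsage(AnalysisUsage &AU) const override {
+//       AU.setPreservesAll();
+//     }
+//   };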
+void RGPassManager::getAnalysisUsage(AnalysisUsage &Info) const { + Info.addRequired<RegionInfoPass>(); + Info.setPreservesAll(); +} + +/// run - Execute all of the passes scheduled for execution. Keep track of +/// whether any of the passes modifies the function, and if so, return true. +bool RGPassManager::runOnFunction(Function &F) { + RI = &getAnalysis<RegionInfoPass>().getRegionInfo(); + bool Changed = false; + + // Collect inherited analysis from Module level pass manager. + populateInheritedAnalysis(TPM->activeStack); + + addRegionIntoQueue(*RI->getTopLevelRegion(), RQ); + + if (RQ.empty()) // No regions, skip calling finalizers + return false; + + // Initialization + for (Region *R : RQ) { + for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { + RegionPass *RP = (RegionPass *)getContainedPass(Index); + Changed |= RP->doInitialization(R, *this); + } + } + + // Walk Regions + while (!RQ.empty()) { + + CurrentRegion = RQ.back(); + skipThisRegion = false; + redoThisRegion = false; + + // Run all passes on the current Region. + for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { + RegionPass *P = (RegionPass*)getContainedPass(Index); + + if (isPassDebuggingExecutionsOrMore()) { + dumpPassInfo(P, EXECUTION_MSG, ON_REGION_MSG, + CurrentRegion->getNameStr()); + dumpRequiredSet(P); + } + + initializeAnalysisImpl(P); + + { + PassManagerPrettyStackEntry X(P, *CurrentRegion->getEntry()); + + TimeRegion PassTimer(getPassTimer(P)); + Changed |= P->runOnRegion(CurrentRegion, *this); + } + + if (isPassDebuggingExecutionsOrMore()) { + if (Changed) + dumpPassInfo(P, MODIFICATION_MSG, ON_REGION_MSG, + skipThisRegion ? "<deleted>" : + CurrentRegion->getNameStr()); + dumpPreservedSet(P); + } + + if (!skipThisRegion) { + // Manually check that this region is still healthy. This is done + // instead of relying on RegionInfo::verifyRegion since RegionInfo + // is a function pass and it's really expensive to verify every + // Region in the function every time. That level of checking can be + // enabled with the -verify-region-info option. + { + TimeRegion PassTimer(getPassTimer(P)); + CurrentRegion->verifyRegion(); + } + + // Then call the regular verifyAnalysis functions. + verifyPreservedAnalysis(P); + } + + removeNotPreservedAnalysis(P); + recordAvailableAnalysis(P); + removeDeadPasses(P, + (!isPassDebuggingExecutionsOrMore() || skipThisRegion) ? + "<deleted>" : CurrentRegion->getNameStr(), + ON_REGION_MSG); + + if (skipThisRegion) + // Do not run other passes on this region. + break; + } + + // If the region was deleted, release all the region passes. This frees up + // some memory, and avoids trouble with the pass manager trying to call + // verifyAnalysis on them. + if (skipThisRegion) + for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { + Pass *P = getContainedPass(Index); + freePass(P, "<deleted>", ON_REGION_MSG); + } + + // Pop the region from queue after running all passes. + RQ.pop_back(); + + if (redoThisRegion) + RQ.push_back(CurrentRegion); + + // Free all region nodes created in region passes. + RI->clearNodeCache(); + } + + // Finalization + for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { + RegionPass *P = (RegionPass*)getContainedPass(Index); + Changed |= P->doFinalization(); + } + + // Print the region tree after all pass. 
+  LLVM_DEBUG(dbgs() << "\nRegion tree of function " << F.getName()
+                    << " after all region passes:\n";
+             RI->dump(); dbgs() << "\n";);
+
+  return Changed;
+}
+
+/// Print the passes managed by this manager.
+void RGPassManager::dumpPassStructure(unsigned Offset) {
+  errs().indent(Offset*2) << "Region Pass Manager\n";
+  for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+    Pass *P = getContainedPass(Index);
+    P->dumpPassStructure(Offset + 1);
+    dumpLastUses(P, Offset+1);
+  }
+}
+
+namespace {
+//===----------------------------------------------------------------------===//
+// PrintRegionPass
+class PrintRegionPass : public RegionPass {
+private:
+  std::string Banner;
+  raw_ostream &Out;       // raw_ostream to print on.
+
+public:
+  static char ID;
+  PrintRegionPass(const std::string &B, raw_ostream &o)
+      : RegionPass(ID), Banner(B), Out(o) {}
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesAll();
+  }
+
+  bool runOnRegion(Region *R, RGPassManager &RGM) override {
+    Out << Banner;
+    for (const auto *BB : R->blocks()) {
+      if (BB)
+        BB->print(Out);
+      else
+        Out << "Printing <null> Block";
+    }
+
+    return false;
+  }
+
+  StringRef getPassName() const override { return "Print Region IR"; }
+};
+
+char PrintRegionPass::ID = 0;
+} //end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// RegionPass
+
+// Check if this pass is suitable for the current RGPassManager, if
+// available. This pass P is not suitable for an RGPassManager if P
+// is not preserving higher level analysis info used by other
+// RGPassManager passes. In such a case, pop the RGPassManager from the
+// stack. This will force assignPassManager() to create a new
+// RGPassManager, as expected.
+void RegionPass::preparePassManager(PMStack &PMS) {
+
+  // Find RGPassManager
+  while (!PMS.empty() &&
+         PMS.top()->getPassManagerType() > PMT_RegionPassManager)
+    PMS.pop();
+
+
+  // If this pass is destroying high level information that is used
+  // by other passes that are managed by the RGPassManager, then do not insert
+  // this pass into the current RGPassManager. Use a new RGPassManager.
+  if (PMS.top()->getPassManagerType() == PMT_RegionPassManager &&
+      !PMS.top()->preserveHigherLevelAnalysis(this))
+    PMS.pop();
+}
+
+/// Assign pass manager to manage this pass.
+void RegionPass::assignPassManager(PMStack &PMS,
+                                   PassManagerType PreferredType) {
+  // Find RGPassManager
+  while (!PMS.empty() &&
+         PMS.top()->getPassManagerType() > PMT_RegionPassManager)
+    PMS.pop();
+
+  RGPassManager *RGPM;
+
+  // Create new Region Pass Manager if it does not exist.
+  if (PMS.top()->getPassManagerType() == PMT_RegionPassManager)
+    RGPM = (RGPassManager*)PMS.top();
+  else {
+
+    assert (!PMS.empty() && "Unable to create Region Pass Manager");
+    PMDataManager *PMD = PMS.top();
+
+    // [1] Create new Region Pass Manager
+    RGPM = new RGPassManager();
+    RGPM->populateInheritedAnalysis(PMS);
+
+    // [2] Set up new manager's top level manager
+    PMTopLevelManager *TPM = PMD->getTopLevelManager();
+    TPM->addIndirectPassManager(RGPM);
+
+    // [3] Assign manager to manage this new manager.
+    // This may create and push new managers into PMS.
+    TPM->schedulePass(RGPM);
+
+    // [4] Push new manager into PMS
+    PMS.push(RGPM);
+  }
+
+  RGPM->add(this);
+}
+
+/// Get the printer pass.
+Pass *RegionPass::createPrinterPass(raw_ostream &O,
+                                    const std::string &Banner) const {
+  return new PrintRegionPass(Banner, O);
+}
+
+static std::string getDescription(const Region &R) {
+  return "region";
+}
+
+bool RegionPass::skipRegion(Region &R) const {
+  Function &F = *R.getEntry()->getParent();
+  OptPassGate &Gate = F.getContext().getOptPassGate();
+  if (Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(R)))
+    return true;
+
+  if (F.hasOptNone()) {
+    // Report this only once per function.
+    if (R.getEntry() == &F.getEntryBlock())
+      LLVM_DEBUG(dbgs() << "Skipping pass '" << getPassName()
+                        << "' on function " << F.getName() << "\n");
+    return true;
+  }
+  return false;
+}
diff --git a/llvm/lib/Analysis/RegionPrinter.cpp b/llvm/lib/Analysis/RegionPrinter.cpp
new file mode 100644
index 000000000000..5bdcb31fbe99
--- /dev/null
+++ b/llvm/lib/Analysis/RegionPrinter.cpp
@@ -0,0 +1,266 @@
+//===- RegionPrinter.cpp - Print regions tree pass ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Print out the region tree of a function using dotty/graphviz.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/RegionPrinter.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/DOTGraphTraitsPass.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/Analysis/RegionInfo.h"
+#include "llvm/Analysis/RegionIterator.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#ifndef NDEBUG
+#include "llvm/IR/LegacyPassManager.h"
+#endif
+
+using namespace llvm;
+
+//===----------------------------------------------------------------------===//
+/// onlySimpleRegions - Show only the simple regions in the RegionViewer.
+static cl::opt<bool> +onlySimpleRegions("only-simple-regions", + cl::desc("Show only simple regions in the graphviz viewer"), + cl::Hidden, + cl::init(false)); + +namespace llvm { +template<> +struct DOTGraphTraits<RegionNode*> : public DefaultDOTGraphTraits { + + DOTGraphTraits (bool isSimple=false) + : DefaultDOTGraphTraits(isSimple) {} + + std::string getNodeLabel(RegionNode *Node, RegionNode *Graph) { + + if (!Node->isSubRegion()) { + BasicBlock *BB = Node->getNodeAs<BasicBlock>(); + + if (isSimple()) + return DOTGraphTraits<const Function*> + ::getSimpleNodeLabel(BB, BB->getParent()); + else + return DOTGraphTraits<const Function*> + ::getCompleteNodeLabel(BB, BB->getParent()); + } + + return "Not implemented"; + } +}; + +template <> +struct DOTGraphTraits<RegionInfo *> : public DOTGraphTraits<RegionNode *> { + + DOTGraphTraits (bool isSimple = false) + : DOTGraphTraits<RegionNode*>(isSimple) {} + + static std::string getGraphName(const RegionInfo *) { return "Region Graph"; } + + std::string getNodeLabel(RegionNode *Node, RegionInfo *G) { + return DOTGraphTraits<RegionNode *>::getNodeLabel( + Node, reinterpret_cast<RegionNode *>(G->getTopLevelRegion())); + } + + std::string getEdgeAttributes(RegionNode *srcNode, + GraphTraits<RegionInfo *>::ChildIteratorType CI, + RegionInfo *G) { + RegionNode *destNode = *CI; + + if (srcNode->isSubRegion() || destNode->isSubRegion()) + return ""; + + // In case of a backedge, do not use it to define the layout of the nodes. + BasicBlock *srcBB = srcNode->getNodeAs<BasicBlock>(); + BasicBlock *destBB = destNode->getNodeAs<BasicBlock>(); + + Region *R = G->getRegionFor(destBB); + + while (R && R->getParent()) + if (R->getParent()->getEntry() == destBB) + R = R->getParent(); + else + break; + + if (R && R->getEntry() == destBB && R->contains(srcBB)) + return "constraint=false"; + + return ""; + } + + // Print the cluster of the subregions. This groups the single basic blocks + // and adds a different background color for each group. 
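+  // An illustrative sketch of the emitted DOT text (not from the original
+  // sources; the pointer values are made up) for a region one level deep:
+  //
+  //   subgraph cluster_0x1234 {
+  //     label = "";
+  //     style = filled;
+  //     color = 1
+  //     Node0x5678;
+  //   }
+  //
+  // Nested regions become nested "subgraph cluster_..." blocks, and the
+  // alternating color indices select entries of the "paired12" color scheme
+  // configured in addCustomGraphFeatures below.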
+ static void printRegionCluster(const Region &R, GraphWriter<RegionInfo *> &GW, + unsigned depth = 0) { + raw_ostream &O = GW.getOStream(); + O.indent(2 * depth) << "subgraph cluster_" << static_cast<const void*>(&R) + << " {\n"; + O.indent(2 * (depth + 1)) << "label = \"\";\n"; + + if (!onlySimpleRegions || R.isSimple()) { + O.indent(2 * (depth + 1)) << "style = filled;\n"; + O.indent(2 * (depth + 1)) << "color = " + << ((R.getDepth() * 2 % 12) + 1) << "\n"; + + } else { + O.indent(2 * (depth + 1)) << "style = solid;\n"; + O.indent(2 * (depth + 1)) << "color = " + << ((R.getDepth() * 2 % 12) + 2) << "\n"; + } + + for (const auto &RI : R) + printRegionCluster(*RI, GW, depth + 1); + + const RegionInfo &RI = *static_cast<const RegionInfo*>(R.getRegionInfo()); + + for (auto *BB : R.blocks()) + if (RI.getRegionFor(BB) == &R) + O.indent(2 * (depth + 1)) << "Node" + << static_cast<const void*>(RI.getTopLevelRegion()->getBBNode(BB)) + << ";\n"; + + O.indent(2 * depth) << "}\n"; + } + + static void addCustomGraphFeatures(const RegionInfo *G, + GraphWriter<RegionInfo *> &GW) { + raw_ostream &O = GW.getOStream(); + O << "\tcolorscheme = \"paired12\"\n"; + printRegionCluster(*G->getTopLevelRegion(), GW, 4); + } +}; +} //end namespace llvm + +namespace { + +struct RegionInfoPassGraphTraits { + static RegionInfo *getGraph(RegionInfoPass *RIP) { + return &RIP->getRegionInfo(); + } +}; + +struct RegionPrinter + : public DOTGraphTraitsPrinter<RegionInfoPass, false, RegionInfo *, + RegionInfoPassGraphTraits> { + static char ID; + RegionPrinter() + : DOTGraphTraitsPrinter<RegionInfoPass, false, RegionInfo *, + RegionInfoPassGraphTraits>("reg", ID) { + initializeRegionPrinterPass(*PassRegistry::getPassRegistry()); + } +}; +char RegionPrinter::ID = 0; + +struct RegionOnlyPrinter + : public DOTGraphTraitsPrinter<RegionInfoPass, true, RegionInfo *, + RegionInfoPassGraphTraits> { + static char ID; + RegionOnlyPrinter() + : DOTGraphTraitsPrinter<RegionInfoPass, true, RegionInfo *, + RegionInfoPassGraphTraits>("reg", ID) { + initializeRegionOnlyPrinterPass(*PassRegistry::getPassRegistry()); + } +}; +char RegionOnlyPrinter::ID = 0; + +struct RegionViewer + : public DOTGraphTraitsViewer<RegionInfoPass, false, RegionInfo *, + RegionInfoPassGraphTraits> { + static char ID; + RegionViewer() + : DOTGraphTraitsViewer<RegionInfoPass, false, RegionInfo *, + RegionInfoPassGraphTraits>("reg", ID) { + initializeRegionViewerPass(*PassRegistry::getPassRegistry()); + } +}; +char RegionViewer::ID = 0; + +struct RegionOnlyViewer + : public DOTGraphTraitsViewer<RegionInfoPass, true, RegionInfo *, + RegionInfoPassGraphTraits> { + static char ID; + RegionOnlyViewer() + : DOTGraphTraitsViewer<RegionInfoPass, true, RegionInfo *, + RegionInfoPassGraphTraits>("regonly", ID) { + initializeRegionOnlyViewerPass(*PassRegistry::getPassRegistry()); + } +}; +char RegionOnlyViewer::ID = 0; + +} //end anonymous namespace + +INITIALIZE_PASS(RegionPrinter, "dot-regions", + "Print regions of function to 'dot' file", true, true) + +INITIALIZE_PASS( + RegionOnlyPrinter, "dot-regions-only", + "Print regions of function to 'dot' file (with no function bodies)", true, + true) + +INITIALIZE_PASS(RegionViewer, "view-regions", "View regions of function", + true, true) + +INITIALIZE_PASS(RegionOnlyViewer, "view-regions-only", + "View regions of function (with no function bodies)", + true, true) + +FunctionPass *llvm::createRegionPrinterPass() { return new RegionPrinter(); } + +FunctionPass *llvm::createRegionOnlyPrinterPass() { + return new 
RegionOnlyPrinter(); +} + +FunctionPass* llvm::createRegionViewerPass() { + return new RegionViewer(); +} + +FunctionPass* llvm::createRegionOnlyViewerPass() { + return new RegionOnlyViewer(); +} + +#ifndef NDEBUG +static void viewRegionInfo(RegionInfo *RI, bool ShortNames) { + assert(RI && "Argument must be non-null"); + + llvm::Function *F = RI->getTopLevelRegion()->getEntry()->getParent(); + std::string GraphName = DOTGraphTraits<RegionInfo *>::getGraphName(RI); + + llvm::ViewGraph(RI, "reg", ShortNames, + Twine(GraphName) + " for '" + F->getName() + "' function"); +} + +static void invokeFunctionPass(const Function *F, FunctionPass *ViewerPass) { + assert(F && "Argument must be non-null"); + assert(!F->isDeclaration() && "Function must have an implementation"); + + // The viewer and analysis passes do not modify anything, so we can safely + // remove the const qualifier + auto NonConstF = const_cast<Function *>(F); + + llvm::legacy::FunctionPassManager FPM(NonConstF->getParent()); + FPM.add(ViewerPass); + FPM.doInitialization(); + FPM.run(*NonConstF); + FPM.doFinalization(); +} + +void llvm::viewRegion(RegionInfo *RI) { viewRegionInfo(RI, false); } + +void llvm::viewRegion(const Function *F) { + invokeFunctionPass(F, createRegionViewerPass()); +} + +void llvm::viewRegionOnly(RegionInfo *RI) { viewRegionInfo(RI, true); } + +void llvm::viewRegionOnly(const Function *F) { + invokeFunctionPass(F, createRegionOnlyViewerPass()); +} +#endif diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp new file mode 100644 index 000000000000..5ce0a1adeaa0 --- /dev/null +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -0,0 +1,12530 @@ +//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the implementation of the scalar evolution analysis +// engine, which is used primarily to analyze expressions involving induction +// variables in loops. +// +// There are several aspects to this library. First is the representation of +// scalar expressions, which are represented as subclasses of the SCEV class. +// These classes are used to represent certain types of subexpressions that we +// can handle. We only create one SCEV of a particular shape, so +// pointer-comparisons for equality are legal. +// +// One important aspect of the SCEV objects is that they are never cyclic, even +// if there is a cycle in the dataflow for an expression (ie, a PHI node). If +// the PHI node is one of the idioms that we can represent (e.g., a polynomial +// recurrence) then we represent it directly as a recurrence node, otherwise we +// represent it as a SCEVUnknown node. +// +// In addition to being able to represent expressions of various types, we also +// have folders that are used to build the *canonical* representation for a +// particular expression. These folders are capable of using a variety of +// rewrite rules to simplify the expressions. +// +// Once the folders are defined, we can implement the more interesting +// higher-level code, such as the code that recognizes PHI nodes of various +// types, computes the execution count of a loop, etc. 
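+//
+// As an illustrative example (not from the original sources): for a loop
+// "for (i = 0; i != n; ++i)" the induction variable is represented by the
+// add recurrence {0,+,1}<%loop>, and a derived expression such as 4*i + 7
+// folds to the canonical recurrence {7,+,4}<%loop>.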
+// +// TODO: We should use these routines and value representations to implement +// dependence analysis! +// +//===----------------------------------------------------------------------===// +// +// There are several good references for the techniques used in this analysis. +// +// Chains of recurrences -- a method to expedite the evaluation +// of closed-form functions +// Olaf Bachmann, Paul S. Wang, Eugene V. Zima +// +// On computational properties of chains of recurrences +// Eugene V. Zima +// +// Symbolic Evaluation of Chains of Recurrences for Loop Optimization +// Robert A. van Engelen +// +// Efficient Symbolic Analysis for Optimizing Compilers +// Robert A. van Engelen +// +// Using the chains of recurrences algebra for data dependence testing and +// induction variable substitution +// MS Thesis, Johnie Birch +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/FoldingSet.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Support/SaveAndRestore.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <climits> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <map> +#include <memory> +#include <tuple> +#include <utility> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "scalar-evolution" + +STATISTIC(NumArrayLenItCounts, + "Number of trip counts computed with array length"); +STATISTIC(NumTripCountsComputed, + 
"Number of loops with predictable loop counts"); +STATISTIC(NumTripCountsNotComputed, + "Number of loops without predictable loop counts"); +STATISTIC(NumBruteForceTripCountsComputed, + "Number of loops with trip counts computed by force"); + +static cl::opt<unsigned> +MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, + cl::ZeroOrMore, + cl::desc("Maximum number of iterations SCEV will " + "symbolically execute a constant " + "derived loop"), + cl::init(100)); + +// FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean. +static cl::opt<bool> VerifySCEV( + "verify-scev", cl::Hidden, + cl::desc("Verify ScalarEvolution's backedge taken counts (slow)")); +static cl::opt<bool> VerifySCEVStrict( + "verify-scev-strict", cl::Hidden, + cl::desc("Enable stricter verification with -verify-scev is passed")); +static cl::opt<bool> + VerifySCEVMap("verify-scev-maps", cl::Hidden, + cl::desc("Verify no dangling value in ScalarEvolution's " + "ExprValueMap (slow)")); + +static cl::opt<bool> VerifyIR( + "scev-verify-ir", cl::Hidden, + cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), + cl::init(false)); + +static cl::opt<unsigned> MulOpsInlineThreshold( + "scev-mulops-inline-threshold", cl::Hidden, + cl::desc("Threshold for inlining multiplication operands into a SCEV"), + cl::init(32)); + +static cl::opt<unsigned> AddOpsInlineThreshold( + "scev-addops-inline-threshold", cl::Hidden, + cl::desc("Threshold for inlining addition operands into a SCEV"), + cl::init(500)); + +static cl::opt<unsigned> MaxSCEVCompareDepth( + "scalar-evolution-max-scev-compare-depth", cl::Hidden, + cl::desc("Maximum depth of recursive SCEV complexity comparisons"), + cl::init(32)); + +static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth( + "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, + cl::desc("Maximum depth of recursive SCEV operations implication analysis"), + cl::init(2)); + +static cl::opt<unsigned> MaxValueCompareDepth( + "scalar-evolution-max-value-compare-depth", cl::Hidden, + cl::desc("Maximum depth of recursive value complexity comparisons"), + cl::init(2)); + +static cl::opt<unsigned> + MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, + cl::desc("Maximum depth of recursive arithmetics"), + cl::init(32)); + +static cl::opt<unsigned> MaxConstantEvolvingDepth( + "scalar-evolution-max-constant-evolving-depth", cl::Hidden, + cl::desc("Maximum depth of recursive constant evolving"), cl::init(32)); + +static cl::opt<unsigned> + MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, + cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), + cl::init(8)); + +static cl::opt<unsigned> + MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, + cl::desc("Max coefficients in AddRec during evolving"), + cl::init(8)); + +static cl::opt<unsigned> + HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, + cl::desc("Size of the expression which is considered huge"), + cl::init(4096)); + +//===----------------------------------------------------------------------===// +// SCEV class definitions +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Implementation of the SCEV class. 
+// + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void SCEV::dump() const { + print(dbgs()); + dbgs() << '\n'; +} +#endif + +void SCEV::print(raw_ostream &OS) const { + switch (static_cast<SCEVTypes>(getSCEVType())) { + case scConstant: + cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false); + return; + case scTruncate: { + const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this); + const SCEV *Op = Trunc->getOperand(); + OS << "(trunc " << *Op->getType() << " " << *Op << " to " + << *Trunc->getType() << ")"; + return; + } + case scZeroExtend: { + const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this); + const SCEV *Op = ZExt->getOperand(); + OS << "(zext " << *Op->getType() << " " << *Op << " to " + << *ZExt->getType() << ")"; + return; + } + case scSignExtend: { + const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this); + const SCEV *Op = SExt->getOperand(); + OS << "(sext " << *Op->getType() << " " << *Op << " to " + << *SExt->getType() << ")"; + return; + } + case scAddRecExpr: { + const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this); + OS << "{" << *AR->getOperand(0); + for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i) + OS << ",+," << *AR->getOperand(i); + OS << "}<"; + if (AR->hasNoUnsignedWrap()) + OS << "nuw><"; + if (AR->hasNoSignedWrap()) + OS << "nsw><"; + if (AR->hasNoSelfWrap() && + !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW))) + OS << "nw><"; + AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false); + OS << ">"; + return; + } + case scAddExpr: + case scMulExpr: + case scUMaxExpr: + case scSMaxExpr: + case scUMinExpr: + case scSMinExpr: { + const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this); + const char *OpStr = nullptr; + switch (NAry->getSCEVType()) { + case scAddExpr: OpStr = " + "; break; + case scMulExpr: OpStr = " * "; break; + case scUMaxExpr: OpStr = " umax "; break; + case scSMaxExpr: OpStr = " smax "; break; + case scUMinExpr: + OpStr = " umin "; + break; + case scSMinExpr: + OpStr = " smin "; + break; + } + OS << "("; + for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); + I != E; ++I) { + OS << **I; + if (std::next(I) != E) + OS << OpStr; + } + OS << ")"; + switch (NAry->getSCEVType()) { + case scAddExpr: + case scMulExpr: + if (NAry->hasNoUnsignedWrap()) + OS << "<nuw>"; + if (NAry->hasNoSignedWrap()) + OS << "<nsw>"; + } + return; + } + case scUDivExpr: { + const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this); + OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")"; + return; + } + case scUnknown: { + const SCEVUnknown *U = cast<SCEVUnknown>(this); + Type *AllocTy; + if (U->isSizeOf(AllocTy)) { + OS << "sizeof(" << *AllocTy << ")"; + return; + } + if (U->isAlignOf(AllocTy)) { + OS << "alignof(" << *AllocTy << ")"; + return; + } + + Type *CTy; + Constant *FieldNo; + if (U->isOffsetOf(CTy, FieldNo)) { + OS << "offsetof(" << *CTy << ", "; + FieldNo->printAsOperand(OS, false); + OS << ")"; + return; + } + + // Otherwise just print it normally. 
+ U->getValue()->printAsOperand(OS, false); + return; + } + case scCouldNotCompute: + OS << "***COULDNOTCOMPUTE***"; + return; + } + llvm_unreachable("Unknown SCEV kind!"); +} + +Type *SCEV::getType() const { + switch (static_cast<SCEVTypes>(getSCEVType())) { + case scConstant: + return cast<SCEVConstant>(this)->getType(); + case scTruncate: + case scZeroExtend: + case scSignExtend: + return cast<SCEVCastExpr>(this)->getType(); + case scAddRecExpr: + case scMulExpr: + case scUMaxExpr: + case scSMaxExpr: + case scUMinExpr: + case scSMinExpr: + return cast<SCEVNAryExpr>(this)->getType(); + case scAddExpr: + return cast<SCEVAddExpr>(this)->getType(); + case scUDivExpr: + return cast<SCEVUDivExpr>(this)->getType(); + case scUnknown: + return cast<SCEVUnknown>(this)->getType(); + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + } + llvm_unreachable("Unknown SCEV kind!"); +} + +bool SCEV::isZero() const { + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this)) + return SC->getValue()->isZero(); + return false; +} + +bool SCEV::isOne() const { + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this)) + return SC->getValue()->isOne(); + return false; +} + +bool SCEV::isAllOnesValue() const { + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this)) + return SC->getValue()->isMinusOne(); + return false; +} + +bool SCEV::isNonConstantNegative() const { + const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this); + if (!Mul) return false; + + // If there is a constant factor, it will be first. + const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0)); + if (!SC) return false; + + // Return true if the value is negative, this matches things like (-42 * V). + return SC->getAPInt().isNegative(); +} + +SCEVCouldNotCompute::SCEVCouldNotCompute() : + SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {} + +bool SCEVCouldNotCompute::classof(const SCEV *S) { + return S->getSCEVType() == scCouldNotCompute; +} + +const SCEV *ScalarEvolution::getConstant(ConstantInt *V) { + FoldingSetNodeID ID; + ID.AddInteger(scConstant); + ID.AddPointer(V); + void *IP = nullptr; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V); + UniqueSCEVs.InsertNode(S, IP); + return S; +} + +const SCEV *ScalarEvolution::getConstant(const APInt &Val) { + return getConstant(ConstantInt::get(getContext(), Val)); +} + +const SCEV * +ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) { + IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty)); + return getConstant(ConstantInt::get(ITy, V, isSigned)); +} + +SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, + unsigned SCEVTy, const SCEV *op, Type *ty) + : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {} + +SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, + const SCEV *op, Type *ty) + : SCEVCastExpr(ID, scTruncate, op, ty) { + assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && + "Cannot truncate non-integer value!"); +} + +SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, + const SCEV *op, Type *ty) + : SCEVCastExpr(ID, scZeroExtend, op, ty) { + assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && + "Cannot zero extend non-integer value!"); +} + +SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, + const SCEV *op, Type *ty) + : SCEVCastExpr(ID, scSignExtend, op, ty) { + assert(Op->getType()->isIntOrPtrTy() && 
Ty->isIntOrPtrTy() && + "Cannot sign extend non-integer value!"); +} + +void SCEVUnknown::deleted() { + // Clear this SCEVUnknown from various maps. + SE->forgetMemoizedResults(this); + + // Remove this SCEVUnknown from the uniquing map. + SE->UniqueSCEVs.RemoveNode(this); + + // Release the value. + setValPtr(nullptr); +} + +void SCEVUnknown::allUsesReplacedWith(Value *New) { + // Remove this SCEVUnknown from the uniquing map. + SE->UniqueSCEVs.RemoveNode(this); + + // Update this SCEVUnknown to point to the new value. This is needed + // because there may still be outstanding SCEVs which still point to + // this SCEVUnknown. + setValPtr(New); +} + +bool SCEVUnknown::isSizeOf(Type *&AllocTy) const { + if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue())) + if (VCE->getOpcode() == Instruction::PtrToInt) + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0))) + if (CE->getOpcode() == Instruction::GetElementPtr && + CE->getOperand(0)->isNullValue() && + CE->getNumOperands() == 2) + if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1))) + if (CI->isOne()) { + AllocTy = cast<PointerType>(CE->getOperand(0)->getType()) + ->getElementType(); + return true; + } + + return false; +} + +bool SCEVUnknown::isAlignOf(Type *&AllocTy) const { + if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue())) + if (VCE->getOpcode() == Instruction::PtrToInt) + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0))) + if (CE->getOpcode() == Instruction::GetElementPtr && + CE->getOperand(0)->isNullValue()) { + Type *Ty = + cast<PointerType>(CE->getOperand(0)->getType())->getElementType(); + if (StructType *STy = dyn_cast<StructType>(Ty)) + if (!STy->isPacked() && + CE->getNumOperands() == 3 && + CE->getOperand(1)->isNullValue()) { + if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2))) + if (CI->isOne() && + STy->getNumElements() == 2 && + STy->getElementType(0)->isIntegerTy(1)) { + AllocTy = STy->getElementType(1); + return true; + } + } + } + + return false; +} + +bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const { + if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue())) + if (VCE->getOpcode() == Instruction::PtrToInt) + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0))) + if (CE->getOpcode() == Instruction::GetElementPtr && + CE->getNumOperands() == 3 && + CE->getOperand(0)->isNullValue() && + CE->getOperand(1)->isNullValue()) { + Type *Ty = + cast<PointerType>(CE->getOperand(0)->getType())->getElementType(); + // Ignore vector types here so that ScalarEvolutionExpander doesn't + // emit getelementptrs that index into vectors. + if (Ty->isStructTy() || Ty->isArrayTy()) { + CTy = Ty; + FieldNo = CE->getOperand(2); + return true; + } + } + + return false; +} + +//===----------------------------------------------------------------------===// +// SCEV Utilities +//===----------------------------------------------------------------------===// + +/// Compare the two values \p LV and \p RV in terms of their "complexity" where +/// "complexity" is a partial (and somewhat ad-hoc) relation used to order +/// operands in SCEV expressions. \p EqCache is a set of pairs of values that +/// have been previously deemed to be "equally complex" by this routine. 
It is +/// intended to avoid exponential time complexity in cases like: +/// +/// %a = f(%x, %y) +/// %b = f(%a, %a) +/// %c = f(%b, %b) +/// +/// %d = f(%x, %y) +/// %e = f(%d, %d) +/// %f = f(%e, %e) +/// +/// CompareValueComplexity(%f, %c) +/// +/// Since we do not continue running this routine on expression trees once we +/// have seen unequal values, there is no need to track them in the cache. +static int +CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue, + const LoopInfo *const LI, Value *LV, Value *RV, + unsigned Depth) { + if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV)) + return 0; + + // Order pointer values after integer values. This helps SCEVExpander form + // GEPs. + bool LIsPointer = LV->getType()->isPointerTy(), + RIsPointer = RV->getType()->isPointerTy(); + if (LIsPointer != RIsPointer) + return (int)LIsPointer - (int)RIsPointer; + + // Compare getValueID values. + unsigned LID = LV->getValueID(), RID = RV->getValueID(); + if (LID != RID) + return (int)LID - (int)RID; + + // Sort arguments by their position. + if (const auto *LA = dyn_cast<Argument>(LV)) { + const auto *RA = cast<Argument>(RV); + unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo(); + return (int)LArgNo - (int)RArgNo; + } + + if (const auto *LGV = dyn_cast<GlobalValue>(LV)) { + const auto *RGV = cast<GlobalValue>(RV); + + const auto IsGVNameSemantic = [&](const GlobalValue *GV) { + auto LT = GV->getLinkage(); + return !(GlobalValue::isPrivateLinkage(LT) || + GlobalValue::isInternalLinkage(LT)); + }; + + // Use the names to distinguish the two values, but only if the + // names are semantically important. + if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV)) + return LGV->getName().compare(RGV->getName()); + } + + // For instructions, compare their loop depth, and their operand count. This + // is pretty loose. + if (const auto *LInst = dyn_cast<Instruction>(LV)) { + const auto *RInst = cast<Instruction>(RV); + + // Compare loop depths. + const BasicBlock *LParent = LInst->getParent(), + *RParent = RInst->getParent(); + if (LParent != RParent) { + unsigned LDepth = LI->getLoopDepth(LParent), + RDepth = LI->getLoopDepth(RParent); + if (LDepth != RDepth) + return (int)LDepth - (int)RDepth; + } + + // Compare the number of operands. + unsigned LNumOps = LInst->getNumOperands(), + RNumOps = RInst->getNumOperands(); + if (LNumOps != RNumOps) + return (int)LNumOps - (int)RNumOps; + + for (unsigned Idx : seq(0u, LNumOps)) { + int Result = + CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx), + RInst->getOperand(Idx), Depth + 1); + if (Result != 0) + return Result; + } + } + + EqCacheValue.unionSets(LV, RV); + return 0; +} + +// Return negative, zero, or positive, if LHS is less than, equal to, or greater +// than RHS, respectively. A three-way result allows recursive comparisons to be +// more efficient. +static int CompareSCEVComplexity( + EquivalenceClasses<const SCEV *> &EqCacheSCEV, + EquivalenceClasses<const Value *> &EqCacheValue, + const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, + DominatorTree &DT, unsigned Depth = 0) { + // Fast-path: SCEVs are uniqued so we can do a quick equality check. + if (LHS == RHS) + return 0; + + // Primarily, sort the SCEVs by their getSCEVType(). 
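+  // (Illustrative note, not from the original sources: the SCEVTypes
+  // enumeration lists scConstant first, so a constant operand such as 4 in
+  // "4 + %x" always sorts ahead of the non-constant operands; code elsewhere
+  // relies on a constant factor, if any, appearing first.)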
+ unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); + if (LType != RType) + return (int)LType - (int)RType; + + if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.isEquivalent(LHS, RHS)) + return 0; + // Aside from the getSCEVType() ordering, the particular ordering + // isn't very important except that it's beneficial to be consistent, + // so that (a + b) and (b + a) don't end up as different expressions. + switch (static_cast<SCEVTypes>(LType)) { + case scUnknown: { + const SCEVUnknown *LU = cast<SCEVUnknown>(LHS); + const SCEVUnknown *RU = cast<SCEVUnknown>(RHS); + + int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(), + RU->getValue(), Depth + 1); + if (X == 0) + EqCacheSCEV.unionSets(LHS, RHS); + return X; + } + + case scConstant: { + const SCEVConstant *LC = cast<SCEVConstant>(LHS); + const SCEVConstant *RC = cast<SCEVConstant>(RHS); + + // Compare constant values. + const APInt &LA = LC->getAPInt(); + const APInt &RA = RC->getAPInt(); + unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth(); + if (LBitWidth != RBitWidth) + return (int)LBitWidth - (int)RBitWidth; + return LA.ult(RA) ? -1 : 1; + } + + case scAddRecExpr: { + const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS); + const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS); + + // There is always a dominance between two recs that are used by one SCEV, + // so we can safely sort recs by loop header dominance. We require such + // order in getAddExpr. + const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); + if (LLoop != RLoop) { + const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader(); + assert(LHead != RHead && "Two loops share the same header?"); + if (DT.dominates(LHead, RHead)) + return 1; + else + assert(DT.dominates(RHead, LHead) && + "No dominance between recurrences used by one SCEV?"); + return -1; + } + + // Addrec complexity grows with operand count. + unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands(); + if (LNumOps != RNumOps) + return (int)LNumOps - (int)RNumOps; + + // Lexicographically compare. + for (unsigned i = 0; i != LNumOps; ++i) { + int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, + LA->getOperand(i), RA->getOperand(i), DT, + Depth + 1); + if (X != 0) + return X; + } + EqCacheSCEV.unionSets(LHS, RHS); + return 0; + } + + case scAddExpr: + case scMulExpr: + case scSMaxExpr: + case scUMaxExpr: + case scSMinExpr: + case scUMinExpr: { + const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS); + const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS); + + // Lexicographically compare n-ary expressions. + unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands(); + if (LNumOps != RNumOps) + return (int)LNumOps - (int)RNumOps; + + for (unsigned i = 0; i != LNumOps; ++i) { + int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, + LC->getOperand(i), RC->getOperand(i), DT, + Depth + 1); + if (X != 0) + return X; + } + EqCacheSCEV.unionSets(LHS, RHS); + return 0; + } + + case scUDivExpr: { + const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS); + const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS); + + // Lexicographically compare udiv expressions. 
+    int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
+                                  RC->getLHS(), DT, Depth + 1);
+    if (X != 0)
+      return X;
+    X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
+                              RC->getRHS(), DT, Depth + 1);
+    if (X == 0)
+      EqCacheSCEV.unionSets(LHS, RHS);
+    return X;
+  }
+
+  case scTruncate:
+  case scZeroExtend:
+  case scSignExtend: {
+    const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
+    const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
+
+    // Compare cast expressions by operand.
+    int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+                                  LC->getOperand(), RC->getOperand(), DT,
+                                  Depth + 1);
+    if (X == 0)
+      EqCacheSCEV.unionSets(LHS, RHS);
+    return X;
+  }
+
+  case scCouldNotCompute:
+    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
+  }
+  llvm_unreachable("Unknown SCEV kind!");
+}
+
+/// Given a list of SCEV objects, order them by their complexity, and group
+/// objects of the same complexity together by value. When this routine is
+/// finished, we know that any duplicates in the vector are consecutive and
+/// that complexity is monotonically increasing.
+///
+/// Note that we take special precautions to ensure that we get deterministic
+/// results from this routine. In other words, we don't want the results of
+/// this to depend on where the addresses of various SCEV objects happened to
+/// land in memory.
+static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
+                              LoopInfo *LI, DominatorTree &DT) {
+  if (Ops.size() < 2) return; // Noop
+
+  EquivalenceClasses<const SCEV *> EqCacheSCEV;
+  EquivalenceClasses<const Value *> EqCacheValue;
+  if (Ops.size() == 2) {
+    // This is the common case, which also happens to be trivially simple.
+    // Special case it.
+    const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
+    if (CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, RHS, LHS, DT) < 0)
+      std::swap(LHS, RHS);
+    return;
+  }
+
+  // Do the rough sort by complexity.
+  llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
+    return CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT) <
+           0;
+  });
+
+  // Now that we are sorted by complexity, group elements of the same
+  // complexity. Note that this is, at worst, N^2, but the vector is likely to
+  // be extremely short in practice. Note that we take this approach because we
+  // do not want to depend on the addresses of the objects we are grouping.
+  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
+    const SCEV *S = Ops[i];
+    unsigned Complexity = S->getSCEVType();
+
+    // If there are any objects of the same complexity and same value as this
+    // one, group them.
+    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
+      if (Ops[j] == S) { // Found a duplicate.
+        // Move it to immediately after i'th element.
+        std::swap(Ops[i+1], Ops[j]);
+        ++i; // no need to rescan it.
+        if (i == e-2) return; // Done!
+      }
+    }
+  }
+}
+
+// Returns the size of the SCEV S.
+static inline int sizeOfSCEV(const SCEV *S) {
+  struct FindSCEVSize {
+    int Size = 0;
+
+    FindSCEVSize() = default;
+
+    bool follow(const SCEV *S) {
+      ++Size;
+      // Keep looking at all operands of S.
+      return true;
+    }
+
+    bool isDone() const {
+      return false;
+    }
+  };
+
+  FindSCEVSize F;
+  SCEVTraversal<FindSCEVSize> ST(F);
+  ST.visitAll(S);
+  return F.Size;
+}
+
+/// Returns true if the subtree of \p S contains at least HugeExprThreshold
+/// nodes.
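+/// (Illustrative note, not from the original sources: getExpressionSize()
+/// returns a node count computed once when the SCEV is constructed, so this
+/// check is O(1), unlike the traversal-based sizeOfSCEV() above.)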
+static bool isHugeExpression(const SCEV *S) {
+  return S->getExpressionSize() >= HugeExprThreshold;
+}
+
+/// Returns true if \p Ops contains a huge SCEV (see definition above).
+static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
+  return any_of(Ops, isHugeExpression);
+}
+
+namespace {
+
+struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {
+public:
+  // Computes the Quotient and Remainder of the division of Numerator by
+  // Denominator.
+  static void divide(ScalarEvolution &SE, const SCEV *Numerator,
+                     const SCEV *Denominator, const SCEV **Quotient,
+                     const SCEV **Remainder) {
+    assert(Numerator && Denominator && "Uninitialized SCEV");
+
+    SCEVDivision D(SE, Numerator, Denominator);
+
+    // Check for the trivial case here to avoid having to check for it in the
+    // rest of the code.
+    if (Numerator == Denominator) {
+      *Quotient = D.One;
+      *Remainder = D.Zero;
+      return;
+    }
+
+    if (Numerator->isZero()) {
+      *Quotient = D.Zero;
+      *Remainder = D.Zero;
+      return;
+    }
+
+    // A simple case: when dividing N/1, the quotient is N.
+    if (Denominator->isOne()) {
+      *Quotient = Numerator;
+      *Remainder = D.Zero;
+      return;
+    }
+
+    // Split the Denominator when it is a product.
+    if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
+      const SCEV *Q, *R;
+      *Quotient = Numerator;
+      for (const SCEV *Op : T->operands()) {
+        divide(SE, *Quotient, Op, &Q, &R);
+        *Quotient = Q;
+
+        // Bail out when the Numerator is not divisible by one of the terms of
+        // the Denominator.
+        if (!R->isZero()) {
+          *Quotient = D.Zero;
+          *Remainder = Numerator;
+          return;
+        }
+      }
+      *Remainder = D.Zero;
+      return;
+    }
+
+    D.visit(Numerator);
+    *Quotient = D.Quotient;
+    *Remainder = D.Remainder;
+  }
+
+  // Except in the trivial case described above, we do not know how to divide
+  // Expr by Denominator for the following functions with empty implementation.
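+  // (Illustrative consequence, not from the original sources: dividing, say,
+  // (trunc i64 %x to i32) by anything other than itself or 1 falls through
+  // one of these empty visitors, leaving the "cannot divide" default of
+  // Quotient = 0, Remainder = Numerator that the constructor below sets up.)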
+ void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} + void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} + void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} + void visitUDivExpr(const SCEVUDivExpr *Numerator) {} + void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} + void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} + void visitSMinExpr(const SCEVSMinExpr *Numerator) {} + void visitUMinExpr(const SCEVUMinExpr *Numerator) {} + void visitUnknown(const SCEVUnknown *Numerator) {} + void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} + + void visitConstant(const SCEVConstant *Numerator) { + if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) { + APInt NumeratorVal = Numerator->getAPInt(); + APInt DenominatorVal = D->getAPInt(); + uint32_t NumeratorBW = NumeratorVal.getBitWidth(); + uint32_t DenominatorBW = DenominatorVal.getBitWidth(); + + if (NumeratorBW > DenominatorBW) + DenominatorVal = DenominatorVal.sext(NumeratorBW); + else if (NumeratorBW < DenominatorBW) + NumeratorVal = NumeratorVal.sext(DenominatorBW); + + APInt QuotientVal(NumeratorVal.getBitWidth(), 0); + APInt RemainderVal(NumeratorVal.getBitWidth(), 0); + APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); + Quotient = SE.getConstant(QuotientVal); + Remainder = SE.getConstant(RemainderVal); + return; + } + } + + void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { + const SCEV *StartQ, *StartR, *StepQ, *StepR; + if (!Numerator->isAffine()) + return cannotDivide(Numerator); + divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); + divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); + // Bail out if the types do not match. + Type *Ty = Denominator->getType(); + if (Ty != StartQ->getType() || Ty != StartR->getType() || + Ty != StepQ->getType() || Ty != StepR->getType()) + return cannotDivide(Numerator); + Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), + Numerator->getNoWrapFlags()); + Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), + Numerator->getNoWrapFlags()); + } + + void visitAddExpr(const SCEVAddExpr *Numerator) { + SmallVector<const SCEV *, 2> Qs, Rs; + Type *Ty = Denominator->getType(); + + for (const SCEV *Op : Numerator->operands()) { + const SCEV *Q, *R; + divide(SE, Op, Denominator, &Q, &R); + + // Bail out if types do not match. + if (Ty != Q->getType() || Ty != R->getType()) + return cannotDivide(Numerator); + + Qs.push_back(Q); + Rs.push_back(R); + } + + if (Qs.size() == 1) { + Quotient = Qs[0]; + Remainder = Rs[0]; + return; + } + + Quotient = SE.getAddExpr(Qs); + Remainder = SE.getAddExpr(Rs); + } + + void visitMulExpr(const SCEVMulExpr *Numerator) { + SmallVector<const SCEV *, 2> Qs; + Type *Ty = Denominator->getType(); + + bool FoundDenominatorTerm = false; + for (const SCEV *Op : Numerator->operands()) { + // Bail out if types do not match. + if (Ty != Op->getType()) + return cannotDivide(Numerator); + + if (FoundDenominatorTerm) { + Qs.push_back(Op); + continue; + } + + // Check whether Denominator divides one of the product operands. + const SCEV *Q, *R; + divide(SE, Op, Denominator, &Q, &R); + if (!R->isZero()) { + Qs.push_back(Op); + continue; + } + + // Bail out if types do not match. 
+      if (Ty != Q->getType())
+        return cannotDivide(Numerator);
+
+      FoundDenominatorTerm = true;
+      Qs.push_back(Q);
+    }
+
+    if (FoundDenominatorTerm) {
+      Remainder = Zero;
+      if (Qs.size() == 1)
+        Quotient = Qs[0];
+      else
+        Quotient = SE.getMulExpr(Qs);
+      return;
+    }
+
+    if (!isa<SCEVUnknown>(Denominator))
+      return cannotDivide(Numerator);
+
+    // The Remainder is obtained by replacing Denominator by 0 in Numerator.
+    ValueToValueMap RewriteMap;
+    RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
+        cast<SCEVConstant>(Zero)->getValue();
+    Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
+
+    if (Remainder->isZero()) {
+      // The Quotient is obtained by replacing Denominator by 1 in Numerator.
+      RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
+          cast<SCEVConstant>(One)->getValue();
+      Quotient =
+          SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
+      return;
+    }
+
+    // Quotient is (Numerator - Remainder) divided by Denominator.
+    const SCEV *Q, *R;
+    const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
+    // This SCEV does not seem to simplify: fail the division here.
+    if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
+      return cannotDivide(Numerator);
+    divide(SE, Diff, Denominator, &Q, &R);
+    if (R != Zero)
+      return cannotDivide(Numerator);
+    Quotient = Q;
+  }
+
+private:
+  SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
+               const SCEV *Denominator)
+      : SE(S), Denominator(Denominator) {
+    Zero = SE.getZero(Denominator->getType());
+    One = SE.getOne(Denominator->getType());
+
+    // We generally do not know how to divide Expr by Denominator. We
+    // initialize the division to a "cannot divide" state to simplify the rest
+    // of the code.
+    cannotDivide(Numerator);
+  }
+
+  // Convenience function for giving up on the division. We set the quotient to
+  // be equal to zero and the remainder to be equal to the numerator.
+  void cannotDivide(const SCEV *Numerator) {
+    Quotient = Zero;
+    Remainder = Numerator;
+  }
+
+  ScalarEvolution &SE;
+  const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Simple SCEV method implementations
+//===----------------------------------------------------------------------===//
+
+/// Compute BC(It, K). The result has width W. Assume K > 0.
+static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
+                                       ScalarEvolution &SE,
+                                       Type *ResultTy) {
+  // Handle the simplest case efficiently.
+  if (K == 1)
+    return SE.getTruncateOrZeroExtend(It, ResultTy);
+
+  // We are using the following formula for BC(It, K):
+  //
+  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
+  //
+  // Suppose, W is the bitwidth of the return value. We must be prepared for
+  // overflow. Hence, we must assure that the result of our computation is
+  // equal to the accurate one modulo 2^W. Unfortunately, division isn't
+  // safe in modular arithmetic.
+  //
+  // However, this code doesn't use exactly that formula; the formula it uses
+  // is something like the following, where T is the number of factors of 2 in
+  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
+  // exponentiation:
+  //
+  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
+  //
+  // This formula is trivially equivalent to the previous formula. However,
+  // this formula can be implemented much more efficiently. The trick is that
+  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
+  // arithmetic. To do exact division in modular arithmetic, all we have
+  // to do is multiply by the inverse. Therefore, this step can be done at
+  // width W.
+  //
+  // The next issue is how to safely do the division by 2^T. The way this
+  // is done is by doing the multiplication step at a width of at least W + T
+  // bits. This way, the bottom W+T bits of the product are accurate. Then,
+  // when we perform the division by 2^T (which is equivalent to a right shift
+  // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
+  // truncated out after the division by 2^T.
+  //
+  // In comparison to just directly using the first formula, this technique
+  // is much more efficient; using the first formula requires W * K bits,
+  // but this formula requires less than W + K bits. Also, the first formula
+  // requires a division step, whereas this formula only requires multiplies
+  // and shifts.
+  //
+  // It doesn't matter whether the subtraction step is done in the calculation
+  // width or the input iteration count's width; if the subtraction overflows,
+  // the result must be zero anyway. We prefer here to do it in the width of
+  // the induction variable because it helps a lot for certain cases; CodeGen
+  // isn't smart enough to ignore the overflow, which leads to much less
+  // efficient code if the width of the subtraction is wider than the native
+  // register width.
+  //
+  // (It's possible to not widen at all by pulling out factors of 2 before
+  // the multiplication; for example, K=2 can be calculated as
+  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
+  // extra arithmetic, so it's not an obvious win, and it gets
+  // much more complicated for K > 3.)
+
+  // Protection from insane SCEVs; this bound is conservative,
+  // but it probably doesn't matter.
+  if (K > 1000)
+    return SE.getCouldNotCompute();
+
+  unsigned W = SE.getTypeSizeInBits(ResultTy);
+
+  // Calculate K! / 2^T and T; we divide out the factors of two before
+  // multiplying for calculating K! / 2^T to avoid overflow.
+  // Other overflow doesn't matter because we only care about the bottom
+  // W bits of the result.
+  APInt OddFactorial(W, 1);
+  unsigned T = 1;
+  for (unsigned i = 3; i <= K; ++i) {
+    APInt Mult(W, i);
+    unsigned TwoFactors = Mult.countTrailingZeros();
+    T += TwoFactors;
+    Mult.lshrInPlace(TwoFactors);
+    OddFactorial *= Mult;
+  }
+
+  // We need at least W + T bits for the multiplication step
+  unsigned CalculationBits = W + T;
+
+  // Calculate 2^T, at width T+W.
+  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
+
+  // Calculate the multiplicative inverse of K! / 2^T;
+  // this multiplication factor will perform the exact division by
+  // K! / 2^T.
+  APInt Mod = APInt::getSignedMinValue(W+1);
+  APInt MultiplyFactor = OddFactorial.zext(W+1);
+  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
+  MultiplyFactor = MultiplyFactor.trunc(W);
+
+  // Calculate the product, at width T+W
+  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
+                                                CalculationBits);
+  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
+  for (unsigned i = 1; i != K; ++i) {
+    const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
+    Dividend = SE.getMulExpr(Dividend,
+                             SE.getTruncateOrZeroExtend(S, CalculationTy));
+  }
+
+  // Divide by 2^T
+  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
+
+  // Truncate the result, and divide by K! / 2^T.
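+  // (A worked instance, illustrative and not from the original sources: for
+  // K = 3 the loop above computes T = 1, since 3! = 6 contains a single
+  // factor of two, and OddFactorial = 3. The product It*(It-1)*(It-2) is
+  // formed at W+1 bits, divided by 2^1, truncated back to W bits, and then
+  // multiplied by the W-bit modular inverse of 3, yielding
+  // BC(It, 3) = It*(It-1)*(It-2)/6 modulo 2^W.)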
+ + return SE.getMulExpr(SE.getConstant(MultiplyFactor), + SE.getTruncateOrZeroExtend(DivResult, ResultTy)); +} + +/// Return the value of this chain of recurrences at the specified iteration +/// number. We can evaluate this recurrence by multiplying each element in the +/// chain by the binomial coefficient corresponding to it. In other words, we +/// can evaluate {A,+,B,+,C,+,D} as: +/// +/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3) +/// +/// where BC(It, k) stands for binomial coefficient. +const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It, + ScalarEvolution &SE) const { + const SCEV *Result = getStart(); + for (unsigned i = 1, e = getNumOperands(); i != e; ++i) { + // The computation is correct in the face of overflow provided that the + // multiplication is performed _after_ the evaluation of the binomial + // coefficient. + const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType()); + if (isa<SCEVCouldNotCompute>(Coeff)) + return Coeff; + + Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff)); + } + return Result; +} + +//===----------------------------------------------------------------------===// +// SCEV Expression folder implementations +//===----------------------------------------------------------------------===// + +const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, + unsigned Depth) { + assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) && + "This is not a truncating conversion!"); + assert(isSCEVable(Ty) && + "This is not a conversion to a SCEVable type!"); + Ty = getEffectiveSCEVType(Ty); + + FoldingSetNodeID ID; + ID.AddInteger(scTruncate); + ID.AddPointer(Op); + ID.AddPointer(Ty); + void *IP = nullptr; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + + // Fold if the operand is constant. + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) + return getConstant( + cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty))); + + // trunc(trunc(x)) --> trunc(x) + if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) + return getTruncateExpr(ST->getOperand(), Ty, Depth + 1); + + // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing + if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op)) + return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1); + + // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing + if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) + return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1); + + if (Depth > MaxCastDepth) { + SCEV *S = + new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty); + UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); + return S; + } + + // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and + // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN), + // if after transforming we have at most one truncate, not counting truncates + // that replace other casts. 
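+  // (An illustrative instance, not from the original sources: truncating
+  // (%a + (zext i16 %b to i64)) from i64 to i32 yields
+  // (trunc i64 %a to i32) + (zext i16 %b to i32); only the first operand
+  // introduces a new truncate, so numTruncs stays below 2 and the fold
+  // fires.)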
+  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
+    auto *CommOp = cast<SCEVCommutativeExpr>(Op);
+    SmallVector<const SCEV *, 4> Operands;
+    unsigned numTruncs = 0;
+    for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
+         ++i) {
+      const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
+      if (!isa<SCEVCastExpr>(CommOp->getOperand(i)) && isa<SCEVTruncateExpr>(S))
+        numTruncs++;
+      Operands.push_back(S);
+    }
+    if (numTruncs < 2) {
+      if (isa<SCEVAddExpr>(Op))
+        return getAddExpr(Operands);
+      else if (isa<SCEVMulExpr>(Op))
+        return getMulExpr(Operands);
+      else
+        llvm_unreachable("Unexpected SCEV type for Op.");
+    }
+    // Although we checked in the beginning that ID is not in the cache, it is
+    // possible that ID was inserted into the cache by the recursive calls
+    // above. So if we find it, just return it.
+    if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
+      return S;
+  }
+
+  // If the input value is a chrec scev, truncate the chrec's operands.
+  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
+    SmallVector<const SCEV *, 4> Operands;
+    for (const SCEV *Op : AddRec->operands())
+      Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
+    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
+  }
+
+  // The cast wasn't folded; create an explicit cast node. We can reuse
+  // the existing insert position since if we get here, we won't have
+  // made any changes which would invalidate it.
+  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
+                                                 Op, Ty);
+  UniqueSCEVs.InsertNode(S, IP);
+  addToLoopUseLists(S);
+  return S;
+}
+
+// Get the limit of a recurrence such that incrementing by Step cannot cause
+// signed overflow as long as the value of the recurrence within the
+// loop does not exceed this limit before incrementing.
+static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
+                                                 ICmpInst::Predicate *Pred,
+                                                 ScalarEvolution *SE) {
+  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
+  if (SE->isKnownPositive(Step)) {
+    *Pred = ICmpInst::ICMP_SLT;
+    return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
+                           SE->getSignedRangeMax(Step));
+  }
+  if (SE->isKnownNegative(Step)) {
+    *Pred = ICmpInst::ICMP_SGT;
+    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
+                           SE->getSignedRangeMin(Step));
+  }
+  return nullptr;
+}
+
+// Get the limit of a recurrence such that incrementing by Step cannot cause
+// unsigned overflow as long as the value of the recurrence within the loop
+// does not exceed this limit before incrementing.
+static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
+                                                   ICmpInst::Predicate *Pred,
+                                                   ScalarEvolution *SE) {
+  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
+  *Pred = ICmpInst::ICMP_ULT;
+
+  return SE->getConstant(APInt::getMinValue(BitWidth) -
+                         SE->getUnsignedRangeMax(Step));
+}
+
+namespace {
+
+struct ExtendOpTraitsBase {
+  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
+                                                          unsigned);
+};
+
+// Used to make code generic over signed and unsigned overflow.
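+// (Illustrative note, not from the original sources: the two specializations
+// below let callers such as getPreStartForExtend write
+// (SE->*GetExtendExpr)(S, Ty, Depth) and have it dispatch to either
+// ScalarEvolution::getSignExtendExpr or getZeroExtendExpr, with WrapType
+// selecting the matching FlagNSW/FlagNUW no-wrap flag.)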
+template <typename ExtendOp> struct ExtendOpTraits {
+  // Members present:
+  //
+  // static const SCEV::NoWrapFlags WrapType;
+  //
+  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
+  //
+  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
+  //                                            ICmpInst::Predicate *Pred,
+  //                                            ScalarEvolution *SE);
+};
+
+template <>
+struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
+  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
+
+  static const GetExtendExprTy GetExtendExpr;
+
+  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
+                                             ICmpInst::Predicate *Pred,
+                                             ScalarEvolution *SE) {
+    return getSignedOverflowLimitForStep(Step, Pred, SE);
+  }
+};
+
+const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
+    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
+
+template <>
+struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
+  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
+
+  static const GetExtendExprTy GetExtendExpr;
+
+  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
+                                             ICmpInst::Predicate *Pred,
+                                             ScalarEvolution *SE) {
+    return getUnsignedOverflowLimitForStep(Step, Pred, SE);
+  }
+};
+
+const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
+    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
+
+} // end anonymous namespace
+
+// The recurrence AR has been shown to have no signed/unsigned wrap or
+// something close to it. Typically, if we can prove NSW/NUW for AR, then we
+// can just as easily prove NSW/NUW for its preincrement or postincrement
+// sibling. This allows normalizing a sign/zero extended AddRec as such:
+//
+//   {sext/zext(Step + Start),+,Step} => {Step + sext/zext(Start),+,Step}
+//
+// As a result, the expression "Step + sext/zext(PreIncAR)" is congruent with
+// "sext/zext(PostIncAR)".
+template <typename ExtendOpTy>
+static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
+                                        ScalarEvolution *SE, unsigned Depth) {
+  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
+  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
+
+  const Loop *L = AR->getLoop();
+  const SCEV *Start = AR->getStart();
+  const SCEV *Step = AR->getStepRecurrence(*SE);
+
+  // Check for a simple looking step prior to loop entry.
+  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
+  if (!SA)
+    return nullptr;
+
+  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
+  // subtraction is expensive. For this purpose, perform a quick and dirty
+  // difference, by checking for Step in the operand list.
+  SmallVector<const SCEV *, 4> DiffOps;
+  for (const SCEV *Op : SA->operands())
+    if (Op != Step)
+      DiffOps.push_back(Op);
+
+  if (DiffOps.size() == SA->getNumOperands())
+    return nullptr;
+
+  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
+  // `Step`:
+
+  // 1. NSW/NUW flags on the step increment.
+  auto PreStartFlags =
+      ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
+  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
+  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
+      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
+
+  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
+  // "S+X does not sign/unsign-overflow".
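+  // (Reasoning sketch, not from the original sources: if the backedge is
+  // taken at least once, the value S+X is actually produced by the no-wrap
+  // recurrence {S,+,X} on its second iteration, so computing it cannot
+  // wrap.)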
+ // + + const SCEV *BECount = SE->getBackedgeTakenCount(L); + if (PreAR && PreAR->getNoWrapFlags(WrapType) && + !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount)) + return PreStart; + + // 2. Direct overflow check on the step operation's expression. + unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); + Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); + const SCEV *OperandExtendedStart = + SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth), + (SE->*GetExtendExpr)(Step, WideTy, Depth)); + if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) { + if (PreAR && AR->getNoWrapFlags(WrapType)) { + // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW + // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then + // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact. + const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(WrapType); + } + return PreStart; + } + + // 3. Loop precondition. + ICmpInst::Predicate Pred; + const SCEV *OverflowLimit = + ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE); + + if (OverflowLimit && + SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) + return PreStart; + + return nullptr; +} + +// Get the normalized zero or sign extended expression for this AddRec's Start. +template <typename ExtendOpTy> +static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, + ScalarEvolution *SE, + unsigned Depth) { + auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr; + + const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth); + if (!PreStart) + return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth); + + return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty, + Depth), + (SE->*GetExtendExpr)(PreStart, Ty, Depth)); +} + +// Try to prove away overflow by looking at "nearby" add recurrences. A +// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it +// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`. +// +// Formally: +// +// {S,+,X} == {S-T,+,X} + T +// => Ext({S,+,X}) == Ext({S-T,+,X} + T) +// +// If ({S-T,+,X} + T) does not overflow ... (1) +// +// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T) +// +// If {S-T,+,X} does not overflow ... (2) +// +// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T) +// == {Ext(S-T)+Ext(T),+,Ext(X)} +// +// If (S-T)+T does not overflow ... (3) +// +// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)} +// == {Ext(S),+,Ext(X)} == LHS +// +// Thus, if (1), (2) and (3) are true for some T, then +// Ext({S,+,X}) == {Ext(S),+,Ext(X)} +// +// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T) +// does not overflow" restricted to the 0th iteration. Therefore we only need +// to check for (1) and (2). +// +// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T +// is `Delta` (defined below). +template <typename ExtendOpTy> +bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, + const SCEV *Step, + const Loop *L) { + auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType; + + // We restrict `Start` to a constant to prevent SCEV from spending too much + // time here. It is correct (but more expensive) to continue with a + // non-constant `Start` and do a general SCEV subtraction to compute + // `PreStart` below. 
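+  // An illustrative instance of steps (1) and (2) above, in 8 bits: to prove
+  // {1,+,4} is <nuw>, take Delta = 1, so PreStart = 0. If {0,+,4} is already
+  // known <nuw> -- this proves (2) -- and {0,+,4} <u (0 - 1) == 255 is known
+  // -- the unsigned overflow limit for a step of 1, proving (1) -- then
+  // {1,+,4} == {0,+,4} + 1 cannot wrap either.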
+ const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start); + if (!StartC) + return false; + + APInt StartAI = StartC->getAPInt(); + + for (unsigned Delta : {-2, -1, 1, 2}) { + const SCEV *PreStart = getConstant(StartAI - Delta); + + FoldingSetNodeID ID; + ID.AddInteger(scAddRecExpr); + ID.AddPointer(PreStart); + ID.AddPointer(Step); + ID.AddPointer(L); + void *IP = nullptr; + const auto *PreAR = + static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + + // Give up if we don't already have the add recurrence we need because + // actually constructing an add recurrence is relatively expensive. + if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2) + const SCEV *DeltaS = getConstant(StartC->getType(), Delta); + ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; + const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep( + DeltaS, &Pred, this); + if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1) + return true; + } + } + + return false; +} + +// Finds an integer D for an expression (C + x + y + ...) such that the top +// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or +// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is +// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and +// the (C + x + y + ...) expression is \p WholeAddExpr. +static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, + const SCEVConstant *ConstantTerm, + const SCEVAddExpr *WholeAddExpr) { + const APInt C = ConstantTerm->getAPInt(); + const unsigned BitWidth = C.getBitWidth(); + // Find number of trailing zeros of (x + y + ...) w/o the C first: + uint32_t TZ = BitWidth; + for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I) + TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I))); + if (TZ) { + // Set D to be as many least significant bits of C as possible while still + // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap: + return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C; + } + return APInt(BitWidth, 0); +} + +// Finds an integer D for an affine AddRec expression {C,+,x} such that the top +// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the +// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p +// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count. +static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, + const APInt &ConstantStart, + const SCEV *Step) { + const unsigned BitWidth = ConstantStart.getBitWidth(); + const uint32_t TZ = SE.GetMinTrailingZeros(Step); + if (TZ) + return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth) + : ConstantStart; + return APInt(BitWidth, 0); +} + +const SCEV * +ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { + assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && + "This is not an extending conversion!"); + assert(isSCEVable(Ty) && + "This is not a conversion to a SCEVable type!"); + Ty = getEffectiveSCEVType(Ty); + + // Fold if the operand is constant. 
+  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
+    return getConstant(
+        cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
+
+  // zext(zext(x)) --> zext(x)
+  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
+    return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
+
+  // Before doing any expensive analysis, check to see if we've already
+  // computed a SCEV for this Op and Ty.
+  FoldingSetNodeID ID;
+  ID.AddInteger(scZeroExtend);
+  ID.AddPointer(Op);
+  ID.AddPointer(Ty);
+  void *IP = nullptr;
+  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  if (Depth > MaxCastDepth) {
+    SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
+                                                     Op, Ty);
+    UniqueSCEVs.InsertNode(S, IP);
+    addToLoopUseLists(S);
+    return S;
+  }
+
+  // zext(trunc(x)) --> zext(x) or x or trunc(x)
+  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
+    // It's possible the bits taken off by the truncate were all zero bits. If
+    // so, we should be able to simplify this further.
+    const SCEV *X = ST->getOperand();
+    ConstantRange CR = getUnsignedRange(X);
+    unsigned TruncBits = getTypeSizeInBits(ST->getType());
+    unsigned NewBits = getTypeSizeInBits(Ty);
+    if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
+            CR.zextOrTrunc(NewBits)))
+      return getTruncateOrZeroExtend(X, Ty, Depth);
+  }
+
+  // If the input value is a chrec scev, and we can prove that the value
+  // did not overflow the old, smaller, value, we can zero extend all of the
+  // operands (often constants). This allows analysis of something like
+  // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
+  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
+    if (AR->isAffine()) {
+      const SCEV *Start = AR->getStart();
+      const SCEV *Step = AR->getStepRecurrence(*this);
+      unsigned BitWidth = getTypeSizeInBits(AR->getType());
+      const Loop *L = AR->getLoop();
+
+      if (!AR->hasNoUnsignedWrap()) {
+        auto NewFlags = proveNoWrapViaConstantRanges(AR);
+        const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
+      }
+
+      // If we have special knowledge that this addrec won't overflow,
+      // we don't need to do any further analysis.
+      if (AR->hasNoUnsignedWrap())
+        return getAddRecExpr(
+            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
+            getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
+
+      // Check whether the backedge-taken count is SCEVCouldNotCompute.
+      // Note that this serves two purposes: It filters out loops that are
+      // simply not analyzable, and it covers the case where this code is
+      // being called from within backedge-taken count analysis, such that
+      // attempting to ask for the backedge-taken count would likely result
+      // in infinite recursion. In the latter case, the analysis code will
+      // cope with a conservative value, and it will take care to purge
+      // that value once it has finished.
+      const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
+      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
+        // Manually compute the final value for AR, checking for
+        // overflow.
+
+        // Check whether the backedge-taken count can be losslessly cast to
+        // the addrec's type. The count is always unsigned.
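+        // For intuition, an illustrative 8-bit instance of the check below:
+        // with Start = 0, Step = 1 and MaxBECount = 99, the narrow final
+        // value zext(0 + 99 * 1) == 99 equals the widened
+        // zext(0) + 99 * zext(1), so the addrec cannot wrap and <nuw> is
+        // inferred. With Start = 250 instead, the narrow sum wraps to 93
+        // while the widened sum is 349; the two differ and nothing is proven.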
+ const SCEV *CastedMaxBECount = + getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth); + const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend( + CastedMaxBECount, MaxBECount->getType(), Depth); + if (MaxBECount == RecastedMaxBECount) { + Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); + // Check whether Start+Step*MaxBECount has no unsigned overflow. + const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step, + SCEV::FlagAnyWrap, Depth + 1); + const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul, + SCEV::FlagAnyWrap, + Depth + 1), + WideTy, Depth + 1); + const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1); + const SCEV *WideMaxBECount = + getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1); + const SCEV *OperandExtendedAdd = + getAddExpr(WideStart, + getMulExpr(WideMaxBECount, + getZeroExtendExpr(Step, WideTy, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); + if (ZAdd == OperandExtendedAdd) { + // Cache knowledge of AR NUW, which is propagated to this AddRec. + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); + // Return the expression with the addrec on the outside. + return getAddRecExpr( + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, + Depth + 1), + getZeroExtendExpr(Step, Ty, Depth + 1), L, + AR->getNoWrapFlags()); + } + // Similar to above, only this time treat the step value as signed. + // This covers loops that count down. + OperandExtendedAdd = + getAddExpr(WideStart, + getMulExpr(WideMaxBECount, + getSignExtendExpr(Step, WideTy, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); + if (ZAdd == OperandExtendedAdd) { + // Cache knowledge of AR NW, which is propagated to this AddRec. + // Negative step causes unsigned wrap, but it still can't self-wrap. + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); + // Return the expression with the addrec on the outside. + return getAddRecExpr( + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, + Depth + 1), + getSignExtendExpr(Step, Ty, Depth + 1), L, + AR->getNoWrapFlags()); + } + } + } + + // Normally, in the cases we can prove no-overflow via a + // backedge guarding condition, we can also compute a backedge + // taken count for the loop. The exceptions are assumptions and + // guards present in the loop -- SCEV is not great at exploiting + // these to compute max backedge taken counts, but can still use + // these to prove lack of overflow. Use this fact to avoid + // doing extra work that may not pay off. + if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards || + !AC.assumptions().empty()) { + // If the backedge is guarded by a comparison with the pre-inc + // value the addrec is safe. Also, if the entry is guarded by + // a comparison with the start value and the backedge is + // guarded by a comparison with the post-inc value, the addrec + // is safe. + if (isKnownPositive(Step)) { + const SCEV *N = getConstant(APInt::getMinValue(BitWidth) - + getUnsignedRangeMax(Step)); + if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) || + isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) { + // Cache knowledge of AR NUW, which is propagated to this + // AddRec. + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); + // Return the expression with the addrec on the outside. 
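+          // (An illustrative count-down instance: for {5,+,-1} in i8 with
+          // MaxBECount = 5, zext(5 + 5 * -1) == 0 matches the widened
+          // 5 + 5 * sext(-1) == 0, so the recurrence cannot self-wrap. Only
+          // <nw> is recorded: each add of a negative step unsign-wraps even
+          // though the value sequence never crosses its start.)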
+ return getAddRecExpr( + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, + Depth + 1), + getZeroExtendExpr(Step, Ty, Depth + 1), L, + AR->getNoWrapFlags()); + } + } else if (isKnownNegative(Step)) { + const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - + getSignedRangeMin(Step)); + if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) || + isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) { + // Cache knowledge of AR NW, which is propagated to this + // AddRec. Negative step causes unsigned wrap, but it + // still can't self-wrap. + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); + // Return the expression with the addrec on the outside. + return getAddRecExpr( + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, + Depth + 1), + getSignExtendExpr(Step, Ty, Depth + 1), L, + AR->getNoWrapFlags()); + } + } + } + + // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw> + // if D + (C - D + Step * n) could be proven to not unsigned wrap + // where D maximizes the number of trailing zeros of (C - D + Step * n) + if (const auto *SC = dyn_cast<SCEVConstant>(Start)) { + const APInt &C = SC->getAPInt(); + const APInt &D = extractConstantWithoutWrapping(*this, C, Step); + if (D != 0) { + const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth); + const SCEV *SResidual = + getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags()); + const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1); + return getAddExpr(SZExtD, SZExtR, + (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), + Depth + 1); + } + } + + if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) { + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); + return getAddRecExpr( + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1), + getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); + } + } + + // zext(A % B) --> zext(A) % zext(B) + { + const SCEV *LHS; + const SCEV *RHS; + if (matchURem(Op, LHS, RHS)) + return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1), + getZeroExtendExpr(RHS, Ty, Depth + 1)); + } + + // zext(A / B) --> zext(A) / zext(B). + if (auto *Div = dyn_cast<SCEVUDivExpr>(Op)) + return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1), + getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1)); + + if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) { + // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw> + if (SA->hasNoUnsignedWrap()) { + // If the addition does not unsign overflow then we can, by definition, + // commute the zero extension with the addition operation. + SmallVector<const SCEV *, 4> Ops; + for (const auto *Op : SA->operands()) + Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1)); + return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1); + } + + // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...)) + // if D + (C - D + x + y + ...) could be proven to not unsigned wrap + // where D maximizes the number of trailing zeros of (C - D + x + y + ...) + // + // Often address arithmetics contain expressions like + // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))). + // This transformation is useful while proving that such expressions are + // equal or differ by a small constant amount, see LoadStoreVectorizer pass. 
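+    // An illustrative instance of the rule above: in zext(5 + 4 * %x) the
+    // non-constant part has two trailing zero bits, so D = 5 mod 4 == 1 and
+    // the expression becomes 1 + zext(4 + 4 * %x); adding 1 back cannot
+    // carry past the low zero bits, so the outer add is <nuw><nsw>.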
+ if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) { + const APInt &D = extractConstantWithoutWrapping(*this, SC, SA); + if (D != 0) { + const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth); + const SCEV *SResidual = + getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth); + const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1); + return getAddExpr(SZExtD, SZExtR, + (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), + Depth + 1); + } + } + } + + if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) { + // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw> + if (SM->hasNoUnsignedWrap()) { + // If the multiply does not unsign overflow then we can, by definition, + // commute the zero extension with the multiply operation. + SmallVector<const SCEV *, 4> Ops; + for (const auto *Op : SM->operands()) + Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1)); + return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1); + } + + // zext(2^K * (trunc X to iN)) to iM -> + // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw> + // + // Proof: + // + // zext(2^K * (trunc X to iN)) to iM + // = zext((trunc X to iN) << K) to iM + // = zext((trunc X to i{N-K}) << K)<nuw> to iM + // (because shl removes the top K bits) + // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM + // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>. + // + if (SM->getNumOperands() == 2) + if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0))) + if (MulLHS->getAPInt().isPowerOf2()) + if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) { + int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) - + MulLHS->getAPInt().logBase2(); + Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits); + return getMulExpr( + getZeroExtendExpr(MulLHS, Ty), + getZeroExtendExpr( + getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty), + SCEV::FlagNUW, Depth + 1); + } + } + + // The cast wasn't folded; create an explicit cast node. + // Recompute the insert position, as it may have been invalidated. + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), + Op, Ty); + UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); + return S; +} + +const SCEV * +ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { + assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && + "This is not an extending conversion!"); + assert(isSCEVable(Ty) && + "This is not a conversion to a SCEVable type!"); + Ty = getEffectiveSCEVType(Ty); + + // Fold if the operand is constant. + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) + return getConstant( + cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty))); + + // sext(sext(x)) --> sext(x) + if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op)) + return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1); + + // sext(zext(x)) --> zext(x) + if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) + return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1); + + // Before doing any expensive analysis, check to see if we've already + // computed a SCEV for this Op and Ty. + FoldingSetNodeID ID; + ID.AddInteger(scSignExtend); + ID.AddPointer(Op); + ID.AddPointer(Ty); + void *IP = nullptr; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + // Limit recursion depth. 
+ if (Depth > MaxCastDepth) { + SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), + Op, Ty); + UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); + return S; + } + + // sext(trunc(x)) --> sext(x) or x or trunc(x) + if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) { + // It's possible the bits taken off by the truncate were all sign bits. If + // so, we should be able to simplify this further. + const SCEV *X = ST->getOperand(); + ConstantRange CR = getSignedRange(X); + unsigned TruncBits = getTypeSizeInBits(ST->getType()); + unsigned NewBits = getTypeSizeInBits(Ty); + if (CR.truncate(TruncBits).signExtend(NewBits).contains( + CR.sextOrTrunc(NewBits))) + return getTruncateOrSignExtend(X, Ty, Depth); + } + + if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) { + // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw> + if (SA->hasNoSignedWrap()) { + // If the addition does not sign overflow then we can, by definition, + // commute the sign extension with the addition operation. + SmallVector<const SCEV *, 4> Ops; + for (const auto *Op : SA->operands()) + Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1)); + return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1); + } + + // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...)) + // if D + (C - D + x + y + ...) could be proven to not signed wrap + // where D maximizes the number of trailing zeros of (C - D + x + y + ...) + // + // For instance, this will bring two seemingly different expressions: + // 1 + sext(5 + 20 * %x + 24 * %y) and + // sext(6 + 20 * %x + 24 * %y) + // to the same form: + // 2 + sext(4 + 20 * %x + 24 * %y) + if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) { + const APInt &D = extractConstantWithoutWrapping(*this, SC, SA); + if (D != 0) { + const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth); + const SCEV *SResidual = + getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth); + const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1); + return getAddExpr(SSExtD, SSExtR, + (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), + Depth + 1); + } + } + } + // If the input value is a chrec scev, and we can prove that the value + // did not overflow the old, smaller, value, we can sign extend all of the + // operands (often constants). This allows analysis of something like + // this: for (signed char X = 0; X < 100; ++X) { int Y = X; } + if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) + if (AR->isAffine()) { + const SCEV *Start = AR->getStart(); + const SCEV *Step = AR->getStepRecurrence(*this); + unsigned BitWidth = getTypeSizeInBits(AR->getType()); + const Loop *L = AR->getLoop(); + + if (!AR->hasNoSignedWrap()) { + auto NewFlags = proveNoWrapViaConstantRanges(AR); + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags); + } + + // If we have special knowledge that this addrec won't overflow, + // we don't need to do any further analysis. + if (AR->hasNoSignedWrap()) + return getAddRecExpr( + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1), + getSignExtendExpr(Step, Ty, Depth + 1), L, SCEV::FlagNSW); + + // Check whether the backedge-taken count is SCEVCouldNotCompute. + // Note that this serves two purposes: It filters out loops that are + // simply not analyzable, and it covers the case where this code is + // being called from within backedge-taken count analysis, such that + // attempting to ask for the backedge-taken count would likely result + // in infinite recursion. 
+      // In the latter case, the analysis code will
+      // cope with a conservative value, and it will take care to purge
+      // that value once it has finished.
+      const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
+      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
+        // Manually compute the final value for AR, checking for
+        // overflow.
+
+        // Check whether the backedge-taken count can be losslessly cast to
+        // the addrec's type. The count is always unsigned.
+        const SCEV *CastedMaxBECount =
+            getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
+        const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
+            CastedMaxBECount, MaxBECount->getType(), Depth);
+        if (MaxBECount == RecastedMaxBECount) {
+          Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
+          // Check whether Start+Step*MaxBECount has no signed overflow.
+          const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
+                                        SCEV::FlagAnyWrap, Depth + 1);
+          const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
+                                                          SCEV::FlagAnyWrap,
+                                                          Depth + 1),
+                                               WideTy, Depth + 1);
+          const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
+          const SCEV *WideMaxBECount =
+              getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
+          const SCEV *OperandExtendedAdd =
+              getAddExpr(WideStart,
+                         getMulExpr(WideMaxBECount,
+                                    getSignExtendExpr(Step, WideTy, Depth + 1),
+                                    SCEV::FlagAnyWrap, Depth + 1),
+                         SCEV::FlagAnyWrap, Depth + 1);
+          if (SAdd == OperandExtendedAdd) {
+            // Cache knowledge of AR NSW, which is propagated to this AddRec.
+            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
+            // Return the expression with the addrec on the outside.
+            return getAddRecExpr(
+                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
+                                                         Depth + 1),
+                getSignExtendExpr(Step, Ty, Depth + 1), L,
+                AR->getNoWrapFlags());
+          }
+          // Similar to above, only this time treat the step value as unsigned.
+          // This covers loops that count up with an unsigned step.
+          OperandExtendedAdd =
+              getAddExpr(WideStart,
+                         getMulExpr(WideMaxBECount,
+                                    getZeroExtendExpr(Step, WideTy, Depth + 1),
+                                    SCEV::FlagAnyWrap, Depth + 1),
+                         SCEV::FlagAnyWrap, Depth + 1);
+          if (SAdd == OperandExtendedAdd) {
+            // If AR wraps around then
+            //
+            //    abs(Step) * MaxBECount > unsigned-max(AR->getType())
+            //    => SAdd != OperandExtendedAdd
+            //
+            // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
+            // (SAdd == OperandExtendedAdd => AR is NW)
+
+            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
+
+            // Return the expression with the addrec on the outside.
+            return getAddRecExpr(
+                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
+                                                         Depth + 1),
+                getZeroExtendExpr(Step, Ty, Depth + 1), L,
+                AR->getNoWrapFlags());
+          }
+        }
+      }
+
+      // Normally, in the cases we can prove no-overflow via a
+      // backedge guarding condition, we can also compute a backedge
+      // taken count for the loop. The exceptions are assumptions and
+      // guards present in the loop -- SCEV is not great at exploiting
+      // these to compute max backedge taken counts, but can still use
+      // these to prove lack of overflow. Use this fact to avoid
+      // doing extra work that may not pay off.
+
+      if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
+          !AC.assumptions().empty()) {
+        // If the backedge is guarded by a comparison with the pre-inc
+        // value the addrec is safe. Also, if the entry is guarded by
+        // a comparison with the start value and the backedge is
+        // guarded by a comparison with the post-inc value, the addrec
+        // is safe.
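+        // A worked instance of the guard-based proof below (illustrative):
+        // for {0,+,1} in i8, the helper returns Pred = SLT and
+        // OverflowLimit = getSignedMinValue(8) - 1, which wraps to 127. A
+        // backedge guarded by "i <s 100" then lets
+        // isLoopBackedgeGuardedByCond prove AR <s 127 on every iteration,
+        // so the addrec is <nsw>.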
+ ICmpInst::Predicate Pred; + const SCEV *OverflowLimit = + getSignedOverflowLimitForStep(Step, &Pred, this); + if (OverflowLimit && + (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) || + isKnownOnEveryIteration(Pred, AR, OverflowLimit))) { + // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec. + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); + return getAddRecExpr( + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1), + getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); + } + } + + // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw> + // if D + (C - D + Step * n) could be proven to not signed wrap + // where D maximizes the number of trailing zeros of (C - D + Step * n) + if (const auto *SC = dyn_cast<SCEVConstant>(Start)) { + const APInt &C = SC->getAPInt(); + const APInt &D = extractConstantWithoutWrapping(*this, C, Step); + if (D != 0) { + const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth); + const SCEV *SResidual = + getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags()); + const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1); + return getAddExpr(SSExtD, SSExtR, + (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), + Depth + 1); + } + } + + if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) { + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); + return getAddRecExpr( + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1), + getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); + } + } + + // If the input value is provably positive and we could not simplify + // away the sext build a zext instead. + if (isKnownNonNegative(Op)) + return getZeroExtendExpr(Op, Ty, Depth + 1); + + // The cast wasn't folded; create an explicit cast node. + // Recompute the insert position, as it may have been invalidated. + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), + Op, Ty); + UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); + return S; +} + +/// getAnyExtendExpr - Return a SCEV for the given operand extended with +/// unspecified bits out to the given type. +const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, + Type *Ty) { + assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && + "This is not an extending conversion!"); + assert(isSCEVable(Ty) && + "This is not a conversion to a SCEVable type!"); + Ty = getEffectiveSCEVType(Ty); + + // Sign-extend negative constants. + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) + if (SC->getAPInt().isNegative()) + return getSignExtendExpr(Op, Ty); + + // Peel off a truncate cast. + if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) { + const SCEV *NewOp = T->getOperand(); + if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty)) + return getAnyExtendExpr(NewOp, Ty); + return getTruncateOrNoop(NewOp, Ty); + } + + // Next try a zext cast. If the cast is folded, use it. + const SCEV *ZExt = getZeroExtendExpr(Op, Ty); + if (!isa<SCEVZeroExtendExpr>(ZExt)) + return ZExt; + + // Next try a sext cast. If the cast is folded, use it. + const SCEV *SExt = getSignExtendExpr(Op, Ty); + if (!isa<SCEVSignExtendExpr>(SExt)) + return SExt; + + // Force the cast to be folded into the operands of an addrec. 
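+  // e.g. (illustrative) anyext({8,+,1}<L>) from i8 to i32, when neither the
+  // zext nor the sext of the whole addrec folds, is forced apart into
+  // {8,+,1}<nw><L> over i32, extending each operand independently.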
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) { + SmallVector<const SCEV *, 4> Ops; + for (const SCEV *Op : AR->operands()) + Ops.push_back(getAnyExtendExpr(Op, Ty)); + return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW); + } + + // If the expression is obviously signed, use the sext cast value. + if (isa<SCEVSMaxExpr>(Op)) + return SExt; + + // Absent any other information, use the zext cast value. + return ZExt; +} + +/// Process the given Ops list, which is a list of operands to be added under +/// the given scale, update the given map. This is a helper function for +/// getAddRecExpr. As an example of what it does, given a sequence of operands +/// that would form an add expression like this: +/// +/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r) +/// +/// where A and B are constants, update the map with these values: +/// +/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0) +/// +/// and add 13 + A*B*29 to AccumulatedConstant. +/// This will allow getAddRecExpr to produce this: +/// +/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B) +/// +/// This form often exposes folding opportunities that are hidden in +/// the original operand list. +/// +/// Return true iff it appears that any interesting folding opportunities +/// may be exposed. This helps getAddRecExpr short-circuit extra work in +/// the common case where no interesting opportunities are present, and +/// is also used as a check to avoid infinite recursion. +static bool +CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M, + SmallVectorImpl<const SCEV *> &NewOps, + APInt &AccumulatedConstant, + const SCEV *const *Ops, size_t NumOperands, + const APInt &Scale, + ScalarEvolution &SE) { + bool Interesting = false; + + // Iterate over the add operands. They are sorted, with constants first. + unsigned i = 0; + while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) { + ++i; + // Pull a buried constant out to the outside. + if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero()) + Interesting = true; + AccumulatedConstant += Scale * C->getAPInt(); + } + + // Next comes everything else. We're especially interested in multiplies + // here, but they're in the middle, so just visit the rest with one loop. + for (; i != NumOperands; ++i) { + const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]); + if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) { + APInt NewScale = + Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt(); + if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) { + // A multiplication of a constant with another add; recurse. + const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1)); + Interesting |= + CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, + Add->op_begin(), Add->getNumOperands(), + NewScale, SE); + } else { + // A multiplication of a constant with some other value. Update + // the map. + SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end()); + const SCEV *Key = SE.getMulExpr(MulOps); + auto Pair = M.insert({Key, NewScale}); + if (Pair.second) { + NewOps.push_back(Pair.first->first); + } else { + Pair.first->second += NewScale; + // The map already had an entry for this value, which may indicate + // a folding opportunity. + Interesting = true; + } + } + } else { + // An ordinary operand. Update the map. 
+ std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair = + M.insert({Ops[i], Scale}); + if (Pair.second) { + NewOps.push_back(Pair.first->first); + } else { + Pair.first->second += Scale; + // The map already had an entry for this value, which may indicate + // a folding opportunity. + Interesting = true; + } + } + } + + return Interesting; +} + +// We're trying to construct a SCEV of type `Type' with `Ops' as operands and +// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of +// can't-overflow flags for the operation if possible. +static SCEV::NoWrapFlags +StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, + const ArrayRef<const SCEV *> Ops, + SCEV::NoWrapFlags Flags) { + using namespace std::placeholders; + + using OBO = OverflowingBinaryOperator; + + bool CanAnalyze = + Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr; + (void)CanAnalyze; + assert(CanAnalyze && "don't call from other places!"); + + int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; + SCEV::NoWrapFlags SignOrUnsignWrap = + ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); + + // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. + auto IsKnownNonNegative = [&](const SCEV *S) { + return SE->isKnownNonNegative(S); + }; + + if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative)) + Flags = + ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); + + SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); + + if (SignOrUnsignWrap != SignOrUnsignMask && + (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 && + isa<SCEVConstant>(Ops[0])) { + + auto Opcode = [&] { + switch (Type) { + case scAddExpr: + return Instruction::Add; + case scMulExpr: + return Instruction::Mul; + default: + llvm_unreachable("Unexpected SCEV op."); + } + }(); + + const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt(); + + // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow. + if (!(SignOrUnsignWrap & SCEV::FlagNSW)) { + auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + Opcode, C, OBO::NoSignedWrap); + if (NSWRegion.contains(SE->getSignedRange(Ops[1]))) + Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); + } + + // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow. + if (!(SignOrUnsignWrap & SCEV::FlagNUW)) { + auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + Opcode, C, OBO::NoUnsignedWrap); + if (NUWRegion.contains(SE->getUnsignedRange(Ops[1]))) + Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); + } + } + + return Flags; +} + +bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) { + return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader()); +} + +/// Get a canonical add expression, or something simpler if possible. +const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, + SCEV::NoWrapFlags Flags, + unsigned Depth) { + assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) && + "only nuw or nsw allowed"); + assert(!Ops.empty() && "Cannot get empty add!"); + if (Ops.size() == 1) return Ops[0]; +#ifndef NDEBUG + Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); + for (unsigned i = 1, e = Ops.size(); i != e; ++i) + assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && + "SCEVAddExpr operand types don't match!"); +#endif + + // Sort by complexity, this groups all similar expression types together. 
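+  // (On StrengthenNoWrapFlags above, an illustrative i8 instance: for
+  // 64 + %x with %x known to lie in [0, 64), the guaranteed no-unsigned-wrap
+  // region for "add 64" is [0, 192), which contains [0, 64), so FlagNUW can
+  // be inferred even when the IR add carries no nuw flag.)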
+  GroupByComplexity(Ops, &LI, DT);
+
+  Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags);
+
+  // If there are any constants, fold them together.
+  unsigned Idx = 0;
+  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
+    ++Idx;
+    assert(Idx < Ops.size());
+    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
+      // We found two constants, fold them together!
+      Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
+      if (Ops.size() == 2) return Ops[0];
+      Ops.erase(Ops.begin()+1);  // Erase the folded element
+      LHSC = cast<SCEVConstant>(Ops[0]);
+    }
+
+    // If we are left with a constant zero being added, strip it off.
+    if (LHSC->getValue()->isZero()) {
+      Ops.erase(Ops.begin());
+      --Idx;
+    }
+
+    if (Ops.size() == 1) return Ops[0];
+  }
+
+  // Limit recursion calls depth.
+  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
+    return getOrCreateAddExpr(Ops, Flags);
+
+  // Okay, check to see if the same value occurs in the operand list more than
+  // once. If so, merge them together into a multiply expression. Since we
+  // sorted the list, these values are required to be adjacent.
+  Type *Ty = Ops[0]->getType();
+  bool FoundMatch = false;
+  for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
+    if (Ops[i] == Ops[i+1]) {      // X + Y + Y  -->  X + Y*2
+      // Scan ahead to count how many equal operands there are.
+      unsigned Count = 2;
+      while (i+Count != e && Ops[i+Count] == Ops[i])
+        ++Count;
+      // Merge the values into a multiply.
+      const SCEV *Scale = getConstant(Ty, Count);
+      const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
+      if (Ops.size() == Count)
+        return Mul;
+      Ops[i] = Mul;
+      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
+      --i; e -= Count - 1;
+      FoundMatch = true;
+    }
+  if (FoundMatch)
+    return getAddExpr(Ops, Flags, Depth + 1);
+
+  // Check for truncates. If all the operands are truncated from the same
+  // type, see if factoring out the truncate would permit the result to be
+  // folded. e.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
+  // if the contents of the resulting outer trunc fold to something simple.
+  auto FindTruncSrcType = [&]() -> Type * {
+    // We're ultimately looking to fold an addrec of truncs and muls of only
+    // constants and truncs, so if we find any other types of SCEV
+    // as operands of the addrec then we bail and return nullptr here.
+    // Otherwise, we return the type of the operand of a trunc that we find.
+    if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
+      return T->getOperand()->getType();
+    if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
+      const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
+      if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
+        return T->getOperand()->getType();
+    }
+    return nullptr;
+  };
+  if (auto *SrcType = FindTruncSrcType()) {
+    SmallVector<const SCEV *, 8> LargeOps;
+    bool Ok = true;
+    // Check all the operands to see if they can be represented in the
+    // source type of the truncate.
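+    // e.g. (illustrative) 4 * trunc(%x) + 4 * trunc(%y) from i64 operands is
+    // re-associated as trunc(4 * %x + 4 * %y) only when the wide sum folds
+    // to a constant or a single unknown; otherwise the original form is kept.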
+ for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) { + if (T->getOperand()->getType() != SrcType) { + Ok = false; + break; + } + LargeOps.push_back(T->getOperand()); + } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) { + LargeOps.push_back(getAnyExtendExpr(C, SrcType)); + } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) { + SmallVector<const SCEV *, 8> LargeMulOps; + for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) { + if (const SCEVTruncateExpr *T = + dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) { + if (T->getOperand()->getType() != SrcType) { + Ok = false; + break; + } + LargeMulOps.push_back(T->getOperand()); + } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) { + LargeMulOps.push_back(getAnyExtendExpr(C, SrcType)); + } else { + Ok = false; + break; + } + } + if (Ok) + LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1)); + } else { + Ok = false; + break; + } + } + if (Ok) { + // Evaluate the expression in the larger type. + const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1); + // If it folds to something simple, use it. Otherwise, don't. + if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold)) + return getTruncateExpr(Fold, Ty); + } + } + + // Skip past any other cast SCEVs. + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr) + ++Idx; + + // If there are add operands they would be next. + if (Idx < Ops.size()) { + bool DeletedAdd = false; + while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) { + if (Ops.size() > AddOpsInlineThreshold || + Add->getNumOperands() > AddOpsInlineThreshold) + break; + // If we have an add, expand the add operands onto the end of the operands + // list. + Ops.erase(Ops.begin()+Idx); + Ops.append(Add->op_begin(), Add->op_end()); + DeletedAdd = true; + } + + // If we deleted at least one add, we added operands to the end of the list, + // and they are not necessarily sorted. Recurse to resort and resimplify + // any operands we just acquired. + if (DeletedAdd) + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); + } + + // Skip over the add expression until we get to a multiply. + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr) + ++Idx; + + // Check to see if there are any folding opportunities present with + // operands multiplied by constant values. + if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) { + uint64_t BitWidth = getTypeSizeInBits(Ty); + DenseMap<const SCEV *, APInt> M; + SmallVector<const SCEV *, 8> NewOps; + APInt AccumulatedConstant(BitWidth, 0); + if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, + Ops.data(), Ops.size(), + APInt(BitWidth, 1), *this)) { + struct APIntCompare { + bool operator()(const APInt &LHS, const APInt &RHS) const { + return LHS.ult(RHS); + } + }; + + // Some interesting folding opportunity is present, so its worthwhile to + // re-generate the operands list. Group the operands by constant scale, + // to avoid multiplying by the same constant scale multiple times. + std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists; + for (const SCEV *NewOp : NewOps) + MulOpLists[M.find(NewOp)->second].push_back(NewOp); + // Re-generate the operands list. 
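+      // For the doc-comment example on CollectAddOperandsWithScales, with
+      // A = 2 and B = 3 (illustrative numbers): the map comes back as
+      // (m,7), (n,1), (o,2), (p,2), (q,6), (r,0), AccumulatedConstant is
+      // 13 + 2*3*29 == 187, and regrouping by scale rebuilds the sum as
+      // 187 + n + 7*m + 2*(o + p) + 6*q, the r terms having cancelled.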
+ Ops.clear(); + if (AccumulatedConstant != 0) + Ops.push_back(getConstant(AccumulatedConstant)); + for (auto &MulOp : MulOpLists) + if (MulOp.first != 0) + Ops.push_back(getMulExpr( + getConstant(MulOp.first), + getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1)); + if (Ops.empty()) + return getZero(Ty); + if (Ops.size() == 1) + return Ops[0]; + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); + } + } + + // If we are adding something to a multiply expression, make sure the + // something is not already an operand of the multiply. If so, merge it into + // the multiply. + for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) { + const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]); + for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) { + const SCEV *MulOpSCEV = Mul->getOperand(MulOp); + if (isa<SCEVConstant>(MulOpSCEV)) + continue; + for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp) + if (MulOpSCEV == Ops[AddOp]) { + // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1)) + const SCEV *InnerMul = Mul->getOperand(MulOp == 0); + if (Mul->getNumOperands() != 2) { + // If the multiply has more than two operands, we must get the + // Y*Z term. + SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(), + Mul->op_begin()+MulOp); + MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); + InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); + } + SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul}; + const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); + const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV, + SCEV::FlagAnyWrap, Depth + 1); + if (Ops.size() == 2) return OuterMul; + if (AddOp < Idx) { + Ops.erase(Ops.begin()+AddOp); + Ops.erase(Ops.begin()+Idx-1); + } else { + Ops.erase(Ops.begin()+Idx); + Ops.erase(Ops.begin()+AddOp-1); + } + Ops.push_back(OuterMul); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); + } + + // Check this multiply against other multiplies being added together. + for (unsigned OtherMulIdx = Idx+1; + OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]); + ++OtherMulIdx) { + const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]); + // If MulOp occurs in OtherMul, we can fold the two multiplies + // together. 
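+        // e.g. (illustrative) %w + (%a * %b) + (%a * %c) -->
+        // %w + %a * (%b + %c), merging the two multiplies that share %a.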
+ for (unsigned OMulOp = 0, e = OtherMul->getNumOperands(); + OMulOp != e; ++OMulOp) + if (OtherMul->getOperand(OMulOp) == MulOpSCEV) { + // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E)) + const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0); + if (Mul->getNumOperands() != 2) { + SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(), + Mul->op_begin()+MulOp); + MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); + InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); + } + const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0); + if (OtherMul->getNumOperands() != 2) { + SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(), + OtherMul->op_begin()+OMulOp); + MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end()); + InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); + } + SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2}; + const SCEV *InnerMulSum = + getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); + const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum, + SCEV::FlagAnyWrap, Depth + 1); + if (Ops.size() == 2) return OuterMul; + Ops.erase(Ops.begin()+Idx); + Ops.erase(Ops.begin()+OtherMulIdx-1); + Ops.push_back(OuterMul); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); + } + } + } + } + + // If there are any add recurrences in the operands list, see if any other + // added values are loop invariant. If so, we can fold them into the + // recurrence. + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr) + ++Idx; + + // Scan over all recurrences, trying to fold loop invariants into them. + for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) { + // Scan all of the other operands to this add and add them to the vector if + // they are loop invariant w.r.t. the recurrence. + SmallVector<const SCEV *, 8> LIOps; + const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); + const Loop *AddRecLoop = AddRec->getLoop(); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) { + LIOps.push_back(Ops[i]); + Ops.erase(Ops.begin()+i); + --i; --e; + } + + // If we found some loop invariants, fold them into the recurrence. + if (!LIOps.empty()) { + // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step} + LIOps.push_back(AddRec->getStart()); + + SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(), + AddRec->op_end()); + // This follows from the fact that the no-wrap flags on the outer add + // expression are applicable on the 0th iteration, when the add recurrence + // will be equal to its start value. + AddRecOps[0] = getAddExpr(LIOps, Flags, Depth + 1); + + // Build the new addrec. Propagate the NUW and NSW flags if both the + // outer add and the inner addrec are guaranteed to have no overflow. + // Always propagate NW. + Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW)); + const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags); + + // If all of the other operands were loop invariant, we are done. + if (Ops.size() == 1) return NewRec; + + // Otherwise, add the folded AddRec by the non-invariant parts. + for (unsigned i = 0;; ++i) + if (Ops[i] == AddRec) { + Ops[i] = NewRec; + break; + } + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); + } + + // Okay, if there weren't any loop invariants to be folded, check to see if + // there are multiple AddRec's with the same loop induction variable being + // added together. If so, we can fold them. 
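+    // e.g. (illustrative) {1,+,2}<L> + {3,+,4}<L> --> {4,+,6}<L>: at
+    // iteration n both sides evaluate to (1 + 2n) + (3 + 4n) == 4 + 6n.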
+ for (unsigned OtherIdx = Idx+1; + OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); + ++OtherIdx) { + // We expect the AddRecExpr's to be sorted in reverse dominance order, + // so that the 1st found AddRecExpr is dominated by all others. + assert(DT.dominates( + cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(), + AddRec->getLoop()->getHeader()) && + "AddRecExprs are not sorted in reverse dominance order?"); + if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) { + // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L> + SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(), + AddRec->op_end()); + for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); + ++OtherIdx) { + const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]); + if (OtherAddRec->getLoop() == AddRecLoop) { + for (unsigned i = 0, e = OtherAddRec->getNumOperands(); + i != e; ++i) { + if (i >= AddRecOps.size()) { + AddRecOps.append(OtherAddRec->op_begin()+i, + OtherAddRec->op_end()); + break; + } + SmallVector<const SCEV *, 2> TwoOps = { + AddRecOps[i], OtherAddRec->getOperand(i)}; + AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); + } + Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; + } + } + // Step size has changed, so we cannot guarantee no self-wraparound. + Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); + } + } + + // Otherwise couldn't fold anything into this recurrence. Move onto the + // next one. + } + + // Okay, it looks like we really DO need an add expr. Check to see if we + // already have one, otherwise create a new one. + return getOrCreateAddExpr(Ops, Flags); +} + +const SCEV * +ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops, + SCEV::NoWrapFlags Flags) { + FoldingSetNodeID ID; + ID.AddInteger(scAddExpr); + for (const SCEV *Op : Ops) + ID.AddPointer(Op); + void *IP = nullptr; + SCEVAddExpr *S = + static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + if (!S) { + const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); + std::uninitialized_copy(Ops.begin(), Ops.end(), O); + S = new (SCEVAllocator) + SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size()); + UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); + } + S->setNoWrapFlags(Flags); + return S; +} + +const SCEV * +ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops, + const Loop *L, SCEV::NoWrapFlags Flags) { + FoldingSetNodeID ID; + ID.AddInteger(scAddRecExpr); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + ID.AddPointer(Ops[i]); + ID.AddPointer(L); + void *IP = nullptr; + SCEVAddRecExpr *S = + static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + if (!S) { + const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); + std::uninitialized_copy(Ops.begin(), Ops.end(), O); + S = new (SCEVAllocator) + SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L); + UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); + } + S->setNoWrapFlags(Flags); + return S; +} + +const SCEV * +ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops, + SCEV::NoWrapFlags Flags) { + FoldingSetNodeID ID; + ID.AddInteger(scMulExpr); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + ID.AddPointer(Ops[i]); + void *IP = nullptr; + SCEVMulExpr *S = + static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + if (!S) { + const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); + 
std::uninitialized_copy(Ops.begin(), Ops.end(), O);
+    S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
+                                        O, Ops.size());
+    UniqueSCEVs.InsertNode(S, IP);
+    addToLoopUseLists(S);
+  }
+  S->setNoWrapFlags(Flags);
+  return S;
+}
+
+static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
+  uint64_t k = i*j;
+  if (j > 1 && k / j != i) Overflow = true;
+  return k;
+}
+
+/// Compute the result of "n choose k", the binomial coefficient. If an
+/// intermediate computation overflows, Overflow will be set and the return will
+/// be garbage. Overflow is not cleared on absence of overflow.
+static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
+  // We use the multiplicative formula:
+  //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
+  // At iteration i we multiply the accumulator by the i-th factor of the
+  // numerator, n-(i-1), and then divide by i, the i-th factor of the
+  // denominator taken in ascending order. The division is always exact,
+  // since the product of any i consecutive integers is divisible by i!, and
+  // dividing early helps reduce the chance of overflow in the intermediate
+  // computations. However, we can still overflow even when the final result
+  // would fit.
+
+  if (n == 0 || n == k) return 1;
+  if (k > n) return 0;
+
+  if (k > n/2)
+    k = n-k;
+
+  uint64_t r = 1;
+  for (uint64_t i = 1; i <= k; ++i) {
+    r = umul_ov(r, n-(i-1), Overflow);
+    r /= i;
+  }
+  return r;
+}
+
+/// Determine if any of the operands in this SCEV are a constant or if
+/// any of the add or multiply expressions in this SCEV contain a constant.
+static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
+  struct FindConstantInAddMulChain {
+    bool FoundConstant = false;
+
+    bool follow(const SCEV *S) {
+      FoundConstant |= isa<SCEVConstant>(S);
+      return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
+    }
+
+    bool isDone() const {
+      return FoundConstant;
+    }
+  };
+
+  FindConstantInAddMulChain F;
+  SCEVTraversal<FindConstantInAddMulChain> ST(F);
+  ST.visitAll(StartExpr);
+  return F.FoundConstant;
+}
+
+/// Get a canonical multiply expression, or something simpler if possible.
+const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
+                                        SCEV::NoWrapFlags Flags,
+                                        unsigned Depth) {
+  assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) &&
+         "only nuw or nsw allowed");
+  assert(!Ops.empty() && "Cannot get empty mul!");
+  if (Ops.size() == 1) return Ops[0];
+#ifndef NDEBUG
+  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
+  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
+    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
+           "SCEVMulExpr operand types don't match!");
+#endif
+
+  // Sort by complexity, this groups all similar expression types together.
+  GroupByComplexity(Ops, &LI, DT);
+
+  Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags);
+
+  // Limit recursion calls depth.
+  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
+    return getOrCreateMulExpr(Ops, Flags);
+
+  // If there are any constants, fold them together.
+  unsigned Idx = 0;
+  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
+
+    if (Ops.size() == 2)
+      // C1*(C2+V) -> C1*C2 + C1*V
+      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
+        // If any of Add's ops are Adds or Muls with a constant, apply this
+        // transformation as well.
+        //
+        // TODO: There are some cases where this transformation is not
+        // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
+        // this transformation should be narrowed down.
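+        // e.g. (illustrative) 2 * (3 + %v) --> 6 + 2 * %v, and with a deeper
+        // constant 2 * (3 + 4 * %v) --> 6 + 8 * %v, exposing folds that the
+        // nested form hides.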
+ if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) + return getAddExpr(getMulExpr(LHSC, Add->getOperand(0), + SCEV::FlagAnyWrap, Depth + 1), + getMulExpr(LHSC, Add->getOperand(1), + SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); + + ++Idx; + while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { + // We found two constants, fold them together! + ConstantInt *Fold = + ConstantInt::get(getContext(), LHSC->getAPInt() * RHSC->getAPInt()); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast<SCEVConstant>(Ops[0]); + } + + // If we are left with a constant one being multiplied, strip it off. + if (cast<SCEVConstant>(Ops[0])->getValue()->isOne()) { + Ops.erase(Ops.begin()); + --Idx; + } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) { + // If we have a multiply of zero, it will always be zero. + return Ops[0]; + } else if (Ops[0]->isAllOnesValue()) { + // If we have a mul by -1 of an add, try distributing the -1 among the + // add operands. + if (Ops.size() == 2) { + if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) { + SmallVector<const SCEV *, 4> NewOps; + bool AnyFolded = false; + for (const SCEV *AddOp : Add->operands()) { + const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap, + Depth + 1); + if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true; + NewOps.push_back(Mul); + } + if (AnyFolded) + return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1); + } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) { + // Negation preserves a recurrence's no self-wrap property. + SmallVector<const SCEV *, 4> Operands; + for (const SCEV *AddRecOp : AddRec->operands()) + Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap, + Depth + 1)); + + return getAddRecExpr(Operands, AddRec->getLoop(), + AddRec->getNoWrapFlags(SCEV::FlagNW)); + } + } + } + + if (Ops.size() == 1) + return Ops[0]; + } + + // Skip over the add expression until we get to a multiply. + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr) + ++Idx; + + // If there are mul operands inline them all into this expression. + if (Idx < Ops.size()) { + bool DeletedMul = false; + while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) { + if (Ops.size() > MulOpsInlineThreshold) + break; + // If we have an mul, expand the mul operands onto the end of the + // operands list. + Ops.erase(Ops.begin()+Idx); + Ops.append(Mul->op_begin(), Mul->op_end()); + DeletedMul = true; + } + + // If we deleted at least one mul, we added operands to the end of the + // list, and they are not necessarily sorted. Recurse to resort and + // resimplify any operands we just acquired. + if (DeletedMul) + return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); + } + + // If there are any add recurrences in the operands list, see if any other + // added values are loop invariant. If so, we can fold them into the + // recurrence. + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr) + ++Idx; + + // Scan over all recurrences, trying to fold loop invariants into them. + for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) { + // Scan all of the other operands to this mul and add them to the vector + // if they are loop invariant w.r.t. the recurrence. 
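+    // e.g. (illustrative) %m * {2,+,1}<L>, with %m loop-invariant, becomes
+    // {2 * %m,+,%m}<L>, while any non-invariant factor stays outside.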
+ SmallVector<const SCEV *, 8> LIOps; + const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); + const Loop *AddRecLoop = AddRec->getLoop(); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) { + LIOps.push_back(Ops[i]); + Ops.erase(Ops.begin()+i); + --i; --e; + } + + // If we found some loop invariants, fold them into the recurrence. + if (!LIOps.empty()) { + // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step} + SmallVector<const SCEV *, 4> NewOps; + NewOps.reserve(AddRec->getNumOperands()); + const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1); + for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) + NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i), + SCEV::FlagAnyWrap, Depth + 1)); + + // Build the new addrec. Propagate the NUW and NSW flags if both the + // outer mul and the inner addrec are guaranteed to have no overflow. + // + // No self-wrap cannot be guaranteed after changing the step size, but + // will be inferred if either NUW or NSW is true. + Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW)); + const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags); + + // If all of the other operands were loop invariant, we are done. + if (Ops.size() == 1) return NewRec; + + // Otherwise, multiply the folded AddRec by the non-invariant parts. + for (unsigned i = 0;; ++i) + if (Ops[i] == AddRec) { + Ops[i] = NewRec; + break; + } + return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); + } + + // Okay, if there weren't any loop invariants to be folded, check to see + // if there are multiple AddRec's with the same loop induction variable + // being multiplied together. If so, we can fold them. + + // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L> + // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [ + // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z + // ]]],+,...up to x=2n}. + // Note that the arguments to choose() are always integers with values + // known at compile time, never SCEV objects. + // + // The implementation avoids pointless extra computations when the two + // addrec's are of different length (mathematically, it's equivalent to + // an infinite stream of zeros on the right). + bool OpsModified = false; + for (unsigned OtherIdx = Idx+1; + OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); + ++OtherIdx) { + const SCEVAddRecExpr *OtherAddRec = + dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]); + if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop) + continue; + + // Limit max number of arguments to avoid creation of unreasonably big + // SCEVAddRecs with very complex operands. 
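+      //
+      // (For reference, in the affine case the formula above reduces to
+      // {A,+,B}<L> * {C,+,D}<L> = {A*C,+,A*D+B*C+B*D,+,2*B*D}<L>; e.g.
+      // {0,+,1} * {0,+,1} = {0,+,1,+,2}, i.e. i*i = i + 2*(i choose 2).)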
+ if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 > + MaxAddRecSize || isHugeExpression(AddRec) || + isHugeExpression(OtherAddRec)) + continue; + + bool Overflow = false; + Type *Ty = AddRec->getType(); + bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; + SmallVector<const SCEV*, 7> AddRecOps; + for (int x = 0, xe = AddRec->getNumOperands() + + OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { + SmallVector <const SCEV *, 7> SumOps; + for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { + uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); + for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), + ze = std::min(x+1, (int)OtherAddRec->getNumOperands()); + z < ze && !Overflow; ++z) { + uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow); + uint64_t Coeff; + if (LargerThan64Bits) + Coeff = umul_ov(Coeff1, Coeff2, Overflow); + else + Coeff = Coeff1*Coeff2; + const SCEV *CoeffTerm = getConstant(Ty, Coeff); + const SCEV *Term1 = AddRec->getOperand(y-z); + const SCEV *Term2 = OtherAddRec->getOperand(z); + SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2, + SCEV::FlagAnyWrap, Depth + 1)); + } + } + if (SumOps.empty()) + SumOps.push_back(getZero(Ty)); + AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1)); + } + if (!Overflow) { + const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRecLoop, + SCEV::FlagAnyWrap); + if (Ops.size() == 2) return NewAddRec; + Ops[Idx] = NewAddRec; + Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; + OpsModified = true; + AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec); + if (!AddRec) + break; + } + } + if (OpsModified) + return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); + + // Otherwise couldn't fold anything into this recurrence. Move onto the + // next one. + } + + // Okay, it looks like we really DO need an mul expr. Check to see if we + // already have one, otherwise create a new one. + return getOrCreateMulExpr(Ops, Flags); +} + +/// Represents an unsigned remainder expression based on unsigned division. +const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS, + const SCEV *RHS) { + assert(getEffectiveSCEVType(LHS->getType()) == + getEffectiveSCEVType(RHS->getType()) && + "SCEVURemExpr operand types don't match!"); + + // Short-circuit easy cases + if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) { + // If constant is one, the result is trivial + if (RHSC->getValue()->isOne()) + return getZero(LHS->getType()); // X urem 1 --> 0 + + // If constant is a power of two, fold into a zext(trunc(LHS)). + if (RHSC->getAPInt().isPowerOf2()) { + Type *FullTy = LHS->getType(); + Type *TruncTy = + IntegerType::get(getContext(), RHSC->getAPInt().logBase2()); + return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy); + } + } + + // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y) + const SCEV *UDiv = getUDivExpr(LHS, RHS); + const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW); + return getMinusSCEV(LHS, Mult, SCEV::FlagNUW); +} + +/// Get a canonical unsigned division expression, or something simpler if +/// possible. +const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, + const SCEV *RHS) { + assert(getEffectiveSCEVType(LHS->getType()) == + getEffectiveSCEVType(RHS->getType()) && + "SCEVUDivExpr operand types don't match!"); + + if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) { + if (RHSC->getValue()->isOne()) + return LHS; // X udiv 1 --> x + // If the denominator is zero, the result of the udiv is undefined. 
Don't
+    // try to analyze it, because the resolution chosen here may differ from
+    // the resolution chosen in other parts of the compiler.
+    if (!RHSC->getValue()->isZero()) {
+      // Determine if the division can be folded into the operands of
+      // its dividend.
+      // TODO: Generalize this to non-constants by using known-bits information.
+      Type *Ty = LHS->getType();
+      unsigned LZ = RHSC->getAPInt().countLeadingZeros();
+      unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
+      // For non-power-of-two values, effectively round the value up to the
+      // nearest power of two.
+      if (!RHSC->getAPInt().isPowerOf2())
+        ++MaxShiftAmt;
+      IntegerType *ExtTy =
+          IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
+      if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
+        if (const SCEVConstant *Step =
+                dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
+          // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
+          const APInt &StepInt = Step->getAPInt();
+          const APInt &DivInt = RHSC->getAPInt();
+          if (!StepInt.urem(DivInt) &&
+              getZeroExtendExpr(AR, ExtTy) ==
+              getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
+                            getZeroExtendExpr(Step, ExtTy),
+                            AR->getLoop(), SCEV::FlagAnyWrap)) {
+            SmallVector<const SCEV *, 4> Operands;
+            for (const SCEV *Op : AR->operands())
+              Operands.push_back(getUDivExpr(Op, RHS));
+            return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
+          }
+          // Get a canonical UDivExpr for a recurrence.
+          // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
+          // We can currently only fold X%N if X is constant.
+          const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
+          if (StartC && !DivInt.urem(StepInt) &&
+              getZeroExtendExpr(AR, ExtTy) ==
+              getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
+                            getZeroExtendExpr(Step, ExtTy),
+                            AR->getLoop(), SCEV::FlagAnyWrap)) {
+            const APInt &StartInt = StartC->getAPInt();
+            const APInt &StartRem = StartInt.urem(StepInt);
+            if (StartRem != 0)
+              LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step,
+                                  AR->getLoop(), SCEV::FlagNW);
+          }
+        }
+      // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
+      if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
+        SmallVector<const SCEV *, 4> Operands;
+        for (const SCEV *Op : M->operands())
+          Operands.push_back(getZeroExtendExpr(Op, ExtTy));
+        if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
+          // Find an operand that's safely divisible.
+          for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
+            const SCEV *Op = M->getOperand(i);
+            const SCEV *Div = getUDivExpr(Op, RHSC);
+            if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
+              Operands = SmallVector<const SCEV *, 4>(M->op_begin(),
+                                                      M->op_end());
+              Operands[i] = Div;
+              return getMulExpr(Operands);
+            }
+          }
+      }
+
+      // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
+      if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
+        if (auto *DivisorConstant =
+                dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
+          bool Overflow = false;
+          APInt NewRHS =
+              DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
+          if (Overflow) {
+            return getConstant(RHSC->getType(), 0, false);
+          }
+          return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
+        }
+      }
+
+      // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
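+      //
+      // For example (illustrative), ((4 * %x) + 8) /u 4 becomes (%x + 2)
+      // below, since both operands divide evenly and the zero-extended
+      // sums agree.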
+ if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) { + SmallVector<const SCEV *, 4> Operands; + for (const SCEV *Op : A->operands()) + Operands.push_back(getZeroExtendExpr(Op, ExtTy)); + if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) { + Operands.clear(); + for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) { + const SCEV *Op = getUDivExpr(A->getOperand(i), RHS); + if (isa<SCEVUDivExpr>(Op) || + getMulExpr(Op, RHS) != A->getOperand(i)) + break; + Operands.push_back(Op); + } + if (Operands.size() == A->getNumOperands()) + return getAddExpr(Operands); + } + } + + // Fold if both operands are constant. + if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) { + Constant *LHSCV = LHSC->getValue(); + Constant *RHSCV = RHSC->getValue(); + return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV, + RHSCV))); + } + } + } + + FoldingSetNodeID ID; + ID.AddInteger(scUDivExpr); + ID.AddPointer(LHS); + ID.AddPointer(RHS); + void *IP = nullptr; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator), + LHS, RHS); + UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); + return S; +} + +static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) { + APInt A = C1->getAPInt().abs(); + APInt B = C2->getAPInt().abs(); + uint32_t ABW = A.getBitWidth(); + uint32_t BBW = B.getBitWidth(); + + if (ABW > BBW) + B = B.zext(ABW); + else if (ABW < BBW) + A = A.zext(BBW); + + return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B)); +} + +/// Get a canonical unsigned division expression, or something simpler if +/// possible. There is no representation for an exact udiv in SCEV IR, but we +/// can attempt to remove factors from the LHS and RHS. We can't do this when +/// it's not exact because the udiv may be clearing bits. +const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, + const SCEV *RHS) { + // TODO: we could try to find factors in all sorts of things, but for now we + // just deal with u/exact (multiply, constant). See SCEVDivision towards the + // end of this file for inspiration. + + const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS); + if (!Mul || !Mul->hasNoUnsignedWrap()) + return getUDivExpr(LHS, RHS); + + if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) { + // If the mulexpr multiplies by a constant, then that constant must be the + // first element of the mulexpr. + if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) { + if (LHSCst == RHSCst) { + SmallVector<const SCEV *, 2> Operands; + Operands.append(Mul->op_begin() + 1, Mul->op_end()); + return getMulExpr(Operands); + } + + // We can't just assume that LHSCst divides RHSCst cleanly, it could be + // that there's a factor provided by one of the other terms. We need to + // check. 
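+      //
+      // For example (illustrative), in (24 * %x)<nuw> /u 8 the constants
+      // share the factor gcd(24, 8) = 8, so this reduces below to
+      // (3 * %x) /u 1, i.e. 3 * %x.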
+      APInt Factor = gcd(LHSCst, RHSCst);
+      if (!Factor.isIntN(1)) {
+        LHSCst =
+            cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
+        RHSCst =
+            cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
+        SmallVector<const SCEV *, 2> Operands;
+        Operands.push_back(LHSCst);
+        Operands.append(Mul->op_begin() + 1, Mul->op_end());
+        LHS = getMulExpr(Operands);
+        RHS = RHSCst;
+        Mul = dyn_cast<SCEVMulExpr>(LHS);
+        if (!Mul)
+          return getUDivExactExpr(LHS, RHS);
+      }
+    }
+  }
+
+  for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
+    if (Mul->getOperand(i) == RHS) {
+      SmallVector<const SCEV *, 2> Operands;
+      Operands.append(Mul->op_begin(), Mul->op_begin() + i);
+      Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
+      return getMulExpr(Operands);
+    }
+  }
+
+  return getUDivExpr(LHS, RHS);
+}
+
+/// Get an add recurrence expression for the specified loop. Simplify the
+/// expression as much as possible.
+const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
+                                           const Loop *L,
+                                           SCEV::NoWrapFlags Flags) {
+  SmallVector<const SCEV *, 4> Operands;
+  Operands.push_back(Start);
+  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
+    if (StepChrec->getLoop() == L) {
+      Operands.append(StepChrec->op_begin(), StepChrec->op_end());
+      return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
+    }
+
+  Operands.push_back(Step);
+  return getAddRecExpr(Operands, L, Flags);
+}
+
+/// Get an add recurrence expression for the specified loop. Simplify the
+/// expression as much as possible.
+const SCEV *
+ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
+                               const Loop *L, SCEV::NoWrapFlags Flags) {
+  if (Operands.size() == 1) return Operands[0];
+#ifndef NDEBUG
+  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
+  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
+    assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
+           "SCEVAddRecExpr operand types don't match!");
+  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
+    assert(isLoopInvariant(Operands[i], L) &&
+           "SCEVAddRecExpr operand is not loop-invariant!");
+#endif
+
+  if (Operands.back()->isZero()) {
+    Operands.pop_back();
+    return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
+  }
+
+  // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
+  // use that information to infer NUW and NSW flags. However, computing a
+  // BE count requires calling getAddRecExpr, so we may not yet have a
+  // meaningful BE count at this point (and if we don't, we'd be stuck
+  // with a SCEVCouldNotCompute as the cached BE count).
+
+  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
+
+  // Canonicalize nested AddRecs by nesting them in order of loop depth.
+  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
+    const Loop *NestedLoop = NestedAR->getLoop();
+    if (L->contains(NestedLoop)
+            ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
+            : (!NestedLoop->contains(L) &&
+               DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
+      SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
+                                                  NestedAR->op_end());
+      Operands[0] = NestedAR->getStart();
+      // AddRecs require their operands be loop-invariant with respect to
+      // their loops. Don't perform this transformation if it would break
+      // this requirement.
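+      //
+      // For example (illustrative, assuming the invariance checks below
+      // succeed), {{0,+,%s}<L2>,+,%t}<L1> with L1 the outer loop becomes
+      // {{0,+,%t}<L1>,+,%s}<L2>, nesting the deeper loop outermost.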
+ bool AllInvariant = all_of( + Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); }); + + if (AllInvariant) { + // Create a recurrence for the outer loop with the same step size. + // + // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the + // inner recurrence has the same property. + SCEV::NoWrapFlags OuterFlags = + maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags()); + + NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags); + AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) { + return isLoopInvariant(Op, NestedLoop); + }); + + if (AllInvariant) { + // Ok, both add recurrences are valid after the transformation. + // + // The inner recurrence keeps its NW flag but only keeps NUW/NSW if + // the outer recurrence has the same property. + SCEV::NoWrapFlags InnerFlags = + maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags); + return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags); + } + } + // Reset Operands to its original state. + Operands[0] = NestedAR; + } + } + + // Okay, it looks like we really DO need an addrec expr. Check to see if we + // already have one, otherwise create a new one. + return getOrCreateAddRecExpr(Operands, L, Flags); +} + +const SCEV * +ScalarEvolution::getGEPExpr(GEPOperator *GEP, + const SmallVectorImpl<const SCEV *> &IndexExprs) { + const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand()); + // getSCEV(Base)->getType() has the same address space as Base->getType() + // because SCEV::getType() preserves the address space. + Type *IntPtrTy = getEffectiveSCEVType(BaseExpr->getType()); + // FIXME(PR23527): Don't blindly transfer the inbounds flag from the GEP + // instruction to its SCEV, because the Instruction may be guarded by control + // flow and the no-overflow bits may not be valid for the expression in any + // context. This can be fixed similarly to how these flags are handled for + // adds. + SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW + : SCEV::FlagAnyWrap; + + const SCEV *TotalOffset = getZero(IntPtrTy); + // The array size is unimportant. The first thing we do on CurTy is getting + // its element type. + Type *CurTy = ArrayType::get(GEP->getSourceElementType(), 0); + for (const SCEV *IndexExpr : IndexExprs) { + // Compute the (potentially symbolic) offset in bytes for this index. + if (StructType *STy = dyn_cast<StructType>(CurTy)) { + // For a struct, add the member offset. + ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue(); + unsigned FieldNo = Index->getZExtValue(); + const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo); + + // Add the field offset to the running total offset. + TotalOffset = getAddExpr(TotalOffset, FieldOffset); + + // Update CurTy to the type of the field at Index. + CurTy = STy->getTypeAtIndex(Index); + } else { + // Update CurTy to its element type. + CurTy = cast<SequentialType>(CurTy)->getElementType(); + // For an array, add the element offset, explicitly scaled. + const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, CurTy); + // Getelementptr indices are signed. + IndexExpr = getTruncateOrSignExtend(IndexExpr, IntPtrTy); + + // Multiply the index by the element size to compute the element offset. + const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, Wrap); + + // Add the element offset to the running total offset. + TotalOffset = getAddExpr(TotalOffset, LocalOffset); + } + } + + // Add the total offset from all the GEP indices to the base. 
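+  //
+  // For example (illustrative), a gep into [10 x i32] at %p with indices
+  // 0 and %i yields roughly (%p + 4 * (sext %i)), where the multiply and
+  // the adds carry <nsw> only when the gep is inbounds.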
+ return getAddExpr(BaseExpr, TotalOffset, Wrap); +} + +std::tuple<const SCEV *, FoldingSetNodeID, void *> +ScalarEvolution::findExistingSCEVInCache(int SCEVType, + ArrayRef<const SCEV *> Ops) { + FoldingSetNodeID ID; + void *IP = nullptr; + ID.AddInteger(SCEVType); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + ID.AddPointer(Ops[i]); + return std::tuple<const SCEV *, FoldingSetNodeID, void *>( + UniqueSCEVs.FindNodeOrInsertPos(ID, IP), std::move(ID), IP); +} + +const SCEV *ScalarEvolution::getMinMaxExpr(unsigned Kind, + SmallVectorImpl<const SCEV *> &Ops) { + assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!"); + if (Ops.size() == 1) return Ops[0]; +#ifndef NDEBUG + Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); + for (unsigned i = 1, e = Ops.size(); i != e; ++i) + assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && + "Operand types don't match!"); +#endif + + bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr; + bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr; + + // Sort by complexity, this groups all similar expression types together. + GroupByComplexity(Ops, &LI, DT); + + // Check if we have created the same expression before. + if (const SCEV *S = std::get<0>(findExistingSCEVInCache(Kind, Ops))) { + return S; + } + + // If there are any constants, fold them together. + unsigned Idx = 0; + if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { + ++Idx; + assert(Idx < Ops.size()); + auto FoldOp = [&](const APInt &LHS, const APInt &RHS) { + if (Kind == scSMaxExpr) + return APIntOps::smax(LHS, RHS); + else if (Kind == scSMinExpr) + return APIntOps::smin(LHS, RHS); + else if (Kind == scUMaxExpr) + return APIntOps::umax(LHS, RHS); + else if (Kind == scUMinExpr) + return APIntOps::umin(LHS, RHS); + llvm_unreachable("Unknown SCEV min/max opcode"); + }; + + while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { + // We found two constants, fold them together! + ConstantInt *Fold = ConstantInt::get( + getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt())); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast<SCEVConstant>(Ops[0]); + } + + bool IsMinV = LHSC->getValue()->isMinValue(IsSigned); + bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned); + + if (IsMax ? IsMinV : IsMaxV) { + // If we are left with a constant minimum(/maximum)-int, strip it off. + Ops.erase(Ops.begin()); + --Idx; + } else if (IsMax ? IsMaxV : IsMinV) { + // If we have a max(/min) with a constant maximum(/minimum)-int, + // it will always be the extremum. + return LHSC; + } + + if (Ops.size() == 1) return Ops[0]; + } + + // Find the first operation of the same kind + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind) + ++Idx; + + // Check to see if one of the operands is of the same kind. If so, expand its + // operands onto our operand list, and recurse to simplify. + if (Idx < Ops.size()) { + bool DeletedAny = false; + while (Ops[Idx]->getSCEVType() == Kind) { + const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]); + Ops.erase(Ops.begin()+Idx); + Ops.append(SMME->op_begin(), SMME->op_end()); + DeletedAny = true; + } + + if (DeletedAny) + return getMinMaxExpr(Kind, Ops); + } + + // Okay, check to see if the same value occurs in the operand list twice. If + // so, delete one. Since we sorted the list, these values are required to + // be adjacent. + llvm::CmpInst::Predicate GEPred = + IsSigned ? 
ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; + llvm::CmpInst::Predicate LEPred = + IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; + llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred; + llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred; + for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) { + if (Ops[i] == Ops[i + 1] || + isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) { + // X op Y op Y --> X op Y + // X op Y --> X, if we know X, Y are ordered appropriately + Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2); + --i; + --e; + } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i], + Ops[i + 1])) { + // X op Y --> Y, if we know X, Y are ordered appropriately + Ops.erase(Ops.begin() + i, Ops.begin() + i + 1); + --i; + --e; + } + } + + if (Ops.size() == 1) return Ops[0]; + + assert(!Ops.empty() && "Reduced smax down to nothing!"); + + // Okay, it looks like we really DO need an expr. Check to see if we + // already have one, otherwise create a new one. + const SCEV *ExistingSCEV; + FoldingSetNodeID ID; + void *IP; + std::tie(ExistingSCEV, ID, IP) = findExistingSCEVInCache(Kind, Ops); + if (ExistingSCEV) + return ExistingSCEV; + const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); + std::uninitialized_copy(Ops.begin(), Ops.end(), O); + SCEV *S = new (SCEVAllocator) SCEVMinMaxExpr( + ID.Intern(SCEVAllocator), static_cast<SCEVTypes>(Kind), O, Ops.size()); + + UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); + return S; +} + +const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) { + SmallVector<const SCEV *, 2> Ops = {LHS, RHS}; + return getSMaxExpr(Ops); +} + +const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { + return getMinMaxExpr(scSMaxExpr, Ops); +} + +const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) { + SmallVector<const SCEV *, 2> Ops = {LHS, RHS}; + return getUMaxExpr(Ops); +} + +const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { + return getMinMaxExpr(scUMaxExpr, Ops); +} + +const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS, + const SCEV *RHS) { + SmallVector<const SCEV *, 2> Ops = { LHS, RHS }; + return getSMinExpr(Ops); +} + +const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) { + return getMinMaxExpr(scSMinExpr, Ops); +} + +const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, + const SCEV *RHS) { + SmallVector<const SCEV *, 2> Ops = { LHS, RHS }; + return getUMinExpr(Ops); +} + +const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops) { + return getMinMaxExpr(scUMinExpr, Ops); +} + +const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { + // We can bypass creating a target-independent + // constant expression and then folding it back into a ConstantInt. + // This is just a compile-time optimization. + return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); +} + +const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, + StructType *STy, + unsigned FieldNo) { + // We can bypass creating a target-independent + // constant expression and then folding it back into a ConstantInt. + // This is just a compile-time optimization. + return getConstant( + IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo)); +} + +const SCEV *ScalarEvolution::getUnknown(Value *V) { + // Don't attempt to do anything other than create a SCEVUnknown object + // here. 
createSCEV only calls getUnknown after checking for all other
+  // interesting possibilities, and any other code that calls getUnknown
+  // is doing so in order to hide a value from SCEV canonicalization.
+
+  FoldingSetNodeID ID;
+  ID.AddInteger(scUnknown);
+  ID.AddPointer(V);
+  void *IP = nullptr;
+  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
+    assert(cast<SCEVUnknown>(S)->getValue() == V &&
+           "Stale SCEVUnknown in uniquing map!");
+    return S;
+  }
+  SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
+                                            FirstUnknown);
+  FirstUnknown = cast<SCEVUnknown>(S);
+  UniqueSCEVs.InsertNode(S, IP);
+  return S;
+}
+
+//===----------------------------------------------------------------------===//
+// Basic SCEV Analysis and PHI Idiom Recognition Code
+//
+
+/// Test if values of the given type are analyzable within the SCEV
+/// framework. This primarily includes integer types, and it can optionally
+/// include pointer types if the ScalarEvolution class has access to
+/// target-specific information.
+bool ScalarEvolution::isSCEVable(Type *Ty) const {
+  // Integers and pointers are always SCEVable.
+  return Ty->isIntOrPtrTy();
+}
+
+/// Return the size in bits of the specified type, for which isSCEVable must
+/// return true.
+uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
+  assert(isSCEVable(Ty) && "Type is not SCEVable!");
+  if (Ty->isPointerTy())
+    return getDataLayout().getIndexTypeSizeInBits(Ty);
+  return getDataLayout().getTypeSizeInBits(Ty);
+}
+
+/// Return a type with the same bitwidth as the given type and which
+/// represents how SCEV will treat the given type, for which isSCEVable must
+/// return true. For pointer types, this is the pointer-sized integer type.
+Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
+  assert(isSCEVable(Ty) && "Type is not SCEVable!");
+
+  if (Ty->isIntegerTy())
+    return Ty;
+
+  // The only other supported type is pointer.
+  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
+  return getDataLayout().getIntPtrType(Ty);
+}
+
+Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
+  return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
+}
+
+const SCEV *ScalarEvolution::getCouldNotCompute() {
+  return CouldNotCompute.get();
+}
+
+bool ScalarEvolution::checkValidity(const SCEV *S) const {
+  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
+    auto *SU = dyn_cast<SCEVUnknown>(S);
+    return SU && SU->getValue() == nullptr;
+  });
+
+  return !ContainsNulls;
+}
+
+bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
+  HasRecMapType::iterator I = HasRecMap.find(S);
+  if (I != HasRecMap.end())
+    return I->second;
+
+  bool FoundAddRec = SCEVExprContains(S, isa<SCEVAddRecExpr, const SCEV *>);
+  HasRecMap.insert({S, FoundAddRec});
+  return FoundAddRec;
+}
+
+/// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}.
+/// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an
+/// offset I, then return {S', I}, else return {\p S, nullptr}.
+static std::pair<const SCEV *, ConstantInt *> splitAddExpr(const SCEV *S) {
+  const auto *Add = dyn_cast<SCEVAddExpr>(S);
+  if (!Add)
+    return {S, nullptr};
+
+  if (Add->getNumOperands() != 2)
+    return {S, nullptr};
+
+  auto *ConstOp = dyn_cast<SCEVConstant>(Add->getOperand(0));
+  if (!ConstOp)
+    return {S, nullptr};
+
+  return {Add->getOperand(1), ConstOp->getValue()};
+}
+
+/// Return the ValueOffsetPair set for \p S.
\p S can be represented +/// by the value and offset from any ValueOffsetPair in the set. +SetVector<ScalarEvolution::ValueOffsetPair> * +ScalarEvolution::getSCEVValues(const SCEV *S) { + ExprValueMapType::iterator SI = ExprValueMap.find_as(S); + if (SI == ExprValueMap.end()) + return nullptr; +#ifndef NDEBUG + if (VerifySCEVMap) { + // Check there is no dangling Value in the set returned. + for (const auto &VE : SI->second) + assert(ValueExprMap.count(VE.first)); + } +#endif + return &SI->second; +} + +/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V) +/// cannot be used separately. eraseValueFromMap should be used to remove +/// V from ValueExprMap and ExprValueMap at the same time. +void ScalarEvolution::eraseValueFromMap(Value *V) { + ValueExprMapType::iterator I = ValueExprMap.find_as(V); + if (I != ValueExprMap.end()) { + const SCEV *S = I->second; + // Remove {V, 0} from the set of ExprValueMap[S] + if (SetVector<ValueOffsetPair> *SV = getSCEVValues(S)) + SV->remove({V, nullptr}); + + // Remove {V, Offset} from the set of ExprValueMap[Stripped] + const SCEV *Stripped; + ConstantInt *Offset; + std::tie(Stripped, Offset) = splitAddExpr(S); + if (Offset != nullptr) { + if (SetVector<ValueOffsetPair> *SV = getSCEVValues(Stripped)) + SV->remove({V, Offset}); + } + ValueExprMap.erase(V); + } +} + +/// Check whether value has nuw/nsw/exact set but SCEV does not. +/// TODO: In reality it is better to check the poison recursively +/// but this is better than nothing. +static bool SCEVLostPoisonFlags(const SCEV *S, const Value *V) { + if (auto *I = dyn_cast<Instruction>(V)) { + if (isa<OverflowingBinaryOperator>(I)) { + if (auto *NS = dyn_cast<SCEVNAryExpr>(S)) { + if (I->hasNoSignedWrap() && !NS->hasNoSignedWrap()) + return true; + if (I->hasNoUnsignedWrap() && !NS->hasNoUnsignedWrap()) + return true; + } + } else if (isa<PossiblyExactOperator>(I) && I->isExact()) + return true; + } + return false; +} + +/// Return an existing SCEV if it exists, otherwise analyze the expression and +/// create a new one. +const SCEV *ScalarEvolution::getSCEV(Value *V) { + assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); + + const SCEV *S = getExistingSCEV(V); + if (S == nullptr) { + S = createSCEV(V); + // During PHI resolution, it is possible to create two SCEVs for the same + // V, so it is needed to double check whether V->S is inserted into + // ValueExprMap before insert S->{V, 0} into ExprValueMap. + std::pair<ValueExprMapType::iterator, bool> Pair = + ValueExprMap.insert({SCEVCallbackVH(V, this), S}); + if (Pair.second && !SCEVLostPoisonFlags(S, V)) { + ExprValueMap[S].insert({V, nullptr}); + + // If S == Stripped + Offset, add Stripped -> {V, Offset} into + // ExprValueMap. + const SCEV *Stripped = S; + ConstantInt *Offset = nullptr; + std::tie(Stripped, Offset) = splitAddExpr(S); + // If stripped is SCEVUnknown, don't bother to save + // Stripped -> {V, offset}. It doesn't simplify and sometimes even + // increase the complexity of the expansion code. + // If V is GetElementPtrInst, don't save Stripped -> {V, offset} + // because it may generate add/sub instead of GEP in SCEV expansion. 
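+      //
+      // For example (illustrative), if %v computes S = (4 + %x), recording
+      // %x -> {%v, 4} lets the expander materialize %x as (%v - 4) rather
+      // than re-expanding %x from scratch.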
+ if (Offset != nullptr && !isa<SCEVUnknown>(Stripped) && + !isa<GetElementPtrInst>(V)) + ExprValueMap[Stripped].insert({V, Offset}); + } + } + return S; +} + +const SCEV *ScalarEvolution::getExistingSCEV(Value *V) { + assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); + + ValueExprMapType::iterator I = ValueExprMap.find_as(V); + if (I != ValueExprMap.end()) { + const SCEV *S = I->second; + if (checkValidity(S)) + return S; + eraseValueFromMap(V); + forgetMemoizedResults(S); + } + return nullptr; +} + +/// Return a SCEV corresponding to -V = -1*V +const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, + SCEV::NoWrapFlags Flags) { + if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) + return getConstant( + cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue()))); + + Type *Ty = V->getType(); + Ty = getEffectiveSCEVType(Ty); + return getMulExpr( + V, getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))), Flags); +} + +/// If Expr computes ~A, return A else return nullptr +static const SCEV *MatchNotExpr(const SCEV *Expr) { + const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr); + if (!Add || Add->getNumOperands() != 2 || + !Add->getOperand(0)->isAllOnesValue()) + return nullptr; + + const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1)); + if (!AddRHS || AddRHS->getNumOperands() != 2 || + !AddRHS->getOperand(0)->isAllOnesValue()) + return nullptr; + + return AddRHS->getOperand(1); +} + +/// Return a SCEV corresponding to ~V = -1-V +const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { + if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) + return getConstant( + cast<ConstantInt>(ConstantExpr::getNot(VC->getValue()))); + + // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y) + if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) { + auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) { + SmallVector<const SCEV *, 2> MatchedOperands; + for (const SCEV *Operand : MME->operands()) { + const SCEV *Matched = MatchNotExpr(Operand); + if (!Matched) + return (const SCEV *)nullptr; + MatchedOperands.push_back(Matched); + } + return getMinMaxExpr( + SCEVMinMaxExpr::negate(static_cast<SCEVTypes>(MME->getSCEVType())), + MatchedOperands); + }; + if (const SCEV *Replaced = MatchMinMaxNegation(MME)) + return Replaced; + } + + Type *Ty = V->getType(); + Ty = getEffectiveSCEVType(Ty); + const SCEV *AllOnes = + getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))); + return getMinusSCEV(AllOnes, V); +} + +const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, + SCEV::NoWrapFlags Flags, + unsigned Depth) { + // Fast path: X - X --> 0. + if (LHS == RHS) + return getZero(LHS->getType()); + + // We represent LHS - RHS as LHS + (-1)*RHS. This transformation + // makes it so that we cannot make much use of NUW. + auto AddFlags = SCEV::FlagAnyWrap; + const bool RHSIsNotMinSigned = + !getSignedRangeMin(RHS).isMinSignedValue(); + if (maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW) { + // Let M be the minimum representable signed value. Then (-1)*RHS + // signed-wraps if and only if RHS is M. That can happen even for + // a NSW subtraction because e.g. (-1)*M signed-wraps even though + // -1 - M does not. So to transfer NSW from LHS - RHS to LHS + + // (-1)*RHS, we need to prove that RHS != M. + // + // If LHS is non-negative and we know that LHS - RHS does not + // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap + // either by proving that RHS > M or that LHS >= 0. 
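+    //
+    // For example, with i8 values and M = -128: (-1) * (-128) signed-wraps
+    // (128 is unrepresentable), even though -1 - (-128) = 127 is fine; if
+    // RHS is known not to be -128, the negation cannot signed-wrap.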
+    if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
+      AddFlags = SCEV::FlagNSW;
+    }
+  }
+
+  // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS - RHS
+  // is NSW and LHS >= 0.
+  //
+  // The difficulty here is that the NSW flag may have been proven
+  // relative to a loop that is to be found in a recurrence in LHS and
+  // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
+  // larger scope than intended.
+  auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
+
+  return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
+}
+
+const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
+                                                     unsigned Depth) {
+  Type *SrcTy = V->getType();
+  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+         "Cannot truncate or zero extend with non-integer arguments!");
+  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
+    return V; // No conversion
+  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
+    return getTruncateExpr(V, Ty, Depth);
+  return getZeroExtendExpr(V, Ty, Depth);
+}
+
+const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
+                                                     unsigned Depth) {
+  Type *SrcTy = V->getType();
+  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+         "Cannot truncate or sign extend with non-integer arguments!");
+  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
+    return V; // No conversion
+  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
+    return getTruncateExpr(V, Ty, Depth);
+  return getSignExtendExpr(V, Ty, Depth);
+}
+
+const SCEV *
+ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
+  Type *SrcTy = V->getType();
+  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+         "Cannot noop or zero extend with non-integer arguments!");
+  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
+         "getNoopOrZeroExtend cannot truncate!");
+  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
+    return V; // No conversion
+  return getZeroExtendExpr(V, Ty);
+}
+
+const SCEV *
+ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
+  Type *SrcTy = V->getType();
+  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+         "Cannot noop or sign extend with non-integer arguments!");
+  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
+         "getNoopOrSignExtend cannot truncate!");
+  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
+    return V; // No conversion
+  return getSignExtendExpr(V, Ty);
+}
+
+const SCEV *
+ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
+  Type *SrcTy = V->getType();
+  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+         "Cannot noop or any extend with non-integer arguments!");
+  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
+         "getNoopOrAnyExtend cannot truncate!");
+  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
+    return V; // No conversion
+  return getAnyExtendExpr(V, Ty);
+}
+
+const SCEV *
+ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
+  Type *SrcTy = V->getType();
+  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+         "Cannot truncate or noop with non-integer arguments!");
+  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
+         "getTruncateOrNoop cannot extend!");
+  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
+    return V; // No conversion
+  return getTruncateExpr(V, Ty);
+}
+
+const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
+                                                        const SCEV *RHS) {
+  const SCEV *PromotedLHS = LHS;
+  const SCEV *PromotedRHS = RHS;
+
+  if
(getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType())) + PromotedRHS = getZeroExtendExpr(RHS, LHS->getType()); + else + PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType()); + + return getUMaxExpr(PromotedLHS, PromotedRHS); +} + +const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, + const SCEV *RHS) { + SmallVector<const SCEV *, 2> Ops = { LHS, RHS }; + return getUMinFromMismatchedTypes(Ops); +} + +const SCEV *ScalarEvolution::getUMinFromMismatchedTypes( + SmallVectorImpl<const SCEV *> &Ops) { + assert(!Ops.empty() && "At least one operand must be!"); + // Trivial case. + if (Ops.size() == 1) + return Ops[0]; + + // Find the max type first. + Type *MaxType = nullptr; + for (auto *S : Ops) + if (MaxType) + MaxType = getWiderType(MaxType, S->getType()); + else + MaxType = S->getType(); + + // Extend all ops to max type. + SmallVector<const SCEV *, 2> PromotedOps; + for (auto *S : Ops) + PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType)); + + // Generate umin. + return getUMinExpr(PromotedOps); +} + +const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { + // A pointer operand may evaluate to a nonpointer expression, such as null. + if (!V->getType()->isPointerTy()) + return V; + + if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) { + return getPointerBase(Cast->getOperand()); + } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) { + const SCEV *PtrOp = nullptr; + for (const SCEV *NAryOp : NAry->operands()) { + if (NAryOp->getType()->isPointerTy()) { + // Cannot find the base of an expression with multiple pointer operands. + if (PtrOp) + return V; + PtrOp = NAryOp; + } + } + if (!PtrOp) + return V; + return getPointerBase(PtrOp); + } + return V; +} + +/// Push users of the given Instruction onto the given Worklist. +static void +PushDefUseChildren(Instruction *I, + SmallVectorImpl<Instruction *> &Worklist) { + // Push the def-use children onto the Worklist stack. + for (User *U : I->users()) + Worklist.push_back(cast<Instruction>(U)); +} + +void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) { + SmallVector<Instruction *, 16> Worklist; + PushDefUseChildren(PN, Worklist); + + SmallPtrSet<Instruction *, 8> Visited; + Visited.insert(PN); + while (!Worklist.empty()) { + Instruction *I = Worklist.pop_back_val(); + if (!Visited.insert(I).second) + continue; + + auto It = ValueExprMap.find_as(static_cast<Value *>(I)); + if (It != ValueExprMap.end()) { + const SCEV *Old = It->second; + + // Short-circuit the def-use traversal if the symbolic name + // ceases to appear in expressions. + if (Old != SymName && !hasOperand(Old, SymName)) + continue; + + // SCEVUnknown for a PHI either means that it has an unrecognized + // structure, it's a PHI that's in the progress of being computed + // by createNodeForPHI, or it's a single-value PHI. In the first case, + // additional loop trip count information isn't going to change anything. + // In the second case, createNodeForPHI will perform the necessary + // updates on its own when it gets to that point. In the third, we do + // want to forget the SCEVUnknown. + if (!isa<PHINode>(I) || + !isa<SCEVUnknown>(Old) || + (I != PN && Old == SymName)) { + eraseValueFromMap(It->first); + forgetMemoizedResults(Old); + } + } + + PushDefUseChildren(I, Worklist); + } +} + +namespace { + +/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start +/// expression in case its Loop is L. 
If it is not L then +/// if IgnoreOtherLoops is true then use AddRec itself +/// otherwise rewrite cannot be done. +/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done. +class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> { +public: + static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, + bool IgnoreOtherLoops = true) { + SCEVInitRewriter Rewriter(L, SE); + const SCEV *Result = Rewriter.visit(S); + if (Rewriter.hasSeenLoopVariantSCEVUnknown()) + return SE.getCouldNotCompute(); + return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops + ? SE.getCouldNotCompute() + : Result; + } + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + if (!SE.isLoopInvariant(Expr, L)) + SeenLoopVariantSCEVUnknown = true; + return Expr; + } + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + // Only re-write AddRecExprs for this loop. + if (Expr->getLoop() == L) + return Expr->getStart(); + SeenOtherLoops = true; + return Expr; + } + + bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; } + + bool hasSeenOtherLoops() { return SeenOtherLoops; } + +private: + explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L) {} + + const Loop *L; + bool SeenLoopVariantSCEVUnknown = false; + bool SeenOtherLoops = false; +}; + +/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post +/// increment expression in case its Loop is L. If it is not L then +/// use AddRec itself. +/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done. +class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> { +public: + static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) { + SCEVPostIncRewriter Rewriter(L, SE); + const SCEV *Result = Rewriter.visit(S); + return Rewriter.hasSeenLoopVariantSCEVUnknown() + ? SE.getCouldNotCompute() + : Result; + } + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + if (!SE.isLoopInvariant(Expr, L)) + SeenLoopVariantSCEVUnknown = true; + return Expr; + } + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + // Only re-write AddRecExprs for this loop. + if (Expr->getLoop() == L) + return Expr->getPostIncExpr(SE); + SeenOtherLoops = true; + return Expr; + } + + bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; } + + bool hasSeenOtherLoops() { return SeenOtherLoops; } + +private: + explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L) {} + + const Loop *L; + bool SeenLoopVariantSCEVUnknown = false; + bool SeenOtherLoops = false; +}; + +/// This class evaluates the compare condition by matching it against the +/// condition of loop latch. If there is a match we assume a true value +/// for the condition while building SCEV nodes. 
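+///
+/// For example (illustrative), when visiting select(%c, %a, %b) defined in
+/// the loop, if %c is exactly the latch condition that keeps the loop
+/// running, the select folds to %a while its SCEV is being built.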
+class SCEVBackedgeConditionFolder + : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> { +public: + static const SCEV *rewrite(const SCEV *S, const Loop *L, + ScalarEvolution &SE) { + bool IsPosBECond = false; + Value *BECond = nullptr; + if (BasicBlock *Latch = L->getLoopLatch()) { + BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator()); + if (BI && BI->isConditional()) { + assert(BI->getSuccessor(0) != BI->getSuccessor(1) && + "Both outgoing branches should not target same header!"); + BECond = BI->getCondition(); + IsPosBECond = BI->getSuccessor(0) == L->getHeader(); + } else { + return S; + } + } + SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE); + return Rewriter.visit(S); + } + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + const SCEV *Result = Expr; + bool InvariantF = SE.isLoopInvariant(Expr, L); + + if (!InvariantF) { + Instruction *I = cast<Instruction>(Expr->getValue()); + switch (I->getOpcode()) { + case Instruction::Select: { + SelectInst *SI = cast<SelectInst>(I); + Optional<const SCEV *> Res = + compareWithBackedgeCondition(SI->getCondition()); + if (Res.hasValue()) { + bool IsOne = cast<SCEVConstant>(Res.getValue())->getValue()->isOne(); + Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue()); + } + break; + } + default: { + Optional<const SCEV *> Res = compareWithBackedgeCondition(I); + if (Res.hasValue()) + Result = Res.getValue(); + break; + } + } + } + return Result; + } + +private: + explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond, + bool IsPosBECond, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond), + IsPositiveBECond(IsPosBECond) {} + + Optional<const SCEV *> compareWithBackedgeCondition(Value *IC); + + const Loop *L; + /// Loop back condition. + Value *BackedgeCond = nullptr; + /// Set to true if loop back is on positive branch condition. + bool IsPositiveBECond; +}; + +Optional<const SCEV *> +SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) { + + // If value matches the backedge condition for loop latch, + // then return a constant evolution node based on loopback + // branch taken. + if (BackedgeCond == IC) + return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext())) + : SE.getZero(Type::getInt1Ty(SE.getContext())); + return None; +} + +class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> { +public: + static const SCEV *rewrite(const SCEV *S, const Loop *L, + ScalarEvolution &SE) { + SCEVShiftRewriter Rewriter(L, SE); + const SCEV *Result = Rewriter.visit(S); + return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); + } + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + // Only allow AddRecExprs for this loop. 
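+    // (A loop-invariant unknown has the same value on every iteration, so
+    // shifting the expression back one iteration may keep it; anything
+    // else invalidates the rewrite.)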
+ if (!SE.isLoopInvariant(Expr, L)) + Valid = false; + return Expr; + } + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + if (Expr->getLoop() == L && Expr->isAffine()) + return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE)); + Valid = false; + return Expr; + } + + bool isValid() { return Valid; } + +private: + explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L) {} + + const Loop *L; + bool Valid = true; +}; + +} // end anonymous namespace + +SCEV::NoWrapFlags +ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) { + if (!AR->isAffine()) + return SCEV::FlagAnyWrap; + + using OBO = OverflowingBinaryOperator; + + SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap; + + if (!AR->hasNoSignedWrap()) { + ConstantRange AddRecRange = getSignedRange(AR); + ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this)); + + auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::Add, IncRange, OBO::NoSignedWrap); + if (NSWRegion.contains(AddRecRange)) + Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW); + } + + if (!AR->hasNoUnsignedWrap()) { + ConstantRange AddRecRange = getUnsignedRange(AR); + ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this)); + + auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::Add, IncRange, OBO::NoUnsignedWrap); + if (NUWRegion.contains(AddRecRange)) + Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW); + } + + return Result; +} + +namespace { + +/// Represents an abstract binary operation. This may exist as a +/// normal instruction or constant expression, or may have been +/// derived from an expression tree. +struct BinaryOp { + unsigned Opcode; + Value *LHS; + Value *RHS; + bool IsNSW = false; + bool IsNUW = false; + + /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or + /// constant expression. + Operator *Op = nullptr; + + explicit BinaryOp(Operator *Op) + : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)), + Op(Op) { + if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) { + IsNSW = OBO->hasNoSignedWrap(); + IsNUW = OBO->hasNoUnsignedWrap(); + } + } + + explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false, + bool IsNUW = false) + : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {} +}; + +} // end anonymous namespace + +/// Try to map \p V into a BinaryOp, and return \c None on failure. +static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { + auto *Op = dyn_cast<Operator>(V); + if (!Op) + return None; + + // Implementation detail: all the cleverness here should happen without + // creating new SCEV expressions -- our caller knowns tricks to avoid creating + // SCEV expressions when possible, and we should not break that. + + switch (Op->getOpcode()) { + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::UDiv: + case Instruction::URem: + case Instruction::And: + case Instruction::Or: + case Instruction::AShr: + case Instruction::Shl: + return BinaryOp(Op); + + case Instruction::Xor: + if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1))) + // If the RHS of the xor is a signmask, then this is just an add. + // Instcombine turns add of signmask into xor as a strength reduction step. 
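+      // For example, in i8: (%x xor 0x80) equals (%x + 0x80), since adding
+      // the sign mask simply flips the sign bit.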
+      if (RHSC->getValue().isSignMask())
+        return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
+    return BinaryOp(Op);
+
+  case Instruction::LShr:
+    // Turn logical shift right of a constant into an unsigned divide.
+    if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
+      uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
+
+      // If the shift count is not less than the bitwidth, the result of
+      // the shift is undefined. Don't try to analyze it, because the
+      // resolution chosen here may differ from the resolution chosen in
+      // other parts of the compiler.
+      if (SA->getValue().ult(BitWidth)) {
+        Constant *X =
+            ConstantInt::get(SA->getContext(),
+                             APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
+        return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
+      }
+    }
+    return BinaryOp(Op);
+
+  case Instruction::ExtractValue: {
+    auto *EVI = cast<ExtractValueInst>(Op);
+    if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
+      break;
+
+    auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
+    if (!WO)
+      break;
+
+    Instruction::BinaryOps BinOp = WO->getBinaryOp();
+    bool Signed = WO->isSigned();
+    // TODO: Should add nuw/nsw flags for mul as well.
+    if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
+      return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
+
+    // Now that we know that all uses of the arithmetic-result component of
+    // WO are guarded by the overflow check, we can go ahead and pretend
+    // that the arithmetic is non-overflowing.
+    return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
+                    /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
+  }
+
+  default:
+    break;
+  }
+
+  return None;
+}
+
+/// Helper function to createAddRecFromPHIWithCasts. We have a phi
+/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
+/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
+/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
+/// follows one of the following patterns:
+/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
+/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
+/// If the SCEV expression of \p Op conforms with one of the expected patterns
+/// we return the type of the truncation operation, and indicate whether the
+/// truncated type should be treated as signed/unsigned by setting
+/// \p Signed to true/false, respectively.
+static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
+                               bool &Signed, ScalarEvolution &SE) {
+  // The case where Op == SymbolicPHI (that is, with no type conversions on
+  // the way) is handled by the regular add recurrence creating logic and
+  // would have already been triggered in createAddRecFromPHI. Reaching it
+  // here means that createAddRecFromPHI had failed for this PHI before
+  // (e.g., because one of the other operands of the SCEVAddExpr updating
+  // this PHI is not invariant).
+  //
+  // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
+  // this case predicates that allow us to prove that Op == SymbolicPHI will
+  // be added.
+ if (Op == SymbolicPHI) + return nullptr; + + unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType()); + unsigned NewBits = SE.getTypeSizeInBits(Op->getType()); + if (SourceBits != NewBits) + return nullptr; + + const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op); + const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op); + if (!SExt && !ZExt) + return nullptr; + const SCEVTruncateExpr *Trunc = + SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand()) + : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand()); + if (!Trunc) + return nullptr; + const SCEV *X = Trunc->getOperand(); + if (X != SymbolicPHI) + return nullptr; + Signed = SExt != nullptr; + return Trunc->getType(); +} + +static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) { + if (!PN->getType()->isIntegerTy()) + return nullptr; + const Loop *L = LI.getLoopFor(PN->getParent()); + if (!L || L->getHeader() != PN->getParent()) + return nullptr; + return L; +} + +// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the +// computation that updates the phi follows the following pattern: +// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum +// which correspond to a phi->trunc->sext/zext->add->phi update chain. +// If so, try to see if it can be rewritten as an AddRecExpr under some +// Predicates. If successful, return them as a pair. Also cache the results +// of the analysis. +// +// Example usage scenario: +// Say the Rewriter is called for the following SCEV: +// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step) +// where: +// %X = phi i64 (%Start, %BEValue) +// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X), +// and call this function with %SymbolicPHI = %X. +// +// The analysis will find that the value coming around the backedge has +// the following SCEV: +// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step) +// Upon concluding that this matches the desired pattern, the function +// will return the pair {NewAddRec, SmallPredsVec} where: +// NewAddRec = {%Start,+,%Step} +// SmallPredsVec = {P1, P2, P3} as follows: +// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw> +// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64) +// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64) +// The returned pair means that SymbolicPHI can be rewritten into NewAddRec +// under the predicates {P1,P2,P3}. +// This predicated rewrite will be cached in PredicatedSCEVRewrites: +// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)} +// +// TODO's: +// +// 1) Extend the Induction descriptor to also support inductions that involve +// casts: When needed (namely, when we are called in the context of the +// vectorizer induction analysis), a Set of cast instructions will be +// populated by this method, and provided back to isInductionPHI. This is +// needed to allow the vectorizer to properly record them to be ignored by +// the cost model and to avoid vectorizing them (otherwise these casts, +// which are redundant under the runtime overflow checks, will be +// vectorized, which can be costly). +// +// 2) Support additional induction/PHISCEV patterns: We also want to support +// inductions where the sext-trunc / zext-trunc operations (partly) occur +// after the induction update operation (the induction increment): +// +// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix) +// which correspond to a phi->add->trunc->sext/zext->phi update chain. 
+// +// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix) +// which correspond to a phi->trunc->add->sext/zext->phi update chain. +// +// 3) Outline common code with createAddRecFromPHI to avoid duplication. +Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>> +ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) { + SmallVector<const SCEVPredicate *, 3> Predicates; + + // *** Part1: Analyze if we have a phi-with-cast pattern for which we can + // return an AddRec expression under some predicate. + + auto *PN = cast<PHINode>(SymbolicPHI->getValue()); + const Loop *L = isIntegerLoopHeaderPHI(PN, LI); + assert(L && "Expecting an integer loop header phi"); + + // The loop may have multiple entrances or multiple exits; we can analyze + // this phi as an addrec if it has a unique entry value and a unique + // backedge value. + Value *BEValueV = nullptr, *StartValueV = nullptr; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + Value *V = PN->getIncomingValue(i); + if (L->contains(PN->getIncomingBlock(i))) { + if (!BEValueV) { + BEValueV = V; + } else if (BEValueV != V) { + BEValueV = nullptr; + break; + } + } else if (!StartValueV) { + StartValueV = V; + } else if (StartValueV != V) { + StartValueV = nullptr; + break; + } + } + if (!BEValueV || !StartValueV) + return None; + + const SCEV *BEValue = getSCEV(BEValueV); + + // If the value coming around the backedge is an add with the symbolic + // value we just inserted, possibly with casts that we can ignore under + // an appropriate runtime guard, then we found a simple induction variable! + const auto *Add = dyn_cast<SCEVAddExpr>(BEValue); + if (!Add) + return None; + + // If there is a single occurrence of the symbolic value, possibly + // casted, replace it with a recurrence. + unsigned FoundIndex = Add->getNumOperands(); + Type *TruncTy = nullptr; + bool Signed; + for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) + if ((TruncTy = + isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this))) + if (FoundIndex == e) { + FoundIndex = i; + break; + } + + if (FoundIndex == Add->getNumOperands()) + return None; + + // Create an add with everything but the specified operand. + SmallVector<const SCEV *, 8> Ops; + for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) + if (i != FoundIndex) + Ops.push_back(Add->getOperand(i)); + const SCEV *Accum = getAddExpr(Ops); + + // The runtime checks will not be valid if the step amount is + // varying inside the loop. + if (!isLoopInvariant(Accum, L)) + return None; + + // *** Part2: Create the predicates + + // Analysis was successful: we have a phi-with-cast pattern for which we + // can return an AddRec expression under the following predicates: + // + // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum) + // fits within the truncated type (does not overflow) for i = 0 to n-1. 
+ // P2: An Equal predicate that guarantees that + // Start = (Ext ix (Trunc iy (Start) to ix) to iy) + // P3: An Equal predicate that guarantees that + // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy) + // + // As we next prove, the above predicates guarantee that: + // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy) + // + // + // More formally, we want to prove that: + // Expr(i+1) = Start + (i+1) * Accum + // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum + // + // Given that: + // 1) Expr(0) = Start + // 2) Expr(1) = Start + Accum + // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2 + // 3) Induction hypothesis (step i): + // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + // + // Proof: + // Expr(i+1) = + // = Start + (i+1)*Accum + // = (Start + i*Accum) + Accum + // = Expr(i) + Accum + // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum + // :: from step i + // + // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum + // + // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + // + (Ext ix (Trunc iy (Accum) to ix) to iy) + // + Accum :: from P3 + // + // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy) + // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y) + // + // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum + // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum + // + // By induction, the same applies to all iterations 1<=i<n: + // + + // Create a truncated addrec for which we will add a no overflow check (P1). + const SCEV *StartVal = getSCEV(StartValueV); + const SCEV *PHISCEV = + getAddRecExpr(getTruncateExpr(StartVal, TruncTy), + getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap); + + // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr. + // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV + // will be constant. + // + // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't + // add P1. + if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) { + SCEVWrapPredicate::IncrementWrapFlags AddedFlags = + Signed ? SCEVWrapPredicate::IncrementNSSW + : SCEVWrapPredicate::IncrementNUSW; + const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags); + Predicates.push_back(AddRecPred); + } + + // Create the Equal Predicates P2,P3: + + // It is possible that the predicates P2 and/or P3 are computable at + // compile time due to StartVal and/or Accum being constants. + // If either one is, then we can check that now and escape if either P2 + // or P3 is false. + + // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy) + // for each of StartVal and Accum + auto getExtendedExpr = [&](const SCEV *Expr, + bool CreateSignExtend) -> const SCEV * { + assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant"); + const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy); + const SCEV *ExtendedExpr = + CreateSignExtend ? 
getSignExtendExpr(TruncatedExpr, Expr->getType())
+                     : getZeroExtendExpr(TruncatedExpr, Expr->getType());
+    return ExtendedExpr;
+  };
+
+  // Given:
+  //   ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy)
+  //                = getExtendedExpr(Expr)
+  // Determine whether the predicate P: Expr == ExtendedExpr
+  // is known to be false at compile time.
+  auto PredIsKnownFalse = [&](const SCEV *Expr,
+                              const SCEV *ExtendedExpr) -> bool {
+    return Expr != ExtendedExpr &&
+           isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
+  };
+
+  const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
+  if (PredIsKnownFalse(StartVal, StartExtended)) {
+    LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
+    return None;
+  }
+
+  // The Step is always Signed (because the overflow checks are either
+  // NSSW or NUSW).
+  const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
+  if (PredIsKnownFalse(Accum, AccumExtended)) {
+    LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
+    return None;
+  }
+
+  auto AppendPredicate = [&](const SCEV *Expr,
+                             const SCEV *ExtendedExpr) -> void {
+    if (Expr != ExtendedExpr &&
+        !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
+      const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
+      LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
+      Predicates.push_back(Pred);
+    }
+  };
+
+  AppendPredicate(StartVal, StartExtended);
+  AppendPredicate(Accum, AccumExtended);
+
+  // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
+  // which the casts had been folded away. The caller can rewrite SymbolicPHI
+  // into NewAR if it will also add the runtime overflow checks specified in
+  // Predicates.
+  auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
+
+  std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
+      std::make_pair(NewAR, Predicates);
+  // Remember the result of the analysis for this SCEV at this location.
+  PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
+  return PredRewrite;
+}
+
+Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
+  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
+  const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
+  if (!L)
+    return None;
+
+  // Check to see if we already analyzed this PHI.
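+  // A previously failed analysis is cached as the sentinel pair
+  // {SymbolicPHI, {}} (see the failure path at the end of this function),
+  // so a cache hit can mean either a ready rewrite or a known failure.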
+  auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
+  if (I != PredicatedSCEVRewrites.end()) {
+    std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
+        I->second;
+    // Analysis was done before and failed to create an AddRec:
+    if (Rewrite.first == SymbolicPHI)
+      return None;
+    // Analysis was done before and succeeded in creating an AddRec under
+    // a predicate:
+    assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
+    assert(!(Rewrite.second).empty() && "Expected to find Predicates");
+    return Rewrite;
+  }
+
+  Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+      Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
+
+  // Record in the cache that the analysis failed.
+  if (!Rewrite) {
+    SmallVector<const SCEVPredicate *, 3> Predicates;
+    PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
+    return None;
+  }
+
+  return Rewrite;
+}
+
+// FIXME: This utility is currently required because the Rewriter currently
+// does not rewrite this expression:
+// {0, +, (sext ix (trunc iy to ix) to iy)}
+// into {0, +, %step},
+// even when the following Equal predicate exists:
+// "%step == (sext ix (trunc iy to ix) to iy)".
+bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
+    const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
+  if (AR1 == AR2)
+    return true;
+
+  auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
+    if (Expr1 != Expr2 && !Preds.implies(SE.getEqualPredicate(Expr1, Expr2)) &&
+        !Preds.implies(SE.getEqualPredicate(Expr2, Expr1)))
+      return false;
+    return true;
+  };
+
+  if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
+      !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
+    return false;
+  return true;
+}
+
+/// A helper function for createAddRecFromPHI to handle simple cases.
+///
+/// This function tries to find an AddRec expression for the simplest (yet most
+/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
+/// If it fails, createAddRecFromPHI will use a more general, but slow,
+/// technique for finding the AddRec expression.
+const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
+                                                      Value *BEValueV,
+                                                      Value *StartValueV) {
+  const Loop *L = LI.getLoopFor(PN->getParent());
+  assert(L && L->getHeader() == PN->getParent());
+  assert(BEValueV && StartValueV);
+
+  auto BO = MatchBinaryOp(BEValueV, DT);
+  if (!BO)
+    return nullptr;
+
+  if (BO->Opcode != Instruction::Add)
+    return nullptr;
+
+  const SCEV *Accum = nullptr;
+  if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
+    Accum = getSCEV(BO->RHS);
+  else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
+    Accum = getSCEV(BO->LHS);
+
+  if (!Accum)
+    return nullptr;
+
+  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
+  if (BO->IsNUW)
+    Flags = setFlags(Flags, SCEV::FlagNUW);
+  if (BO->IsNSW)
+    Flags = setFlags(Flags, SCEV::FlagNSW);
+
+  const SCEV *StartVal = getSCEV(StartValueV);
+  const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
+
+  ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
+
+  // We can add Flags to the post-inc expression only if we
+  // know that it is *undefined behavior* for BEValueV to
+  // overflow.
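+  //
+  // Illustrative example (hypothetical IR, not from this file): for
+  //   %iv  = phi i32 [ 0, %preheader ], [ %inc, %loop ]
+  //   %inc = add nuw nsw i32 %iv, 1
+  // PHISCEV is {0,+,1}<nuw><nsw>, and if %inc is known never to produce
+  // poison, the post-increment expression {1,+,1} is created with the same
+  // flags.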
+  if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
+    if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
+      (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
+
+  return PHISCEV;
+}
+
+const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
+  const Loop *L = LI.getLoopFor(PN->getParent());
+  if (!L || L->getHeader() != PN->getParent())
+    return nullptr;
+
+  // The loop may have multiple entrances or multiple exits; we can analyze
+  // this phi as an addrec if it has a unique entry value and a unique
+  // backedge value.
+  Value *BEValueV = nullptr, *StartValueV = nullptr;
+  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+    Value *V = PN->getIncomingValue(i);
+    if (L->contains(PN->getIncomingBlock(i))) {
+      if (!BEValueV) {
+        BEValueV = V;
+      } else if (BEValueV != V) {
+        BEValueV = nullptr;
+        break;
+      }
+    } else if (!StartValueV) {
+      StartValueV = V;
+    } else if (StartValueV != V) {
+      StartValueV = nullptr;
+      break;
+    }
+  }
+  if (!BEValueV || !StartValueV)
+    return nullptr;
+
+  assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
+         "PHI node already processed?");
+
+  // First, try to find an AddRec expression without creating a fictitious
+  // symbolic value for PN.
+  if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
+    return S;
+
+  // Handle PHI node value symbolically.
+  const SCEV *SymbolicName = getUnknown(PN);
+  ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName});
+
+  // Using this symbolic name for the PHI, analyze the value coming around
+  // the back-edge.
+  const SCEV *BEValue = getSCEV(BEValueV);
+
+  // NOTE: If BEValue is loop invariant, we know that the PHI node just
+  // has a special value for the first iteration of the loop.
+
+  // If the value coming around the backedge is an add with the symbolic
+  // value we just inserted, then we found a simple induction variable!
+  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
+    // If there is a single occurrence of the symbolic value, replace it
+    // with a recurrence.
+    unsigned FoundIndex = Add->getNumOperands();
+    for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+      if (Add->getOperand(i) == SymbolicName)
+        if (FoundIndex == e) {
+          FoundIndex = i;
+          break;
+        }
+
+    if (FoundIndex != Add->getNumOperands()) {
+      // Create an add with everything but the specified operand.
+      SmallVector<const SCEV *, 8> Ops;
+      for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+        if (i != FoundIndex)
+          Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
+                                                             L, *this));
+      const SCEV *Accum = getAddExpr(Ops);
+
+      // This is not a valid addrec if the step amount is varying each
+      // loop iteration, but is not itself an addrec in this loop.
+      if (isLoopInvariant(Accum, L) ||
+          (isa<SCEVAddRecExpr>(Accum) &&
+           cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
+        SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
+
+        if (auto BO = MatchBinaryOp(BEValueV, DT)) {
+          if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
+            if (BO->IsNUW)
+              Flags = setFlags(Flags, SCEV::FlagNUW);
+            if (BO->IsNSW)
+              Flags = setFlags(Flags, SCEV::FlagNSW);
+          }
+        } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
+          // If the increment is an inbounds GEP, then we know the address
+          // space cannot be wrapped around. We cannot make any guarantee
+          // about signed or unsigned overflow because pointers are
+          // unsigned but we may have a negative index from the base
+          // pointer.
We can guarantee that no unsigned wrap occurs if the
+          // indices form a positive value.
+          if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
+            Flags = setFlags(Flags, SCEV::FlagNW);
+
+            const SCEV *Ptr = getSCEV(GEP->getPointerOperand());
+            if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
+              Flags = setFlags(Flags, SCEV::FlagNUW);
+          }
+
+          // We cannot transfer nuw and nsw flags from subtraction
+          // operations -- sub nuw X, Y is not the same as add nuw X, -Y
+          // for instance.
+        }
+
+        const SCEV *StartVal = getSCEV(StartValueV);
+        const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
+
+        // Okay, for the entire analysis of this edge we assumed the PHI
+        // to be symbolic. We now need to go back and purge all of the
+        // entries for the scalars that use the symbolic expression.
+        forgetSymbolicName(PN, SymbolicName);
+        ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
+
+        // We can add Flags to the post-inc expression only if we
+        // know that it is *undefined behavior* for BEValueV to
+        // overflow.
+        if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
+          if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
+            (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
+
+        return PHISCEV;
+      }
+    }
+  } else {
+    // Otherwise, this could be a loop like this:
+    //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
+    // In this case, j = {1,+,1} and BEValue is j.
+    // Because the other in-value of i (0) fits the evolution of BEValue,
+    // i really is an addrec evolution.
+    //
+    // We can generalize this saying that i is the shifted value of BEValue
+    // by one iteration:
+    //   PHI(f(0), f({1,+,1})) --> f({0,+,1})
+    const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
+    const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
+    if (Shifted != getCouldNotCompute() &&
+        Start != getCouldNotCompute()) {
+      const SCEV *StartVal = getSCEV(StartValueV);
+      if (Start == StartVal) {
+        // Okay, for the entire analysis of this edge we assumed the PHI
+        // to be symbolic. We now need to go back and purge all of the
+        // entries for the scalars that use the symbolic expression.
+        forgetSymbolicName(PN, SymbolicName);
+        ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted;
+        return Shifted;
+      }
+    }
+  }
+
+  // Remove the temporary PHI node SCEV that has been inserted while intending
+  // to create an AddRecExpr for this PHI node. We cannot keep this temporary,
+  // as it will prevent later (possibly simpler) SCEV expressions from being
+  // added to the ValueExprMap.
+  eraseValueFromMap(PN);
+
+  return nullptr;
+}
+
+// Checks if the SCEV S is available at BB. S is considered available at BB
+// if S can be materialized at BB without introducing a fault.
+static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
+                               BasicBlock *BB) {
+  struct CheckAvailable {
+    bool TraversalDone = false;
+    bool Available = true;
+
+    const Loop *L = nullptr;  // The loop BB is in (can be nullptr)
+    BasicBlock *BB = nullptr;
+    DominatorTree &DT;
+
+    CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT)
+      : L(L), BB(BB), DT(DT) {}
+
+    bool setUnavailable() {
+      TraversalDone = true;
+      Available = false;
+      return false;
+    }
+
+    bool follow(const SCEV *S) {
+      switch (S->getSCEVType()) {
+      case scConstant: case scTruncate: case scZeroExtend: case scSignExtend:
+      case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr:
+      case scUMinExpr:
+      case scSMinExpr:
+        // These expressions are available if their operand(s) is/are.
+        return true;
+
+      case scAddRecExpr: {
+        // We allow add recurrences that are on the loop BB is in, or some
+        // outer loop. This guarantees availability because the value of the
+        // add recurrence at BB is simply the "current" value of the induction
+        // variable. We can relax this in the future; for instance an add
+        // recurrence on a sibling dominating loop is also available at BB.
+        const auto *ARLoop = cast<SCEVAddRecExpr>(S)->getLoop();
+        if (L && (ARLoop == L || ARLoop->contains(L)))
+          return true;
+
+        return setUnavailable();
+      }
+
+      case scUnknown: {
+        // For SCEVUnknown, we check for simple dominance.
+        const auto *SU = cast<SCEVUnknown>(S);
+        Value *V = SU->getValue();
+
+        if (isa<Argument>(V))
+          return false;
+
+        if (isa<Instruction>(V) && DT.dominates(cast<Instruction>(V), BB))
+          return false;
+
+        return setUnavailable();
+      }
+
+      case scUDivExpr:
+      case scCouldNotCompute:
+        // We do not try to be smart about these at all.
+        return setUnavailable();
+      }
+      llvm_unreachable("switch should be fully covered!");
+    }
+
+    bool isDone() { return TraversalDone; }
+  };
+
+  CheckAvailable CA(L, BB, DT);
+  SCEVTraversal<CheckAvailable> ST(CA);
+
+  ST.visitAll(S);
+  return CA.Available;
+}
+
+// Try to match a control flow sequence that branches out at BI and merges back
+// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
+// match.
+static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
+                          Value *&C, Value *&LHS, Value *&RHS) {
+  C = BI->getCondition();
+
+  BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
+  BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
+
+  if (!LeftEdge.isSingleEdge())
+    return false;
+
+  assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
+
+  Use &LeftUse = Merge->getOperandUse(0);
+  Use &RightUse = Merge->getOperandUse(1);
+
+  if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
+    LHS = LeftUse;
+    RHS = RightUse;
+    return true;
+  }
+
+  if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
+    LHS = RightUse;
+    RHS = LeftUse;
+    return true;
+  }
+
+  return false;
+}
+
+const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
+  auto IsReachable =
+      [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
+  if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
+    const Loop *L = LI.getLoopFor(PN->getParent());
+
+    // We don't want to break LCSSA, even in a SCEV expression tree.
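+    // (An incoming value whose block lives in a different loop would be
+    // used across a loop boundary without the mediating LCSSA phi, so such
+    // phis are rejected below.)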
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (LI.getLoopFor(PN->getIncomingBlock(i)) != L) + return nullptr; + + // Try to match + // + // br %cond, label %left, label %right + // left: + // br label %merge + // right: + // br label %merge + // merge: + // V = phi [ %x, %left ], [ %y, %right ] + // + // as "select %cond, %x, %y" + + BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock(); + assert(IDom && "At least the entry block should dominate PN"); + + auto *BI = dyn_cast<BranchInst>(IDom->getTerminator()); + Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr; + + if (BI && BI->isConditional() && + BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) && + IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) && + IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent())) + return createNodeForSelectOrPHI(PN, Cond, LHS, RHS); + } + + return nullptr; +} + +const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { + if (const SCEV *S = createAddRecFromPHI(PN)) + return S; + + if (const SCEV *S = createNodeFromSelectLikePHI(PN)) + return S; + + // If the PHI has a single incoming value, follow that value, unless the + // PHI's incoming blocks are in a different loop, in which case doing so + // risks breaking LCSSA form. Instcombine would normally zap these, but + // it doesn't have DominatorTree information, so it may miss cases. + if (Value *V = SimplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC})) + if (LI.replacementPreservesLCSSAForm(PN, V)) + return getSCEV(V); + + // If it's not a loop phi, we can't handle it yet. + return getUnknown(PN); +} + +const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I, + Value *Cond, + Value *TrueVal, + Value *FalseVal) { + // Handle "constant" branch or select. This can occur for instance when a + // loop pass transforms an inner loop and moves on to process the outer loop. + if (auto *CI = dyn_cast<ConstantInt>(Cond)) + return getSCEV(CI->isOne() ? TrueVal : FalseVal); + + // Try to match some simple smax or umax patterns. + auto *ICI = dyn_cast<ICmpInst>(Cond); + if (!ICI) + return getUnknown(I); + + Value *LHS = ICI->getOperand(0); + Value *RHS = ICI->getOperand(1); + + switch (ICI->getPredicate()) { + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + std::swap(LHS, RHS); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + // a >s b ? a+x : b+x -> smax(a, b)+x + // a >s b ? b+x : a+x -> smin(a, b)+x + if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) { + const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType()); + const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType()); + const SCEV *LA = getSCEV(TrueVal); + const SCEV *RA = getSCEV(FalseVal); + const SCEV *LDiff = getMinusSCEV(LA, LS); + const SCEV *RDiff = getMinusSCEV(RA, RS); + if (LDiff == RDiff) + return getAddExpr(getSMaxExpr(LS, RS), LDiff); + LDiff = getMinusSCEV(LA, RS); + RDiff = getMinusSCEV(RA, LS); + if (LDiff == RDiff) + return getAddExpr(getSMinExpr(LS, RS), LDiff); + } + break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + std::swap(LHS, RHS); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + // a >u b ? a+x : b+x -> umax(a, b)+x + // a >u b ? 
b+x : a+x -> umin(a, b)+x + if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) { + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); + const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType()); + const SCEV *LA = getSCEV(TrueVal); + const SCEV *RA = getSCEV(FalseVal); + const SCEV *LDiff = getMinusSCEV(LA, LS); + const SCEV *RDiff = getMinusSCEV(RA, RS); + if (LDiff == RDiff) + return getAddExpr(getUMaxExpr(LS, RS), LDiff); + LDiff = getMinusSCEV(LA, RS); + RDiff = getMinusSCEV(RA, LS); + if (LDiff == RDiff) + return getAddExpr(getUMinExpr(LS, RS), LDiff); + } + break; + case ICmpInst::ICMP_NE: + // n != 0 ? n+x : 1+x -> umax(n, 1)+x + if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) && + isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) { + const SCEV *One = getOne(I->getType()); + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); + const SCEV *LA = getSCEV(TrueVal); + const SCEV *RA = getSCEV(FalseVal); + const SCEV *LDiff = getMinusSCEV(LA, LS); + const SCEV *RDiff = getMinusSCEV(RA, One); + if (LDiff == RDiff) + return getAddExpr(getUMaxExpr(One, LS), LDiff); + } + break; + case ICmpInst::ICMP_EQ: + // n == 0 ? 1+x : n+x -> umax(n, 1)+x + if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) && + isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) { + const SCEV *One = getOne(I->getType()); + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); + const SCEV *LA = getSCEV(TrueVal); + const SCEV *RA = getSCEV(FalseVal); + const SCEV *LDiff = getMinusSCEV(LA, One); + const SCEV *RDiff = getMinusSCEV(RA, LS); + if (LDiff == RDiff) + return getAddExpr(getUMaxExpr(One, LS), LDiff); + } + break; + default: + break; + } + + return getUnknown(I); +} + +/// Expand GEP instructions into add and multiply operations. This allows them +/// to be analyzed by regular SCEV code. +const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { + // Don't attempt to analyze GEPs over unsized objects. + if (!GEP->getSourceElementType()->isSized()) + return getUnknown(GEP); + + SmallVector<const SCEV *, 4> IndexExprs; + for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index) + IndexExprs.push_back(getSCEV(*Index)); + return getGEPExpr(GEP, IndexExprs); +} + +uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) { + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) + return C->getAPInt().countTrailingZeros(); + + if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S)) + return std::min(GetMinTrailingZeros(T->getOperand()), + (uint32_t)getTypeSizeInBits(T->getType())); + + if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) { + uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); + return OpRes == getTypeSizeInBits(E->getOperand()->getType()) + ? getTypeSizeInBits(E->getType()) + : OpRes; + } + + if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) { + uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); + return OpRes == getTypeSizeInBits(E->getOperand()->getType()) + ? getTypeSizeInBits(E->getType()) + : OpRes; + } + + if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) { + // The result is the min of all operands results. 
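+    // E.g. (illustrative): for (8 + 12) this gives min(3, 2) = 2; the sum
+    // of two values each divisible by 4 is itself divisible by 4.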
+ uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); + for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); + return MinOpRes; + } + + if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) { + // The result is the sum of all operands results. + uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0)); + uint32_t BitWidth = getTypeSizeInBits(M->getType()); + for (unsigned i = 1, e = M->getNumOperands(); + SumOpRes != BitWidth && i != e; ++i) + SumOpRes = + std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth); + return SumOpRes; + } + + if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); + for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); + return MinOpRes; + } + + if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); + for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); + return MinOpRes; + } + + if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); + for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); + return MinOpRes; + } + + if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { + // For a SCEVUnknown, ask ValueTracking. + KnownBits Known = computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT); + return Known.countMinTrailingZeros(); + } + + // SCEVUDivExpr + return 0; +} + +uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { + auto I = MinTrailingZerosCache.find(S); + if (I != MinTrailingZerosCache.end()) + return I->second; + + uint32_t Result = GetMinTrailingZerosImpl(S); + auto InsertPair = MinTrailingZerosCache.insert({S, Result}); + assert(InsertPair.second && "Should insert a new key"); + return InsertPair.first->second; +} + +/// Helper method to assign a range to V from metadata present in the IR. +static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { + if (Instruction *I = dyn_cast<Instruction>(V)) + if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) + return getConstantRangeFromMetadata(*MD); + + return None; +} + +/// Determine the range for a particular SCEV. If SignHint is +/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges +/// with a "cleaner" unsigned (resp. signed) representation. +const ConstantRange & +ScalarEvolution::getRangeRef(const SCEV *S, + ScalarEvolution::RangeSignHint SignHint) { + DenseMap<const SCEV *, ConstantRange> &Cache = + SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges + : SignedRanges; + ConstantRange::PreferredRangeType RangeType = + SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED + ? ConstantRange::Unsigned : ConstantRange::Signed; + + // See if we've computed this range already. 
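+  // (The unsigned and signed walks may settle on different representations
+  // of the same set of values, which is why a separate cache is kept for
+  // each sign hint.)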
+ DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S); + if (I != Cache.end()) + return I->second; + + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) + return setRange(C, SignHint, ConstantRange(C->getAPInt())); + + unsigned BitWidth = getTypeSizeInBits(S->getType()); + ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); + + // If the value has known zeros, the maximum value will have those known zeros + // as well. + uint32_t TZ = GetMinTrailingZeros(S); + if (TZ != 0) { + if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) + ConservativeResult = + ConstantRange(APInt::getMinValue(BitWidth), + APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1); + else + ConservativeResult = ConstantRange( + APInt::getSignedMinValue(BitWidth), + APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1); + } + + if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { + ConstantRange X = getRangeRef(Add->getOperand(0), SignHint); + for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) + X = X.add(getRangeRef(Add->getOperand(i), SignHint)); + return setRange(Add, SignHint, + ConservativeResult.intersectWith(X, RangeType)); + } + + if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) { + ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint); + for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) + X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint)); + return setRange(Mul, SignHint, + ConservativeResult.intersectWith(X, RangeType)); + } + + if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) { + ConstantRange X = getRangeRef(SMax->getOperand(0), SignHint); + for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) + X = X.smax(getRangeRef(SMax->getOperand(i), SignHint)); + return setRange(SMax, SignHint, + ConservativeResult.intersectWith(X, RangeType)); + } + + if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) { + ConstantRange X = getRangeRef(UMax->getOperand(0), SignHint); + for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) + X = X.umax(getRangeRef(UMax->getOperand(i), SignHint)); + return setRange(UMax, SignHint, + ConservativeResult.intersectWith(X, RangeType)); + } + + if (const SCEVSMinExpr *SMin = dyn_cast<SCEVSMinExpr>(S)) { + ConstantRange X = getRangeRef(SMin->getOperand(0), SignHint); + for (unsigned i = 1, e = SMin->getNumOperands(); i != e; ++i) + X = X.smin(getRangeRef(SMin->getOperand(i), SignHint)); + return setRange(SMin, SignHint, + ConservativeResult.intersectWith(X, RangeType)); + } + + if (const SCEVUMinExpr *UMin = dyn_cast<SCEVUMinExpr>(S)) { + ConstantRange X = getRangeRef(UMin->getOperand(0), SignHint); + for (unsigned i = 1, e = UMin->getNumOperands(); i != e; ++i) + X = X.umin(getRangeRef(UMin->getOperand(i), SignHint)); + return setRange(UMin, SignHint, + ConservativeResult.intersectWith(X, RangeType)); + } + + if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) { + ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint); + ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint); + return setRange(UDiv, SignHint, + ConservativeResult.intersectWith(X.udiv(Y), RangeType)); + } + + if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) { + ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint); + return setRange(ZExt, SignHint, + ConservativeResult.intersectWith(X.zeroExtend(BitWidth), + RangeType)); + } + + if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) { + ConstantRange X = getRangeRef(SExt->getOperand(), SignHint); + return setRange(SExt, 
SignHint, + ConservativeResult.intersectWith(X.signExtend(BitWidth), + RangeType)); + } + + if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) { + ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint); + return setRange(Trunc, SignHint, + ConservativeResult.intersectWith(X.truncate(BitWidth), + RangeType)); + } + + if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) { + // If there's no unsigned wrap, the value will never be less than its + // initial value. + if (AddRec->hasNoUnsignedWrap()) + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart())) + if (!C->getValue()->isZero()) + ConservativeResult = ConservativeResult.intersectWith( + ConstantRange(C->getAPInt(), APInt(BitWidth, 0)), RangeType); + + // If there's no signed wrap, and all the operands have the same sign or + // zero, the value won't ever change sign. + if (AddRec->hasNoSignedWrap()) { + bool AllNonNeg = true; + bool AllNonPos = true; + for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { + if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false; + if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false; + } + if (AllNonNeg) + ConservativeResult = ConservativeResult.intersectWith( + ConstantRange(APInt(BitWidth, 0), + APInt::getSignedMinValue(BitWidth)), RangeType); + else if (AllNonPos) + ConservativeResult = ConservativeResult.intersectWith( + ConstantRange(APInt::getSignedMinValue(BitWidth), + APInt(BitWidth, 1)), RangeType); + } + + // TODO: non-affine addrec + if (AddRec->isAffine()) { + const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(AddRec->getLoop()); + if (!isa<SCEVCouldNotCompute>(MaxBECount) && + getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { + auto RangeFromAffine = getRangeForAffineAR( + AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, + BitWidth); + if (!RangeFromAffine.isFullSet()) + ConservativeResult = + ConservativeResult.intersectWith(RangeFromAffine, RangeType); + + auto RangeFromFactoring = getRangeViaFactoring( + AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, + BitWidth); + if (!RangeFromFactoring.isFullSet()) + ConservativeResult = + ConservativeResult.intersectWith(RangeFromFactoring, RangeType); + } + } + + return setRange(AddRec, SignHint, std::move(ConservativeResult)); + } + + if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { + // Check if the IR explicitly contains !range metadata. + Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue()); + if (MDRange.hasValue()) + ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue(), + RangeType); + + // Split here to avoid paying the compile-time cost of calling both + // computeKnownBits and ComputeNumSignBits. This restriction can be lifted + // if needed. + const DataLayout &DL = getDataLayout(); + if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) { + // For a SCEVUnknown, ask ValueTracking. 
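+      // E.g. (illustrative): i8 known bits Zero = 0xF0, One = 0x01 yield the
+      // unsigned range [Known.One, ~Known.Zero + 1) = [1, 16).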
+      KnownBits Known = computeKnownBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
+      if (Known.One != ~Known.Zero + 1)
+        ConservativeResult =
+            ConservativeResult.intersectWith(
+                ConstantRange(Known.One, ~Known.Zero + 1), RangeType);
+    } else {
+      assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED &&
+             "generalize as needed!");
+      unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
+      if (NS > 1)
+        ConservativeResult = ConservativeResult.intersectWith(
+            ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
+                          APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
+            RangeType);
+    }
+
+    // A range of Phi is a subset of union of all ranges of its input.
+    if (const PHINode *Phi = dyn_cast<PHINode>(U->getValue())) {
+      // Make sure that we do not run over cycled Phis.
+      if (PendingPhiRanges.insert(Phi).second) {
+        ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
+        for (auto &Op : Phi->operands()) {
+          auto OpRange = getRangeRef(getSCEV(Op), SignHint);
+          RangeFromOps = RangeFromOps.unionWith(OpRange);
+          // No point to continue if we already have a full set.
+          if (RangeFromOps.isFullSet())
+            break;
+        }
+        ConservativeResult =
+            ConservativeResult.intersectWith(RangeFromOps, RangeType);
+        bool Erased = PendingPhiRanges.erase(Phi);
+        assert(Erased && "Failed to erase Phi properly?");
+        (void) Erased;
+      }
+    }
+
+    return setRange(U, SignHint, std::move(ConservativeResult));
+  }
+
+  return setRange(S, SignHint, std::move(ConservativeResult));
+}
+
+// Given a StartRange, Step and MaxBECount for an expression compute a range of
+// values that the expression can take. Initially, the expression has a value
+// from StartRange and then is changed by Step up to MaxBECount times. Signed
+// argument defines if we treat Step as signed or unsigned.
+static ConstantRange getRangeForAffineARHelper(APInt Step,
+                                               const ConstantRange &StartRange,
+                                               const APInt &MaxBECount,
+                                               unsigned BitWidth, bool Signed) {
+  // If either Step or MaxBECount is 0, then the expression won't change, and we
+  // just need to return the initial range.
+  if (Step == 0 || MaxBECount == 0)
+    return StartRange;
+
+  // If we don't know anything about the initial value (i.e. StartRange is
+  // FullRange), then we don't know anything about the final range either.
+  // Return FullRange.
+  if (StartRange.isFullSet())
+    return ConstantRange::getFull(BitWidth);
+
+  // If Step is signed and negative, then we use its absolute value, but we also
+  // note that we're moving in the opposite direction.
+  bool Descending = Signed && Step.isNegative();
+
+  if (Signed)
+    // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
+    // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
+    // These equations hold true due to the well-defined wrap-around behavior
+    // of APInt.
+    Step = Step.abs();
+
+  // Check if Offset is more than full span of BitWidth. If it is, the
+  // expression is guaranteed to overflow.
+  if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
+    return ConstantRange::getFull(BitWidth);
+
+  // Offset is by how much the expression can change. Checks above guarantee no
+  // overflow here.
+  APInt Offset = Step * MaxBECount;
+
+  // Minimum value of the final range will match the minimal value of StartRange
+  // if the expression is increasing and will be decreased by Offset otherwise.
+  // Maximum value of the final range will match the maximal value of StartRange
+  // if the expression is decreasing and will be increased by Offset otherwise.
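+  // Worked example (illustrative, not from the original source): for i8 with
+  // StartRange = [10, 21), Step = 2 and MaxBECount = 3, Offset = 6; the
+  // ascending case moves the upper bound from 20 to 26, giving [10, 27).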
+ APInt StartLower = StartRange.getLower(); + APInt StartUpper = StartRange.getUpper() - 1; + APInt MovedBoundary = Descending ? (StartLower - std::move(Offset)) + : (StartUpper + std::move(Offset)); + + // It's possible that the new minimum/maximum value will fall into the initial + // range (due to wrap around). This means that the expression can take any + // value in this bitwidth, and we have to return full range. + if (StartRange.contains(MovedBoundary)) + return ConstantRange::getFull(BitWidth); + + APInt NewLower = + Descending ? std::move(MovedBoundary) : std::move(StartLower); + APInt NewUpper = + Descending ? std::move(StartUpper) : std::move(MovedBoundary); + NewUpper += 1; + + // No overflow detected, return [StartLower, StartUpper + Offset + 1) range. + return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper)); +} + +ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, + const SCEV *Step, + const SCEV *MaxBECount, + unsigned BitWidth) { + assert(!isa<SCEVCouldNotCompute>(MaxBECount) && + getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && + "Precondition!"); + + MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType()); + APInt MaxBECountValue = getUnsignedRangeMax(MaxBECount); + + // First, consider step signed. + ConstantRange StartSRange = getSignedRange(Start); + ConstantRange StepSRange = getSignedRange(Step); + + // If Step can be both positive and negative, we need to find ranges for the + // maximum absolute step values in both directions and union them. + ConstantRange SR = + getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange, + MaxBECountValue, BitWidth, /* Signed = */ true); + SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(), + StartSRange, MaxBECountValue, + BitWidth, /* Signed = */ true)); + + // Next, consider step unsigned. + ConstantRange UR = getRangeForAffineARHelper( + getUnsignedRangeMax(Step), getUnsignedRange(Start), + MaxBECountValue, BitWidth, /* Signed = */ false); + + // Finally, intersect signed and unsigned ranges. + return SR.intersectWith(UR, ConstantRange::Smallest); +} + +ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, + const SCEV *Step, + const SCEV *MaxBECount, + unsigned BitWidth) { + // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q}) + // == RangeOf({A,+,P}) union RangeOf({B,+,Q}) + + struct SelectPattern { + Value *Condition = nullptr; + APInt TrueValue; + APInt FalseValue; + + explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth, + const SCEV *S) { + Optional<unsigned> CastOp; + APInt Offset(BitWidth, 0); + + assert(SE.getTypeSizeInBits(S->getType()) == BitWidth && + "Should be!"); + + // Peel off a constant offset: + if (auto *SA = dyn_cast<SCEVAddExpr>(S)) { + // In the future we could consider being smarter here and handle + // {Start+Step,+,Step} too. 
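+        // E.g. (illustrative): S = (5 + %sel) with
+        //   %sel = select i1 %c, i8 1, i8 3
+        // peels Offset = 5, and after the offset is re-applied below this
+        // yields TrueValue = 6 and FalseValue = 8.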
+ if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0))) + return; + + Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt(); + S = SA->getOperand(1); + } + + // Peel off a cast operation + if (auto *SCast = dyn_cast<SCEVCastExpr>(S)) { + CastOp = SCast->getSCEVType(); + S = SCast->getOperand(); + } + + using namespace llvm::PatternMatch; + + auto *SU = dyn_cast<SCEVUnknown>(S); + const APInt *TrueVal, *FalseVal; + if (!SU || + !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal), + m_APInt(FalseVal)))) { + Condition = nullptr; + return; + } + + TrueValue = *TrueVal; + FalseValue = *FalseVal; + + // Re-apply the cast we peeled off earlier + if (CastOp.hasValue()) + switch (*CastOp) { + default: + llvm_unreachable("Unknown SCEV cast type!"); + + case scTruncate: + TrueValue = TrueValue.trunc(BitWidth); + FalseValue = FalseValue.trunc(BitWidth); + break; + case scZeroExtend: + TrueValue = TrueValue.zext(BitWidth); + FalseValue = FalseValue.zext(BitWidth); + break; + case scSignExtend: + TrueValue = TrueValue.sext(BitWidth); + FalseValue = FalseValue.sext(BitWidth); + break; + } + + // Re-apply the constant offset we peeled off earlier + TrueValue += Offset; + FalseValue += Offset; + } + + bool isRecognized() { return Condition != nullptr; } + }; + + SelectPattern StartPattern(*this, BitWidth, Start); + if (!StartPattern.isRecognized()) + return ConstantRange::getFull(BitWidth); + + SelectPattern StepPattern(*this, BitWidth, Step); + if (!StepPattern.isRecognized()) + return ConstantRange::getFull(BitWidth); + + if (StartPattern.Condition != StepPattern.Condition) { + // We don't handle this case today; but we could, by considering four + // possibilities below instead of two. I'm not sure if there are cases where + // that will help over what getRange already does, though. + return ConstantRange::getFull(BitWidth); + } + + // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to + // construct arbitrary general SCEV expressions here. This function is called + // from deep in the call stack, and calling getSCEV (on a sext instruction, + // say) can end up caching a suboptimal value. + + // FIXME: without the explicit `this` receiver below, MSVC errors out with + // C2352 and C2512 (otherwise it isn't needed). + + const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue); + const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue); + const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue); + const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue); + + ConstantRange TrueRange = + this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth); + ConstantRange FalseRange = + this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth); + + return TrueRange.unionWith(FalseRange); +} + +SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) { + if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap; + const BinaryOperator *BinOp = cast<BinaryOperator>(V); + + // Return early if there are no flags to propagate to the SCEV. + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; + if (BinOp->hasNoUnsignedWrap()) + Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); + if (BinOp->hasNoSignedWrap()) + Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); + if (Flags == SCEV::FlagAnyWrap) + return SCEV::FlagAnyWrap; + + return isSCEVExprNeverPoison(BinOp) ? 
Flags : SCEV::FlagAnyWrap; +} + +bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) { + // Here we check that I is in the header of the innermost loop containing I, + // since we only deal with instructions in the loop header. The actual loop we + // need to check later will come from an add recurrence, but getting that + // requires computing the SCEV of the operands, which can be expensive. This + // check we can do cheaply to rule out some cases early. + Loop *InnermostContainingLoop = LI.getLoopFor(I->getParent()); + if (InnermostContainingLoop == nullptr || + InnermostContainingLoop->getHeader() != I->getParent()) + return false; + + // Only proceed if we can prove that I does not yield poison. + if (!programUndefinedIfFullPoison(I)) + return false; + + // At this point we know that if I is executed, then it does not wrap + // according to at least one of NSW or NUW. If I is not executed, then we do + // not know if the calculation that I represents would wrap. Multiple + // instructions can map to the same SCEV. If we apply NSW or NUW from I to + // the SCEV, we must guarantee no wrapping for that SCEV also when it is + // derived from other instructions that map to the same SCEV. We cannot make + // that guarantee for cases where I is not executed. So we need to find the + // loop that I is considered in relation to and prove that I is executed for + // every iteration of that loop. That implies that the value that I + // calculates does not wrap anywhere in the loop, so then we can apply the + // flags to the SCEV. + // + // We check isLoopInvariant to disambiguate in case we are adding recurrences + // from different loops, so that we know which loop to prove that I is + // executed in. + for (unsigned OpIndex = 0; OpIndex < I->getNumOperands(); ++OpIndex) { + // I could be an extractvalue from a call to an overflow intrinsic. + // TODO: We can do better here in some cases. + if (!isSCEVable(I->getOperand(OpIndex)->getType())) + return false; + const SCEV *Op = getSCEV(I->getOperand(OpIndex)); + if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) { + bool AllOtherOpsLoopInvariant = true; + for (unsigned OtherOpIndex = 0; OtherOpIndex < I->getNumOperands(); + ++OtherOpIndex) { + if (OtherOpIndex != OpIndex) { + const SCEV *OtherOp = getSCEV(I->getOperand(OtherOpIndex)); + if (!isLoopInvariant(OtherOp, AddRec->getLoop())) { + AllOtherOpsLoopInvariant = false; + break; + } + } + } + if (AllOtherOpsLoopInvariant && + isGuaranteedToExecuteForEveryIteration(I, AddRec->getLoop())) + return true; + } + } + return false; +} + +bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) { + // If we know that \c I can never be poison period, then that's enough. + if (isSCEVExprNeverPoison(I)) + return true; + + // For an add recurrence specifically, we assume that infinite loops without + // side effects are undefined behavior, and then reason as follows: + // + // If the add recurrence is poison in any iteration, it is poison on all + // future iterations (since incrementing poison yields poison). If the result + // of the add recurrence is fed into the loop latch condition and the loop + // does not contain any throws or exiting blocks other than the latch, we now + // have the ability to "choose" whether the backedge is taken or not (by + // choosing a sufficiently evil value for the poison feeding into the branch) + // for every iteration including and after the one in which \p I first became + // poison. 
There are two possibilities (let's call the iteration in which \p
+  // I first became poison as K):
+  //
+  // 1. In the set of iterations including and after K, the loop body executes
+  //    no side effects. In this case executing the backedge an infinite
+  //    number of times will yield undefined behavior.
+  //
+  // 2. In the set of iterations including and after K, the loop body executes
+  //    at least one side effect. In this case, that specific instance of side
+  //    effect is control dependent on poison, which also yields undefined
+  //    behavior.
+
+  auto *ExitingBB = L->getExitingBlock();
+  auto *LatchBB = L->getLoopLatch();
+  if (!ExitingBB || !LatchBB || ExitingBB != LatchBB)
+    return false;
+
+  SmallPtrSet<const Instruction *, 16> Pushed;
+  SmallVector<const Instruction *, 8> PoisonStack;
+
+  // We start by assuming \c I, the post-inc add recurrence, is poison. Only
+  // things that are known to be fully poison under that assumption go on the
+  // PoisonStack.
+  Pushed.insert(I);
+  PoisonStack.push_back(I);
+
+  bool LatchControlDependentOnPoison = false;
+  while (!PoisonStack.empty() && !LatchControlDependentOnPoison) {
+    const Instruction *Poison = PoisonStack.pop_back_val();
+
+    for (auto *PoisonUser : Poison->users()) {
+      if (propagatesFullPoison(cast<Instruction>(PoisonUser))) {
+        if (Pushed.insert(cast<Instruction>(PoisonUser)).second)
+          PoisonStack.push_back(cast<Instruction>(PoisonUser));
+      } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) {
+        assert(BI->isConditional() && "Only possibility!");
+        if (BI->getParent() == LatchBB) {
+          LatchControlDependentOnPoison = true;
+          break;
+        }
+      }
+    }
+  }
+
+  return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L);
+}
+
+ScalarEvolution::LoopProperties
+ScalarEvolution::getLoopProperties(const Loop *L) {
+  using LoopProperties = ScalarEvolution::LoopProperties;
+
+  auto Itr = LoopPropertiesCache.find(L);
+  if (Itr == LoopPropertiesCache.end()) {
+    auto HasSideEffects = [](Instruction *I) {
+      if (auto *SI = dyn_cast<StoreInst>(I))
+        return !SI->isSimple();
+
+      return I->mayHaveSideEffects();
+    };
+
+    LoopProperties LP = {/* HasNoAbnormalExits */ true,
+                         /*HasNoSideEffects*/ true};
+
+    for (auto *BB : L->getBlocks())
+      for (auto &I : *BB) {
+        if (!isGuaranteedToTransferExecutionToSuccessor(&I))
+          LP.HasNoAbnormalExits = false;
+        if (HasSideEffects(&I))
+          LP.HasNoSideEffects = false;
+        if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
+          break; // We're already as pessimistic as we can get.
+      }
+
+    auto InsertPair = LoopPropertiesCache.insert({L, LP});
+    assert(InsertPair.second && "We just checked!");
+    Itr = InsertPair.first;
+  }
+
+  return Itr->second;
+}
+
+const SCEV *ScalarEvolution::createSCEV(Value *V) {
+  if (!isSCEVable(V->getType()))
+    return getUnknown(V);
+
+  if (Instruction *I = dyn_cast<Instruction>(V)) {
+    // Don't attempt to analyze instructions in blocks that aren't
+    // reachable. Such instructions don't matter, and they aren't required
+    // to obey basic rules for definitions dominating uses which this
+    // analysis depends on.
+    if (!DT.isReachableFromEntry(I->getParent()))
+      return getUnknown(UndefValue::get(V->getType()));
+  } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
+    return getConstant(CI);
+  else if (isa<ConstantPointerNull>(V))
+    return getZero(V->getType());
+  else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
+    return GA->isInterposable() ?
getUnknown(V) : getSCEV(GA->getAliasee()); + else if (!isa<ConstantExpr>(V)) + return getUnknown(V); + + Operator *U = cast<Operator>(V); + if (auto BO = MatchBinaryOp(U, DT)) { + switch (BO->Opcode) { + case Instruction::Add: { + // The simple thing to do would be to just call getSCEV on both operands + // and call getAddExpr with the result. However if we're looking at a + // bunch of things all added together, this can be quite inefficient, + // because it leads to N-1 getAddExpr calls for N ultimate operands. + // Instead, gather up all the operands and make a single getAddExpr call. + // LLVM IR canonical form means we need only traverse the left operands. + SmallVector<const SCEV *, 4> AddOps; + do { + if (BO->Op) { + if (auto *OpSCEV = getExistingSCEV(BO->Op)) { + AddOps.push_back(OpSCEV); + break; + } + + // If a NUW or NSW flag can be applied to the SCEV for this + // addition, then compute the SCEV for this addition by itself + // with a separate call to getAddExpr. We need to do that + // instead of pushing the operands of the addition onto AddOps, + // since the flags are only known to apply to this particular + // addition - they may not apply to other additions that can be + // formed with operands from AddOps. + const SCEV *RHS = getSCEV(BO->RHS); + SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); + if (Flags != SCEV::FlagAnyWrap) { + const SCEV *LHS = getSCEV(BO->LHS); + if (BO->Opcode == Instruction::Sub) + AddOps.push_back(getMinusSCEV(LHS, RHS, Flags)); + else + AddOps.push_back(getAddExpr(LHS, RHS, Flags)); + break; + } + } + + if (BO->Opcode == Instruction::Sub) + AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS))); + else + AddOps.push_back(getSCEV(BO->RHS)); + + auto NewBO = MatchBinaryOp(BO->LHS, DT); + if (!NewBO || (NewBO->Opcode != Instruction::Add && + NewBO->Opcode != Instruction::Sub)) { + AddOps.push_back(getSCEV(BO->LHS)); + break; + } + BO = NewBO; + } while (true); + + return getAddExpr(AddOps); + } + + case Instruction::Mul: { + SmallVector<const SCEV *, 4> MulOps; + do { + if (BO->Op) { + if (auto *OpSCEV = getExistingSCEV(BO->Op)) { + MulOps.push_back(OpSCEV); + break; + } + + SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); + if (Flags != SCEV::FlagAnyWrap) { + MulOps.push_back( + getMulExpr(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags)); + break; + } + } + + MulOps.push_back(getSCEV(BO->RHS)); + auto NewBO = MatchBinaryOp(BO->LHS, DT); + if (!NewBO || NewBO->Opcode != Instruction::Mul) { + MulOps.push_back(getSCEV(BO->LHS)); + break; + } + BO = NewBO; + } while (true); + + return getMulExpr(MulOps); + } + case Instruction::UDiv: + return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS)); + case Instruction::URem: + return getURemExpr(getSCEV(BO->LHS), getSCEV(BO->RHS)); + case Instruction::Sub: { + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; + if (BO->Op) + Flags = getNoWrapFlagsFromUB(BO->Op); + return getMinusSCEV(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags); + } + case Instruction::And: + // For an expression like x&255 that merely masks off the high bits, + // use zext(trunc(x)) as the SCEV expression. + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { + if (CI->isZero()) + return getSCEV(BO->RHS); + if (CI->isMinusOne()) + return getSCEV(BO->LHS); + const APInt &A = CI->getValue(); + + // Instcombine's ShrinkDemandedConstant may strip bits out of + // constants, obscuring what would otherwise be a low-bits mask. + // Use computeKnownBits to compute what ShrinkDemandedConstant + // knew about to reconstruct a low-bits mask value. 
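+        //
+        // Worked example (illustrative): A = 0x78 (0b01111000) has LZ = 1 and
+        // TZ = 3, so EffectiveMask covers bits 3..6. If the known bits of the
+        // LHS rule out demanded bits outside that window, (x & 0x78) is
+        // modeled as (zext (trunc (x /u 8) to i4) to i8) * 8.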
+ unsigned LZ = A.countLeadingZeros(); + unsigned TZ = A.countTrailingZeros(); + unsigned BitWidth = A.getBitWidth(); + KnownBits Known(BitWidth); + computeKnownBits(BO->LHS, Known, getDataLayout(), + 0, &AC, nullptr, &DT); + + APInt EffectiveMask = + APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); + if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) { + const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ)); + const SCEV *LHS = getSCEV(BO->LHS); + const SCEV *ShiftedLHS = nullptr; + if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) { + if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) { + // For an expression like (x * 8) & 8, simplify the multiply. + unsigned MulZeros = OpC->getAPInt().countTrailingZeros(); + unsigned GCD = std::min(MulZeros, TZ); + APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD); + SmallVector<const SCEV*, 4> MulOps; + MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD))); + MulOps.append(LHSMul->op_begin() + 1, LHSMul->op_end()); + auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags()); + ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt)); + } + } + if (!ShiftedLHS) + ShiftedLHS = getUDivExpr(LHS, MulCount); + return getMulExpr( + getZeroExtendExpr( + getTruncateExpr(ShiftedLHS, + IntegerType::get(getContext(), BitWidth - LZ - TZ)), + BO->LHS->getType()), + MulCount); + } + } + break; + + case Instruction::Or: + // If the RHS of the Or is a constant, we may have something like: + // X*4+1 which got turned into X*4|1. Handle this as an Add so loop + // optimizations will transparently handle this case. + // + // In order for this transformation to be safe, the LHS must be of the + // form X*(2^n) and the Or constant must be less than 2^n. + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { + const SCEV *LHS = getSCEV(BO->LHS); + const APInt &CIVal = CI->getValue(); + if (GetMinTrailingZeros(LHS) >= + (CIVal.getBitWidth() - CIVal.countLeadingZeros())) { + // Build a plain add SCEV. + const SCEV *S = getAddExpr(LHS, getSCEV(CI)); + // If the LHS of the add was an addrec and it has no-wrap flags, + // transfer the no-wrap flags, since an or won't introduce a wrap. + if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) { + const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS); + const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags( + OldAR->getNoWrapFlags()); + } + return S; + } + } + break; + + case Instruction::Xor: + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { + // If the RHS of xor is -1, then this is a not operation. + if (CI->isMinusOne()) + return getNotSCEV(getSCEV(BO->LHS)); + + // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask. + // This is a variant of the check for xor with -1, and it handles + // the case where instcombine has trimmed non-demanded bits out + // of an xor with -1. + if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS)) + if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1))) + if (LBO->getOpcode() == Instruction::And && + LCI->getValue() == CI->getValue()) + if (const SCEVZeroExtendExpr *Z = + dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) { + Type *UTy = BO->LHS->getType(); + const SCEV *Z0 = Z->getOperand(); + Type *Z0Ty = Z0->getType(); + unsigned Z0TySize = getTypeSizeInBits(Z0Ty); + + // If C is a low-bits mask, the zero extend is serving to + // mask off the high bits. Complement the operand and + // re-apply the zext. 
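+              // For example (illustrative), if %w = zext i8 %b to i32, then
+              // (xor (and i32 %w, 255), 255) flips exactly the low 8 bits
+              // and is modeled as (zext i8 (not %b) to i32).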
+              if (CI->getValue().isMask(Z0TySize))
+                return getZeroExtendExpr(getNotSCEV(Z0), UTy);
+
+              // If C is a single bit, it may be in the sign-bit position
+              // before the zero-extend. In this case, represent the xor
+              // using an add, which is equivalent, and re-apply the zext.
+              APInt Trunc = CI->getValue().trunc(Z0TySize);
+              if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
+                  Trunc.isSignMask())
+                return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
+                                         UTy);
+            }
+      break;
+
+    case Instruction::Shl:
+      // Turn shift left of a constant amount into a multiply.
+      if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
+        uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
+
+        // If the shift count is not less than the bitwidth, the result of
+        // the shift is undefined. Don't try to analyze it, because the
+        // resolution chosen here may differ from the resolution chosen in
+        // other parts of the compiler.
+        if (SA->getValue().uge(BitWidth))
+          break;
+
+        // It is currently not resolved how to interpret NSW for left
+        // shift by BitWidth - 1, so we avoid applying flags in that
+        // case. Remove this check (or this comment) once the situation
+        // is resolved. See
+        // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html
+        // and http://reviews.llvm.org/D8890 .
+        auto Flags = SCEV::FlagAnyWrap;
+        if (BO->Op && SA->getValue().ult(BitWidth - 1))
+          Flags = getNoWrapFlagsFromUB(BO->Op);
+
+        Constant *X = ConstantInt::get(
+            getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
+        return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags);
+      }
+      break;
+
+    case Instruction::AShr: {
+      // AShr X, C, where C is a constant.
+      ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
+      if (!CI)
+        break;
+
+      Type *OuterTy = BO->LHS->getType();
+      uint64_t BitWidth = getTypeSizeInBits(OuterTy);
+      // If the shift count is not less than the bitwidth, the result of
+      // the shift is undefined. Don't try to analyze it, because the
+      // resolution chosen here may differ from the resolution chosen in
+      // other parts of the compiler.
+      if (CI->getValue().uge(BitWidth))
+        break;
+
+      if (CI->isZero())
+        return getSCEV(BO->LHS); // shift by zero --> noop
+
+      uint64_t AShrAmt = CI->getZExtValue();
+      Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
+
+      Operator *L = dyn_cast<Operator>(BO->LHS);
+      if (L && L->getOpcode() == Instruction::Shl) {
+        // X = Shl A, n
+        // Y = AShr X, m
+        // Both n and m are constant.
+
+        const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
+        if (L->getOperand(1) == BO->RHS)
+          // For a two-shift sext-inreg, i.e. n = m,
+          // use sext(trunc(x)) as the SCEV expression.
+          return getSignExtendExpr(
+              getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy);
+
+        ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
+        if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) {
+          uint64_t ShlAmt = ShlAmtCI->getZExtValue();
+          if (ShlAmt > AShrAmt) {
+            // When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV
+            // expression. We already checked that ShlAmt < BitWidth, so
+            // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
+            // ShlAmt - AShrAmt < BitWidth - AShrAmt.
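+            // For example (illustrative), with i32 %a, n = 8 and m = 4:
+            //   %x = shl i32 %a, 8
+            //   %y = ashr i32 %x, 4
+            // is modeled as sext(mul(trunc i32 %a to i28, 16)) to i32.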
+            APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
+                                            ShlAmt - AShrAmt);
+            return getSignExtendExpr(
+                getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy),
+                getConstant(Mul)), OuterTy);
+          }
+        }
+      }
+      break;
+    }
+    }
+  }
+
+  switch (U->getOpcode()) {
+  case Instruction::Trunc:
+    return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
+
+  case Instruction::ZExt:
+    return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
+
+  case Instruction::SExt:
+    if (auto BO = MatchBinaryOp(U->getOperand(0), DT)) {
+      // The NSW flag of a subtract does not always survive the conversion to
+      // A + (-1)*B. By pushing sign extension onto its operands we are much
+      // more likely to preserve NSW and allow later AddRec optimisations.
+      //
+      // NOTE: This is effectively duplicating this logic from getSignExtend:
+      //   sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
+      // but by that point the NSW information has potentially been lost.
+      if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
+        Type *Ty = U->getType();
+        auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
+        auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
+        return getMinusSCEV(V1, V2, SCEV::FlagNSW);
+      }
+    }
+    return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
+
+  case Instruction::BitCast:
+    // BitCasts are no-op casts so we just eliminate the cast.
+    if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
+      return getSCEV(U->getOperand(0));
+    break;
+
+  // It's tempting to handle inttoptr and ptrtoint as no-ops; however, this
+  // can lead to pointer expressions which cannot safely be expanded to GEPs,
+  // because ScalarEvolution doesn't respect the GEP aliasing rules when
+  // simplifying integer expressions.
+
+  case Instruction::GetElementPtr:
+    return createNodeForGEP(cast<GEPOperator>(U));
+
+  case Instruction::PHI:
+    return createNodeForPHI(cast<PHINode>(U));
+
+  case Instruction::Select:
+    // U can also be a select constant expr, which we let fall through. Since
+    // createNodeForSelectOrPHI only works for a condition that is an
+    // `ICmpInst`, and constant expressions cannot have instructions as
+    // operands, we'd have returned getUnknown for a select constant
+    // expression anyway.
+    if (isa<Instruction>(U))
+      return createNodeForSelectOrPHI(cast<Instruction>(U), U->getOperand(0),
+                                      U->getOperand(1), U->getOperand(2));
+    break;
+
+  case Instruction::Call:
+  case Instruction::Invoke:
+    if (Value *RV = CallSite(U).getReturnedArgOperand())
+      return getSCEV(RV);
+    break;
+  }
+
+  return getUnknown(V);
+}
+
+//===----------------------------------------------------------------------===//
+//                   Iteration Count Computation Code
+//
+
+static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
+  if (!ExitCount)
+    return 0;
+
+  ConstantInt *ExitConst = ExitCount->getValue();
+
+  // Guard against huge trip counts.
+  if (ExitConst->getValue().getActiveBits() > 32)
+    return 0;
+
+  // In case of integer overflow, this returns 0, which is correct.
+  return ((unsigned)ExitConst->getZExtValue()) + 1;
+}
+
+unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
+  if (BasicBlock *ExitingBB = L->getExitingBlock())
+    return getSmallConstantTripCount(L, ExitingBB);
+
+  // No trip count information for multiple exits.
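+  // (The single-exit overload below can still compute a count for one
+  // particular exiting block of a multi-exit loop.)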
+  return 0;
+}
+
+unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L,
+                                                    BasicBlock *ExitingBlock) {
+  assert(ExitingBlock && "Must pass a non-null exiting block!");
+  assert(L->isLoopExiting(ExitingBlock) &&
+         "Exiting block must actually branch out of the loop!");
+  const SCEVConstant *ExitCount =
+      dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
+  return getConstantTripCount(ExitCount);
+}
+
+unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
+  const auto *MaxExitCount =
+      dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
+  return getConstantTripCount(MaxExitCount);
+}
+
+unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
+  if (BasicBlock *ExitingBB = L->getExitingBlock())
+    return getSmallConstantTripMultiple(L, ExitingBB);
+
+  // No trip multiple information for multiple exits.
+  return 0;
+}
+
+/// Returns the largest constant divisor of the trip count of this loop as a
+/// normal unsigned value, if possible. This means that the actual trip count
+/// is always a multiple of the returned value (don't forget the trip count
+/// could very well be zero as well!).
+///
+/// Returns 1 if the trip count is unknown or not guaranteed to be a multiple
+/// of a constant (which is also the case if the trip count is simply
+/// constant; use getSmallConstantTripCount for that case). It will also
+/// return 1 if the trip count is very large (>= 2^32).
+///
+/// As explained in the comments for getSmallConstantTripCount, this assumes
+/// that control exits the loop via ExitingBlock.
+unsigned
+ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
+                                              BasicBlock *ExitingBlock) {
+  assert(ExitingBlock && "Must pass a non-null exiting block!");
+  assert(L->isLoopExiting(ExitingBlock) &&
+         "Exiting block must actually branch out of the loop!");
+  const SCEV *ExitCount = getExitCount(L, ExitingBlock);
+  if (ExitCount == getCouldNotCompute())
+    return 1;
+
+  // Get the trip count from the BE count by adding 1.
+  const SCEV *TCExpr = getAddExpr(ExitCount, getOne(ExitCount->getType()));
+
+  const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr);
+  if (!TC)
+    // Attempt to factor more general cases. Returns the greatest power of
+    // two divisor. If overflow happens, the trip count expression is still
+    // divisible by the greatest power of 2 divisor returned.
+    return 1U << std::min((uint32_t)31, GetMinTrailingZeros(TCExpr));
+
+  ConstantInt *Result = TC->getValue();
+
+  // Guard against huge trip counts (this requires checking
+  // for zero to handle the case where the trip count == -1 and the
+  // addition wraps).
+  if (!Result || Result->getValue().getActiveBits() > 32 ||
+      Result->getValue().getActiveBits() == 0)
+    return 1;
+
+  return (unsigned)Result->getZExtValue();
+}
+
+/// Get the expression for the number of loop iterations for which this loop
+/// is guaranteed not to exit via ExitingBlock. Otherwise return
+/// SCEVCouldNotCompute.
+const SCEV *ScalarEvolution::getExitCount(const Loop *L,
+                                          BasicBlock *ExitingBlock) {
+  return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
+}
+
+const SCEV *
+ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
+                                                 SCEVUnionPredicate &Preds) {
+  return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
+}
+
+const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
+  return getBackedgeTakenInfo(L).getExact(L, this);
+}
+
+/// Similar to getBackedgeTakenCount, except return the least SCEV value that
+/// is known never to be less than the actual backedge taken count.
+const SCEV *ScalarEvolution::getConstantMaxBackedgeTakenCount(const Loop *L) {
+  return getBackedgeTakenInfo(L).getMax(this);
+}
+
+bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
+  return getBackedgeTakenInfo(L).isMaxOrZero(this);
+}
+
+/// Push PHI nodes in the header of the given loop onto the given Worklist.
+static void
+PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
+  BasicBlock *Header = L->getHeader();
+
+  // Push all Loop-header PHIs onto the Worklist stack.
+  for (PHINode &PN : Header->phis())
+    Worklist.push_back(&PN);
+}
+
+const ScalarEvolution::BackedgeTakenInfo &
+ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
+  auto &BTI = getBackedgeTakenInfo(L);
+  if (BTI.hasFullInfo())
+    return BTI;
+
+  auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
+
+  if (!Pair.second)
+    return Pair.first->second;
+
+  BackedgeTakenInfo Result =
+      computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
+
+  return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
+}
+
+const ScalarEvolution::BackedgeTakenInfo &
+ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
+  // Initially insert an invalid entry for this loop. If the insertion
+  // succeeds, proceed to actually compute a backedge-taken count and
+  // update the value. The temporary CouldNotCompute value tells SCEV
+  // code elsewhere that it shouldn't attempt to request a new
+  // backedge-taken count, which could result in infinite recursion.
+  std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
+      BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
+  if (!Pair.second)
+    return Pair.first->second;
+
+  // computeBackedgeTakenCount may allocate memory for its result. Inserting
+  // it into the BackedgeTakenCounts map transfers ownership. Otherwise, the
+  // result must be cleared in this scope.
+  BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
+
+  // In a product build, the statistics are unused.
+  (void)NumTripCountsComputed;
+  (void)NumTripCountsNotComputed;
+#if LLVM_ENABLE_STATS || !defined(NDEBUG)
+  const SCEV *BEExact = Result.getExact(L, this);
+  if (BEExact != getCouldNotCompute()) {
+    assert(isLoopInvariant(BEExact, L) &&
+           isLoopInvariant(Result.getMax(this), L) &&
+           "Computed backedge-taken count isn't loop invariant for loop!");
+    ++NumTripCountsComputed;
+  } else if (Result.getMax(this) == getCouldNotCompute() &&
+             isa<PHINode>(L->getHeader()->begin())) {
+    // Only count loops that have phi nodes as not being computable.
+    ++NumTripCountsNotComputed;
+  }
+#endif // LLVM_ENABLE_STATS || !defined(NDEBUG)
+
+  // Now that we know more about the trip count for this loop, forget any
+  // existing SCEV values for PHI nodes in this loop since they are only
+  // conservative estimates made without the benefit of trip count
+  // information.
This is similar to the code in forgetLoop, except that
+  // it handles SCEVUnknown PHI nodes specially.
+  if (Result.hasAnyInfo()) {
+    SmallVector<Instruction *, 16> Worklist;
+    PushLoopPHIs(L, Worklist);
+
+    SmallPtrSet<Instruction *, 8> Discovered;
+    while (!Worklist.empty()) {
+      Instruction *I = Worklist.pop_back_val();
+
+      ValueExprMapType::iterator It =
+          ValueExprMap.find_as(static_cast<Value *>(I));
+      if (It != ValueExprMap.end()) {
+        const SCEV *Old = It->second;
+
+        // SCEVUnknown for a PHI either means that it has an unrecognized
+        // structure, or it's a PHI that's in the process of being computed
+        // by createNodeForPHI. In the former case, additional loop trip
+        // count information isn't going to change anything. In the latter
+        // case, createNodeForPHI will perform the necessary updates on its
+        // own when it gets to that point.
+        if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
+          eraseValueFromMap(It->first);
+          forgetMemoizedResults(Old);
+        }
+        if (PHINode *PN = dyn_cast<PHINode>(I))
+          ConstantEvolutionLoopExitValue.erase(PN);
+      }
+
+      // Since we don't need to invalidate anything for correctness and we're
+      // only invalidating to make SCEV's results more precise, we get to stop
+      // early to avoid invalidating too much. This is especially important in
+      // cases like:
+      //
+      //   %v = f(pn0, pn1) // pn0 and pn1 used through some other phi node
+      //   loop0:
+      //     %pn0 = phi
+      //     ...
+      //   loop1:
+      //     %pn1 = phi
+      //     ...
+      //
+      // where both loop0's and loop1's backedge taken counts use the SCEV
+      // expression for %v. If we don't have the early stop below then in
+      // cases like the above, getBackedgeTakenInfo(loop1) will clear out the
+      // trip count for loop0 and getBackedgeTakenInfo(loop0) will clear out
+      // the trip count for loop1, effectively nullifying SCEV's trip count
+      // cache.
+      for (auto *U : I->users())
+        if (auto *I = dyn_cast<Instruction>(U)) {
+          auto *LoopForUser = LI.getLoopFor(I->getParent());
+          if (LoopForUser && L->contains(LoopForUser) &&
+              Discovered.insert(I).second)
+            Worklist.push_back(I);
+        }
+    }
+  }
+
+  // Re-lookup the insert position, since the call to
+  // computeBackedgeTakenCount above could result in a
+  // recursive call to getBackedgeTakenInfo (on a different
+  // loop), which would invalidate the iterator computed
+  // earlier.
+  return BackedgeTakenCounts.find(L)->second = std::move(Result);
+}
+
+void ScalarEvolution::forgetAllLoops() {
+  // This method is intended to forget all info about loops. It should
+  // invalidate caches as if the following happened:
+  // - The trip counts of all loops have changed arbitrarily
+  // - Every llvm::Value has been updated in place to produce a different
+  //   result.
+  BackedgeTakenCounts.clear();
+  PredicatedBackedgeTakenCounts.clear();
+  LoopPropertiesCache.clear();
+  ConstantEvolutionLoopExitValue.clear();
+  ValueExprMap.clear();
+  ValuesAtScopes.clear();
+  LoopDispositions.clear();
+  BlockDispositions.clear();
+  UnsignedRanges.clear();
+  SignedRanges.clear();
+  ExprValueMap.clear();
+  HasRecMap.clear();
+  MinTrailingZerosCache.clear();
+  PredicatedSCEVRewrites.clear();
+}
+
+void ScalarEvolution::forgetLoop(const Loop *L) {
+  // Drop any stored trip count value.
+ auto RemoveLoopFromBackedgeMap = + [](DenseMap<const Loop *, BackedgeTakenInfo> &Map, const Loop *L) { + auto BTCPos = Map.find(L); + if (BTCPos != Map.end()) { + BTCPos->second.clear(); + Map.erase(BTCPos); + } + }; + + SmallVector<const Loop *, 16> LoopWorklist(1, L); + SmallVector<Instruction *, 32> Worklist; + SmallPtrSet<Instruction *, 16> Visited; + + // Iterate over all the loops and sub-loops to drop SCEV information. + while (!LoopWorklist.empty()) { + auto *CurrL = LoopWorklist.pop_back_val(); + + RemoveLoopFromBackedgeMap(BackedgeTakenCounts, CurrL); + RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts, CurrL); + + // Drop information about predicated SCEV rewrites for this loop. + for (auto I = PredicatedSCEVRewrites.begin(); + I != PredicatedSCEVRewrites.end();) { + std::pair<const SCEV *, const Loop *> Entry = I->first; + if (Entry.second == CurrL) + PredicatedSCEVRewrites.erase(I++); + else + ++I; + } + + auto LoopUsersItr = LoopUsers.find(CurrL); + if (LoopUsersItr != LoopUsers.end()) { + for (auto *S : LoopUsersItr->second) + forgetMemoizedResults(S); + LoopUsers.erase(LoopUsersItr); + } + + // Drop information about expressions based on loop-header PHIs. + PushLoopPHIs(CurrL, Worklist); + + while (!Worklist.empty()) { + Instruction *I = Worklist.pop_back_val(); + if (!Visited.insert(I).second) + continue; + + ValueExprMapType::iterator It = + ValueExprMap.find_as(static_cast<Value *>(I)); + if (It != ValueExprMap.end()) { + eraseValueFromMap(It->first); + forgetMemoizedResults(It->second); + if (PHINode *PN = dyn_cast<PHINode>(I)) + ConstantEvolutionLoopExitValue.erase(PN); + } + + PushDefUseChildren(I, Worklist); + } + + LoopPropertiesCache.erase(CurrL); + // Forget all contained loops too, to avoid dangling entries in the + // ValuesAtScopes map. + LoopWorklist.append(CurrL->begin(), CurrL->end()); + } +} + +void ScalarEvolution::forgetTopmostLoop(const Loop *L) { + while (Loop *Parent = L->getParentLoop()) + L = Parent; + forgetLoop(L); +} + +void ScalarEvolution::forgetValue(Value *V) { + Instruction *I = dyn_cast<Instruction>(V); + if (!I) return; + + // Drop information about expressions based on loop-header PHIs. + SmallVector<Instruction *, 16> Worklist; + Worklist.push_back(I); + + SmallPtrSet<Instruction *, 8> Visited; + while (!Worklist.empty()) { + I = Worklist.pop_back_val(); + if (!Visited.insert(I).second) + continue; + + ValueExprMapType::iterator It = + ValueExprMap.find_as(static_cast<Value *>(I)); + if (It != ValueExprMap.end()) { + eraseValueFromMap(It->first); + forgetMemoizedResults(It->second); + if (PHINode *PN = dyn_cast<PHINode>(I)) + ConstantEvolutionLoopExitValue.erase(PN); + } + + PushDefUseChildren(I, Worklist); + } +} + +/// Get the exact loop backedge taken count considering all loop exits. A +/// computable result can only be returned for loops with all exiting blocks +/// dominating the latch. howFarToZero assumes that the limit of each loop test +/// is never skipped. This is a valid assumption as long as the loop exits via +/// that test. For precise results, it is the caller's responsibility to specify +/// the relevant loop exiting block using getExact(ExitingBlock, SE). +const SCEV * +ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE, + SCEVUnionPredicate *Preds) const { + // If any exits were not computable, the loop is not computable. 
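+  // (For instance, if one of two exits has an unknown count, the loop could
+  // leave through that exit first, so no exact overall count can be formed.)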
+ if (!isComplete() || ExitNotTaken.empty()) + return SE->getCouldNotCompute(); + + const BasicBlock *Latch = L->getLoopLatch(); + // All exiting blocks we have collected must dominate the only backedge. + if (!Latch) + return SE->getCouldNotCompute(); + + // All exiting blocks we have gathered dominate loop's latch, so exact trip + // count is simply a minimum out of all these calculated exit counts. + SmallVector<const SCEV *, 2> Ops; + for (auto &ENT : ExitNotTaken) { + const SCEV *BECount = ENT.ExactNotTaken; + assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!"); + assert(SE->DT.dominates(ENT.ExitingBlock, Latch) && + "We should only have known counts for exiting blocks that dominate " + "latch!"); + + Ops.push_back(BECount); + + if (Preds && !ENT.hasAlwaysTruePredicate()) + Preds->add(ENT.Predicate.get()); + + assert((Preds || ENT.hasAlwaysTruePredicate()) && + "Predicate should be always true!"); + } + + return SE->getUMinFromMismatchedTypes(Ops); +} + +/// Get the exact not taken count for this loop exit. +const SCEV * +ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock, + ScalarEvolution *SE) const { + for (auto &ENT : ExitNotTaken) + if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate()) + return ENT.ExactNotTaken; + + return SE->getCouldNotCompute(); +} + +/// getMax - Get the max backedge taken count for the loop. +const SCEV * +ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const { + auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) { + return !ENT.hasAlwaysTruePredicate(); + }; + + if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getMax()) + return SE->getCouldNotCompute(); + + assert((isa<SCEVCouldNotCompute>(getMax()) || isa<SCEVConstant>(getMax())) && + "No point in having a non-constant max backedge taken count!"); + return getMax(); +} + +bool ScalarEvolution::BackedgeTakenInfo::isMaxOrZero(ScalarEvolution *SE) const { + auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) { + return !ENT.hasAlwaysTruePredicate(); + }; + return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue); +} + +bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, + ScalarEvolution *SE) const { + if (getMax() && getMax() != SE->getCouldNotCompute() && + SE->hasOperand(getMax(), S)) + return true; + + for (auto &ENT : ExitNotTaken) + if (ENT.ExactNotTaken != SE->getCouldNotCompute() && + SE->hasOperand(ENT.ExactNotTaken, S)) + return true; + + return false; +} + +ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E) + : ExactNotTaken(E), MaxNotTaken(E) { + assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || + isa<SCEVConstant>(MaxNotTaken)) && + "No point in having a non-constant max backedge taken count!"); +} + +ScalarEvolution::ExitLimit::ExitLimit( + const SCEV *E, const SCEV *M, bool MaxOrZero, + ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList) + : ExactNotTaken(E), MaxNotTaken(M), MaxOrZero(MaxOrZero) { + assert((isa<SCEVCouldNotCompute>(ExactNotTaken) || + !isa<SCEVCouldNotCompute>(MaxNotTaken)) && + "Exact is not allowed to be less precise than Max"); + assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || + isa<SCEVConstant>(MaxNotTaken)) && + "No point in having a non-constant max backedge taken count!"); + for (auto *PredSet : PredSetList) + for (auto *P : *PredSet) + addPredicate(P); +} + +ScalarEvolution::ExitLimit::ExitLimit( + const SCEV *E, const SCEV *M, bool MaxOrZero, + const SmallPtrSetImpl<const SCEVPredicate *> &PredSet) + : ExitLimit(E, M, MaxOrZero, 
{&PredSet}) {
+  assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
+          isa<SCEVConstant>(MaxNotTaken)) &&
+         "No point in having a non-constant max backedge taken count!");
+}
+
+ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M,
+                                      bool MaxOrZero)
+    : ExitLimit(E, M, MaxOrZero, None) {
+  assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
+          isa<SCEVConstant>(MaxNotTaken)) &&
+         "No point in having a non-constant max backedge taken count!");
+}
+
+/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
+/// computable exit into a persistent ExitNotTakenInfo array.
+ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
+    ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo>
+        ExitCounts,
+    bool Complete, const SCEV *MaxCount, bool MaxOrZero)
+    : MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) {
+  using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
+
+  ExitNotTaken.reserve(ExitCounts.size());
+  std::transform(
+      ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken),
+      [&](const EdgeExitInfo &EEI) {
+        BasicBlock *ExitBB = EEI.first;
+        const ExitLimit &EL = EEI.second;
+        if (EL.Predicates.empty())
+          return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, nullptr);
+
+        std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate);
+        for (auto *Pred : EL.Predicates)
+          Predicate->add(Pred);
+
+        return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
+                                std::move(Predicate));
+      });
+  assert((isa<SCEVCouldNotCompute>(MaxCount) || isa<SCEVConstant>(MaxCount)) &&
+         "No point in having a non-constant max backedge taken count!");
+}
+
+/// Invalidate this result and free the ExitNotTakenInfo array.
+void ScalarEvolution::BackedgeTakenInfo::clear() {
+  ExitNotTaken.clear();
+}
+
+/// Compute the number of times the backedge of the specified loop will
+/// execute.
+ScalarEvolution::BackedgeTakenInfo
+ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
+                                           bool AllowPredicates) {
+  SmallVector<BasicBlock *, 8> ExitingBlocks;
+  L->getExitingBlocks(ExitingBlocks);
+
+  using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
+
+  SmallVector<EdgeExitInfo, 4> ExitCounts;
+  bool CouldComputeBECount = true;
+  BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
+  const SCEV *MustExitMaxBECount = nullptr;
+  const SCEV *MayExitMaxBECount = nullptr;
+  bool MustExitMaxOrZero = false;
+
+  // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
+  // and compute maxBECount.
+  // Do a union of all the predicates here.
+  for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
+    BasicBlock *ExitBB = ExitingBlocks[i];
+    ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
+
+    assert((AllowPredicates || EL.Predicates.empty()) &&
+           "Predicated exit limit when predicates are not allowed!");
+
+    // 1. For each exit that can be computed, add an entry to ExitCounts.
+    // CouldComputeBECount is true only if all exits can be computed.
+    if (EL.ExactNotTaken == getCouldNotCompute())
+      // We couldn't compute an exact value for this exit, so
+      // we won't be able to compute an exact value for the loop.
+      CouldComputeBECount = false;
+    else
+      ExitCounts.emplace_back(ExitBB, EL);
+
+    // 2. Derive the loop's MaxBECount from each exit's max number of
+    // non-exiting iterations. Partition the loop exits into two kinds:
+    // LoopMustExits and LoopMayExits.
+    //
+    // If the exit dominates the loop latch, it is a LoopMustExit; otherwise
+    // it is a LoopMayExit.
If any computable LoopMustExit is found, then + // MaxBECount is the minimum EL.MaxNotTaken of computable + // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum + // EL.MaxNotTaken, where CouldNotCompute is considered greater than any + // computable EL.MaxNotTaken. + if (EL.MaxNotTaken != getCouldNotCompute() && Latch && + DT.dominates(ExitBB, Latch)) { + if (!MustExitMaxBECount) { + MustExitMaxBECount = EL.MaxNotTaken; + MustExitMaxOrZero = EL.MaxOrZero; + } else { + MustExitMaxBECount = + getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken); + } + } else if (MayExitMaxBECount != getCouldNotCompute()) { + if (!MayExitMaxBECount || EL.MaxNotTaken == getCouldNotCompute()) + MayExitMaxBECount = EL.MaxNotTaken; + else { + MayExitMaxBECount = + getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.MaxNotTaken); + } + } + } + const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount : + (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute()); + // The loop backedge will be taken the maximum or zero times if there's + // a single exit that must be taken the maximum or zero times. + bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1); + return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount, + MaxBECount, MaxOrZero); +} + +ScalarEvolution::ExitLimit +ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, + bool AllowPredicates) { + assert(L->contains(ExitingBlock) && "Exit count for non-loop block?"); + // If our exiting block does not dominate the latch, then its connection with + // loop's exit limit may be far from trivial. + const BasicBlock *Latch = L->getLoopLatch(); + if (!Latch || !DT.dominates(ExitingBlock, Latch)) + return getCouldNotCompute(); + + bool IsOnlyExit = (L->getExitingBlock() != nullptr); + Instruction *Term = ExitingBlock->getTerminator(); + if (BranchInst *BI = dyn_cast<BranchInst>(Term)) { + assert(BI->isConditional() && "If unconditional, it can't be in loop!"); + bool ExitIfTrue = !L->contains(BI->getSuccessor(0)); + assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) && + "It should have one successor in loop and one exit block!"); + // Proceed to the next level to examine the exit condition expression. + return computeExitLimitFromCond( + L, BI->getCondition(), ExitIfTrue, + /*ControlsExit=*/IsOnlyExit, AllowPredicates); + } + + if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) { + // For switch, make sure that there is a single exit from the loop. + BasicBlock *Exit = nullptr; + for (auto *SBB : successors(ExitingBlock)) + if (!L->contains(SBB)) { + if (Exit) // Multiple exit successors. 
+ return getCouldNotCompute(); + Exit = SBB; + } + assert(Exit && "Exiting block must have at least one exit"); + return computeExitLimitFromSingleExitSwitch(L, SI, Exit, + /*ControlsExit=*/IsOnlyExit); + } + + return getCouldNotCompute(); +} + +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond( + const Loop *L, Value *ExitCond, bool ExitIfTrue, + bool ControlsExit, bool AllowPredicates) { + ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates); + return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue, + ControlsExit, AllowPredicates); +} + +Optional<ScalarEvolution::ExitLimit> +ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond, + bool ExitIfTrue, bool ControlsExit, + bool AllowPredicates) { + (void)this->L; + (void)this->ExitIfTrue; + (void)this->AllowPredicates; + + assert(this->L == L && this->ExitIfTrue == ExitIfTrue && + this->AllowPredicates == AllowPredicates && + "Variance in assumed invariant key components!"); + auto Itr = TripCountMap.find({ExitCond, ControlsExit}); + if (Itr == TripCountMap.end()) + return None; + return Itr->second; +} + +void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond, + bool ExitIfTrue, + bool ControlsExit, + bool AllowPredicates, + const ExitLimit &EL) { + assert(this->L == L && this->ExitIfTrue == ExitIfTrue && + this->AllowPredicates == AllowPredicates && + "Variance in assumed invariant key components!"); + + auto InsertResult = TripCountMap.insert({{ExitCond, ControlsExit}, EL}); + assert(InsertResult.second && "Expected successful insertion!"); + (void)InsertResult; + (void)ExitIfTrue; +} + +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached( + ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, + bool ControlsExit, bool AllowPredicates) { + + if (auto MaybeEL = + Cache.find(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates)) + return *MaybeEL; + + ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, ExitIfTrue, + ControlsExit, AllowPredicates); + Cache.insert(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates, EL); + return EL; +} + +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( + ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, + bool ControlsExit, bool AllowPredicates) { + // Check if the controlling expression for this loop is an And or Or. + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) { + if (BO->getOpcode() == Instruction::And) { + // Recurse on the operands of the and. + bool EitherMayExit = !ExitIfTrue; + ExitLimit EL0 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(0), ExitIfTrue, + ControlsExit && !EitherMayExit, AllowPredicates); + ExitLimit EL1 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(1), ExitIfTrue, + ControlsExit && !EitherMayExit, AllowPredicates); + const SCEV *BECount = getCouldNotCompute(); + const SCEV *MaxBECount = getCouldNotCompute(); + if (EitherMayExit) { + // Both conditions must be true for the loop to continue executing. + // Choose the less conservative count. 
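+        // (For example, for a continue condition of the form (i != n && %c),
+        // the computable i != n arm still provides a MaxNotTaken bound even
+        // if %c leaves the exact count unknown.)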
+ if (EL0.ExactNotTaken == getCouldNotCompute() || + EL1.ExactNotTaken == getCouldNotCompute()) + BECount = getCouldNotCompute(); + else + BECount = + getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); + if (EL0.MaxNotTaken == getCouldNotCompute()) + MaxBECount = EL1.MaxNotTaken; + else if (EL1.MaxNotTaken == getCouldNotCompute()) + MaxBECount = EL0.MaxNotTaken; + else + MaxBECount = + getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken); + } else { + // Both conditions must be true at the same time for the loop to exit. + // For now, be conservative. + if (EL0.MaxNotTaken == EL1.MaxNotTaken) + MaxBECount = EL0.MaxNotTaken; + if (EL0.ExactNotTaken == EL1.ExactNotTaken) + BECount = EL0.ExactNotTaken; + } + + // There are cases (e.g. PR26207) where computeExitLimitFromCond is able + // to be more aggressive when computing BECount than when computing + // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and + // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken + // to not. + if (isa<SCEVCouldNotCompute>(MaxBECount) && + !isa<SCEVCouldNotCompute>(BECount)) + MaxBECount = getConstant(getUnsignedRangeMax(BECount)); + + return ExitLimit(BECount, MaxBECount, false, + {&EL0.Predicates, &EL1.Predicates}); + } + if (BO->getOpcode() == Instruction::Or) { + // Recurse on the operands of the or. + bool EitherMayExit = ExitIfTrue; + ExitLimit EL0 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(0), ExitIfTrue, + ControlsExit && !EitherMayExit, AllowPredicates); + ExitLimit EL1 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(1), ExitIfTrue, + ControlsExit && !EitherMayExit, AllowPredicates); + const SCEV *BECount = getCouldNotCompute(); + const SCEV *MaxBECount = getCouldNotCompute(); + if (EitherMayExit) { + // Both conditions must be false for the loop to continue executing. + // Choose the less conservative count. + if (EL0.ExactNotTaken == getCouldNotCompute() || + EL1.ExactNotTaken == getCouldNotCompute()) + BECount = getCouldNotCompute(); + else + BECount = + getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); + if (EL0.MaxNotTaken == getCouldNotCompute()) + MaxBECount = EL1.MaxNotTaken; + else if (EL1.MaxNotTaken == getCouldNotCompute()) + MaxBECount = EL0.MaxNotTaken; + else + MaxBECount = + getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken); + } else { + // Both conditions must be false at the same time for the loop to exit. + // For now, be conservative. + if (EL0.MaxNotTaken == EL1.MaxNotTaken) + MaxBECount = EL0.MaxNotTaken; + if (EL0.ExactNotTaken == EL1.ExactNotTaken) + BECount = EL0.ExactNotTaken; + } + // There are cases (e.g. PR26207) where computeExitLimitFromCond is able + // to be more aggressive when computing BECount than when computing + // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and + // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken + // to not. + if (isa<SCEVCouldNotCompute>(MaxBECount) && + !isa<SCEVCouldNotCompute>(BECount)) + MaxBECount = getConstant(getUnsignedRangeMax(BECount)); + + return ExitLimit(BECount, MaxBECount, false, + {&EL0.Predicates, &EL1.Predicates}); + } + } + + // With an icmp, it may be feasible to compute an exact backedge-taken count. + // Proceed to the next level to examine the icmp. 
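+  // (e.g. a canonical latch test such as "icmp ne i32 %iv, %n" is resolved
+  // below by rewriting it as a howFarToZero query on %iv - %n.)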
+ if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) { + ExitLimit EL = + computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit); + if (EL.hasFullInfo() || !AllowPredicates) + return EL; + + // Try again, but use SCEV predicates this time. + return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit, + /*AllowPredicates=*/true); + } + + // Check for a constant condition. These are normally stripped out by + // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to + // preserve the CFG and is temporarily leaving constant conditions + // in place. + if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) { + if (ExitIfTrue == !CI->getZExtValue()) + // The backedge is always taken. + return getCouldNotCompute(); + else + // The backedge is never taken. + return getZero(CI->getType()); + } + + // If it's not an integer or pointer comparison then compute it the hard way. + return computeExitCountExhaustively(L, ExitCond, ExitIfTrue); +} + +ScalarEvolution::ExitLimit +ScalarEvolution::computeExitLimitFromICmp(const Loop *L, + ICmpInst *ExitCond, + bool ExitIfTrue, + bool ControlsExit, + bool AllowPredicates) { + // If the condition was exit on true, convert the condition to exit on false + ICmpInst::Predicate Pred; + if (!ExitIfTrue) + Pred = ExitCond->getPredicate(); + else + Pred = ExitCond->getInversePredicate(); + const ICmpInst::Predicate OriginalPred = Pred; + + // Handle common loops like: for (X = "string"; *X; ++X) + if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0))) + if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) { + ExitLimit ItCnt = + computeLoadConstantCompareExitLimit(LI, RHS, L, Pred); + if (ItCnt.hasAnyInfo()) + return ItCnt; + } + + const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); + const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); + + // Try to evaluate any dependencies out of the loop. + LHS = getSCEVAtScope(LHS, L); + RHS = getSCEVAtScope(RHS, L); + + // At this point, we would like to compute how many iterations of the + // loop the predicate will return true for these inputs. + if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) { + // If there is a loop-invariant, force it into the RHS. + std::swap(LHS, RHS); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + // Simplify the operands before analyzing them. + (void)SimplifyICmpOperands(Pred, LHS, RHS); + + // If we have a comparison of a chrec against a constant, try to use value + // ranges to answer this query. + if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) + if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS)) + if (AddRec->getLoop() == L) { + // Form the constant range. 
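+        // makeExactICmpRegion(Pred, C) is the exact set of LHS values for
+        // which "LHS Pred C" holds, so asking how long the add recurrence
+        // stays inside that region yields the backedge-taken count.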
+ ConstantRange CompRange = + ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt()); + + const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); + if (!isa<SCEVCouldNotCompute>(Ret)) return Ret; + } + + switch (Pred) { + case ICmpInst::ICMP_NE: { // while (X != Y) + // Convert to: while (X-Y != 0) + ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit, + AllowPredicates); + if (EL.hasAnyInfo()) return EL; + break; + } + case ICmpInst::ICMP_EQ: { // while (X == Y) + // Convert to: while (X-Y == 0) + ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L); + if (EL.hasAnyInfo()) return EL; + break; + } + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_ULT: { // while (X < Y) + bool IsSigned = Pred == ICmpInst::ICMP_SLT; + ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit, + AllowPredicates); + if (EL.hasAnyInfo()) return EL; + break; + } + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_UGT: { // while (X > Y) + bool IsSigned = Pred == ICmpInst::ICMP_SGT; + ExitLimit EL = + howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit, + AllowPredicates); + if (EL.hasAnyInfo()) return EL; + break; + } + default: + break; + } + + auto *ExhaustiveCount = + computeExitCountExhaustively(L, ExitCond, ExitIfTrue); + + if (!isa<SCEVCouldNotCompute>(ExhaustiveCount)) + return ExhaustiveCount; + + return computeShiftCompareExitLimit(ExitCond->getOperand(0), + ExitCond->getOperand(1), L, OriginalPred); +} + +ScalarEvolution::ExitLimit +ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, + SwitchInst *Switch, + BasicBlock *ExitingBlock, + bool ControlsExit) { + assert(!L->contains(ExitingBlock) && "Not an exiting block!"); + + // Give up if the exit is the default dest of a switch. + if (Switch->getDefaultDest() == ExitingBlock) + return getCouldNotCompute(); + + assert(L->contains(Switch->getDefaultDest()) && + "Default case must not exit the loop!"); + const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L); + const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); + + // while (X != Y) --> while (X-Y != 0) + ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); + if (EL.hasAnyInfo()) + return EL; + + return getCouldNotCompute(); +} + +static ConstantInt * +EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, + ScalarEvolution &SE) { + const SCEV *InVal = SE.getConstant(C); + const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE); + assert(isa<SCEVConstant>(Val) && + "Evaluation of SCEV at constant didn't fold correctly?"); + return cast<SCEVConstant>(Val)->getValue(); +} + +/// Given an exit condition of 'icmp op load X, cst', try to see if we can +/// compute the backedge execution count. +ScalarEvolution::ExitLimit +ScalarEvolution::computeLoadConstantCompareExitLimit( + LoadInst *LI, + Constant *RHS, + const Loop *L, + ICmpInst::Predicate predicate) { + if (LI->isVolatile()) return getCouldNotCompute(); + + // Check to see if the loaded pointer is a getelementptr of a global. + // TODO: Use SCEV instead of manually grubbing with GEPs. + GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0)); + if (!GEP) return getCouldNotCompute(); + + // Make sure that it is really a constant global we are gepping, with an + // initializer, and make sure the first IDX is really 0. 
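+  // That is, we are looking for loads of the shape (names illustrative):
+  //   %p = getelementptr [17 x i8], [17 x i8]* @str, i32 0, i32 %i
+  //   %ch = load i8, i8* %p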
+  GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
+  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() ||
+      GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
+      !cast<Constant>(GEP->getOperand(1))->isNullValue())
+    return getCouldNotCompute();
+
+  // Okay, we allow one non-constant index into the GEP instruction.
+  Value *VarIdx = nullptr;
+  std::vector<Constant*> Indexes;
+  unsigned VarIdxNum = 0;
+  for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
+      Indexes.push_back(CI);
+    } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
+      if (VarIdx) return getCouldNotCompute(); // Multiple non-constant idx's.
+      VarIdx = GEP->getOperand(i);
+      VarIdxNum = i-2;
+      Indexes.push_back(nullptr);
+    }
+
+  // Loop-invariant loads may be a byproduct of loop optimization. Skip them.
+  if (!VarIdx)
+    return getCouldNotCompute();
+
+  // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
+  // Check to see if X is a loop variant variable value now.
+  const SCEV *Idx = getSCEV(VarIdx);
+  Idx = getSCEVAtScope(Idx, L);
+
+  // We can only recognize very limited forms of loop index expressions, in
+  // particular, only affine AddRec's like {C1,+,C2}.
+  const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
+  if (!IdxExpr || !IdxExpr->isAffine() || isLoopInvariant(IdxExpr, L) ||
+      !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
+      !isa<SCEVConstant>(IdxExpr->getOperand(1)))
+    return getCouldNotCompute();
+
+  unsigned MaxSteps = MaxBruteForceIterations;
+  for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
+    ConstantInt *ItCst = ConstantInt::get(
+        cast<IntegerType>(IdxExpr->getType()), IterationNum);
+    ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
+
+    // Form the GEP offset.
+    Indexes[VarIdxNum] = Val;
+
+    Constant *Result = ConstantFoldLoadThroughGEPIndices(GV->getInitializer(),
+                                                         Indexes);
+    if (!Result) break; // Cannot compute!
+
+    // Evaluate the condition for this iteration.
+    Result = ConstantExpr::getICmp(predicate, Result, RHS);
+    if (!isa<ConstantInt>(Result)) break; // Couldn't decide for sure
+    if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
+      ++NumArrayLenItCounts;
+      return getConstant(ItCst); // Found terminating iteration!
+    }
+  }
+  return getCouldNotCompute();
+}
+
+ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
+    Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
+  ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
+  if (!RHS)
+    return getCouldNotCompute();
+
+  const BasicBlock *Latch = L->getLoopLatch();
+  if (!Latch)
+    return getCouldNotCompute();
+
+  const BasicBlock *Predecessor = L->getLoopPredecessor();
+  if (!Predecessor)
+    return getCouldNotCompute();
+
+  // Return true if V is of the form "LHS `shift_op` <positive constant>".
+  // Return LHS in OutLHS and shift_op in OutOpCode.
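+  // (e.g. "%sh = lshr i32 %x, 3" matches with OutLHS = %x and
+  // OutOpCode = Instruction::LShr.)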
+  auto MatchPositiveShift =
+      [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
+
+    using namespace PatternMatch;
+
+    ConstantInt *ShiftAmt;
+    if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+      OutOpCode = Instruction::LShr;
+    else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+      OutOpCode = Instruction::AShr;
+    else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+      OutOpCode = Instruction::Shl;
+    else
+      return false;
+
+    return ShiftAmt->getValue().isStrictlyPositive();
+  };
+
+  // Recognize a "shift recurrence", either %iv or %iv.shifted, in
+  //
+  // loop:
+  //   %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
+  //   %iv.shifted = lshr i32 %iv, <positive constant>
+  //
+  // Return true on a successful match. Return the corresponding PHI node
+  // (%iv above) in PNOut and the opcode of the shift operation in OpCodeOut.
+  auto MatchShiftRecurrence =
+      [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
+    Optional<Instruction::BinaryOps> PostShiftOpCode;
+
+    {
+      Instruction::BinaryOps OpC;
+      Value *V;
+
+      // If we encounter a shift instruction, "peel off" the shift operation,
+      // and remember that we did so. Later when we inspect %iv's backedge
+      // value, we will make sure that the backedge value uses the same
+      // operation.
+      //
+      // Note: the peeled shift operation does not have to be the same
+      // instruction as the one feeding into the PHI's backedge value. We only
+      // really care about it being the same *kind* of shift instruction --
+      // that's all that is required for our later inferences to hold.
+      if (MatchPositiveShift(LHS, V, OpC)) {
+        PostShiftOpCode = OpC;
+        LHS = V;
+      }
+    }
+
+    PNOut = dyn_cast<PHINode>(LHS);
+    if (!PNOut || PNOut->getParent() != L->getHeader())
+      return false;
+
+    Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
+    Value *OpLHS;
+
+    return
+        // The backedge value for the PHI node must be a shift by a positive
+        // amount
+        MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
+
+        // of the PHI node itself
+        OpLHS == PNOut &&
+
+        // and the kind of shift should match the kind of shift we peeled
+        // off, if any.
+        (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut);
+  };
+
+  PHINode *PN;
+  Instruction::BinaryOps OpCode;
+  if (!MatchShiftRecurrence(LHS, PN, OpCode))
+    return getCouldNotCompute();
+
+  const DataLayout &DL = getDataLayout();
+
+  // The key rationale for this optimization is that for some kinds of shift
+  // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
+  // within a finite number of iterations. If the condition guarding the
+  // backedge (in the sense that the backedge is taken if the condition is
+  // true) is false for the value the shift recurrence stabilizes to, then we
+  // know that the backedge is taken only a finite number of times.
+
+  ConstantInt *StableValue = nullptr;
+  switch (OpCode) {
+  default:
+    llvm_unreachable("Impossible case!");
+
+  case Instruction::AShr: {
+    // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
+    // bitwidth(K) iterations.
+ Value *FirstValue = PN->getIncomingValueForBlock(Predecessor); + KnownBits Known = computeKnownBits(FirstValue, DL, 0, nullptr, + Predecessor->getTerminator(), &DT); + auto *Ty = cast<IntegerType>(RHS->getType()); + if (Known.isNonNegative()) + StableValue = ConstantInt::get(Ty, 0); + else if (Known.isNegative()) + StableValue = ConstantInt::get(Ty, -1, true); + else + return getCouldNotCompute(); + + break; + } + case Instruction::LShr: + case Instruction::Shl: + // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>} + // stabilize to 0 in at most bitwidth(K) iterations. + StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0); + break; + } + + auto *Result = + ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI); + assert(Result->getType()->isIntegerTy(1) && + "Otherwise cannot be an operand to a branch instruction"); + + if (Result->isZeroValue()) { + unsigned BitWidth = getTypeSizeInBits(RHS->getType()); + const SCEV *UpperBound = + getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); + return ExitLimit(getCouldNotCompute(), UpperBound, false); + } + + return getCouldNotCompute(); +} + +/// Return true if we can constant fold an instruction of the specified type, +/// assuming that all operands were constants. +static bool CanConstantFold(const Instruction *I) { + if (isa<BinaryOperator>(I) || isa<CmpInst>(I) || + isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) || + isa<LoadInst>(I) || isa<ExtractValueInst>(I)) + return true; + + if (const CallInst *CI = dyn_cast<CallInst>(I)) + if (const Function *F = CI->getCalledFunction()) + return canConstantFoldCallTo(CI, F); + return false; +} + +/// Determine whether this instruction can constant evolve within this loop +/// assuming its operands can all constant evolve. +static bool canConstantEvolve(Instruction *I, const Loop *L) { + // An instruction outside of the loop can't be derived from a loop PHI. + if (!L->contains(I)) return false; + + if (isa<PHINode>(I)) { + // We don't currently keep track of the control flow needed to evaluate + // PHIs, so we cannot handle PHIs inside of loops. + return L->getHeader() == I->getParent(); + } + + // If we won't be able to constant fold this expression even if the operands + // are constants, bail early. + return CanConstantFold(I); +} + +/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by +/// recursing through each instruction operand until reaching a loop header phi. +static PHINode * +getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, + DenseMap<Instruction *, PHINode *> &PHIMap, + unsigned Depth) { + if (Depth > MaxConstantEvolvingDepth) + return nullptr; + + // Otherwise, we can evaluate this instruction if all of its operands are + // constant or derived from a PHI node themselves. + PHINode *PHI = nullptr; + for (Value *Op : UseInst->operands()) { + if (isa<Constant>(Op)) continue; + + Instruction *OpInst = dyn_cast<Instruction>(Op); + if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr; + + PHINode *P = dyn_cast<PHINode>(OpInst); + if (!P) + // If this operand is already visited, reuse the prior result. + // We may have P != PHI if this is the deepest point at which the + // inconsistent paths meet. + P = PHIMap.lookup(OpInst); + if (!P) { + // Recurse and memoize the results, whether a phi is found or not. + // This recursive call invalidates pointers into PHIMap. 
+      P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
+      PHIMap[OpInst] = P;
+    }
+    if (!P)
+      return nullptr; // Not evolving from PHI
+    if (PHI && PHI != P)
+      return nullptr; // Evolving from multiple different PHIs.
+    PHI = P;
+  }
+  // This is an expression evolving from a constant PHI!
+  return PHI;
+}
+
+/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
+/// in the loop that V is derived from. We allow arbitrary operations along
+/// the way, but the operands of an operation must either be constants or a
+/// value derived from a constant PHI. If this expression does not fit with
+/// these constraints, return null.
+static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I || !canConstantEvolve(I, L)) return nullptr;
+
+  if (PHINode *PN = dyn_cast<PHINode>(I))
+    return PN;
+
+  // Record non-constant instructions contained by the loop.
+  DenseMap<Instruction *, PHINode *> PHIMap;
+  return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
+}
+
+/// EvaluateExpression - Given an expression that passes the
+/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
+/// in the loop has the value PHIVal. If we can't fold this expression for
+/// some reason, return null.
+static Constant *EvaluateExpression(Value *V, const Loop *L,
+                                    DenseMap<Instruction *, Constant *> &Vals,
+                                    const DataLayout &DL,
+                                    const TargetLibraryInfo *TLI) {
+  // Convenient constant check, but redundant for recursive calls.
+  if (Constant *C = dyn_cast<Constant>(V)) return C;
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I) return nullptr;
+
+  if (Constant *C = Vals.lookup(I)) return C;
+
+  // An instruction inside the loop depends on a value outside the loop that we
+  // weren't given a mapping for, or a value such as a call inside the loop.
+  if (!canConstantEvolve(I, L)) return nullptr;
+
+  // An unmapped PHI can be due to a branch or another loop inside this loop,
+  // or due to this not being the initial iteration through a loop where we
+  // couldn't compute the evolution of this particular PHI last time.
+  if (isa<PHINode>(I)) return nullptr;
+
+  std::vector<Constant*> Operands(I->getNumOperands());
+
+  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
+    Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
+    if (!Operand) {
+      Operands[i] = dyn_cast<Constant>(I->getOperand(i));
+      if (!Operands[i]) return nullptr;
+      continue;
+    }
+    Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
+    Vals[Operand] = C;
+    if (!C) return nullptr;
+    Operands[i] = C;
+  }
+
+  if (CmpInst *CI = dyn_cast<CmpInst>(I))
+    return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
+                                           Operands[1], DL, TLI);
+  if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
+    if (!LI->isVolatile())
+      return ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL);
+  }
+  return ConstantFoldInstOperands(I, Operands, DL, TLI);
+}
+
+
+// If every incoming value to PN except the one for BB is a specific Constant,
+// return that, else return nullptr.
+static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) { + Constant *IncomingVal = nullptr; + + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + if (PN->getIncomingBlock(i) == BB) + continue; + + auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i)); + if (!CurrentVal) + return nullptr; + + if (IncomingVal != CurrentVal) { + if (IncomingVal) + return nullptr; + IncomingVal = CurrentVal; + } + } + + return IncomingVal; +} + +/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is +/// in the header of its containing loop, we know the loop executes a +/// constant number of times, and the PHI node is just a recurrence +/// involving constants, fold it. +Constant * +ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, + const APInt &BEs, + const Loop *L) { + auto I = ConstantEvolutionLoopExitValue.find(PN); + if (I != ConstantEvolutionLoopExitValue.end()) + return I->second; + + if (BEs.ugt(MaxBruteForceIterations)) + return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it. + + Constant *&RetVal = ConstantEvolutionLoopExitValue[PN]; + + DenseMap<Instruction *, Constant *> CurrentIterVals; + BasicBlock *Header = L->getHeader(); + assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!"); + + BasicBlock *Latch = L->getLoopLatch(); + if (!Latch) + return nullptr; + + for (PHINode &PHI : Header->phis()) { + if (auto *StartCST = getOtherIncomingValue(&PHI, Latch)) + CurrentIterVals[&PHI] = StartCST; + } + if (!CurrentIterVals.count(PN)) + return RetVal = nullptr; + + Value *BEValue = PN->getIncomingValueForBlock(Latch); + + // Execute the loop symbolically to determine the exit value. + assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) && + "BEs is <= MaxBruteForceIterations which is an 'unsigned'!"); + + unsigned NumIterations = BEs.getZExtValue(); // must be in range + unsigned IterationNum = 0; + const DataLayout &DL = getDataLayout(); + for (; ; ++IterationNum) { + if (IterationNum == NumIterations) + return RetVal = CurrentIterVals[PN]; // Got exit value! + + // Compute the value of the PHIs for the next iteration. + // EvaluateExpression adds non-phi values to the CurrentIterVals map. + DenseMap<Instruction *, Constant *> NextIterVals; + Constant *NextPHI = + EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI); + if (!NextPHI) + return nullptr; // Couldn't evaluate! + NextIterVals[PN] = NextPHI; + + bool StoppedEvolving = NextPHI == CurrentIterVals[PN]; + + // Also evaluate the other PHI nodes. However, we don't get to stop if we + // cease to be able to evaluate one of them or if they stop evolving, + // because that doesn't necessarily prevent us from computing PN. + SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute; + for (const auto &I : CurrentIterVals) { + PHINode *PHI = dyn_cast<PHINode>(I.first); + if (!PHI || PHI == PN || PHI->getParent() != Header) continue; + PHIsToCompute.emplace_back(PHI, I.second); + } + // We use two distinct loops because EvaluateExpression may invalidate any + // iterators into CurrentIterVals. + for (const auto &I : PHIsToCompute) { + PHINode *PHI = I.first; + Constant *&NextPHI = NextIterVals[PHI]; + if (!NextPHI) { // Not already computed. 
+ Value *BEValue = PHI->getIncomingValueForBlock(Latch);
+ NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
+ }
+ if (NextPHI != I.second)
+ StoppedEvolving = false;
+ }
+
+ // If all entries in CurrentIterVals == NextIterVals then we can stop
+ // iterating; the loop can't continue to change.
+ if (StoppedEvolving)
+ return RetVal = CurrentIterVals[PN];
+
+ CurrentIterVals.swap(NextIterVals);
+ }
+}
+
+const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
+ Value *Cond,
+ bool ExitWhen) {
+ PHINode *PN = getConstantEvolvingPHI(Cond, L);
+ if (!PN) return getCouldNotCompute();
+
+ // If the loop is canonicalized, the PHI will have exactly two entries.
+ // That's the only form we support here.
+ if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
+
+ DenseMap<Instruction *, Constant *> CurrentIterVals;
+ BasicBlock *Header = L->getHeader();
+ assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
+
+ BasicBlock *Latch = L->getLoopLatch();
+ assert(Latch && "Should follow from NumIncomingValues == 2!");
+
+ for (PHINode &PHI : Header->phis()) {
+ if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
+ CurrentIterVals[&PHI] = StartCST;
+ }
+ if (!CurrentIterVals.count(PN))
+ return getCouldNotCompute();
+
+ // Okay, we found a PHI node that defines the trip count of this loop. Execute
+ // the loop symbolically to determine when the condition gets a value of
+ // "ExitWhen".
+ unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
+ const DataLayout &DL = getDataLayout();
+ for (unsigned IterationNum = 0; IterationNum != MaxIterations; ++IterationNum) {
+ auto *CondVal = dyn_cast_or_null<ConstantInt>(
+ EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
+
+ // Couldn't symbolically evaluate.
+ if (!CondVal) return getCouldNotCompute();
+
+ if (CondVal->getValue() == uint64_t(ExitWhen)) {
+ ++NumBruteForceTripCountsComputed;
+ return getConstant(Type::getInt32Ty(getContext()), IterationNum);
+ }
+
+ // Update all the PHI nodes for the next iteration.
+ DenseMap<Instruction *, Constant *> NextIterVals;
+
+ // Create a list of which PHIs we need to compute. We want to do this before
+ // calling EvaluateExpression on them because that may invalidate iterators
+ // into CurrentIterVals.
+ SmallVector<PHINode *, 8> PHIsToCompute;
+ for (const auto &I : CurrentIterVals) {
+ PHINode *PHI = dyn_cast<PHINode>(I.first);
+ if (!PHI || PHI->getParent() != Header) continue;
+ PHIsToCompute.push_back(PHI);
+ }
+ for (PHINode *PHI : PHIsToCompute) {
+ Constant *&NextPHI = NextIterVals[PHI];
+ if (NextPHI) continue; // Already computed!
+
+ Value *BEValue = PHI->getIncomingValueForBlock(Latch);
+ NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
+ }
+ CurrentIterVals.swap(NextIterVals);
+ }
+
+ // Too many iterations were needed to evaluate.
+ return getCouldNotCompute();
+}
+
+const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
+ SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
+ ValuesAtScopes[V];
+ // Check to see if we've folded this expression at this loop before.
+ for (auto &LS : Values)
+ if (LS.first == L)
+ return LS.second ? LS.second : V;
+
+ Values.emplace_back(L, nullptr);
+
+ // Otherwise compute it.
+ const SCEV *C = computeSCEVAtScope(V, L);
+ for (auto &LS : reverse(ValuesAtScopes[V]))
+ if (LS.first == L) {
+ LS.second = C;
+ break;
+ }
+ return C;
+}
+
+/// This builds up a Constant using the ConstantExpr interface.
+/// That way, we will return Constants for objects which aren't represented
+/// by a SCEVConstant, because SCEVConstant is restricted to ConstantInt.
+/// Returns NULL if the SCEV isn't representable as a Constant.
+static Constant *BuildConstantFromSCEV(const SCEV *V) {
+ switch (static_cast<SCEVTypes>(V->getSCEVType())) {
+ case scCouldNotCompute:
+ case scAddRecExpr:
+ break;
+ case scConstant:
+ return cast<SCEVConstant>(V)->getValue();
+ case scUnknown:
+ return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
+ case scSignExtend: {
+ const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
+ if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
+ return ConstantExpr::getSExt(CastOp, SS->getType());
+ break;
+ }
+ case scZeroExtend: {
+ const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
+ if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
+ return ConstantExpr::getZExt(CastOp, SZ->getType());
+ break;
+ }
+ case scTruncate: {
+ const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
+ if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
+ return ConstantExpr::getTrunc(CastOp, ST->getType());
+ break;
+ }
+ case scAddExpr: {
+ const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
+ if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) {
+ if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
+ unsigned AS = PTy->getAddressSpace();
+ Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
+ C = ConstantExpr::getBitCast(C, DestPtrTy);
+ }
+ for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) {
+ Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i));
+ if (!C2) return nullptr;
+
+ // First pointer!
+ if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) {
+ unsigned AS = C2->getType()->getPointerAddressSpace();
+ std::swap(C, C2);
+ Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
+ // The offsets have been converted to bytes. We can add bytes to an
+ // i8* by GEP with the byte count in the first index.
+ C = ConstantExpr::getBitCast(C, DestPtrTy);
+ }
+
+ // Don't bother trying to sum two pointers. We probably can't
+ // statically compute a load that results from it anyway.
+ if (C2->getType()->isPointerTy())
+ return nullptr;
+
+ if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
+ if (PTy->getElementType()->isStructTy())
+ C2 = ConstantExpr::getIntegerCast(
+ C2, Type::getInt32Ty(C->getContext()), true);
+ C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2);
+ } else
+ C = ConstantExpr::getAdd(C, C2);
+ }
+ return C;
+ }
+ break;
+ }
+ case scMulExpr: {
+ const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
+ if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) {
+ // Don't bother with pointers at all.
+ if (C->getType()->isPointerTy()) return nullptr;
+ for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) {
+ Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i));
+ if (!C2 || C2->getType()->isPointerTy()) return nullptr;
+ C = ConstantExpr::getMul(C, C2);
+ }
+ return C;
+ }
+ break;
+ }
+ case scUDivExpr: {
+ const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
+ if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS()))
+ if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
+ if (LHS->getType() == RHS->getType())
+ return ConstantExpr::getUDiv(LHS, RHS);
+ break;
+ }
+ case scSMaxExpr:
+ case scUMaxExpr:
+ case scSMinExpr:
+ case scUMinExpr:
+ break; // TODO: smax, umax, smin, umin.
+ }
+ return nullptr;
+}
+
+const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
+ if (isa<SCEVConstant>(V)) return V;
+
+ // If this instruction is evolved from a constant-evolving PHI, compute the
+ // exit value from the loop without using SCEVs.
+ if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
+ if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
+ if (PHINode *PN = dyn_cast<PHINode>(I)) {
+ const Loop *LI = this->LI[I->getParent()];
+ // Looking for loop exit value.
+ if (LI && LI->getParentLoop() == L &&
+ PN->getParent() == LI->getHeader()) {
+ // Okay, there is no closed form solution for the PHI node. Check
+ // to see if the loop that contains it has a known backedge-taken
+ // count. If so, we may be able to force computation of the exit
+ // value.
+ const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI);
+ // This trivial case can show up in some degenerate cases where
+ // the incoming IR has not yet been fully simplified.
+ if (BackedgeTakenCount->isZero()) {
+ Value *InitValue = nullptr;
+ bool MultipleInitValues = false;
+ for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
+ if (!LI->contains(PN->getIncomingBlock(i))) {
+ if (!InitValue)
+ InitValue = PN->getIncomingValue(i);
+ else if (InitValue != PN->getIncomingValue(i)) {
+ MultipleInitValues = true;
+ break;
+ }
+ }
+ }
+ if (!MultipleInitValues && InitValue)
+ return getSCEV(InitValue);
+ }
+ // Do we have a loop invariant value flowing around the backedge
+ // for a loop which must execute the backedge?
+ if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
+ isKnownPositive(BackedgeTakenCount) &&
+ PN->getNumIncomingValues() == 2) {
+ unsigned InLoopPred = LI->contains(PN->getIncomingBlock(0)) ? 0 : 1;
+ const SCEV *OnBackedge = getSCEV(PN->getIncomingValue(InLoopPred));
+ if (IsAvailableOnEntry(LI, DT, OnBackedge, PN->getParent()))
+ return OnBackedge;
+ }
+ if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
+ // Okay, we know how many times the containing loop executes. If
+ // this is a constant evolving PHI node, get the final value at
+ // the specified iteration number.
+ Constant *RV =
+ getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), LI);
+ if (RV) return getSCEV(RV);
+ }
+ }
+
+ // If there is a single-input Phi, evaluate it at our scope. If we can
+ // prove that this replacement does not break LCSSA form, use the new
+ // value.
+ if (PN->getNumOperands() == 1) {
+ const SCEV *Input = getSCEV(PN->getOperand(0));
+ const SCEV *InputAtScope = getSCEVAtScope(Input, L);
+ // TODO: We can generalize this using LI.replacementPreservesLCSSAForm;
+ // for the simplest case, just support constants.
+ if (isa<SCEVConstant>(InputAtScope)) return InputAtScope;
+ }
+ }
+
+ // Okay, this is an expression that we cannot symbolically evaluate
+ // into a SCEV. Check to see if it's possible to symbolically evaluate
+ // the arguments into constants, and if so, try to constant propagate the
+ // result. This is particularly useful for computing loop exit values.
+ if (CanConstantFold(I)) {
+ SmallVector<Constant *, 4> Operands;
+ bool MadeImprovement = false;
+ for (Value *Op : I->operands()) {
+ if (Constant *C = dyn_cast<Constant>(Op)) {
+ Operands.push_back(C);
+ continue;
+ }
+
+ // If any operand is non-constant and of a non-integer, non-pointer
+ // type, don't even try to analyze it with SCEV techniques.
+ if (!isSCEVable(Op->getType())) + return V; + + const SCEV *OrigV = getSCEV(Op); + const SCEV *OpV = getSCEVAtScope(OrigV, L); + MadeImprovement |= OrigV != OpV; + + Constant *C = BuildConstantFromSCEV(OpV); + if (!C) return V; + if (C->getType() != Op->getType()) + C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false, + Op->getType(), + false), + C, Op->getType()); + Operands.push_back(C); + } + + // Check to see if getSCEVAtScope actually made an improvement. + if (MadeImprovement) { + Constant *C = nullptr; + const DataLayout &DL = getDataLayout(); + if (const CmpInst *CI = dyn_cast<CmpInst>(I)) + C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], + Operands[1], DL, &TLI); + else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) { + if (!LI->isVolatile()) + C = ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL); + } else + C = ConstantFoldInstOperands(I, Operands, DL, &TLI); + if (!C) return V; + return getSCEV(C); + } + } + } + + // This is some other type of SCEVUnknown, just return it. + return V; + } + + if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) { + // Avoid performing the look-up in the common case where the specified + // expression has no loop-variant portions. + for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) { + const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); + if (OpAtScope != Comm->getOperand(i)) { + // Okay, at least one of these operands is loop variant but might be + // foldable. Build a new instance of the folded commutative expression. + SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(), + Comm->op_begin()+i); + NewOps.push_back(OpAtScope); + + for (++i; i != e; ++i) { + OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); + NewOps.push_back(OpAtScope); + } + if (isa<SCEVAddExpr>(Comm)) + return getAddExpr(NewOps, Comm->getNoWrapFlags()); + if (isa<SCEVMulExpr>(Comm)) + return getMulExpr(NewOps, Comm->getNoWrapFlags()); + if (isa<SCEVMinMaxExpr>(Comm)) + return getMinMaxExpr(Comm->getSCEVType(), NewOps); + llvm_unreachable("Unknown commutative SCEV type!"); + } + } + // If we got here, all operands are loop invariant. + return Comm; + } + + if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) { + const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L); + const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L); + if (LHS == Div->getLHS() && RHS == Div->getRHS()) + return Div; // must be loop invariant + return getUDivExpr(LHS, RHS); + } + + // If this is a loop recurrence for a loop that does not contain L, then we + // are dealing with the final value computed by the loop. + if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) { + // First, attempt to evaluate each operand. + // Avoid performing the look-up in the common case where the specified + // expression has no loop-variant portions. + for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { + const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L); + if (OpAtScope == AddRec->getOperand(i)) + continue; + + // Okay, at least one of these operands is loop variant but might be + // foldable. Build a new instance of the folded commutative expression. 
+ SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(), + AddRec->op_begin()+i); + NewOps.push_back(OpAtScope); + for (++i; i != e; ++i) + NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L)); + + const SCEV *FoldedRec = + getAddRecExpr(NewOps, AddRec->getLoop(), + AddRec->getNoWrapFlags(SCEV::FlagNW)); + AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec); + // The addrec may be folded to a nonrecurrence, for example, if the + // induction variable is multiplied by zero after constant folding. Go + // ahead and return the folded value. + if (!AddRec) + return FoldedRec; + break; + } + + // If the scope is outside the addrec's loop, evaluate it by using the + // loop exit value of the addrec. + if (!AddRec->getLoop()->contains(L)) { + // To evaluate this recurrence, we need to know how many times the AddRec + // loop iterates. Compute this now. + const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); + if (BackedgeTakenCount == getCouldNotCompute()) return AddRec; + + // Then, evaluate the AddRec. + return AddRec->evaluateAtIteration(BackedgeTakenCount, *this); + } + + return AddRec; + } + + if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) { + const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); + if (Op == Cast->getOperand()) + return Cast; // must be loop invariant + return getZeroExtendExpr(Op, Cast->getType()); + } + + if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) { + const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); + if (Op == Cast->getOperand()) + return Cast; // must be loop invariant + return getSignExtendExpr(Op, Cast->getType()); + } + + if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) { + const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); + if (Op == Cast->getOperand()) + return Cast; // must be loop invariant + return getTruncateExpr(Op, Cast->getType()); + } + + llvm_unreachable("Unknown SCEV type!"); +} + +const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { + return getSCEVAtScope(getSCEV(V), L); +} + +const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const { + if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) + return stripInjectiveFunctions(ZExt->getOperand()); + if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) + return stripInjectiveFunctions(SExt->getOperand()); + return S; +} + +/// Finds the minimum unsigned root of the following equation: +/// +/// A * X = B (mod N) +/// +/// where N = 2^BW and BW is the common bit width of A and B. The signedness of +/// A and B isn't important. +/// +/// If the equation does not have a solution, SCEVCouldNotCompute is returned. +static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, + ScalarEvolution &SE) { + uint32_t BW = A.getBitWidth(); + assert(BW == SE.getTypeSizeInBits(B->getType())); + assert(A != 0 && "A must be non-zero."); + + // 1. D = gcd(A, N) + // + // The gcd of A and N may have only one prime factor: 2. The number of + // trailing zeros in A is its multiplicity + uint32_t Mult2 = A.countTrailingZeros(); + // D = 2^Mult2 + + // 2. Check if B is divisible by D. + // + // B is divisible by D if and only if the multiplicity of prime factor 2 for B + // is not less than multiplicity of this prime factor for D. + if (SE.GetMinTrailingZeros(B) < Mult2) + return SE.getCouldNotCompute(); + + // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic + // modulo (N / D). 
+ // + // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent + // (N / D) in general. The inverse itself always fits into BW bits, though, + // so we immediately truncate it. + APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D + APInt Mod(BW + 1, 0); + Mod.setBit(BW - Mult2); // Mod = N / D + APInt I = AD.multiplicativeInverse(Mod).trunc(BW); + + // 4. Compute the minimum unsigned root of the equation: + // I * (B / D) mod (N / D) + // To simplify the computation, we factor out the divide by D: + // (I * B mod N) / D + const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2)); + return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D); +} + +/// For a given quadratic addrec, generate coefficients of the corresponding +/// quadratic equation, multiplied by a common value to ensure that they are +/// integers. +/// The returned value is a tuple { A, B, C, M, BitWidth }, where +/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C +/// were multiplied by, and BitWidth is the bit width of the original addrec +/// coefficients. +/// This function returns None if the addrec coefficients are not compile- +/// time constants. +static Optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>> +GetQuadraticEquation(const SCEVAddRecExpr *AddRec) { + assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); + const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0)); + const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1)); + const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2)); + LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: " + << *AddRec << '\n'); + + // We currently can only solve this if the coefficients are constants. + if (!LC || !MC || !NC) { + LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n"); + return None; + } + + APInt L = LC->getAPInt(); + APInt M = MC->getAPInt(); + APInt N = NC->getAPInt(); + assert(!N.isNullValue() && "This is not a quadratic addrec"); + + unsigned BitWidth = LC->getAPInt().getBitWidth(); + unsigned NewWidth = BitWidth + 1; + LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: " + << BitWidth << '\n'); + // The sign-extension (as opposed to a zero-extension) here matches the + // extension used in SolveQuadraticEquationWrap (with the same motivation). + N = N.sext(NewWidth); + M = M.sext(NewWidth); + L = L.sext(NewWidth); + + // The increments are M, M+N, M+2N, ..., so the accumulated values are + // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is, + // L+M, L+2M+N, L+3M+3N, ... + // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N. + // + // The equation Acc = 0 is then + // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0. + // In a quadratic form it becomes: + // N n^2 + (2M-N) n + 2L = 0. + + APInt A = N; + APInt B = 2 * M - A; + APInt C = 2 * L; + APInt T = APInt(NewWidth, 2); + LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B + << "x + " << C << ", coeff bw: " << NewWidth + << ", multiplied by " << T << '\n'); + return std::make_tuple(A, B, C, T, BitWidth); +} + +/// Helper function to compare optional APInts: +/// (a) if X and Y both exist, return min(X, Y), +/// (b) if neither X nor Y exist, return None, +/// (c) if exactly one of X and Y exists, return that value. 
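+///
+/// For example (a sketch of the intended semantics): with i8 inputs,
+/// MinOptional(3, -1) yields -1, because both present values are
+/// sign-extended to a common width and compared as signed integers.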
+static Optional<APInt> MinOptional(Optional<APInt> X, Optional<APInt> Y) {
+ if (X.hasValue() && Y.hasValue()) {
+ unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
+ APInt XW = X->sextOrSelf(W);
+ APInt YW = Y->sextOrSelf(W);
+ return XW.slt(YW) ? *X : *Y;
+ }
+ if (!X.hasValue() && !Y.hasValue())
+ return None;
+ return X.hasValue() ? *X : *Y;
+}
+
+/// Helper function to truncate an optional APInt to a given BitWidth.
+/// When solving addrec-related equations, it is preferable to return a value
+/// that has the same bit width as the original addrec's coefficients. If the
+/// solution fits in the original bit width, truncate it (except for i1).
+/// Returning a value of a different bit width may inhibit some optimizations.
+///
+/// In general, a solution to a quadratic equation generated from an addrec
+/// may require BW+1 bits, where BW is the bit width of the addrec's
+/// coefficients. The reason is that the coefficients of the quadratic
+/// equation are BW+1 bits wide (to avoid truncation when converting from
+/// the addrec to the equation).
+static Optional<APInt> TruncIfPossible(Optional<APInt> X, unsigned BitWidth) {
+ if (!X.hasValue())
+ return None;
+ unsigned W = X->getBitWidth();
+ if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
+ return X->trunc(BitWidth);
+ return X;
+}
+
+/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
+/// iterations. The values L, M, N are assumed to be signed, and they
+/// should all have the same bit widths.
+/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
+/// where BW is the bit width of the addrec's coefficients.
+/// If the calculated value is a BW-bit integer (for BW > 1), it will be
+/// returned as such; otherwise the bit width of the returned value may
+/// be greater than BW.
+///
+/// This function returns None if
+/// (a) the addrec coefficients are not constant, or
+/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
+/// like x^2 = 5, no integer solutions exist; in other cases an integer
+/// solution may exist but SolveQuadraticEquationWrap may fail to find it.
+static Optional<APInt>
+SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
+ APInt A, B, C, M;
+ unsigned BitWidth;
+ auto T = GetQuadraticEquation(AddRec);
+ if (!T.hasValue())
+ return None;
+
+ std::tie(A, B, C, M, BitWidth) = *T;
+ LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
+ Optional<APInt> X = APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth+1);
+ if (!X.hasValue())
+ return None;
+
+ ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
+ ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
+ if (!V->isZero())
+ return None;
+
+ return TruncIfPossible(X, BitWidth);
+}
+
+/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
+/// iterations. The values M, N are assumed to be signed, and they
+/// should all have the same bit widths.
+/// Find the least n such that c(n) does not belong to the given range,
+/// while c(n-1) does.
+///
+/// This function returns None if
+/// (a) the addrec coefficients are not constant, or
+/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
+/// bounds of the range.
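+///
+/// As a worked example of the query itself: for the chrec {0,+,1,+,1}, whose
+/// values are 0, 1, 3, 6, 10, ..., and Range = [0, 7), the value sought is
+/// n = 4, since c(4) = 10 is the first value outside the range while
+/// c(3) = 6 is still inside it.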
+static Optional<APInt>
+SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
+ const ConstantRange &Range, ScalarEvolution &SE) {
+ assert(AddRec->getOperand(0)->isZero() &&
+ "Starting value of addrec should be 0");
+ LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
+ << Range << ", addrec " << *AddRec << '\n');
+ // This case is handled in getNumIterationsInRange. Here we can assume that
+ // we start in the range.
+ assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
+ "Addrec's initial value should be in range");
+
+ APInt A, B, C, M;
+ unsigned BitWidth;
+ auto T = GetQuadraticEquation(AddRec);
+ if (!T.hasValue())
+ return None;
+
+ // Be careful about the return value: there can be two reasons for not
+ // returning an actual number. First, if no solutions to the equations
+ // were found, and second, if the solutions don't leave the given range.
+ // The first case means that the actual solution is "unknown"; the second
+ // means that it's known, but not valid. If the solution is unknown, we
+ // cannot make any conclusions.
+ // Return a pair: the optional solution and a flag indicating if the
+ // solution was found.
+ auto SolveForBoundary = [&](APInt Bound) -> std::pair<Optional<APInt>, bool> {
+ // Solve for signed overflow and unsigned overflow, pick the lower
+ // solution.
+ LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
+ << Bound << " (before multiplying by " << M << ")\n");
+ Bound *= M; // The quadratic equation multiplier.
+
+ Optional<APInt> SO = None;
+ if (BitWidth > 1) {
+ LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
+ "signed overflow\n");
+ SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
+ }
+ LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
+ "unsigned overflow\n");
+ Optional<APInt> UO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound,
+ BitWidth+1);
+
+ auto LeavesRange = [&] (const APInt &X) {
+ ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
+ ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
+ if (Range.contains(V0->getValue()))
+ return false;
+ // X should be at least 1, so X-1 is non-negative.
+ ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
+ ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
+ if (Range.contains(V1->getValue()))
+ return true;
+ return false;
+ };
+
+ // If SolveQuadraticEquationWrap returns None, it means that there can
+ // be a solution, but the function failed to find it. We cannot treat it
+ // as "no solution".
+ if (!SO.hasValue() || !UO.hasValue())
+ return { None, false };
+
+ // Check the smaller value first to see if it leaves the range.
+ // At this point, both SO and UO must have values.
+ Optional<APInt> Min = MinOptional(SO, UO);
+ if (LeavesRange(*Min))
+ return { Min, true };
+ Optional<APInt> Max = Min == SO ? UO : SO;
+ if (LeavesRange(*Max))
+ return { Max, true };
+
+ // Solutions were found, but were eliminated, hence the "true".
+ return { None, true };
+ };
+
+ std::tie(A, B, C, M, BitWidth) = *T;
+ // The lower bound is inclusive, so subtract 1 to represent the exiting value.
+ APInt Lower = Range.getLower().sextOrSelf(A.getBitWidth()) - 1;
+ APInt Upper = Range.getUpper().sextOrSelf(A.getBitWidth());
+ auto SL = SolveForBoundary(Lower);
+ auto SU = SolveForBoundary(Upper);
+ // If any of the solutions was unknown, no meaningful conclusions can
+ // be made.
+ if (!SL.second || !SU.second)
+ return None;
+
+ // Claim: The correct solution is not some value between Min and Max.
+ //
+ // Justification: Assuming that Min and Max are different values, one of
+ // them is when the first signed overflow happens, the other is when the
+ // first unsigned overflow happens. Crossing the range boundary is only
+ // possible via an overflow (treating 0 as a special case of it, modeling
+ // an overflow as crossing k*2^W for some k).
+ //
+ // The interesting case here is when Min was eliminated as an invalid
+ // solution, but Max was not. The argument is that if there was another
+ // overflow between Min and Max, it would also have been eliminated if
+ // it was considered.
+ //
+ // For a given boundary, it is possible to have two overflows of the same
+ // type (signed/unsigned) without having the other type in between: this
+ // can happen when the vertex of the parabola is between the iterations
+ // corresponding to the overflows. This is only possible when the two
+ // overflows cross k*2^W for the same k. In such case, if the second one
+ // left the range (and was the first one to do so), the first overflow
+ // would have to enter the range, which would mean that either we had left
+ // the range before or that we started outside of it. Both of these cases
+ // are contradictions.
+ //
+ // Claim: In the case where SolveForBoundary returns None, the correct
+ // solution is not some value between the Max for this boundary and the
+ // Min of the other boundary.
+ //
+ // Justification: Assume that we had such Max_A and Min_B corresponding
+ // to range boundaries A and B and such that Max_A < Min_B. If there was
+ // a solution between Max_A and Min_B, it would have to be caused by an
+ // overflow corresponding to either A or B. It cannot correspond to B,
+ // since Min_B is the first occurrence of such an overflow. If it
+ // corresponded to A, it would have to be either a signed or an unsigned
+ // overflow that is larger than both eliminated overflows for A. But
+ // between the eliminated overflows and this overflow, the values would
+ // cover the entire value space, thus crossing the other boundary, which
+ // is a contradiction.
+
+ return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
+ bool AllowPredicates) {
+
+ // This is only used for loops with an "x != y" exit test. The exit condition
+ // is now expressed as a single expression, V = x-y. So the exit test is
+ // effectively V != 0. We know and take advantage of the fact that this
+ // expression is only used in a compare-against-zero context.
+
+ SmallPtrSet<const SCEVPredicate *, 4> Predicates;
+ // If the value is a constant
+ if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
+ // If the value is already zero, the branch will execute zero times.
+ if (C->getValue()->isZero()) return C;
+ return getCouldNotCompute(); // Otherwise it will loop infinitely.
+ }
+
+ const SCEVAddRecExpr *AddRec =
+ dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
+
+ if (!AddRec && AllowPredicates)
+ // Try to make this an AddRec using runtime tests, in the first X
+ // iterations of this loop, where X is the SCEV expression found by the
+ // algorithm below.
+ AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
+
+ if (!AddRec || AddRec->getLoop() != L)
+ return getCouldNotCompute();
+
+ // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
+ // the quadratic equation to solve it.
+ if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
+ // We can only use this value if the chrec ends up with an exact zero
+ // value at this index. When solving for "X*X != 5", for example, we
+ // should not accept a root of 2.
+ if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
+ const auto *R = cast<SCEVConstant>(getConstant(S.getValue()));
+ return ExitLimit(R, R, false, Predicates);
+ }
+ return getCouldNotCompute();
+ }
+
+ // Otherwise we can only handle this if it is affine.
+ if (!AddRec->isAffine())
+ return getCouldNotCompute();
+
+ // If this is an affine expression, the execution count of this branch is
+ // the minimum unsigned root of the following equation:
+ //
+ // Start + Step*N = 0 (mod 2^BW)
+ //
+ // equivalent to:
+ //
+ // Step*N = -Start (mod 2^BW)
+ //
+ // where BW is the common bit width of Start and Step.
+
+ // Get the initial value for the loop.
+ const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
+ const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
+
+ // For now we handle only constant steps.
+ //
+ // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
+ // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
+ // to 0; it must be counting down to equal 0. Consequently, N = Start / -Step.
+ // We have not yet seen any such cases.
+ const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
+ if (!StepC || StepC->getValue()->isZero())
+ return getCouldNotCompute();
+
+ // For positive steps (counting up until unsigned overflow):
+ // N = -Start/Step (as unsigned)
+ // For negative steps (counting down to zero):
+ // N = Start/-Step
+ // First compute the unsigned distance from zero in the direction of Step.
+ bool CountDown = StepC->getAPInt().isNegative();
+ const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
+
+ // Handle unitary steps, which cannot wrap around.
+ // 1*N = -Start; -1*N = Start (mod 2^BW), so:
+ // N = Distance (as unsigned)
+ if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) {
+ APInt MaxBECount = getUnsignedRangeMax(Distance);
+
+ // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
+ // we end up with a loop whose backedge-taken count is n - 1. Detect this
+ // case, and see if we can improve the bound.
+ //
+ // Explicitly handling this here is necessary because getUnsignedRange
+ // isn't context-sensitive; it doesn't know that we only care about the
+ // range inside the loop.
+ const SCEV *Zero = getZero(Distance->getType());
+ const SCEV *One = getOne(Distance->getType());
+ const SCEV *DistancePlusOne = getAddExpr(Distance, One);
+ if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
+ // If Distance + 1 doesn't overflow, we can compute the maximum distance
+ // as "unsigned_max(Distance + 1) - 1".
+ ConstantRange CR = getUnsignedRange(DistancePlusOne);
+ MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
+ }
+ return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates);
+ }
+
+ // If the condition controls loop exit (the loop exits only if the expression
+ // is true) and the addition is no-wrap, we can use unsigned divide to
+ // compute the backedge count.
In this case, the step may not divide the + // distance, but we don't care because if the condition is "missed" the loop + // will have undefined behavior due to wrapping. + if (ControlsExit && AddRec->hasNoSelfWrap() && + loopHasNoAbnormalExits(AddRec->getLoop())) { + const SCEV *Exact = + getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); + const SCEV *Max = + Exact == getCouldNotCompute() + ? Exact + : getConstant(getUnsignedRangeMax(Exact)); + return ExitLimit(Exact, Max, false, Predicates); + } + + // Solve the general equation. + const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(), + getNegativeSCEV(Start), *this); + const SCEV *M = E == getCouldNotCompute() + ? E + : getConstant(getUnsignedRangeMax(E)); + return ExitLimit(E, M, false, Predicates); +} + +ScalarEvolution::ExitLimit +ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) { + // Loops that look like: while (X == 0) are very strange indeed. We don't + // handle them yet except for the trivial case. This could be expanded in the + // future as needed. + + // If the value is a constant, check to see if it is known to be non-zero + // already. If so, the backedge will execute zero times. + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) { + if (!C->getValue()->isZero()) + return getZero(C->getType()); + return getCouldNotCompute(); // Otherwise it will loop infinitely. + } + + // We could implement others, but I really doubt anyone writes loops like + // this, and if they did, they would already be constant folded. + return getCouldNotCompute(); +} + +std::pair<BasicBlock *, BasicBlock *> +ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { + // If the block has a unique predecessor, then there is no path from the + // predecessor to the block that does not go through the direct edge + // from the predecessor to the block. + if (BasicBlock *Pred = BB->getSinglePredecessor()) + return {Pred, BB}; + + // A loop's header is defined to be a block that dominates the loop. + // If the header has a unique predecessor outside the loop, it must be + // a block that has exactly one successor that can reach the loop. + if (Loop *L = LI.getLoopFor(BB)) + return {L->getLoopPredecessor(), L->getHeader()}; + + return {nullptr, nullptr}; +} + +/// SCEV structural equivalence is usually sufficient for testing whether two +/// expressions are equal, however for the purposes of looking for a condition +/// guarding a loop, it can be useful to be a little more general, since a +/// front-end may have replicated the controlling expression. +static bool HasSameValue(const SCEV *A, const SCEV *B) { + // Quick check to see if they are the same SCEV. + if (A == B) return true; + + auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) { + // Not all instructions that are "identical" compute the same value. For + // instance, two distinct alloca instructions allocating the same type are + // identical and do not read memory; but compute distinct values. + return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A)); + }; + + // Otherwise, if they're both SCEVUnknown, it's possible that they hold + // two different instructions with the same value. Check for this case. 
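+ // (For example, a front-end that replicates a loop guard may emit two
+ // identical 'add' or 'getelementptr' instructions in different blocks;
+ // they map to distinct SCEVUnknowns but compute the same value.)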
+ if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
+ if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
+ if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
+ if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
+ if (ComputesEqualValues(AI, BI))
+ return true;
+
+ // Otherwise assume they may have a different value.
+ return false;
+}
+
+bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
+ const SCEV *&LHS, const SCEV *&RHS,
+ unsigned Depth) {
+ bool Changed = false;
+ // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
+ // '0 != 0'.
+ auto TrivialCase = [&](bool TriviallyTrue) {
+ LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
+ Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
+ return true;
+ };
+ // If we hit the max recursion limit, bail out.
+ if (Depth >= 3)
+ return false;
+
+ // Canonicalize a constant to the right side.
+ if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
+ // Check for both operands constant.
+ if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
+ if (ConstantExpr::getICmp(Pred,
+ LHSC->getValue(),
+ RHSC->getValue())->isNullValue())
+ return TrivialCase(false);
+ else
+ return TrivialCase(true);
+ }
+ // Otherwise swap the operands to put the constant on the right.
+ std::swap(LHS, RHS);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ Changed = true;
+ }
+
+ // If we're comparing an addrec with a value which is loop-invariant in the
+ // addrec's loop, put the addrec on the left. Also make a dominance check,
+ // as both operands could be addrecs loop-invariant in each other's loop.
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
+ const Loop *L = AR->getLoop();
+ if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
+ std::swap(LHS, RHS);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ Changed = true;
+ }
+ }
+
+ // If there's a constant operand, canonicalize comparisons with boundary
+ // cases, and canonicalize *-or-equal comparisons to regular comparisons.
+ if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
+ const APInt &RA = RC->getAPInt();
+
+ bool SimplifiedByConstantRange = false;
+
+ if (!ICmpInst::isEquality(Pred)) {
+ ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
+ if (ExactCR.isFullSet())
+ return TrivialCase(true);
+ else if (ExactCR.isEmptySet())
+ return TrivialCase(false);
+
+ APInt NewRHS;
+ CmpInst::Predicate NewPred;
+ if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
+ ICmpInst::isEquality(NewPred)) {
+ // We were able to convert an inequality to an equality.
+ Pred = NewPred;
+ RHS = getConstant(NewRHS);
+ Changed = SimplifiedByConstantRange = true;
+ }
+ }
+
+ if (!SimplifiedByConstantRange) {
+ switch (Pred) {
+ default:
+ break;
+ case ICmpInst::ICMP_EQ:
+ case ICmpInst::ICMP_NE:
+ // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
+ if (!RA)
+ if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS))
+ if (const SCEVMulExpr *ME =
+ dyn_cast<SCEVMulExpr>(AE->getOperand(0)))
+ if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 &&
+ ME->getOperand(0)->isAllOnesValue()) {
+ RHS = AE->getOperand(1);
+ LHS = ME->getOperand(1);
+ Changed = true;
+ }
+ break;
+
+
+ // The "Should have been caught earlier!" messages refer to the fact
+ // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
+ // should have fired on the corresponding cases, and canonicalized the
+ // check to a trivial case.
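+ //
+ // For instance, with RA == 5, 'x u>= 5' is rewritten to 'x u> 4' and
+ // 'x s<= 5' to 'x s< 6' below; a boundary case such as 'x u>= 0' never
+ // reaches this switch, since makeExactICmpRegion folds it to a trivial
+ // true above.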
+ + case ICmpInst::ICMP_UGE: + assert(!RA.isMinValue() && "Should have been caught earlier!"); + Pred = ICmpInst::ICMP_UGT; + RHS = getConstant(RA - 1); + Changed = true; + break; + case ICmpInst::ICMP_ULE: + assert(!RA.isMaxValue() && "Should have been caught earlier!"); + Pred = ICmpInst::ICMP_ULT; + RHS = getConstant(RA + 1); + Changed = true; + break; + case ICmpInst::ICMP_SGE: + assert(!RA.isMinSignedValue() && "Should have been caught earlier!"); + Pred = ICmpInst::ICMP_SGT; + RHS = getConstant(RA - 1); + Changed = true; + break; + case ICmpInst::ICMP_SLE: + assert(!RA.isMaxSignedValue() && "Should have been caught earlier!"); + Pred = ICmpInst::ICMP_SLT; + RHS = getConstant(RA + 1); + Changed = true; + break; + } + } + } + + // Check for obvious equality. + if (HasSameValue(LHS, RHS)) { + if (ICmpInst::isTrueWhenEqual(Pred)) + return TrivialCase(true); + if (ICmpInst::isFalseWhenEqual(Pred)) + return TrivialCase(false); + } + + // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by + // adding or subtracting 1 from one of the operands. + switch (Pred) { + case ICmpInst::ICMP_SLE: + if (!getSignedRangeMax(RHS).isMaxSignedValue()) { + RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, + SCEV::FlagNSW); + Pred = ICmpInst::ICMP_SLT; + Changed = true; + } else if (!getSignedRangeMin(LHS).isMinSignedValue()) { + LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS, + SCEV::FlagNSW); + Pred = ICmpInst::ICMP_SLT; + Changed = true; + } + break; + case ICmpInst::ICMP_SGE: + if (!getSignedRangeMin(RHS).isMinSignedValue()) { + RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS, + SCEV::FlagNSW); + Pred = ICmpInst::ICMP_SGT; + Changed = true; + } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) { + LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, + SCEV::FlagNSW); + Pred = ICmpInst::ICMP_SGT; + Changed = true; + } + break; + case ICmpInst::ICMP_ULE: + if (!getUnsignedRangeMax(RHS).isMaxValue()) { + RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, + SCEV::FlagNUW); + Pred = ICmpInst::ICMP_ULT; + Changed = true; + } else if (!getUnsignedRangeMin(LHS).isMinValue()) { + LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS); + Pred = ICmpInst::ICMP_ULT; + Changed = true; + } + break; + case ICmpInst::ICMP_UGE: + if (!getUnsignedRangeMin(RHS).isMinValue()) { + RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS); + Pred = ICmpInst::ICMP_UGT; + Changed = true; + } else if (!getUnsignedRangeMax(LHS).isMaxValue()) { + LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, + SCEV::FlagNUW); + Pred = ICmpInst::ICMP_UGT; + Changed = true; + } + break; + default: + break; + } + + // TODO: More simplifications are possible here. + + // Recursively simplify until we either hit a recursion limit or nothing + // changes. 
+ if (Changed) + return SimplifyICmpOperands(Pred, LHS, RHS, Depth+1); + + return Changed; +} + +bool ScalarEvolution::isKnownNegative(const SCEV *S) { + return getSignedRangeMax(S).isNegative(); +} + +bool ScalarEvolution::isKnownPositive(const SCEV *S) { + return getSignedRangeMin(S).isStrictlyPositive(); +} + +bool ScalarEvolution::isKnownNonNegative(const SCEV *S) { + return !getSignedRangeMin(S).isNegative(); +} + +bool ScalarEvolution::isKnownNonPositive(const SCEV *S) { + return !getSignedRangeMax(S).isStrictlyPositive(); +} + +bool ScalarEvolution::isKnownNonZero(const SCEV *S) { + return isKnownNegative(S) || isKnownPositive(S); +} + +std::pair<const SCEV *, const SCEV *> +ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) { + // Compute SCEV on entry of loop L. + const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this); + if (Start == getCouldNotCompute()) + return { Start, Start }; + // Compute post increment SCEV for loop L. + const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this); + assert(PostInc != getCouldNotCompute() && "Unexpected could not compute"); + return { Start, PostInc }; +} + +bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + // First collect all loops. + SmallPtrSet<const Loop *, 8> LoopsUsed; + getUsedLoops(LHS, LoopsUsed); + getUsedLoops(RHS, LoopsUsed); + + if (LoopsUsed.empty()) + return false; + + // Domination relationship must be a linear order on collected loops. +#ifndef NDEBUG + for (auto *L1 : LoopsUsed) + for (auto *L2 : LoopsUsed) + assert((DT.dominates(L1->getHeader(), L2->getHeader()) || + DT.dominates(L2->getHeader(), L1->getHeader())) && + "Domination relationship is not a linear order"); +#endif + + const Loop *MDL = + *std::max_element(LoopsUsed.begin(), LoopsUsed.end(), + [&](const Loop *L1, const Loop *L2) { + return DT.properlyDominates(L1->getHeader(), L2->getHeader()); + }); + + // Get init and post increment value for LHS. + auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS); + // if LHS contains unknown non-invariant SCEV then bail out. + if (SplitLHS.first == getCouldNotCompute()) + return false; + assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC"); + // Get init and post increment value for RHS. + auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS); + // if RHS contains unknown non-invariant SCEV then bail out. + if (SplitRHS.first == getCouldNotCompute()) + return false; + assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC"); + // It is possible that init SCEV contains an invariant load but it does + // not dominate MDL and is not available at MDL loop entry, so we should + // check it here. + if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) || + !isAvailableAtLoopEntry(SplitRHS.first, MDL)) + return false; + + return isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first) && + isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second, + SplitRHS.second); +} + +bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + // Canonicalize the inputs first. + (void)SimplifyICmpOperands(Pred, LHS, RHS); + + if (isKnownViaInduction(Pred, LHS, RHS)) + return true; + + if (isKnownPredicateViaSplitting(Pred, LHS, RHS)) + return true; + + // Otherwise see what can be done with some simple reasoning. 
+ return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
+}
+
+bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
+ const SCEVAddRecExpr *LHS,
+ const SCEV *RHS) {
+ const Loop *L = LHS->getLoop();
+ return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
+ isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
+}
+
+bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS,
+ ICmpInst::Predicate Pred,
+ bool &Increasing) {
+ bool Result = isMonotonicPredicateImpl(LHS, Pred, Increasing);
+
+#ifndef NDEBUG
+ // Verify an invariant: inverting the predicate should turn a monotonically
+ // increasing change to a monotonically decreasing one, and vice versa.
+ bool IncreasingSwapped;
+ bool ResultSwapped = isMonotonicPredicateImpl(
+ LHS, ICmpInst::getSwappedPredicate(Pred), IncreasingSwapped);
+
+ assert(Result == ResultSwapped && "should be able to analyze both!");
+ if (ResultSwapped)
+ assert(Increasing == !IncreasingSwapped &&
+ "monotonicity should flip as we flip the predicate");
+#endif
+
+ return Result;
+}
+
+bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS,
+ ICmpInst::Predicate Pred,
+ bool &Increasing) {
+
+ // A zero step value for LHS means the induction variable is essentially a
+ // loop invariant value. We don't really depend on the predicate actually
+ // flipping from false to true (for increasing predicates, and the other way
+ // around for decreasing predicates); all we care about is that *if* the
+ // predicate changes then it only changes from false to true.
+ //
+ // A zero step value in itself is not very useful, but there may be places
+ // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
+ // as general as possible.
+
+ switch (Pred) {
+ default:
+ return false; // Conservative answer
+
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_UGE:
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULE:
+ if (!LHS->hasNoUnsignedWrap())
+ return false;
+
+ Increasing = Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE;
+ return true;
+
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_SGE:
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_SLE: {
+ if (!LHS->hasNoSignedWrap())
+ return false;
+
+ const SCEV *Step = LHS->getStepRecurrence(*this);
+
+ if (isKnownNonNegative(Step)) {
+ Increasing = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE;
+ return true;
+ }
+
+ if (isKnownNonPositive(Step)) {
+ Increasing = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE;
+ return true;
+ }
+
+ return false;
+ }
+
+ }
+
+ llvm_unreachable("switch has default clause!");
+}
+
+bool ScalarEvolution::isLoopInvariantPredicate(
+ ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
+ ICmpInst::Predicate &InvariantPred, const SCEV *&InvariantLHS,
+ const SCEV *&InvariantRHS) {
+
+ // If one operand is loop-invariant, force it into the RHS; otherwise bail
+ // out.
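+ // For example, given a loop-invariant %n, '%n s> {0,+,1}' is first
+ // rewritten to '{0,+,1} s< %n' so that the addrec ends up on the LHS.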
+ if (!isLoopInvariant(RHS, L)) { + if (!isLoopInvariant(LHS, L)) + return false; + + std::swap(LHS, RHS); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS); + if (!ArLHS || ArLHS->getLoop() != L) + return false; + + bool Increasing; + if (!isMonotonicPredicate(ArLHS, Pred, Increasing)) + return false; + + // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to + // true as the loop iterates, and the backedge is control dependent on + // "ArLHS `Pred` RHS" == true then we can reason as follows: + // + // * if the predicate was false in the first iteration then the predicate + // is never evaluated again, since the loop exits without taking the + // backedge. + // * if the predicate was true in the first iteration then it will + // continue to be true for all future iterations since it is + // monotonically increasing. + // + // For both the above possibilities, we can replace the loop varying + // predicate with its value on the first iteration of the loop (which is + // loop invariant). + // + // A similar reasoning applies for a monotonically decreasing predicate, by + // replacing true with false and false with true in the above two bullets. + + auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred); + + if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS)) + return false; + + InvariantPred = Pred; + InvariantLHS = ArLHS->getStart(); + InvariantRHS = RHS; + return true; +} + +bool ScalarEvolution::isKnownPredicateViaConstantRanges( + ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { + if (HasSameValue(LHS, RHS)) + return ICmpInst::isTrueWhenEqual(Pred); + + // This code is split out from isKnownPredicate because it is called from + // within isLoopEntryGuardedByCond. + + auto CheckRanges = + [&](const ConstantRange &RangeLHS, const ConstantRange &RangeRHS) { + return ConstantRange::makeSatisfyingICmpRegion(Pred, RangeRHS) + .contains(RangeLHS); + }; + + // The check at the top of the function catches the case where the values are + // known to be equal. + if (Pred == CmpInst::ICMP_EQ) + return false; + + if (Pred == CmpInst::ICMP_NE) + return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)) || + CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)) || + isKnownNonZero(getMinusSCEV(LHS, RHS)); + + if (CmpInst::isSigned(Pred)) + return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)); + + return CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)); +} + +bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, + const SCEV *LHS, + const SCEV *RHS) { + // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer. + // Return Y via OutY. 
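+ //
+ // For example, matching Result = (%x + 42)<nsw> against X = %x with
+ // ExpectedFlags = FlagNSW sets OutY to 42 and succeeds; without the
+ // nsw flag on the add, the same match fails.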
+ auto MatchBinaryAddToConst =
+ [this](const SCEV *Result, const SCEV *X, APInt &OutY,
+ SCEV::NoWrapFlags ExpectedFlags) {
+ const SCEV *NonConstOp, *ConstOp;
+ SCEV::NoWrapFlags FlagsPresent;
+
+ if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
+ !isa<SCEVConstant>(ConstOp) || NonConstOp != X)
+ return false;
+
+ OutY = cast<SCEVConstant>(ConstOp)->getAPInt();
+ return (FlagsPresent & ExpectedFlags) == ExpectedFlags;
+ };
+
+ APInt C;
+
+ switch (Pred) {
+ default:
+ break;
+
+ case ICmpInst::ICMP_SGE:
+ std::swap(LHS, RHS);
+ LLVM_FALLTHROUGH;
+ case ICmpInst::ICMP_SLE:
+ // X s<= (X + C)<nsw> if C >= 0
+ if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative())
+ return true;
+
+ // (X + C)<nsw> s<= X if C <= 0
+ if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
+ !C.isStrictlyPositive())
+ return true;
+ break;
+
+ case ICmpInst::ICMP_SGT:
+ std::swap(LHS, RHS);
+ LLVM_FALLTHROUGH;
+ case ICmpInst::ICMP_SLT:
+ // X s< (X + C)<nsw> if C > 0
+ if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
+ C.isStrictlyPositive())
+ return true;
+
+ // (X + C)<nsw> s< X if C < 0
+ if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
+ return true;
+ break;
+ }
+
+ return false;
+}
+
+bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
+ const SCEV *LHS,
+ const SCEV *RHS) {
+ if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
+ return false;
+
+ // Allowing an arbitrary number of activations of
+ // isKnownPredicateViaSplitting on the stack can result in exponential time
+ // complexity.
+ SaveAndRestore<bool> Restore(ProvingSplitPredicate, true);
+
+ // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
+ //
+ // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
+ // isKnownPredicate. isKnownPredicate is more powerful, but also more
+ // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
+ // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
+ // use isKnownPredicate later if needed.
+ return isKnownNonNegative(RHS) &&
+ isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
+ isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
+}
+
+bool ScalarEvolution::isImpliedViaGuard(BasicBlock *BB,
+ ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS) {
+ // No need to even try if we know the module has no guards.
+ if (!HasGuards)
+ return false;
+
+ return any_of(*BB, [&](Instruction &I) {
+ using namespace llvm::PatternMatch;
+
+ Value *Condition;
+ return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
+ m_Value(Condition))) &&
+ isImpliedCond(Pred, LHS, RHS, Condition, false);
+ });
+}
+
+/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
+/// protected by a conditional between LHS and RHS. This is used to eliminate
+/// casts.
+bool
+ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
+ ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS) {
+ // Interpret a null as meaning no loop, where there is obviously no guard
+ // (interprocedural conditions notwithstanding).
+  if (!L) return true;
+
+  if (VerifyIR)
+    assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
+           "This cannot be done on broken IR!");
+
+  if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
+    return true;
+
+  BasicBlock *Latch = L->getLoopLatch();
+  if (!Latch)
+    return false;
+
+  BranchInst *LoopContinuePredicate =
+      dyn_cast<BranchInst>(Latch->getTerminator());
+  if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
+      isImpliedCond(Pred, LHS, RHS,
+                    LoopContinuePredicate->getCondition(),
+                    LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
+    return true;
+
+  // We don't want more than one activation of the following loops on the stack
+  // -- that can lead to O(n!) time complexity.
+  if (WalkingBEDominatingConds)
+    return false;
+
+  SaveAndRestore<bool> ClearOnExit(WalkingBEDominatingConds, true);
+
+  // See if we can exploit a trip count to prove the predicate.
+  const auto &BETakenInfo = getBackedgeTakenInfo(L);
+  const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
+  if (LatchBECount != getCouldNotCompute()) {
+    // We know that Latch branches back to the loop header exactly
+    // LatchBECount times. This means the backedge condition at Latch is
+    // equivalent to "{0,+,1} u< LatchBECount".
+    Type *Ty = LatchBECount->getType();
+    auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
+    const SCEV *LoopCounter =
+        getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
+    if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
+                      LatchBECount))
+      return true;
+  }
+
+  // Check conditions due to any @llvm.assume intrinsics.
+  for (auto &AssumeVH : AC.assumptions()) {
+    if (!AssumeVH)
+      continue;
+    auto *CI = cast<CallInst>(AssumeVH);
+    if (!DT.dominates(CI, Latch->getTerminator()))
+      continue;
+
+    if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
+      return true;
+  }
+
+  // If the loop is not reachable from the entry block, we risk running into an
+  // infinite loop as we walk up into the dom tree. These loops do not matter
+  // anyway, so we just return a conservative answer when we see them.
+  if (!DT.isReachableFromEntry(L->getHeader()))
+    return false;
+
+  if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
+    return true;
+
+  for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
+       DTN != HeaderDTN; DTN = DTN->getIDom()) {
+    assert(DTN && "should reach the loop header before reaching the root!");
+
+    BasicBlock *BB = DTN->getBlock();
+    if (isImpliedViaGuard(BB, Pred, LHS, RHS))
+      return true;
+
+    BasicBlock *PBB = BB->getSinglePredecessor();
+    if (!PBB)
+      continue;
+
+    BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
+    if (!ContinuePredicate || !ContinuePredicate->isConditional())
+      continue;
+
+    Value *Condition = ContinuePredicate->getCondition();
+
+    // If we have an edge `E` within the loop body that dominates the only
+    // latch, the condition guarding `E` also guards the backedge. This
+    // reasoning works only for loops with a single latch.
+
+    BasicBlockEdge DominatingEdge(PBB, BB);
+    if (DominatingEdge.isSingleEdge()) {
+      // We're constructively (and conservatively) enumerating edges within the
+      // loop body that dominate the latch. The dominator tree had better
+      // agree with us on this:
+      assert(DT.dominates(DominatingEdge, Latch) && "should be!");
+
+      if (isImpliedCond(Pred, LHS, RHS, Condition,
+                        BB != ContinuePredicate->getSuccessor(0)))
+        return true;
+    }
+  }
+
+  return false;
+}
+
+bool
+ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
+                                          ICmpInst::Predicate Pred,
+                                          const SCEV *LHS, const SCEV *RHS) {
+  // Interpret a null as meaning no loop, where there is obviously no guard
+  // (interprocedural conditions notwithstanding).
+  if (!L) return false;
+
+  if (VerifyIR)
+    assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
+           "This cannot be done on broken IR!");
+
+  // Both LHS and RHS must be available at loop entry.
+  assert(isAvailableAtLoopEntry(LHS, L) &&
+         "LHS is not available at Loop Entry");
+  assert(isAvailableAtLoopEntry(RHS, L) &&
+         "RHS is not available at Loop Entry");
+
+  if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
+    return true;
+
+  // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
+  // the facts (a >= b && a != b) separately. A typical situation is when the
+  // non-strict comparison is known from ranges and non-equality is known from
+  // dominating predicates. If we are proving strict comparison, we always try
+  // to prove non-equality and non-strict comparison separately.
+  auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
+  const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
+  bool ProvedNonStrictComparison = false;
+  bool ProvedNonEquality = false;
+
+  if (ProvingStrictComparison) {
+    ProvedNonStrictComparison =
+        isKnownViaNonRecursiveReasoning(NonStrictPredicate, LHS, RHS);
+    ProvedNonEquality =
+        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, LHS, RHS);
+    if (ProvedNonStrictComparison && ProvedNonEquality)
+      return true;
+  }
+
+  // Try to prove (Pred, LHS, RHS) using isImpliedViaGuard.
+  auto ProveViaGuard = [&](BasicBlock *Block) {
+    if (isImpliedViaGuard(Block, Pred, LHS, RHS))
+      return true;
+    if (ProvingStrictComparison) {
+      if (!ProvedNonStrictComparison)
+        ProvedNonStrictComparison =
+            isImpliedViaGuard(Block, NonStrictPredicate, LHS, RHS);
+      if (!ProvedNonEquality)
+        ProvedNonEquality =
+            isImpliedViaGuard(Block, ICmpInst::ICMP_NE, LHS, RHS);
+      if (ProvedNonStrictComparison && ProvedNonEquality)
+        return true;
+    }
+    return false;
+  };
+
+  // Try to prove (Pred, LHS, RHS) using isImpliedCond.
+  auto ProveViaCond = [&](Value *Condition, bool Inverse) {
+    if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse))
+      return true;
+    if (ProvingStrictComparison) {
+      if (!ProvedNonStrictComparison)
+        ProvedNonStrictComparison =
+            isImpliedCond(NonStrictPredicate, LHS, RHS, Condition, Inverse);
+      if (!ProvedNonEquality)
+        ProvedNonEquality =
+            isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS, Condition, Inverse);
+      if (ProvedNonStrictComparison && ProvedNonEquality)
+        return true;
+    }
+    return false;
+  };
+
+  // Starting at the loop predecessor, climb up the predecessor chain, as long
+  // as we can find predecessors that have unique successors leading to the
+  // original header.
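+  //
+  // For example (hypothetical CFG):
+  //
+  //   guard1:     br i1 %c1, label %guard2, label %exit
+  //   guard2:     br i1 %c2, label %preheader, label %exit
+  //   preheader:  br label %header
+  //   header:     ; loop header
+  //
+  // Walking (preheader, header) -> (guard2, preheader) -> (guard1, guard2)
+  // lets us use both %c2 and %c1 as facts that are known to hold on entry.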
+  for (std::pair<BasicBlock *, BasicBlock *>
+         Pair(L->getLoopPredecessor(), L->getHeader());
+       Pair.first;
+       Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
+
+    if (ProveViaGuard(Pair.first))
+      return true;
+
+    BranchInst *LoopEntryPredicate =
+        dyn_cast<BranchInst>(Pair.first->getTerminator());
+    if (!LoopEntryPredicate ||
+        LoopEntryPredicate->isUnconditional())
+      continue;
+
+    if (ProveViaCond(LoopEntryPredicate->getCondition(),
+                     LoopEntryPredicate->getSuccessor(0) != Pair.second))
+      return true;
+  }
+
+  // Check conditions due to any @llvm.assume intrinsics.
+  for (auto &AssumeVH : AC.assumptions()) {
+    if (!AssumeVH)
+      continue;
+    auto *CI = cast<CallInst>(AssumeVH);
+    if (!DT.dominates(CI, L->getHeader()))
+      continue;
+
+    if (ProveViaCond(CI->getArgOperand(0), false))
+      return true;
+  }
+
+  return false;
+}
+
+bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
+                                    const SCEV *LHS, const SCEV *RHS,
+                                    Value *FoundCondValue,
+                                    bool Inverse) {
+  if (!PendingLoopPredicates.insert(FoundCondValue).second)
+    return false;
+
+  auto ClearOnExit =
+      make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
+
+  // Recursively handle And and Or conditions.
+  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
+    if (BO->getOpcode() == Instruction::And) {
+      if (!Inverse)
+        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
+               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
+    } else if (BO->getOpcode() == Instruction::Or) {
+      if (Inverse)
+        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
+               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
+    }
+  }
+
+  ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
+  if (!ICI) return false;
+
+  // Now that we've found a conditional branch that dominates the loop or
+  // controls the loop latch, check to see if it is the comparison we are
+  // looking for.
+  ICmpInst::Predicate FoundPred;
+  if (Inverse)
+    FoundPred = ICI->getInversePredicate();
+  else
+    FoundPred = ICI->getPredicate();
+
+  const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
+  const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
+
+  return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS);
+}
+
+bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
+                                    const SCEV *RHS,
+                                    ICmpInst::Predicate FoundPred,
+                                    const SCEV *FoundLHS,
+                                    const SCEV *FoundRHS) {
+  // Balance the types.
+  if (getTypeSizeInBits(LHS->getType()) <
+      getTypeSizeInBits(FoundLHS->getType())) {
+    if (CmpInst::isSigned(Pred)) {
+      LHS = getSignExtendExpr(LHS, FoundLHS->getType());
+      RHS = getSignExtendExpr(RHS, FoundLHS->getType());
+    } else {
+      LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
+      RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
+    }
+  } else if (getTypeSizeInBits(LHS->getType()) >
+             getTypeSizeInBits(FoundLHS->getType())) {
+    if (CmpInst::isSigned(FoundPred)) {
+      FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
+      FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
+    } else {
+      FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
+      FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
+    }
+  }
+
+  // Canonicalize the query to match the way instcombine will have
+  // canonicalized the comparison.
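+  // For instance (illustrative), SimplifyICmpOperands may rewrite
+  // "%x s<= 41" as the canonical "%x s< 42", so that both sides of the
+  // implication are compared below in a single normal form.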
+  if (SimplifyICmpOperands(Pred, LHS, RHS))
+    if (LHS == RHS)
+      return CmpInst::isTrueWhenEqual(Pred);
+  if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
+    if (FoundLHS == FoundRHS)
+      return CmpInst::isFalseWhenEqual(FoundPred);
+
+  // Check to see if we can make the LHS or RHS match.
+  if (LHS == FoundRHS || RHS == FoundLHS) {
+    if (isa<SCEVConstant>(RHS)) {
+      std::swap(FoundLHS, FoundRHS);
+      FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
+    } else {
+      std::swap(LHS, RHS);
+      Pred = ICmpInst::getSwappedPredicate(Pred);
+    }
+  }
+
+  // Check whether the found predicate is the same as the desired predicate.
+  if (FoundPred == Pred)
+    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
+
+  // Check whether swapping the found predicate makes it the same as the
+  // desired predicate.
+  if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
+    if (isa<SCEVConstant>(RHS))
+      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
+    else
+      return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
+                                   RHS, LHS, FoundLHS, FoundRHS);
+  }
+
+  // Unsigned comparison is the same as signed comparison when both operands
+  // are non-negative.
+  if (CmpInst::isUnsigned(FoundPred) &&
+      CmpInst::getSignedPredicate(FoundPred) == Pred &&
+      isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS))
+    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
+
+  // Check if we can make progress by sharpening ranges.
+  if (FoundPred == ICmpInst::ICMP_NE &&
+      (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
+
+    const SCEVConstant *C = nullptr;
+    const SCEV *V = nullptr;
+
+    if (isa<SCEVConstant>(FoundLHS)) {
+      C = cast<SCEVConstant>(FoundLHS);
+      V = FoundRHS;
+    } else {
+      C = cast<SCEVConstant>(FoundRHS);
+      V = FoundLHS;
+    }
+
+    // The guarding predicate tells us that C != V. If the known range
+    // of V is [C, t), we can sharpen the range to [C + 1, t). The
+    // range we consider has to correspond to the same signedness as the
+    // predicate we're interested in folding.
+
+    APInt Min = ICmpInst::isSigned(Pred) ?
+        getSignedRangeMin(V) : getUnsignedRangeMin(V);
+
+    if (Min == C->getAPInt()) {
+      // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
+      // This is true even if (Min + 1) wraps around -- in case of
+      // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
+
+      APInt SharperMin = Min + 1;
+
+      switch (Pred) {
+      case ICmpInst::ICMP_SGE:
+      case ICmpInst::ICMP_UGE:
+        // We know V `Pred` SharperMin. If this implies LHS `Pred`
+        // RHS, we're done.
+        if (isImpliedCondOperands(Pred, LHS, RHS, V,
+                                  getConstant(SharperMin)))
+          return true;
+        LLVM_FALLTHROUGH;
+
+      case ICmpInst::ICMP_SGT:
+      case ICmpInst::ICMP_UGT:
+        // We know from the range information that (V `Pred` Min ||
+        // V == Min). We know from the guarding condition that !(V
+        // == Min). This gives us
+        //
+        //   V `Pred` Min || V == Min && !(V == Min)
+        //   =>  V `Pred` Min
+        //
+        // If V `Pred` Min implies LHS `Pred` RHS, we're done.
+
+        if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
+          return true;
+        LLVM_FALLTHROUGH;
+
+      default:
+        // No change
+        break;
+      }
+    }
+  }
+
+  // Check whether the actual condition is stronger than needed.
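+  // For example (illustrative), a known "%a == %b" is stronger than a desired
+  // non-strict "%a s<= %b", and a known strict "%a s< %b" is stronger than a
+  // desired "%a != %b".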
+ if (FoundPred == ICmpInst::ICMP_EQ) + if (ICmpInst::isTrueWhenEqual(Pred)) + if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS)) + return true; + if (Pred == ICmpInst::ICMP_NE) + if (!ICmpInst::isTrueWhenEqual(FoundPred)) + if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS)) + return true; + + // Otherwise assume the worst. + return false; +} + +bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr, + const SCEV *&L, const SCEV *&R, + SCEV::NoWrapFlags &Flags) { + const auto *AE = dyn_cast<SCEVAddExpr>(Expr); + if (!AE || AE->getNumOperands() != 2) + return false; + + L = AE->getOperand(0); + R = AE->getOperand(1); + Flags = AE->getNoWrapFlags(); + return true; +} + +Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More, + const SCEV *Less) { + // We avoid subtracting expressions here because this function is usually + // fairly deep in the call stack (i.e. is called many times). + + // X - X = 0. + if (More == Less) + return APInt(getTypeSizeInBits(More->getType()), 0); + + if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) { + const auto *LAR = cast<SCEVAddRecExpr>(Less); + const auto *MAR = cast<SCEVAddRecExpr>(More); + + if (LAR->getLoop() != MAR->getLoop()) + return None; + + // We look at affine expressions only; not for correctness but to keep + // getStepRecurrence cheap. + if (!LAR->isAffine() || !MAR->isAffine()) + return None; + + if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this)) + return None; + + Less = LAR->getStart(); + More = MAR->getStart(); + + // fall through + } + + if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) { + const auto &M = cast<SCEVConstant>(More)->getAPInt(); + const auto &L = cast<SCEVConstant>(Less)->getAPInt(); + return M - L; + } + + SCEV::NoWrapFlags Flags; + const SCEV *LLess = nullptr, *RLess = nullptr; + const SCEV *LMore = nullptr, *RMore = nullptr; + const SCEVConstant *C1 = nullptr, *C2 = nullptr; + // Compare (X + C1) vs X. + if (splitBinaryAdd(Less, LLess, RLess, Flags)) + if ((C1 = dyn_cast<SCEVConstant>(LLess))) + if (RLess == More) + return -(C1->getAPInt()); + + // Compare X vs (X + C2). + if (splitBinaryAdd(More, LMore, RMore, Flags)) + if ((C2 = dyn_cast<SCEVConstant>(LMore))) + if (RMore == Less) + return C2->getAPInt(); + + // Compare (X + C1) vs (X + C2). + if (C1 && C2 && RLess == RMore) + return C2->getAPInt() - C1->getAPInt(); + + return None; +} + +bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( + ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, const SCEV *FoundRHS) { + if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT) + return false; + + const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS); + if (!AddRecLHS) + return false; + + const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS); + if (!AddRecFoundLHS) + return false; + + // We'd like to let SCEV reason about control dependencies, so we constrain + // both the inequalities to be about add recurrences on the same loop. This + // way we can use isLoopEntryGuardedByCond later. + + const Loop *L = AddRecFoundLHS->getLoop(); + if (L != AddRecLHS->getLoop()) + return false; + + // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1) + // + // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C) + // ... (2) + // + // Informal proof for (2), assuming (1) [*]: + // + // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... 
+  // (3) [**]
+  //
+  // Then
+  //
+  //   FoundLHS s< FoundRHS s< INT_MIN - C
+  //   <=>  (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C   [ using (3) ]
+  //   <=>  (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
+  //   <=>  (FoundLHS + INT_MIN + C + INT_MIN) s<
+  //        (FoundRHS + INT_MIN + C + INT_MIN)                   [ using (3) ]
+  //   <=>  FoundLHS + C s< FoundRHS + C
+  //
+  // [*]: (1) can be proved by ruling out overflow.
+  //
+  // [**]: This can be proved by analyzing all the four possibilities:
+  //    (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
+  //    (A s>= 0, B s>= 0).
+  //
+  // Note:
+  // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
+  // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
+  // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
+  // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
+  // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
+  // C)".
+
+  Optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
+  Optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
+  if (!LDiff || !RDiff || *LDiff != *RDiff)
+    return false;
+
+  if (LDiff->isMinValue())
+    return true;
+
+  APInt FoundRHSLimit;
+
+  if (Pred == CmpInst::ICMP_ULT) {
+    FoundRHSLimit = -(*RDiff);
+  } else {
+    assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
+    FoundRHSLimit =
+        APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
+  }
+
+  // Try to prove (1) or (2), as needed.
+  return isAvailableAtLoopEntry(FoundRHS, L) &&
+         isLoopEntryGuardedByCond(L, Pred, FoundRHS,
+                                  getConstant(FoundRHSLimit));
+}
+
+bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
+                                        const SCEV *LHS, const SCEV *RHS,
+                                        const SCEV *FoundLHS,
+                                        const SCEV *FoundRHS, unsigned Depth) {
+  const PHINode *LPhi = nullptr, *RPhi = nullptr;
+
+  auto ClearOnExit = make_scope_exit([&]() {
+    if (LPhi) {
+      bool Erased = PendingMerges.erase(LPhi);
+      assert(Erased && "Failed to erase LPhi!");
+      (void)Erased;
+    }
+    if (RPhi) {
+      bool Erased = PendingMerges.erase(RPhi);
+      assert(Erased && "Failed to erase RPhi!");
+      (void)Erased;
+    }
+  });
+
+  // Find the respective Phis and check that they are not already pending.
+  if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
+    if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
+      if (!PendingMerges.insert(Phi).second)
+        return false;
+      LPhi = Phi;
+    }
+  if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
+    if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
+      // If we detect a loop of Phi nodes being processed by this method, for
+      // example:
+      //
+      //   %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
+      //   %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
+      //
+      // we don't want to deal with a case that complex, so we return the
+      // conservative answer false.
+      if (!PendingMerges.insert(Phi).second)
+        return false;
+      RPhi = Phi;
+    }
+
+  // If none of LHS, RHS is a Phi, nothing to do here.
+  if (!LPhi && !RPhi)
+    return false;
+
+  // If there is a SCEVUnknown Phi we are interested in, make it left.
+  if (!LPhi) {
+    std::swap(LHS, RHS);
+    std::swap(FoundLHS, FoundRHS);
+    std::swap(LPhi, RPhi);
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+  }
+
+  assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
+  const BasicBlock *LBB = LPhi->getParent();
+  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
+
+  auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
+    return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
+           isImpliedCondOperandsViaRanges(Pred, S1, S2, FoundLHS, FoundRHS) ||
+           isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
+  };
+
+  if (RPhi && RPhi->getParent() == LBB) {
+    // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
+    // If we compare two Phis from the same block, and for each predecessor
+    // block the predicate is true for the incoming values from that block,
+    // then the predicate is also true for the Phis.
+    for (const BasicBlock *IncBB : predecessors(LBB)) {
+      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
+      const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
+      if (!ProvedEasily(L, R))
+        return false;
+    }
+  } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
+    // Case two: RHS is an AddRec whose loop's header is LBB. This means that
+    // there is a loop which has both an AddRec and an Unknown PHI; for it, we
+    // can compare the incoming values of the AddRec from above the loop and
+    // from the latch with the respective incoming values of LPhi.
+    // TODO: Generalize to handle loops with many inputs in a header.
+    if (LPhi->getNumIncomingValues() != 2) return false;
+
+    auto *RLoop = RAR->getLoop();
+    auto *Predecessor = RLoop->getLoopPredecessor();
+    assert(Predecessor && "Loop with AddRec with no predecessor?");
+    const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
+    if (!ProvedEasily(L1, RAR->getStart()))
+      return false;
+    auto *Latch = RLoop->getLoopLatch();
+    assert(Latch && "Loop with AddRec with no latch?");
+    const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
+    if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
+      return false;
+  } else {
+    // In all other cases go over the inputs of LHS and compare each of them
+    // to RHS: the predicate is true for (LHS, RHS) if it is true for all such
+    // pairs. At this point RHS is either a non-Phi, or it is a Phi from some
+    // block different from LBB.
+    for (const BasicBlock *IncBB : predecessors(LBB)) {
+      // Check that RHS is available in this block.
+      if (!dominates(RHS, IncBB))
+        return false;
+      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
+      if (!ProvedEasily(L, RHS))
+        return false;
+    }
+  }
+  return true;
+}
+
+bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
+                                            const SCEV *LHS, const SCEV *RHS,
+                                            const SCEV *FoundLHS,
+                                            const SCEV *FoundRHS) {
+  if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
+    return true;
+
+  if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
+    return true;
+
+  return isImpliedCondOperandsHelper(Pred, LHS, RHS,
+                                     FoundLHS, FoundRHS) ||
+         // ~x < ~y --> x > y
+         isImpliedCondOperandsHelper(Pred, LHS, RHS,
+                                     getNotSCEV(FoundRHS),
+                                     getNotSCEV(FoundLHS));
+}
+
+/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
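+/// For example (illustrative), Candidate = %n is found among the operands of
+/// (umin %m, %n), so a query like "umin(%m, %n) u<= %n" can be answered
+/// without any further reasoning.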
+template <typename MinMaxExprType>
+static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
+                                 const SCEV *Candidate) {
+  const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
+  if (!MinMaxExpr)
+    return false;
+
+  return find(MinMaxExpr->operands(), Candidate) != MinMaxExpr->op_end();
+}
+
+static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
+                                           ICmpInst::Predicate Pred,
+                                           const SCEV *LHS, const SCEV *RHS) {
+  // If both sides are affine addrecs for the same loop, with equal
+  // steps, and we know the recurrences don't wrap, then we only
+  // need to check the predicate on the starting values.
+
+  if (!ICmpInst::isRelational(Pred))
+    return false;
+
+  const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
+  if (!LAR)
+    return false;
+  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
+  if (!RAR)
+    return false;
+  if (LAR->getLoop() != RAR->getLoop())
+    return false;
+  if (!LAR->isAffine() || !RAR->isAffine())
+    return false;
+
+  if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
+    return false;
+
+  SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
+                         SCEV::FlagNSW : SCEV::FlagNUW;
+  if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
+    return false;
+
+  return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
+}
+
+/// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
+/// expression?
+static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
+                                        ICmpInst::Predicate Pred,
+                                        const SCEV *LHS, const SCEV *RHS) {
+  switch (Pred) {
+  default:
+    return false;
+
+  case ICmpInst::ICMP_SGE:
+    std::swap(LHS, RHS);
+    LLVM_FALLTHROUGH;
+  case ICmpInst::ICMP_SLE:
+    return
+        // min(A, ...) <= A
+        IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
+        // A <= max(A, ...)
+        IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
+
+  case ICmpInst::ICMP_UGE:
+    std::swap(LHS, RHS);
+    LLVM_FALLTHROUGH;
+  case ICmpInst::ICMP_ULE:
+    return
+        // min(A, ...) <= A
+        IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
+        // A <= max(A, ...)
+        IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
+  }
+
+  llvm_unreachable("covered switch fell through?!");
+}
+
+bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
+                                             const SCEV *LHS, const SCEV *RHS,
+                                             const SCEV *FoundLHS,
+                                             const SCEV *FoundRHS,
+                                             unsigned Depth) {
+  assert(getTypeSizeInBits(LHS->getType()) ==
+             getTypeSizeInBits(RHS->getType()) &&
+         "LHS and RHS have different sizes?");
+  assert(getTypeSizeInBits(FoundLHS->getType()) ==
+             getTypeSizeInBits(FoundRHS->getType()) &&
+         "FoundLHS and FoundRHS have different sizes?");
+  // We want to avoid hurting the compile time with analysis of too big trees.
+  if (Depth > MaxSCEVOperationsImplicationDepth)
+    return false;
+  // We only want to work with ICMP_SGT comparison so far.
+  // TODO: Extend to ICMP_UGT?
+  if (Pred == ICmpInst::ICMP_SLT) {
+    Pred = ICmpInst::ICMP_SGT;
+    std::swap(LHS, RHS);
+    std::swap(FoundLHS, FoundRHS);
+  }
+  if (Pred != ICmpInst::ICMP_SGT)
+    return false;
+
+  auto GetOpFromSExt = [&](const SCEV *S) {
+    if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
+      return Ext->getOperand();
+    // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
+    // the constant in some cases.
+    return S;
+  };
+
+  // Acquire values from extensions.
+  auto *OrigLHS = LHS;
+  auto *OrigFoundLHS = FoundLHS;
+  LHS = GetOpFromSExt(LHS);
+  FoundLHS = GetOpFromSExt(FoundLHS);
+
+  // Check if the SGT predicate can be proved trivially or using the found
+  // context.
+  auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
+    return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
+           isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
+                                  FoundRHS, Depth + 1);
+  };
+
+  if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
+    // We want to avoid creation of any new non-constant SCEV. Since we are
+    // going to compare the operands to RHS, we should be certain that we don't
+    // need any size extensions for this. So let's decline all cases when the
+    // sizes of types of LHS and RHS do not match.
+    // TODO: Maybe try to get RHS from sext to catch more cases?
+    if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
+      return false;
+
+    // Should not overflow.
+    if (!LHSAddExpr->hasNoSignedWrap())
+      return false;
+
+    auto *LL = LHSAddExpr->getOperand(0);
+    auto *LR = LHSAddExpr->getOperand(1);
+    auto *MinusOne = getNegativeSCEV(getOne(RHS->getType()));
+
+    // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
+    auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
+      return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
+    };
+    // Try to prove the following rule:
+    // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
+    // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
+    if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
+      return true;
+  } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
+    Value *LL, *LR;
+    // FIXME: Once we have SDiv implemented, we can get rid of this matching.
+
+    using namespace llvm::PatternMatch;
+
+    if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
+      // Rules for division.
+      // We are going to perform some comparisons with Denominator and its
+      // derivative expressions. In the general case, creating a SCEV for it
+      // may lead to a complex analysis of the entire graph, and in particular
+      // it can request trip count recalculation for the same loop, which would
+      // then be cached as SCEVCouldNotCompute to break the infinite recursion.
+      // To avoid this, we only want to create SCEVs that are constants in this
+      // section. So we bail if Denominator is not a constant.
+      if (!isa<ConstantInt>(LR))
+        return false;
+
+      auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
+
+      // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
+      // then a SCEV for the numerator already exists and matches with FoundLHS.
+      auto *Numerator = getExistingSCEV(LL);
+      if (!Numerator || Numerator->getType() != FoundLHS->getType())
+        return false;
+
+      // Make sure that the numerator matches with FoundLHS and the denominator
+      // is positive.
+      if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
+        return false;
+
+      auto *DTy = Denominator->getType();
+      auto *FRHSTy = FoundRHS->getType();
+      if (DTy->isPointerTy() != FRHSTy->isPointerTy())
+        // One of types is a pointer and another one is not. We cannot extend
+        // them properly to a wider type, so let us just reject this case.
+        // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
+        // to avoid this check.
+        return false;
+
+      // Given that:
+      // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
+      auto *WTy = getWiderType(DTy, FRHSTy);
+      auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
+      auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
+
+      // Try to prove the following rule:
+      // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
+      // For example, given that FoundLHS > 2, FoundLHS is at least 3. If we
+      // divide it by Denominator < 4, we will have at least 1.
+      auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
+      if (isKnownNonPositive(RHS) &&
+          IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
+        return true;
+
+      // Try to prove the following rule:
+      // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
+      // For example, given that FoundLHS > -3, FoundLHS is at least -2. If we
+      // divide it by Denominator > 2, then:
+      // 1. If FoundLHS is negative, then the result is 0.
+      // 2. If FoundLHS is non-negative, then the result is non-negative.
+      // Either way, the result is non-negative.
+      auto *MinusOne = getNegativeSCEV(getOne(WTy));
+      auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
+      if (isKnownNegative(RHS) &&
+          IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
+        return true;
+    }
+  }
+
+  // If our expression contained SCEVUnknown Phis, and we split it down and now
+  // need to prove something for them, try to prove the predicate for every
+  // possible incoming value of those Phis.
+  if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
+    return true;
+
+  return false;
+}
+
+static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
+                                        const SCEV *LHS, const SCEV *RHS) {
+  // zext x u<= sext x, sext x s<= zext x
+  switch (Pred) {
+  case ICmpInst::ICMP_SGE:
+    std::swap(LHS, RHS);
+    LLVM_FALLTHROUGH;
+  case ICmpInst::ICMP_SLE: {
+    // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
+    const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
+    const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
+    if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
+      return true;
+    break;
+  }
+  case ICmpInst::ICMP_UGE:
+    std::swap(LHS, RHS);
+    LLVM_FALLTHROUGH;
+  case ICmpInst::ICMP_ULE: {
+    // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
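+    // E.g. for i8 %x = -1: zext(%x) = i16 255 u< sext(%x) = i16 65535, while
+    // for any %x s>= 0 the two extensions agree, so "zext x u<= sext x" holds
+    // unconditionally.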
+    const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
+    const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
+    if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
+      return true;
+    break;
+  }
+  default:
+    break;
+  }
+  return false;
+}
+
+bool
+ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
+                                                 const SCEV *LHS,
+                                                 const SCEV *RHS) {
+  return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
+         isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
+         IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
+         IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
+         isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
+}
+
+bool
+ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
+                                             const SCEV *LHS, const SCEV *RHS,
+                                             const SCEV *FoundLHS,
+                                             const SCEV *FoundRHS) {
+  switch (Pred) {
+  default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
+  case ICmpInst::ICMP_EQ:
+  case ICmpInst::ICMP_NE:
+    if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
+      return true;
+    break;
+  case ICmpInst::ICMP_SLT:
+  case ICmpInst::ICMP_SLE:
+    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
+        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
+      return true;
+    break;
+  case ICmpInst::ICMP_SGT:
+  case ICmpInst::ICMP_SGE:
+    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
+        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
+      return true;
+    break;
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULE:
+    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
+        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
+      return true;
+    break;
+  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_UGE:
+    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
+        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
+      return true;
+    break;
+  }
+
+  // Maybe it can be proved via operations?
+  if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
+    return true;
+
+  return false;
+}
+
+bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
+                                                     const SCEV *LHS,
+                                                     const SCEV *RHS,
+                                                     const SCEV *FoundLHS,
+                                                     const SCEV *FoundRHS) {
+  if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
+    // The restriction on `FoundRHS` can be lifted easily -- it exists only to
+    // reduce the compile time impact of this optimization.
+    return false;
+
+  Optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
+  if (!Addend)
+    return false;
+
+  const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
+
+  // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
+  // antecedent "`FoundLHS` `Pred` `FoundRHS`".
+  ConstantRange FoundLHSRange =
+      ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS);
+
+  // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
+  ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
+
+  // We can also compute the range of values for `LHS` that satisfy the
+  // consequent, "`LHS` `Pred` `RHS`":
+  const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
+  ConstantRange SatisfyingLHSRange =
+      ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS);
+
+  // The antecedent implies the consequent if every value of `LHS` that
+  // satisfies the antecedent also satisfies the consequent.
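+  //
+  // Worked example (hypothetical constants): from "FoundLHS u< 8" we get
+  // FoundLHSRange = [0, 8); with Addend = 2 this gives LHSRange = [2, 10);
+  // if the consequent is "LHS u< 10" then SatisfyingLHSRange = [0, 10), which
+  // contains [2, 10), so the implication holds.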
+ return SatisfyingLHSRange.contains(LHSRange); +} + +bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, + bool IsSigned, bool NoWrap) { + assert(isKnownPositive(Stride) && "Positive stride expected!"); + + if (NoWrap) return false; + + unsigned BitWidth = getTypeSizeInBits(RHS->getType()); + const SCEV *One = getOne(Stride->getType()); + + if (IsSigned) { + APInt MaxRHS = getSignedRangeMax(RHS); + APInt MaxValue = APInt::getSignedMaxValue(BitWidth); + APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One)); + + // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow! + return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS); + } + + APInt MaxRHS = getUnsignedRangeMax(RHS); + APInt MaxValue = APInt::getMaxValue(BitWidth); + APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One)); + + // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow! + return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS); +} + +bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, + bool IsSigned, bool NoWrap) { + if (NoWrap) return false; + + unsigned BitWidth = getTypeSizeInBits(RHS->getType()); + const SCEV *One = getOne(Stride->getType()); + + if (IsSigned) { + APInt MinRHS = getSignedRangeMin(RHS); + APInt MinValue = APInt::getSignedMinValue(BitWidth); + APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One)); + + // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow! + return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS); + } + + APInt MinRHS = getUnsignedRangeMin(RHS); + APInt MinValue = APInt::getMinValue(BitWidth); + APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One)); + + // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow! + return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS); +} + +const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, + bool Equality) { + const SCEV *One = getOne(Step->getType()); + Delta = Equality ? getAddExpr(Delta, Step) + : getAddExpr(Delta, getMinusSCEV(Step, One)); + return getUDivExpr(Delta, Step); +} + +const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start, + const SCEV *Stride, + const SCEV *End, + unsigned BitWidth, + bool IsSigned) { + + assert(!isKnownNonPositive(Stride) && + "Stride is expected strictly positive!"); + // Calculate the maximum backedge count based on the range of values + // permitted by Start, End, and Stride. + const SCEV *MaxBECount; + APInt MinStart = + IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start); + + APInt StrideForMaxBECount = + IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride); + + // We already know that the stride is positive, so we paper over conservatism + // in our range computation by forcing StrideForMaxBECount to be at least one. + // In theory this is unnecessary, but we expect MaxBECount to be a + // SCEVConstant, and (udiv <constant> 0) is not constant folded by SCEV (there + // is nothing to constant fold it to). + APInt One(BitWidth, 1, IsSigned); + StrideForMaxBECount = APIntOps::smax(One, StrideForMaxBECount); + + APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth) + : APInt::getMaxValue(BitWidth); + APInt Limit = MaxValue - (StrideForMaxBECount - 1); + + // Although End can be a MAX expression we estimate MaxEnd considering only + // the case End = RHS of the loop termination condition. 
+  // This is safe because in the other case (End - Start) is zero, leading to a
+  // zero maximum backedge taken count.
+  APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
+                          : APIntOps::umin(getUnsignedRangeMax(End), Limit);
+
+  MaxBECount = computeBECount(getConstant(MaxEnd - MinStart) /* Delta */,
+                              getConstant(StrideForMaxBECount) /* Step */,
+                              false /* Equality */);
+
+  return MaxBECount;
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
+                                  const Loop *L, bool IsSigned,
+                                  bool ControlsExit, bool AllowPredicates) {
+  SmallPtrSet<const SCEVPredicate *, 4> Predicates;
+
+  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
+  bool PredicatedIV = false;
+
+  if (!IV && AllowPredicates) {
+    // Try to make this an AddRec using runtime tests, in the first X
+    // iterations of this loop, where X is the SCEV expression found by the
+    // algorithm below.
+    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
+    PredicatedIV = true;
+  }
+
+  // Avoid weird loops.
+  if (!IV || IV->getLoop() != L || !IV->isAffine())
+    return getCouldNotCompute();
+
+  bool NoWrap = ControlsExit &&
+                IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
+
+  const SCEV *Stride = IV->getStepRecurrence(*this);
+
+  bool PositiveStride = isKnownPositive(Stride);
+
+  // Avoid negative or zero stride values.
+  if (!PositiveStride) {
+    // We can compute the correct backedge taken count for loops with unknown
+    // strides if we can prove that the loop is not an infinite loop with side
+    // effects. Here's the loop structure we are trying to handle:
+    //
+    //   i = start
+    //   do {
+    //     A[i] = i;
+    //     i += s;
+    //   } while (i < end);
+    //
+    // The backedge taken count for such loops is evaluated as:
+    //
+    //   (max(end, start + stride) - start - 1) /u stride
+    //
+    // The additional preconditions that we need to check to prove correctness
+    // of the above formula are as follows:
+    //
+    // a) IV is either nuw or nsw depending upon signedness (indicated by the
+    //    NoWrap flag).
+    // b) the loop is single exit with no side effects.
+    //
+    // Precondition a) implies that if the stride is negative, this is a single
+    // trip loop. The backedge taken count formula reduces to zero in this
+    // case.
+    //
+    // Precondition b) implies that the unknown stride cannot be zero,
+    // otherwise we have UB.
+    //
+    // The positive stride case is the same as isKnownPositive(Stride)
+    // returning true (original behavior of the function).
+    //
+    // We want to make sure that the stride is truly unknown as there are edge
+    // cases where ScalarEvolution propagates no wrap flags to the
+    // post-increment/decrement IV even though the increment/decrement
+    // operation itself is wrapping. The computed backedge taken count may be
+    // wrong in such cases. This is prevented by checking that the stride is
+    // not known to be either positive or non-positive. For example, no wrap
+    // flags are propagated to the post-increment IV of this loop with a trip
+    // count of 2:
+    //
+    //   unsigned char i;
+    //   for (i = 127; i < 128; i += 129)
+    //     A[i] = i;
+    //
+    if (PredicatedIV || !NoWrap || isKnownNonPositive(Stride) ||
+        !loopHasNoSideEffects(L))
+      return getCouldNotCompute();
+  } else if (!Stride->isOne() &&
+             doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
+    // Avoid proven overflow cases: this will ensure that the backedge taken
+    // count will not generate any unsigned overflow.
+    // Relaxed no-overflow conditions exploit NoWrapFlags, allowing us to
+    // optimize in the presence of undefined behavior, as in C.
+    return getCouldNotCompute();
+
+  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT
+                                      : ICmpInst::ICMP_ULT;
+  const SCEV *Start = IV->getStart();
+  const SCEV *End = RHS;
+  // When the RHS is not invariant, we do not know the end bound of the loop
+  // and cannot calculate the ExactBECount needed by ExitLimit. However, we can
+  // calculate the MaxBECount, given the start, stride and max value for the
+  // end bound of the loop (RHS), and the fact that IV does not overflow (which
+  // is checked above).
+  if (!isLoopInvariant(RHS, L)) {
+    const SCEV *MaxBECount = computeMaxBECountForLT(
+        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
+    return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
+                     false /*MaxOrZero*/, Predicates);
+  }
+  // If the backedge is taken at least once, then it will be taken
+  // (End-Start)/Stride times (rounded up to a multiple of Stride), where Start
+  // is the LHS value of the less-than comparison the first time it is
+  // evaluated and End is the RHS.
+  const SCEV *BECountIfBackedgeTaken =
+      computeBECount(getMinusSCEV(End, Start), Stride, false);
+  // If the loop entry is guarded by the result of the backedge test of the
+  // first loop iteration, then we know the backedge will be taken at least
+  // once and so the backedge taken count is as above. If not then we use the
+  // expression (max(End,Start)-Start)/Stride to describe the backedge count,
+  // as if the backedge is taken at least once max(End,Start) is End and so the
+  // result is as above, and if not max(End,Start) is Start so we get a
+  // backedge count of zero.
+  const SCEV *BECount;
+  if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS))
+    BECount = BECountIfBackedgeTaken;
+  else {
+    End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
+    BECount = computeBECount(getMinusSCEV(End, Start), Stride, false);
+  }
+
+  const SCEV *MaxBECount;
+  bool MaxOrZero = false;
+  if (isa<SCEVConstant>(BECount))
+    MaxBECount = BECount;
+  else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
+    // If we know exactly how many times the backedge will be taken if it's
+    // taken at least once, then the backedge count will either be that or
+    // zero.
+    MaxBECount = BECountIfBackedgeTaken;
+    MaxOrZero = true;
+  } else {
+    MaxBECount = computeMaxBECountForLT(
+        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
+  }
+
+  if (isa<SCEVCouldNotCompute>(MaxBECount) &&
+      !isa<SCEVCouldNotCompute>(BECount))
+    MaxBECount = getConstant(getUnsignedRangeMax(BECount));
+
+  return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
+                                     const Loop *L, bool IsSigned,
+                                     bool ControlsExit, bool AllowPredicates) {
+  SmallPtrSet<const SCEVPredicate *, 4> Predicates;
+  // We handle only IV > Invariant.
+  if (!isLoopInvariant(RHS, L))
+    return getCouldNotCompute();
+
+  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
+  if (!IV && AllowPredicates)
+    // Try to make this an AddRec using runtime tests, in the first X
+    // iterations of this loop, where X is the SCEV expression found by the
+    // algorithm below.
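+    // (For instance, this may record a SCEVWrapPredicate under which the IV
+    // can be treated as a non-wrapping AddRec; callers are then expected to
+    // emit runtime checks for the collected Predicates.)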
+    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
+
+  // Avoid weird loops.
+  if (!IV || IV->getLoop() != L || !IV->isAffine())
+    return getCouldNotCompute();
+
+  bool NoWrap = ControlsExit &&
+                IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
+
+  const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
+
+  // Avoid negative or zero stride values.
+  if (!isKnownPositive(Stride))
+    return getCouldNotCompute();
+
+  // Avoid proven overflow cases: this will ensure that the backedge taken
+  // count will not generate any unsigned overflow. Relaxed no-overflow
+  // conditions exploit NoWrapFlags, allowing us to optimize in the presence of
+  // undefined behavior, as in C.
+  if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap))
+    return getCouldNotCompute();
+
+  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT
+                                      : ICmpInst::ICMP_UGT;
+
+  const SCEV *Start = IV->getStart();
+  const SCEV *End = RHS;
+  if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS))
+    End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
+
+  const SCEV *BECount =
+      computeBECount(getMinusSCEV(Start, End), Stride, false);
+
+  APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
+                            : getUnsignedRangeMax(Start);
+
+  APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
+                             : getUnsignedRangeMin(Stride);
+
+  unsigned BitWidth = getTypeSizeInBits(LHS->getType());
+  APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
+                         : APInt::getMinValue(BitWidth) + (MinStride - 1);
+
+  // Although End can be a MIN expression we estimate MinEnd considering only
+  // the case End = RHS. This is safe because in the other case (Start - End)
+  // is zero, leading to a zero maximum backedge taken count.
+  APInt MinEnd =
+      IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
+               : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
+
+  const SCEV *MaxBECount = isa<SCEVConstant>(BECount)
+                               ? BECount
+                               : computeBECount(getConstant(MaxStart - MinEnd),
+                                                getConstant(MinStride), false);
+
+  if (isa<SCEVCouldNotCompute>(MaxBECount))
+    MaxBECount = BECount;
+
+  return ExitLimit(BECount, MaxBECount, false, Predicates);
+}
+
+const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
+                                                    ScalarEvolution &SE) const {
+  if (Range.isFullSet())  // Infinite loop.
+    return SE.getCouldNotCompute();
+
+  // If the start is a non-zero constant, shift the range to simplify things.
+  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
+    if (!SC->getValue()->isZero()) {
+      SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
+      Operands[0] = SE.getZero(SC->getType());
+      const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
+                                             getNoWrapFlags(FlagNW));
+      if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
+        return ShiftedAddRec->getNumIterationsInRange(
+            Range.subtract(SC->getAPInt()), SE);
+      // This is strange and shouldn't happen.
+      return SE.getCouldNotCompute();
+    }
+
+  // The only time we can solve this is when we have all constant indices.
+  // Otherwise, we cannot determine the overflow conditions.
+  if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
+    return SE.getCouldNotCompute();
+
+  // Okay at this point we know that all elements of the chrec are constants
+  // and that the start element is zero.
+
+  // First check to see if the range contains zero. If not, the first
+  // iteration exits.
+  unsigned BitWidth = SE.getTypeSizeInBits(getType());
+  if (!Range.contains(APInt(BitWidth, 0)))
+    return SE.getZero(getType());
+
+  if (isAffine()) {
+    // If this is an affine expression then we have this situation:
+    //   Solve {0,+,A} in Range  ===  Ax in Range
+
+    // We know that zero is in the range. If A is positive then we know that
+    // the upper value of the range must be the first possible exit value.
+    // If A is negative then the lower of the range is the last possible loop
+    // value. Also note that we already checked for a full range.
+    APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
+    APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
+
+    // The exit value should be (End+A)/A.
+    APInt ExitVal = (End + A).udiv(A);
+    ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
+
+    // Evaluate at the exit value. If we really did fall out of the valid
+    // range, then we computed our trip count, otherwise wraparound or other
+    // strange behavior must have happened.
+    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
+    if (Range.contains(Val->getValue()))
+      return SE.getCouldNotCompute();  // Something strange happened
+
+    // Ensure that the previous value is in the range. This is a sanity check.
+    assert(Range.contains(
+           EvaluateConstantChrecAtConstant(this,
+               ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
+           "Linear scev computation is off in a bad way!");
+    return SE.getConstant(ExitValue);
+  }
+
+  if (isQuadratic()) {
+    if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
+      return SE.getConstant(S.getValue());
+  }
+
+  return SE.getCouldNotCompute();
+}
+
+const SCEVAddRecExpr *
+SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
+  assert(getNumOperands() > 1 && "AddRec with zero step?");
+  // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
+  // but in this case we cannot guarantee that the value returned will be an
+  // AddRec because SCEV does not have a fixed point where it stops
+  // simplification: it is legal to return ({rec1} + {rec2}). For example, it
+  // may happen if we reach arithmetic depth limit while simplifying. So we
+  // construct the returned value explicitly.
+  SmallVector<const SCEV *, 3> Ops;
+  // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
+  // (this + Step) is {A+B,+,B+C,+...,+,N}.
+  for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
+    Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
+  // We know that the last operand is not a constant zero (otherwise it would
+  // have been popped out earlier). This guarantees us that if the result has
+  // the same last operand, then it will also not be popped out, meaning that
+  // the returned value will be an AddRec.
+  const SCEV *Last = getOperand(getNumOperands() - 1);
+  assert(!Last->isZero() && "Recurrence with zero step?");
+  Ops.push_back(Last);
+  return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
+                                               SCEV::FlagAnyWrap));
+}
+
+// Return true when S contains at least one undef value.
+static inline bool containsUndefs(const SCEV *S) {
+  return SCEVExprContains(S, [](const SCEV *S) {
+    if (const auto *SU = dyn_cast<SCEVUnknown>(S))
+      return isa<UndefValue>(SU->getValue());
+    return false;
+  });
+}
+
+namespace {
+
+// Collect all steps of SCEV expressions.
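+//
+// For example (illustrative), visiting
+//   {%a,+,(4 * %n)}<%outer> + {0,+,4}<%inner>
+// pushes the two step recurrences (4 * %n) and 4 into Strides.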
+struct SCEVCollectStrides {
+  ScalarEvolution &SE;
+  SmallVectorImpl<const SCEV *> &Strides;
+
+  SCEVCollectStrides(ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &S)
+      : SE(SE), Strides(S) {}
+
+  bool follow(const SCEV *S) {
+    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S))
+      Strides.push_back(AR->getStepRecurrence(SE));
+    return true;
+  }
+
+  bool isDone() const { return false; }
+};
+
+// Collect all SCEVUnknown and SCEVMulExpr expressions.
+struct SCEVCollectTerms {
+  SmallVectorImpl<const SCEV *> &Terms;
+
+  SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T) : Terms(T) {}
+
+  bool follow(const SCEV *S) {
+    if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S) ||
+        isa<SCEVSignExtendExpr>(S)) {
+      if (!containsUndefs(S))
+        Terms.push_back(S);
+
+      // Stop recursion: once we collected a term, do not walk its operands.
+      return false;
+    }
+
+    // Keep looking.
+    return true;
+  }
+
+  bool isDone() const { return false; }
+};
+
+// Check if a SCEV contains an AddRecExpr.
+struct SCEVHasAddRec {
+  bool &ContainsAddRec;
+
+  SCEVHasAddRec(bool &ContainsAddRec) : ContainsAddRec(ContainsAddRec) {
+    ContainsAddRec = false;
+  }
+
+  bool follow(const SCEV *S) {
+    if (isa<SCEVAddRecExpr>(S)) {
+      ContainsAddRec = true;
+
+      // Stop recursion: once we found an AddRec, do not walk its operands.
+      return false;
+    }
+
+    // Keep looking.
+    return true;
+  }
+
+  bool isDone() const { return false; }
+};
+
+// Find factors that are multiplied with an expression that (possibly as a
+// subexpression) contains an AddRecExpr. In the expression:
+//
+//   8 * (100 + %p * %q * (%a + {0, +, 1}_loop))
+//
+// "%p * %q" are factors multiplied by the expression "(%a + {0, +, 1}_loop)"
+// that contains the AddRec {0, +, 1}_loop. %p * %q are likely to be array size
+// parameters as they form a product with an induction variable.
+//
+// This collector expects all array size parameters to be in the same MulExpr.
+// It might be necessary to later add support for collecting parameters that
+// are spread over different nested MulExpr.
+struct SCEVCollectAddRecMultiplies {
+  SmallVectorImpl<const SCEV *> &Terms;
+  ScalarEvolution &SE;
+
+  SCEVCollectAddRecMultiplies(SmallVectorImpl<const SCEV *> &T,
+                              ScalarEvolution &SE)
+      : Terms(T), SE(SE) {}
+
+  bool follow(const SCEV *S) {
+    if (auto *Mul = dyn_cast<SCEVMulExpr>(S)) {
+      bool HasAddRec = false;
+      SmallVector<const SCEV *, 0> Operands;
+      for (auto Op : Mul->operands()) {
+        const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(Op);
+        if (Unknown && !isa<CallInst>(Unknown->getValue())) {
+          Operands.push_back(Op);
+        } else if (Unknown) {
+          HasAddRec = true;
+        } else {
+          bool ContainsAddRec;
+          SCEVHasAddRec ContainsAddRecVisitor(ContainsAddRec);
+          visitAll(Op, ContainsAddRecVisitor);
+          HasAddRec |= ContainsAddRec;
+        }
+      }
+      if (Operands.size() == 0)
+        return true;
+
+      if (!HasAddRec)
+        return false;
+
+      Terms.push_back(SE.getMulExpr(Operands));
+      // Stop recursion: once we collected a term, do not walk its operands.
+      return false;
+    }
+
+    // Keep looking.
+    return true;
+  }
+
+  bool isDone() const { return false; }
+};
+
+} // end anonymous namespace
+
+/// Find parametric terms in this SCEVAddRecExpr. We look for parameters in
+/// two places:
+///   1) The strides of AddRec expressions.
+///   2) Unknowns that are multiplied with AddRec expressions.
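+///
+/// For example (illustrative), for {{%base,+,(8 * %m)}<%i>,+,8}<%j> step 1)
+/// collects the strides (8 * %m) and 8; the MulExpr (8 * %m) is then recorded
+/// as a term, and its constant factor is stripped later in
+/// findArrayDimensions.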
+void ScalarEvolution::collectParametricTerms(const SCEV *Expr,
+                                             SmallVectorImpl<const SCEV *> &Terms) {
+  SmallVector<const SCEV *, 4> Strides;
+  SCEVCollectStrides StrideCollector(*this, Strides);
+  visitAll(Expr, StrideCollector);
+
+  LLVM_DEBUG({
+    dbgs() << "Strides:\n";
+    for (const SCEV *S : Strides)
+      dbgs() << *S << "\n";
+  });
+
+  for (const SCEV *S : Strides) {
+    SCEVCollectTerms TermCollector(Terms);
+    visitAll(S, TermCollector);
+  }
+
+  LLVM_DEBUG({
+    dbgs() << "Terms:\n";
+    for (const SCEV *T : Terms)
+      dbgs() << *T << "\n";
+  });
+
+  SCEVCollectAddRecMultiplies MulCollector(Terms, *this);
+  visitAll(Expr, MulCollector);
+}
+
+static bool findArrayDimensionsRec(ScalarEvolution &SE,
+                                   SmallVectorImpl<const SCEV *> &Terms,
+                                   SmallVectorImpl<const SCEV *> &Sizes) {
+  int Last = Terms.size() - 1;
+  const SCEV *Step = Terms[Last];
+
+  // End of recursion.
+  if (Last == 0) {
+    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Step)) {
+      SmallVector<const SCEV *, 2> Qs;
+      for (const SCEV *Op : M->operands())
+        if (!isa<SCEVConstant>(Op))
+          Qs.push_back(Op);
+
+      Step = SE.getMulExpr(Qs);
+    }
+
+    Sizes.push_back(Step);
+    return true;
+  }
+
+  for (const SCEV *&Term : Terms) {
+    // Normalize the terms before the next call to findArrayDimensionsRec.
+    const SCEV *Q, *R;
+    SCEVDivision::divide(SE, Term, Step, &Q, &R);
+
+    // Bail out when GCD does not evenly divide one of the terms.
+    if (!R->isZero())
+      return false;
+
+    Term = Q;
+  }
+
+  // Remove all SCEVConstants.
+  Terms.erase(
+      remove_if(Terms, [](const SCEV *E) { return isa<SCEVConstant>(E); }),
+      Terms.end());
+
+  if (Terms.size() > 0)
+    if (!findArrayDimensionsRec(SE, Terms, Sizes))
+      return false;
+
+  Sizes.push_back(Step);
+  return true;
+}
+
+// Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter.
+static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
+  for (const SCEV *T : Terms)
+    if (SCEVExprContains(T, isa<SCEVUnknown, const SCEV *>))
+      return true;
+  return false;
+}
+
+// Return the number of product terms in S.
+static inline int numberOfTerms(const SCEV *S) {
+  if (const SCEVMulExpr *Expr = dyn_cast<SCEVMulExpr>(S))
+    return Expr->getNumOperands();
+  return 1;
+}
+
+static const SCEV *removeConstantFactors(ScalarEvolution &SE, const SCEV *T) {
+  if (isa<SCEVConstant>(T))
+    return nullptr;
+
+  if (isa<SCEVUnknown>(T))
+    return T;
+
+  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(T)) {
+    SmallVector<const SCEV *, 2> Factors;
+    for (const SCEV *Op : M->operands())
+      if (!isa<SCEVConstant>(Op))
+        Factors.push_back(Op);
+
+    return SE.getMulExpr(Factors);
+  }
+
+  return T;
+}
+
+/// Return the size of an element read or written by Inst.
+const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
+  Type *Ty;
+  if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
+    Ty = Store->getValueOperand()->getType();
+  else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
+    Ty = Load->getType();
+  else
+    return nullptr;
+
+  Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
+  return getSizeOfExpr(ETy, Ty);
+}
+
+void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
+                                          SmallVectorImpl<const SCEV *> &Sizes,
+                                          const SCEV *ElementSize) {
+  if (Terms.size() < 1 || !ElementSize)
+    return;
+
+  // Early return when Terms do not contain parameters: we do not delinearize
+  // non-parametric SCEVs.
+  if (!containsParameters(Terms))
+    return;
+
+  LLVM_DEBUG({
+    dbgs() << "Terms:\n";
+    for (const SCEV *T : Terms)
+      dbgs() << *T << "\n";
+  });
+
+  // Remove duplicates.
+  array_pod_sort(Terms.begin(), Terms.end());
+  Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end());
+
+  // Put larger terms first.
+  llvm::sort(Terms, [](const SCEV *LHS, const SCEV *RHS) {
+    return numberOfTerms(LHS) > numberOfTerms(RHS);
+  });
+
+  // Try to divide all terms by the element size. If a term is not divisible
+  // by the element size, proceed with the original term.
+  for (const SCEV *&Term : Terms) {
+    const SCEV *Q, *R;
+    SCEVDivision::divide(*this, Term, ElementSize, &Q, &R);
+    if (!Q->isZero())
+      Term = Q;
+  }
+
+  SmallVector<const SCEV *, 4> NewTerms;
+
+  // Remove constant factors.
+  for (const SCEV *T : Terms)
+    if (const SCEV *NewT = removeConstantFactors(*this, T))
+      NewTerms.push_back(NewT);
+
+  LLVM_DEBUG({
+    dbgs() << "Terms after sorting and removing constant factors:\n";
+    for (const SCEV *T : NewTerms)
+      dbgs() << *T << "\n";
+  });
+
+  if (NewTerms.empty() || !findArrayDimensionsRec(*this, NewTerms, Sizes)) {
+    Sizes.clear();
+    return;
+  }
+
+  // The last element to be pushed into Sizes is the size of an element.
+  Sizes.push_back(ElementSize);
+
+  LLVM_DEBUG({
+    dbgs() << "Sizes:\n";
+    for (const SCEV *S : Sizes)
+      dbgs() << *S << "\n";
+  });
+}
+
+void ScalarEvolution::computeAccessFunctions(
+    const SCEV *Expr, SmallVectorImpl<const SCEV *> &Subscripts,
+    SmallVectorImpl<const SCEV *> &Sizes) {
+  // Early exit in case this SCEV is not an affine multivariate function.
+  if (Sizes.empty())
+    return;
+
+  if (auto *AR = dyn_cast<SCEVAddRecExpr>(Expr))
+    if (!AR->isAffine())
+      return;
+
+  const SCEV *Res = Expr;
+  int Last = Sizes.size() - 1;
+  for (int i = Last; i >= 0; i--) {
+    const SCEV *Q, *R;
+    SCEVDivision::divide(*this, Res, Sizes[i], &Q, &R);
+
+    LLVM_DEBUG({
+      dbgs() << "Res: " << *Res << "\n";
+      dbgs() << "Sizes[i]: " << *Sizes[i] << "\n";
+      dbgs() << "Res divided by Sizes[i]:\n";
+      dbgs() << "Quotient: " << *Q << "\n";
+      dbgs() << "Remainder: " << *R << "\n";
+    });
+
+    Res = Q;
+
+    // Do not record the last subscript corresponding to the size of elements
+    // in the array.
+    if (i == Last) {
+      // Bail out if the remainder is too complex.
+      if (isa<SCEVAddRecExpr>(R)) {
+        Subscripts.clear();
+        Sizes.clear();
+        return;
+      }
+
+      continue;
+    }
+
+    // Record the access function for the current subscript.
+    Subscripts.push_back(R);
+  }
+
+  // Also push in last position the remainder of the last division: it will be
+  // the access function of the innermost dimension.
+  Subscripts.push_back(Res);
+
+  std::reverse(Subscripts.begin(), Subscripts.end());
+
+  LLVM_DEBUG({
+    dbgs() << "Subscripts:\n";
+    for (const SCEV *S : Subscripts)
+      dbgs() << *S << "\n";
+  });
+}
+
+/// Splits the SCEV into two vectors of SCEVs representing the subscripts and
+/// sizes of an array access; the remainder of the delinearization is the
+/// offset of the start of the array. The SCEV->delinearize algorithm computes
+/// the multiples of SCEV coefficients: that is a pattern matching of
+/// subexpressions in the stride and base of a SCEV corresponding to the
+/// computation of a GCD (greatest common divisor) of base and stride. When
+/// SCEV->delinearize fails, it returns the SCEV unchanged.
+///
+/// For example: when analyzing the memory access A[i][j][k] in this loop nest
+///
+///  void foo(long n, long m, long o, double A[n][m][o]) {
+///
+///    for (long i = 0; i < n; i++)
+///      for (long j = 0; j < m; j++)
+///        for (long k = 0; k < o; k++)
+///          A[i][j][k] = 1.0;
+///  }
+///
+/// the delinearization input is the following AddRec SCEV:
+///
+///  AddRec: {{{%A,+,(8 * %m * %o)}<%for.i>,+,(8 * %o)}<%for.j>,+,8}<%for.k>
+///
+/// From this SCEV, we are able to say that the base offset of the access is %A
+/// because it appears as an offset that does not divide any of the strides in
+/// the loops:
+///
+///  CHECK: Base offset: %A
+///
+/// and then SCEV->delinearize determines the size of some of the dimensions of
+/// the array as these are the multiples by which the strides are happening:
+///
+///  CHECK: ArrayDecl[UnknownSize][%m][%o] with elements of sizeof(double) bytes.
+///
+/// Note that the outermost dimension remains UnknownSize because there are no
+/// strides that would help identify the size of that dimension: when the array
+/// has been statically allocated, one could compute the size of that dimension
+/// by dividing the overall size of the array by the size of the known
+/// dimensions: %m * %o * 8.
+///
+/// Finally, delinearize provides the access functions for the array reference
+/// that corresponds to A[i][j][k] in the above C testcase:
+///
+///  CHECK: ArrayRef[{0,+,1}<%for.i>][{0,+,1}<%for.j>][{0,+,1}<%for.k>]
+///
+/// The testcases check the output of a function pass, DelinearizationPass,
+/// that walks through all loads and stores of a function, asking for the SCEV
+/// of each memory access with respect to all enclosing loops, calling
+/// SCEV->delinearize on that, and printing the results.
+void ScalarEvolution::delinearize(const SCEV *Expr,
+                                  SmallVectorImpl<const SCEV *> &Subscripts,
+                                  SmallVectorImpl<const SCEV *> &Sizes,
+                                  const SCEV *ElementSize) {
+  // First step: collect parametric terms.
+  SmallVector<const SCEV *, 4> Terms;
+  collectParametricTerms(Expr, Terms);
+
+  if (Terms.empty())
+    return;
+
+  // Second step: find subscript sizes.
+  findArrayDimensions(Terms, Sizes, ElementSize);
+
+  if (Sizes.empty())
+    return;
+
+  // Third step: compute the access functions for each subscript.
+  computeAccessFunctions(Expr, Subscripts, Sizes);
+
+  if (Subscripts.empty())
+    return;
+
+  LLVM_DEBUG({
+    dbgs() << "succeeded to delinearize " << *Expr << "\n";
+    dbgs() << "ArrayDecl[UnknownSize]";
+    for (const SCEV *S : Sizes)
+      dbgs() << "[" << *S << "]";
+
+    dbgs() << "\nArrayRef";
+    for (const SCEV *S : Subscripts)
+      dbgs() << "[" << *S << "]";
+    dbgs() << "\n";
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// SCEVCallbackVH Class Implementation
+//===----------------------------------------------------------------------===//
+
+void ScalarEvolution::SCEVCallbackVH::deleted() {
+  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
+  if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
+    SE->ConstantEvolutionLoopExitValue.erase(PN);
+  SE->eraseValueFromMap(getValPtr());
+  // this now dangles!
+}
+
+void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
+  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
+
+  // Forget all the expressions associated with users of the old value, so
+  // that future queries will recompute the expressions using the new value.
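+  //
+  // For instance (illustrative): if %old is RAUW'd with %new, a cached SCEV
+  // for %add = add i64 %old, 1 still refers to the expression built from
+  // %old, so %add's map entry must be erased and recomputed on the next
+  // getSCEV(%add) query.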
+ Value *Old = getValPtr(); + SmallVector<User *, 16> Worklist(Old->user_begin(), Old->user_end()); + SmallPtrSet<User *, 8> Visited; + while (!Worklist.empty()) { + User *U = Worklist.pop_back_val(); + // Deleting the Old value will cause this to dangle. Postpone + // that until everything else is done. + if (U == Old) + continue; + if (!Visited.insert(U).second) + continue; + if (PHINode *PN = dyn_cast<PHINode>(U)) + SE->ConstantEvolutionLoopExitValue.erase(PN); + SE->eraseValueFromMap(U); + Worklist.insert(Worklist.end(), U->user_begin(), U->user_end()); + } + // Delete the Old value. + if (PHINode *PN = dyn_cast<PHINode>(Old)) + SE->ConstantEvolutionLoopExitValue.erase(PN); + SE->eraseValueFromMap(Old); + // this now dangles! +} + +ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se) + : CallbackVH(V), SE(se) {} + +//===----------------------------------------------------------------------===// +// ScalarEvolution Class Implementation +//===----------------------------------------------------------------------===// + +ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI, + AssumptionCache &AC, DominatorTree &DT, + LoopInfo &LI) + : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI), + CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64), + LoopDispositions(64), BlockDispositions(64) { + // To use guards for proving predicates, we need to scan every instruction in + // relevant basic blocks, and not just terminators. Doing this is a waste of + // time if the IR does not actually contain any calls to + // @llvm.experimental.guard, so do a quick check and remember this beforehand. + // + // This pessimizes the case where a pass that preserves ScalarEvolution wants + // to _add_ guards to the module when there weren't any before, and wants + // ScalarEvolution to optimize based on those guards. For now we prefer to be + // efficient in lieu of being smart in that rather obscure case. 
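+  //
+  // In IR, such a guard looks like (illustrative):
+  //
+  //   call void (i1, ...) @llvm.experimental.guard(i1 %cond) [ "deopt"() ]
+  //
+  // so any use of the declaration is a cheap proxy for "guards are present".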
+ + auto *GuardDecl = F.getParent()->getFunction( + Intrinsic::getName(Intrinsic::experimental_guard)); + HasGuards = GuardDecl && !GuardDecl->use_empty(); +} + +ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) + : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), + LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)), + ValueExprMap(std::move(Arg.ValueExprMap)), + PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)), + PendingPhiRanges(std::move(Arg.PendingPhiRanges)), + PendingMerges(std::move(Arg.PendingMerges)), + MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)), + BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)), + PredicatedBackedgeTakenCounts( + std::move(Arg.PredicatedBackedgeTakenCounts)), + ConstantEvolutionLoopExitValue( + std::move(Arg.ConstantEvolutionLoopExitValue)), + ValuesAtScopes(std::move(Arg.ValuesAtScopes)), + LoopDispositions(std::move(Arg.LoopDispositions)), + LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)), + BlockDispositions(std::move(Arg.BlockDispositions)), + UnsignedRanges(std::move(Arg.UnsignedRanges)), + SignedRanges(std::move(Arg.SignedRanges)), + UniqueSCEVs(std::move(Arg.UniqueSCEVs)), + UniquePreds(std::move(Arg.UniquePreds)), + SCEVAllocator(std::move(Arg.SCEVAllocator)), + LoopUsers(std::move(Arg.LoopUsers)), + PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)), + FirstUnknown(Arg.FirstUnknown) { + Arg.FirstUnknown = nullptr; +} + +ScalarEvolution::~ScalarEvolution() { + // Iterate through all the SCEVUnknown instances and call their + // destructors, so that they release their references to their values. + for (SCEVUnknown *U = FirstUnknown; U;) { + SCEVUnknown *Tmp = U; + U = U->Next; + Tmp->~SCEVUnknown(); + } + FirstUnknown = nullptr; + + ExprValueMap.clear(); + ValueExprMap.clear(); + HasRecMap.clear(); + + // Free any extra memory created for ExitNotTakenInfo in the unlikely event + // that a loop had multiple computable exits. 
+ for (auto &BTCI : BackedgeTakenCounts) + BTCI.second.clear(); + for (auto &BTCI : PredicatedBackedgeTakenCounts) + BTCI.second.clear(); + + assert(PendingLoopPredicates.empty() && "isImpliedCond garbage"); + assert(PendingPhiRanges.empty() && "getRangeRef garbage"); + assert(PendingMerges.empty() && "isImpliedViaMerge garbage"); + assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!"); + assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!"); +} + +bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) { + return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L)); +} + +static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, + const Loop *L) { + // Print all inner loops first + for (Loop *I : *L) + PrintLoopInfo(OS, SE, I); + + OS << "Loop "; + L->getHeader()->printAsOperand(OS, /*PrintType=*/false); + OS << ": "; + + SmallVector<BasicBlock *, 8> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + if (ExitingBlocks.size() != 1) + OS << "<multiple exits> "; + + if (SE->hasLoopInvariantBackedgeTakenCount(L)) + OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L) << "\n"; + else + OS << "Unpredictable backedge-taken count.\n"; + + if (ExitingBlocks.size() > 1) + for (BasicBlock *ExitingBlock : ExitingBlocks) { + OS << " exit count for " << ExitingBlock->getName() << ": " + << *SE->getExitCount(L, ExitingBlock) << "\n"; + } + + OS << "Loop "; + L->getHeader()->printAsOperand(OS, /*PrintType=*/false); + OS << ": "; + + if (!isa<SCEVCouldNotCompute>(SE->getConstantMaxBackedgeTakenCount(L))) { + OS << "max backedge-taken count is " << *SE->getConstantMaxBackedgeTakenCount(L); + if (SE->isBackedgeTakenCountMaxOrZero(L)) + OS << ", actual taken count either this or zero."; + } else { + OS << "Unpredictable max backedge-taken count. "; + } + + OS << "\n" + "Loop "; + L->getHeader()->printAsOperand(OS, /*PrintType=*/false); + OS << ": "; + + SCEVUnionPredicate Pred; + auto PBT = SE->getPredicatedBackedgeTakenCount(L, Pred); + if (!isa<SCEVCouldNotCompute>(PBT)) { + OS << "Predicated backedge-taken count is " << *PBT << "\n"; + OS << " Predicates:\n"; + Pred.print(OS, 4); + } else { + OS << "Unpredictable predicated backedge-taken count. "; + } + OS << "\n"; + + if (SE->hasLoopInvariantBackedgeTakenCount(L)) { + OS << "Loop "; + L->getHeader()->printAsOperand(OS, /*PrintType=*/false); + OS << ": "; + OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n"; + } +} + +static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) { + switch (LD) { + case ScalarEvolution::LoopVariant: + return "Variant"; + case ScalarEvolution::LoopInvariant: + return "Invariant"; + case ScalarEvolution::LoopComputable: + return "Computable"; + } + llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!"); +} + +void ScalarEvolution::print(raw_ostream &OS) const { + // ScalarEvolution's implementation of the print method is to print + // out SCEV values of all instructions that are interesting. Doing + // this potentially causes it to create new SCEV objects though, + // which technically conflicts with the const qualifier. This isn't + // observable from outside the class though, so casting away the + // const isn't dangerous. 
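+  //
+  // The per-instruction output produced below has roughly this shape
+  // (illustrative):
+  //
+  //   %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
+  //     --> {0,+,1}<%loop> U: [0,100) S: [0,100)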
+ ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this); + + OS << "Classifying expressions for: "; + F.printAsOperand(OS, /*PrintType=*/false); + OS << "\n"; + for (Instruction &I : instructions(F)) + if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) { + OS << I << '\n'; + OS << " --> "; + const SCEV *SV = SE.getSCEV(&I); + SV->print(OS); + if (!isa<SCEVCouldNotCompute>(SV)) { + OS << " U: "; + SE.getUnsignedRange(SV).print(OS); + OS << " S: "; + SE.getSignedRange(SV).print(OS); + } + + const Loop *L = LI.getLoopFor(I.getParent()); + + const SCEV *AtUse = SE.getSCEVAtScope(SV, L); + if (AtUse != SV) { + OS << " --> "; + AtUse->print(OS); + if (!isa<SCEVCouldNotCompute>(AtUse)) { + OS << " U: "; + SE.getUnsignedRange(AtUse).print(OS); + OS << " S: "; + SE.getSignedRange(AtUse).print(OS); + } + } + + if (L) { + OS << "\t\t" "Exits: "; + const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop()); + if (!SE.isLoopInvariant(ExitValue, L)) { + OS << "<<Unknown>>"; + } else { + OS << *ExitValue; + } + + bool First = true; + for (auto *Iter = L; Iter; Iter = Iter->getParentLoop()) { + if (First) { + OS << "\t\t" "LoopDispositions: { "; + First = false; + } else { + OS << ", "; + } + + Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false); + OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter)); + } + + for (auto *InnerL : depth_first(L)) { + if (InnerL == L) + continue; + if (First) { + OS << "\t\t" "LoopDispositions: { "; + First = false; + } else { + OS << ", "; + } + + InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false); + OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL)); + } + + OS << " }"; + } + + OS << "\n"; + } + + OS << "Determining loop execution counts for: "; + F.printAsOperand(OS, /*PrintType=*/false); + OS << "\n"; + for (Loop *I : LI) + PrintLoopInfo(OS, &SE, I); +} + +ScalarEvolution::LoopDisposition +ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) { + auto &Values = LoopDispositions[S]; + for (auto &V : Values) { + if (V.getPointer() == L) + return V.getInt(); + } + Values.emplace_back(L, LoopVariant); + LoopDisposition D = computeLoopDisposition(S, L); + auto &Values2 = LoopDispositions[S]; + for (auto &V : make_range(Values2.rbegin(), Values2.rend())) { + if (V.getPointer() == L) { + V.setInt(D); + break; + } + } + return D; +} + +ScalarEvolution::LoopDisposition +ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { + switch (static_cast<SCEVTypes>(S->getSCEVType())) { + case scConstant: + return LoopInvariant; + case scTruncate: + case scZeroExtend: + case scSignExtend: + return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L); + case scAddRecExpr: { + const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S); + + // If L is the addrec's loop, it's computable. + if (AR->getLoop() == L) + return LoopComputable; + + // Add recurrences are never invariant in the function-body (null loop). + if (!L) + return LoopVariant; + + // Everything that is not defined at loop entry is variant. + if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader())) + return LoopVariant; + assert(!L->contains(AR->getLoop()) && "Containing loop's header does not" + " dominate the contained loop's header?"); + + // This recurrence is invariant w.r.t. L if AR's loop contains L. + if (AR->getLoop()->contains(L)) + return LoopInvariant; + + // This recurrence is variant w.r.t. L if any of its operands + // are variant. 
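+    //
+    // For example (sketch): {%n,+,1}<%inner> is invariant w.r.t. a sibling
+    // loop %other only if its operands, here %n, are invariant in %other.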
+ for (auto *Op : AR->operands()) + if (!isLoopInvariant(Op, L)) + return LoopVariant; + + // Otherwise it's loop-invariant. + return LoopInvariant; + } + case scAddExpr: + case scMulExpr: + case scUMaxExpr: + case scSMaxExpr: + case scUMinExpr: + case scSMinExpr: { + bool HasVarying = false; + for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) { + LoopDisposition D = getLoopDisposition(Op, L); + if (D == LoopVariant) + return LoopVariant; + if (D == LoopComputable) + HasVarying = true; + } + return HasVarying ? LoopComputable : LoopInvariant; + } + case scUDivExpr: { + const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S); + LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L); + if (LD == LoopVariant) + return LoopVariant; + LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L); + if (RD == LoopVariant) + return LoopVariant; + return (LD == LoopInvariant && RD == LoopInvariant) ? + LoopInvariant : LoopComputable; + } + case scUnknown: + // All non-instruction values are loop invariant. All instructions are loop + // invariant if they are not contained in the specified loop. + // Instructions are never considered invariant in the function body + // (null loop) because they are defined within the "loop". + if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) + return (L && !L->contains(I)) ? LoopInvariant : LoopVariant; + return LoopInvariant; + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + } + llvm_unreachable("Unknown SCEV kind!"); +} + +bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) { + return getLoopDisposition(S, L) == LoopInvariant; +} + +bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) { + return getLoopDisposition(S, L) == LoopComputable; +} + +ScalarEvolution::BlockDisposition +ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) { + auto &Values = BlockDispositions[S]; + for (auto &V : Values) { + if (V.getPointer() == BB) + return V.getInt(); + } + Values.emplace_back(BB, DoesNotDominateBlock); + BlockDisposition D = computeBlockDisposition(S, BB); + auto &Values2 = BlockDispositions[S]; + for (auto &V : make_range(Values2.rbegin(), Values2.rend())) { + if (V.getPointer() == BB) { + V.setInt(D); + break; + } + } + return D; +} + +ScalarEvolution::BlockDisposition +ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { + switch (static_cast<SCEVTypes>(S->getSCEVType())) { + case scConstant: + return ProperlyDominatesBlock; + case scTruncate: + case scZeroExtend: + case scSignExtend: + return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB); + case scAddRecExpr: { + // This uses a "dominates" query instead of "properly dominates" query + // to test for proper dominance too, because the instruction which + // produces the addrec's value is a PHI, and a PHI effectively properly + // dominates its entire containing block. + const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S); + if (!DT.dominates(AR->getLoop()->getHeader(), BB)) + return DoesNotDominateBlock; + + // Fall through into SCEVNAryExpr handling. 
+ LLVM_FALLTHROUGH; + } + case scAddExpr: + case scMulExpr: + case scUMaxExpr: + case scSMaxExpr: + case scUMinExpr: + case scSMinExpr: { + const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S); + bool Proper = true; + for (const SCEV *NAryOp : NAry->operands()) { + BlockDisposition D = getBlockDisposition(NAryOp, BB); + if (D == DoesNotDominateBlock) + return DoesNotDominateBlock; + if (D == DominatesBlock) + Proper = false; + } + return Proper ? ProperlyDominatesBlock : DominatesBlock; + } + case scUDivExpr: { + const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S); + const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS(); + BlockDisposition LD = getBlockDisposition(LHS, BB); + if (LD == DoesNotDominateBlock) + return DoesNotDominateBlock; + BlockDisposition RD = getBlockDisposition(RHS, BB); + if (RD == DoesNotDominateBlock) + return DoesNotDominateBlock; + return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ? + ProperlyDominatesBlock : DominatesBlock; + } + case scUnknown: + if (Instruction *I = + dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) { + if (I->getParent() == BB) + return DominatesBlock; + if (DT.properlyDominates(I->getParent(), BB)) + return ProperlyDominatesBlock; + return DoesNotDominateBlock; + } + return ProperlyDominatesBlock; + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + } + llvm_unreachable("Unknown SCEV kind!"); +} + +bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) { + return getBlockDisposition(S, BB) >= DominatesBlock; +} + +bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) { + return getBlockDisposition(S, BB) == ProperlyDominatesBlock; +} + +bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const { + return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; }); +} + +bool ScalarEvolution::ExitLimit::hasOperand(const SCEV *S) const { + auto IsS = [&](const SCEV *X) { return S == X; }; + auto ContainsS = [&](const SCEV *X) { + return !isa<SCEVCouldNotCompute>(X) && SCEVExprContains(X, IsS); + }; + return ContainsS(ExactNotTaken) || ContainsS(MaxNotTaken); +} + +void +ScalarEvolution::forgetMemoizedResults(const SCEV *S) { + ValuesAtScopes.erase(S); + LoopDispositions.erase(S); + BlockDispositions.erase(S); + UnsignedRanges.erase(S); + SignedRanges.erase(S); + ExprValueMap.erase(S); + HasRecMap.erase(S); + MinTrailingZerosCache.erase(S); + + for (auto I = PredicatedSCEVRewrites.begin(); + I != PredicatedSCEVRewrites.end();) { + std::pair<const SCEV *, const Loop *> Entry = I->first; + if (Entry.first == S) + PredicatedSCEVRewrites.erase(I++); + else + ++I; + } + + auto RemoveSCEVFromBackedgeMap = + [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) { + for (auto I = Map.begin(), E = Map.end(); I != E;) { + BackedgeTakenInfo &BEInfo = I->second; + if (BEInfo.hasOperand(S, this)) { + BEInfo.clear(); + Map.erase(I++); + } else + ++I; + } + }; + + RemoveSCEVFromBackedgeMap(BackedgeTakenCounts); + RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts); +} + +void +ScalarEvolution::getUsedLoops(const SCEV *S, + SmallPtrSetImpl<const Loop *> &LoopsUsed) { + struct FindUsedLoops { + FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed) + : LoopsUsed(LoopsUsed) {} + SmallPtrSetImpl<const Loop *> &LoopsUsed; + bool follow(const SCEV *S) { + if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) + LoopsUsed.insert(AR->getLoop()); + return true; + } + + bool isDone() const { return false; } + }; + + FindUsedLoops F(LoopsUsed); + 
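+  // Visit every node of S once; follow() records the loop of each addrec
+  // encountered along the way.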
+  SCEVTraversal<FindUsedLoops>(F).visitAll(S);
+}
+
+void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
+  SmallPtrSet<const Loop *, 8> LoopsUsed;
+  getUsedLoops(S, LoopsUsed);
+  for (auto *L : LoopsUsed)
+    LoopUsers[L].push_back(S);
+}
+
+void ScalarEvolution::verify() const {
+  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
+  ScalarEvolution SE2(F, TLI, AC, DT, LI);
+
+  SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
+
+  // Maps SCEV expressions from one ScalarEvolution "universe" to another.
+  struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
+    SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
+
+    const SCEV *visitConstant(const SCEVConstant *Constant) {
+      return SE.getConstant(Constant->getAPInt());
+    }
+
+    const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+      return SE.getUnknown(Expr->getValue());
+    }
+
+    const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
+      return SE.getCouldNotCompute();
+    }
+  };
+
+  SCEVMapper SCM(SE2);
+
+  while (!LoopStack.empty()) {
+    auto *L = LoopStack.pop_back_val();
+    LoopStack.insert(LoopStack.end(), L->begin(), L->end());
+
+    auto *CurBECount = SCM.visit(
+        const_cast<ScalarEvolution *>(this)->getBackedgeTakenCount(L));
+    auto *NewBECount = SE2.getBackedgeTakenCount(L);
+
+    if (CurBECount == SE2.getCouldNotCompute() ||
+        NewBECount == SE2.getCouldNotCompute()) {
+      // NB! This situation is legal, but is very suspicious -- whatever pass
+      // changed the loop to make a trip count go from could not compute to
+      // computable or vice-versa *should have* invalidated SCEV. However, we
+      // choose not to assert here (for now) since we don't want false
+      // positives.
+      continue;
+    }
+
+    if (containsUndefs(CurBECount) || containsUndefs(NewBECount)) {
+      // SCEV treats "undef" as an unknown but consistent value (i.e. it does
+      // not propagate undef aggressively). This means we can (and do) fail
+      // verification in cases where a transform makes the trip count of a loop
+      // go from "undef" to "undef+1" (say). The transform is fine, since in
+      // both cases the loop iterates "undef" times, but SCEV thinks we
+      // increased the trip count of the loop by 1 incorrectly.
+      continue;
+    }
+
+    if (SE.getTypeSizeInBits(CurBECount->getType()) >
+        SE.getTypeSizeInBits(NewBECount->getType()))
+      NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
+    else if (SE.getTypeSizeInBits(CurBECount->getType()) <
+             SE.getTypeSizeInBits(NewBECount->getType()))
+      CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
+
+    const SCEV *Delta = SE2.getMinusSCEV(CurBECount, NewBECount);
+
+    // Unless VerifySCEVStrict is set, we only compare constant deltas.
+    if ((VerifySCEVStrict || isa<SCEVConstant>(Delta)) && !Delta->isZero()) {
+      dbgs() << "Trip Count for " << *L << " Changed!\n";
+      dbgs() << "Old: " << *CurBECount << "\n";
+      dbgs() << "New: " << *NewBECount << "\n";
+      dbgs() << "Delta: " << *Delta << "\n";
+      std::abort();
+    }
+  }
+}
+
+bool ScalarEvolution::invalidate(
+    Function &F, const PreservedAnalyses &PA,
+    FunctionAnalysisManager::Invalidator &Inv) {
+  // Invalidate the ScalarEvolution object whenever it isn't preserved or one
+  // of its dependencies is invalidated.
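+  //
+  // For example, a pass that preserves the DominatorTree but not LoopInfo
+  // still invalidates ScalarEvolution here, since cached dispositions and
+  // trip counts are expressed in terms of Loop objects.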
+ auto PAC = PA.getChecker<ScalarEvolutionAnalysis>(); + return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) || + Inv.invalidate<AssumptionAnalysis>(F, PA) || + Inv.invalidate<DominatorTreeAnalysis>(F, PA) || + Inv.invalidate<LoopAnalysis>(F, PA); +} + +AnalysisKey ScalarEvolutionAnalysis::Key; + +ScalarEvolution ScalarEvolutionAnalysis::run(Function &F, + FunctionAnalysisManager &AM) { + return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F), + AM.getResult<AssumptionAnalysis>(F), + AM.getResult<DominatorTreeAnalysis>(F), + AM.getResult<LoopAnalysis>(F)); +} + +PreservedAnalyses +ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { + AM.getResult<ScalarEvolutionAnalysis>(F).print(OS); + return PreservedAnalyses::all(); +} + +INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution", + "Scalar Evolution Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution", + "Scalar Evolution Analysis", false, true) + +char ScalarEvolutionWrapperPass::ID = 0; + +ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) { + initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) { + SE.reset(new ScalarEvolution( + F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F), + getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), + getAnalysis<DominatorTreeWrapperPass>().getDomTree(), + getAnalysis<LoopInfoWrapperPass>().getLoopInfo())); + return false; +} + +void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); } + +void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const { + SE->print(OS); +} + +void ScalarEvolutionWrapperPass::verifyAnalysis() const { + if (!VerifySCEV) + return; + + SE->verify(); +} + +void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequiredTransitive<AssumptionCacheTracker>(); + AU.addRequiredTransitive<LoopInfoWrapperPass>(); + AU.addRequiredTransitive<DominatorTreeWrapperPass>(); + AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>(); +} + +const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS, + const SCEV *RHS) { + FoldingSetNodeID ID; + assert(LHS->getType() == RHS->getType() && + "Type mismatch between LHS and RHS"); + // Unique this node based on the arguments + ID.AddInteger(SCEVPredicate::P_Equal); + ID.AddPointer(LHS); + ID.AddPointer(RHS); + void *IP = nullptr; + if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) + return S; + SCEVEqualPredicate *Eq = new (SCEVAllocator) + SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS); + UniquePreds.InsertNode(Eq, IP); + return Eq; +} + +const SCEVPredicate *ScalarEvolution::getWrapPredicate( + const SCEVAddRecExpr *AR, + SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { + FoldingSetNodeID ID; + // Unique this node based on the arguments + ID.AddInteger(SCEVPredicate::P_Wrap); + ID.AddPointer(AR); + ID.AddInteger(AddedFlags); + void *IP = nullptr; + if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) + return S; + auto *OF = new (SCEVAllocator) + SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags); + UniquePreds.InsertNode(OF, IP); + return OF; +} + +namespace { + +class 
SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> { +public: + + /// Rewrites \p S in the context of a loop L and the SCEV predication + /// infrastructure. + /// + /// If \p Pred is non-null, the SCEV expression is rewritten to respect the + /// equivalences present in \p Pred. + /// + /// If \p NewPreds is non-null, rewrite is free to add further predicates to + /// \p NewPreds such that the result will be an AddRecExpr. + static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, + SmallPtrSetImpl<const SCEVPredicate *> *NewPreds, + SCEVUnionPredicate *Pred) { + SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred); + return Rewriter.visit(S); + } + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + if (Pred) { + auto ExprPreds = Pred->getPredicatesForExpr(Expr); + for (auto *Pred : ExprPreds) + if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred)) + if (IPred->getLHS() == Expr) + return IPred->getRHS(); + } + return convertToAddRecWithPreds(Expr); + } + + const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand); + if (AR && AR->getLoop() == L && AR->isAffine()) { + // This couldn't be folded because the operand didn't have the nuw + // flag. Add the nusw flag as an assumption that we could make. + const SCEV *Step = AR->getStepRecurrence(SE); + Type *Ty = Expr->getType(); + if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW)) + return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty), + SE.getSignExtendExpr(Step, Ty), L, + AR->getNoWrapFlags()); + } + return SE.getZeroExtendExpr(Operand, Expr->getType()); + } + + const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand); + if (AR && AR->getLoop() == L && AR->isAffine()) { + // This couldn't be folded because the operand didn't have the nsw + // flag. Add the nssw flag as an assumption that we could make. + const SCEV *Step = AR->getStepRecurrence(SE); + Type *Ty = Expr->getType(); + if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW)) + return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty), + SE.getSignExtendExpr(Step, Ty), L, + AR->getNoWrapFlags()); + } + return SE.getSignExtendExpr(Operand, Expr->getType()); + } + +private: + explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, + SmallPtrSetImpl<const SCEVPredicate *> *NewPreds, + SCEVUnionPredicate *Pred) + : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {} + + bool addOverflowAssumption(const SCEVPredicate *P) { + if (!NewPreds) { + // Check if we've already made this assumption. + return Pred && Pred->implies(P); + } + NewPreds->insert(P); + return true; + } + + bool addOverflowAssumption(const SCEVAddRecExpr *AR, + SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { + auto *A = SE.getWrapPredicate(AR, AddedFlags); + return addOverflowAssumption(A); + } + + // If \p Expr represents a PHINode, we try to see if it can be represented + // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible + // to add this predicate as a runtime overflow check, we return the AddRec. + // If \p Expr does not meet these conditions (is not a PHI node, or we + // couldn't create an AddRec for it, or couldn't add the predicate), we just + // return \p Expr. 
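+  //
+  // Illustrative sketch: for a PHI such as
+  //
+  //   %iv = phi i16 [ %start, %preheader ], [ %iv.next, %loop ]
+  //   %iv.next = add i16 %iv, 1
+  //
+  // where the increment cannot be proven non-wrapping statically, this may
+  // return {%start,+,1}<%loop> together with a wrap predicate recorded
+  // through addOverflowAssumption.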
+ const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) { + if (!isa<PHINode>(Expr->getValue())) + return Expr; + Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>> + PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr); + if (!PredicatedRewrite) + return Expr; + for (auto *P : PredicatedRewrite->second){ + // Wrap predicates from outer loops are not supported. + if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) { + auto *AR = cast<const SCEVAddRecExpr>(WP->getExpr()); + if (L != AR->getLoop()) + return Expr; + } + if (!addOverflowAssumption(P)) + return Expr; + } + return PredicatedRewrite->first; + } + + SmallPtrSetImpl<const SCEVPredicate *> *NewPreds; + SCEVUnionPredicate *Pred; + const Loop *L; +}; + +} // end anonymous namespace + +const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L, + SCEVUnionPredicate &Preds) { + return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds); +} + +const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates( + const SCEV *S, const Loop *L, + SmallPtrSetImpl<const SCEVPredicate *> &Preds) { + SmallPtrSet<const SCEVPredicate *, 4> TransformPreds; + S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr); + auto *AddRec = dyn_cast<SCEVAddRecExpr>(S); + + if (!AddRec) + return nullptr; + + // Since the transformation was successful, we can now transfer the SCEV + // predicates. + for (auto *P : TransformPreds) + Preds.insert(P); + + return AddRec; +} + +/// SCEV predicates +SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID, + SCEVPredicateKind Kind) + : FastID(ID), Kind(Kind) {} + +SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID, + const SCEV *LHS, const SCEV *RHS) + : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) { + assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match"); + assert(LHS != RHS && "LHS and RHS are the same SCEV"); +} + +bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const { + const auto *Op = dyn_cast<SCEVEqualPredicate>(N); + + if (!Op) + return false; + + return Op->LHS == LHS && Op->RHS == RHS; +} + +bool SCEVEqualPredicate::isAlwaysTrue() const { return false; } + +const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; } + +void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const { + OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; +} + +SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID, + const SCEVAddRecExpr *AR, + IncrementWrapFlags Flags) + : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {} + +const SCEV *SCEVWrapPredicate::getExpr() const { return AR; } + +bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const { + const auto *Op = dyn_cast<SCEVWrapPredicate>(N); + + return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags; +} + +bool SCEVWrapPredicate::isAlwaysTrue() const { + SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags(); + IncrementWrapFlags IFlags = Flags; + + if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags) + IFlags = clearFlags(IFlags, IncrementNSSW); + + return IFlags == IncrementAnyWrap; +} + +void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const { + OS.indent(Depth) << *getExpr() << " Added Flags: "; + if (SCEVWrapPredicate::IncrementNUSW & getFlags()) + OS << "<nusw>"; + if (SCEVWrapPredicate::IncrementNSSW & getFlags()) + OS << "<nssw>"; + OS << "\n"; +} + +SCEVWrapPredicate::IncrementWrapFlags +SCEVWrapPredicate::getImpliedFlags(const 
SCEVAddRecExpr *AR, + ScalarEvolution &SE) { + IncrementWrapFlags ImpliedFlags = IncrementAnyWrap; + SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags(); + + // We can safely transfer the NSW flag as NSSW. + if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags) + ImpliedFlags = IncrementNSSW; + + if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) { + // If the increment is positive, the SCEV NUW flag will also imply the + // WrapPredicate NUSW flag. + if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) + if (Step->getValue()->getValue().isNonNegative()) + ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW); + } + + return ImpliedFlags; +} + +/// Union predicates don't get cached so create a dummy set ID for it. +SCEVUnionPredicate::SCEVUnionPredicate() + : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {} + +bool SCEVUnionPredicate::isAlwaysTrue() const { + return all_of(Preds, + [](const SCEVPredicate *I) { return I->isAlwaysTrue(); }); +} + +ArrayRef<const SCEVPredicate *> +SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) { + auto I = SCEVToPreds.find(Expr); + if (I == SCEVToPreds.end()) + return ArrayRef<const SCEVPredicate *>(); + return I->second; +} + +bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const { + if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) + return all_of(Set->Preds, + [this](const SCEVPredicate *I) { return this->implies(I); }); + + auto ScevPredsIt = SCEVToPreds.find(N->getExpr()); + if (ScevPredsIt == SCEVToPreds.end()) + return false; + auto &SCEVPreds = ScevPredsIt->second; + + return any_of(SCEVPreds, + [N](const SCEVPredicate *I) { return I->implies(N); }); +} + +const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; } + +void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { + for (auto Pred : Preds) + Pred->print(OS, Depth); +} + +void SCEVUnionPredicate::add(const SCEVPredicate *N) { + if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) { + for (auto Pred : Set->Preds) + add(Pred); + return; + } + + if (implies(N)) + return; + + const SCEV *Key = N->getExpr(); + assert(Key && "Only SCEVUnionPredicate doesn't have an " + " associated expression!"); + + SCEVToPreds[Key].push_back(N); + Preds.push_back(N); +} + +PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE, + Loop &L) + : SE(SE), L(L) {} + +const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { + const SCEV *Expr = SE.getSCEV(V); + RewriteEntry &Entry = RewriteMap[Expr]; + + // If we already have an entry and the version matches, return it. + if (Entry.second && Generation == Entry.first) + return Entry.second; + + // We found an entry but it's stale. Rewrite the stale entry + // according to the current predicate. 
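+  //
+  // (The intent: since predicates are only ever added, rewriting the stale
+  // result under the current, larger predicate set converges to the same
+  // expression as rewriting the original from scratch.)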
+ if (Entry.second) + Expr = Entry.second; + + const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds); + Entry = {Generation, NewSCEV}; + + return NewSCEV; +} + +const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() { + if (!BackedgeCount) { + SCEVUnionPredicate BackedgePred; + BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, BackedgePred); + addPredicate(BackedgePred); + } + return BackedgeCount; +} + +void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) { + if (Preds.implies(&Pred)) + return; + Preds.add(&Pred); + updateGeneration(); +} + +const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const { + return Preds; +} + +void PredicatedScalarEvolution::updateGeneration() { + // If the generation number wrapped recompute everything. + if (++Generation == 0) { + for (auto &II : RewriteMap) { + const SCEV *Rewritten = II.second.second; + II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)}; + } + } +} + +void PredicatedScalarEvolution::setNoOverflow( + Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { + const SCEV *Expr = getSCEV(V); + const auto *AR = cast<SCEVAddRecExpr>(Expr); + + auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE); + + // Clear the statically implied flags. + Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags); + addPredicate(*SE.getWrapPredicate(AR, Flags)); + + auto II = FlagsMap.insert({V, Flags}); + if (!II.second) + II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second); +} + +bool PredicatedScalarEvolution::hasNoOverflow( + Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { + const SCEV *Expr = getSCEV(V); + const auto *AR = cast<SCEVAddRecExpr>(Expr); + + Flags = SCEVWrapPredicate::clearFlags( + Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE)); + + auto II = FlagsMap.find(V); + + if (II != FlagsMap.end()) + Flags = SCEVWrapPredicate::clearFlags(Flags, II->second); + + return Flags == SCEVWrapPredicate::IncrementAnyWrap; +} + +const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { + const SCEV *Expr = this->getSCEV(V); + SmallPtrSet<const SCEVPredicate *, 4> NewPreds; + auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds); + + if (!New) + return nullptr; + + for (auto *P : NewPreds) + Preds.add(P); + + updateGeneration(); + RewriteMap[SE.getSCEV(V)] = {Generation, New}; + return New; +} + +PredicatedScalarEvolution::PredicatedScalarEvolution( + const PredicatedScalarEvolution &Init) + : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds), + Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { + for (const auto &I : Init.FlagsMap) + FlagsMap.insert(I); +} + +void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const { + // For each block. + for (auto *BB : L.getBlocks()) + for (auto &I : *BB) { + if (!SE.isSCEVable(I.getType())) + continue; + + auto *Expr = SE.getSCEV(&I); + auto II = RewriteMap.find(Expr); + + if (II == RewriteMap.end()) + continue; + + // Don't print things that are not interesting. + if (II->second.second == Expr) + continue; + + OS.indent(Depth) << "[PSE]" << I << ":\n"; + OS.indent(Depth + 2) << *Expr << "\n"; + OS.indent(Depth + 2) << "--> " << *II->second.second << "\n"; + } +} + +// Match the mathematical pattern A - (A / B) * B, where A and B can be +// arbitrary expressions. +// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is +// 4, A / B becomes X / 8). 
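+// For example (illustrative): for %x urem 8, ScalarEvolution builds
+//   ((-8 * (%x /u 8)) + %x)
+// and this matcher recovers LHS = %x and RHS = 8 from that form.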
+bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
+                                const SCEV *&RHS) {
+  const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
+  if (Add == nullptr || Add->getNumOperands() != 2)
+    return false;
+
+  const SCEV *A = Add->getOperand(1);
+  const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
+
+  if (Mul == nullptr)
+    return false;
+
+  const auto MatchURemWithDivisor = [&](const SCEV *B) {
+    // (SomeExpr + (-(SomeExpr / B) * B)).
+    if (Expr == getURemExpr(A, B)) {
+      LHS = A;
+      RHS = B;
+      return true;
+    }
+    return false;
+  };
+
+  // (SomeExpr + (-1 * (SomeExpr / B) * B)).
+  if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
+    return MatchURemWithDivisor(Mul->getOperand(1)) ||
+           MatchURemWithDivisor(Mul->getOperand(2));
+
+  // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
+  if (Mul->getNumOperands() == 2)
+    return MatchURemWithDivisor(Mul->getOperand(1)) ||
+           MatchURemWithDivisor(Mul->getOperand(0)) ||
+           MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
+           MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
+  return false;
+}
diff --git a/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp b/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp
new file mode 100644
index 000000000000..96da0a24cddd
--- /dev/null
+++ b/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp
@@ -0,0 +1,147 @@
+//===- ScalarEvolutionAliasAnalysis.cpp - SCEV-based Alias Analysis -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ScalarEvolutionAliasAnalysis pass, which implements a
+// simple alias analysis in terms of ScalarEvolution queries.
+//
+// This differs from traditional loop dependence analysis in that it tests
+// for dependencies within a single iteration of a loop, rather than
+// dependencies between different iterations.
+//
+// ScalarEvolution has a more complete understanding of pointer arithmetic
+// than BasicAliasAnalysis' collection of ad-hoc analyses.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
+using namespace llvm;
+
+AliasResult SCEVAAResult::alias(const MemoryLocation &LocA,
+                                const MemoryLocation &LocB, AAQueryInfo &AAQI) {
+  // If either of the memory references is empty, it doesn't matter what the
+  // pointer values are. This allows the code below to ignore this special
+  // case.
+  if (LocA.Size.isZero() || LocB.Size.isZero())
+    return NoAlias;
+
+  // This is SCEVAAResult. Get the SCEVs!
+  const SCEV *AS = SE.getSCEV(const_cast<Value *>(LocA.Ptr));
+  const SCEV *BS = SE.getSCEV(const_cast<Value *>(LocB.Ptr));
+
+  // If they evaluate to the same expression, it's a MustAlias.
+  if (AS == BS)
+    return MustAlias;
+
+  // If something is known about the difference between the two addresses,
+  // see if it's enough to prove a NoAlias.
+  if (SE.getEffectiveSCEVType(AS->getType()) ==
+      SE.getEffectiveSCEVType(BS->getType())) {
+    unsigned BitWidth = SE.getTypeSizeInBits(AS->getType());
+    APInt ASizeInt(BitWidth, LocA.Size.hasValue()
+                                 ? LocA.Size.getValue()
+                                 : MemoryLocation::UnknownSize);
+    APInt BSizeInt(BitWidth, LocB.Size.hasValue()
+                                 ? 
LocB.Size.getValue() + : MemoryLocation::UnknownSize); + + // Compute the difference between the two pointers. + const SCEV *BA = SE.getMinusSCEV(BS, AS); + + // Test whether the difference is known to be great enough that memory of + // the given sizes don't overlap. This assumes that ASizeInt and BSizeInt + // are non-zero, which is special-cased above. + if (ASizeInt.ule(SE.getUnsignedRange(BA).getUnsignedMin()) && + (-BSizeInt).uge(SE.getUnsignedRange(BA).getUnsignedMax())) + return NoAlias; + + // Folding the subtraction while preserving range information can be tricky + // (because of INT_MIN, etc.); if the prior test failed, swap AS and BS + // and try again to see if things fold better that way. + + // Compute the difference between the two pointers. + const SCEV *AB = SE.getMinusSCEV(AS, BS); + + // Test whether the difference is known to be great enough that memory of + // the given sizes don't overlap. This assumes that ASizeInt and BSizeInt + // are non-zero, which is special-cased above. + if (BSizeInt.ule(SE.getUnsignedRange(AB).getUnsignedMin()) && + (-ASizeInt).uge(SE.getUnsignedRange(AB).getUnsignedMax())) + return NoAlias; + } + + // If ScalarEvolution can find an underlying object, form a new query. + // The correctness of this depends on ScalarEvolution not recognizing + // inttoptr and ptrtoint operators. + Value *AO = GetBaseValue(AS); + Value *BO = GetBaseValue(BS); + if ((AO && AO != LocA.Ptr) || (BO && BO != LocB.Ptr)) + if (alias(MemoryLocation(AO ? AO : LocA.Ptr, + AO ? LocationSize::unknown() : LocA.Size, + AO ? AAMDNodes() : LocA.AATags), + MemoryLocation(BO ? BO : LocB.Ptr, + BO ? LocationSize::unknown() : LocB.Size, + BO ? AAMDNodes() : LocB.AATags), + AAQI) == NoAlias) + return NoAlias; + + // Forward the query to the next analysis. + return AAResultBase::alias(LocA, LocB, AAQI); +} + +/// Given an expression, try to find a base value. +/// +/// Returns null if none was found. +Value *SCEVAAResult::GetBaseValue(const SCEV *S) { + if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { + // In an addrec, assume that the base will be in the start, rather + // than the step. + return GetBaseValue(AR->getStart()); + } else if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) { + // If there's a pointer operand, it'll be sorted at the end of the list. + const SCEV *Last = A->getOperand(A->getNumOperands() - 1); + if (Last->getType()->isPointerTy()) + return GetBaseValue(Last); + } else if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { + // This is a leaf node. + return U->getValue(); + } + // No Identified object found. 
+ return nullptr; +} + +AnalysisKey SCEVAA::Key; + +SCEVAAResult SCEVAA::run(Function &F, FunctionAnalysisManager &AM) { + return SCEVAAResult(AM.getResult<ScalarEvolutionAnalysis>(F)); +} + +char SCEVAAWrapperPass::ID = 0; +INITIALIZE_PASS_BEGIN(SCEVAAWrapperPass, "scev-aa", + "ScalarEvolution-based Alias Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_END(SCEVAAWrapperPass, "scev-aa", + "ScalarEvolution-based Alias Analysis", false, true) + +FunctionPass *llvm::createSCEVAAWrapperPass() { + return new SCEVAAWrapperPass(); +} + +SCEVAAWrapperPass::SCEVAAWrapperPass() : FunctionPass(ID) { + initializeSCEVAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool SCEVAAWrapperPass::runOnFunction(Function &F) { + Result.reset( + new SCEVAAResult(getAnalysis<ScalarEvolutionWrapperPass>().getSE())); + return false; +} + +void SCEVAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<ScalarEvolutionWrapperPass>(); +} diff --git a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp new file mode 100644 index 000000000000..bceec921188e --- /dev/null +++ b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp @@ -0,0 +1,2452 @@ +//===- ScalarEvolutionExpander.cpp - Scalar Evolution Analysis ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the implementation of the scalar evolution expander, +// which is used to generate the code corresponding to a given scalar evolution +// expression. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using namespace PatternMatch; + +/// ReuseOrCreateCast - Arrange for there to be a cast of V to Ty at IP, +/// reusing an existing cast if a suitable one exists, moving an existing +/// cast if a suitable one exists but isn't in the right place, or +/// creating a new one. +Value *SCEVExpander::ReuseOrCreateCast(Value *V, Type *Ty, + Instruction::CastOps Op, + BasicBlock::iterator IP) { + // This function must be called with the builder having a valid insertion + // point. It doesn't need to be the actual IP where the uses of the returned + // cast will be added, but it must dominate such IP. + // We use this precondition to produce a cast that will dominate all its + // uses. In particular, this is crucial for the case where the builder's + // insertion point *is* the point where we were asked to put the cast. + // Since we don't know the builder's insertion point is actually + // where the uses will be added (only that it dominates it), we are + // not allowed to move it. 
+ BasicBlock::iterator BIP = Builder.GetInsertPoint(); + + Instruction *Ret = nullptr; + + // Check to see if there is already a cast! + for (User *U : V->users()) + if (U->getType() == Ty) + if (CastInst *CI = dyn_cast<CastInst>(U)) + if (CI->getOpcode() == Op) { + // If the cast isn't where we want it, create a new cast at IP. + // Likewise, do not reuse a cast at BIP because it must dominate + // instructions that might be inserted before BIP. + if (BasicBlock::iterator(CI) != IP || BIP == IP) { + // Create a new cast, and leave the old cast in place in case + // it is being used as an insert point. + Ret = CastInst::Create(Op, V, Ty, "", &*IP); + Ret->takeName(CI); + CI->replaceAllUsesWith(Ret); + break; + } + Ret = CI; + break; + } + + // Create a new cast. + if (!Ret) + Ret = CastInst::Create(Op, V, Ty, V->getName(), &*IP); + + // We assert at the end of the function since IP might point to an + // instruction with different dominance properties than a cast + // (an invoke for example) and not dominate BIP (but the cast does). + assert(SE.DT.dominates(Ret, &*BIP)); + + rememberInstruction(Ret); + return Ret; +} + +static BasicBlock::iterator findInsertPointAfter(Instruction *I, + BasicBlock *MustDominate) { + BasicBlock::iterator IP = ++I->getIterator(); + if (auto *II = dyn_cast<InvokeInst>(I)) + IP = II->getNormalDest()->begin(); + + while (isa<PHINode>(IP)) + ++IP; + + if (isa<FuncletPadInst>(IP) || isa<LandingPadInst>(IP)) { + ++IP; + } else if (isa<CatchSwitchInst>(IP)) { + IP = MustDominate->getFirstInsertionPt(); + } else { + assert(!IP->isEHPad() && "unexpected eh pad!"); + } + + return IP; +} + +/// InsertNoopCastOfTo - Insert a cast of V to the specified type, +/// which must be possible with a noop cast, doing what we can to share +/// the casts. +Value *SCEVExpander::InsertNoopCastOfTo(Value *V, Type *Ty) { + Instruction::CastOps Op = CastInst::getCastOpcode(V, false, Ty, false); + assert((Op == Instruction::BitCast || + Op == Instruction::PtrToInt || + Op == Instruction::IntToPtr) && + "InsertNoopCastOfTo cannot perform non-noop casts!"); + assert(SE.getTypeSizeInBits(V->getType()) == SE.getTypeSizeInBits(Ty) && + "InsertNoopCastOfTo cannot change sizes!"); + + // Short-circuit unnecessary bitcasts. + if (Op == Instruction::BitCast) { + if (V->getType() == Ty) + return V; + if (CastInst *CI = dyn_cast<CastInst>(V)) { + if (CI->getOperand(0)->getType() == Ty) + return CI->getOperand(0); + } + } + // Short-circuit unnecessary inttoptr<->ptrtoint casts. + if ((Op == Instruction::PtrToInt || Op == Instruction::IntToPtr) && + SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(V->getType())) { + if (CastInst *CI = dyn_cast<CastInst>(V)) + if ((CI->getOpcode() == Instruction::PtrToInt || + CI->getOpcode() == Instruction::IntToPtr) && + SE.getTypeSizeInBits(CI->getType()) == + SE.getTypeSizeInBits(CI->getOperand(0)->getType())) + return CI->getOperand(0); + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) + if ((CE->getOpcode() == Instruction::PtrToInt || + CE->getOpcode() == Instruction::IntToPtr) && + SE.getTypeSizeInBits(CE->getType()) == + SE.getTypeSizeInBits(CE->getOperand(0)->getType())) + return CE->getOperand(0); + } + + // Fold a cast of a constant. + if (Constant *C = dyn_cast<Constant>(V)) + return ConstantExpr::getCast(Op, C, Ty); + + // Cast the argument at the beginning of the entry block, after + // any bitcasts of other arguments. 
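+  // (Skipping the leading argument bitcasts and debug intrinsics below keeps
+  // all such entry-block casts clustered together after any existing ones.)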
+  if (Argument *A = dyn_cast<Argument>(V)) {
+    BasicBlock::iterator IP = A->getParent()->getEntryBlock().begin();
+    while ((isa<BitCastInst>(IP) &&
+            isa<Argument>(cast<BitCastInst>(IP)->getOperand(0)) &&
+            cast<BitCastInst>(IP)->getOperand(0) != A) ||
+           isa<DbgInfoIntrinsic>(IP))
+      ++IP;
+    return ReuseOrCreateCast(A, Ty, Op, IP);
+  }
+
+  // Otherwise, cast the value immediately after the instruction that
+  // defines it.
+  Instruction *I = cast<Instruction>(V);
+  BasicBlock::iterator IP = findInsertPointAfter(I, Builder.GetInsertBlock());
+  return ReuseOrCreateCast(I, Ty, Op, IP);
+}
+
+/// InsertBinop - Insert the specified binary operator, doing a small amount
+/// of work to avoid inserting an obviously redundant operation, and hoisting
+/// to an outer loop when the opportunity is there and it is safe.
+Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode,
+                                 Value *LHS, Value *RHS,
+                                 SCEV::NoWrapFlags Flags, bool IsSafeToHoist) {
+  // Fold a binop with constant operands.
+  if (Constant *CLHS = dyn_cast<Constant>(LHS))
+    if (Constant *CRHS = dyn_cast<Constant>(RHS))
+      return ConstantExpr::get(Opcode, CLHS, CRHS);
+
+  // Do a quick scan to see if we have this binop nearby. If so, reuse it.
+  unsigned ScanLimit = 6;
+  BasicBlock::iterator BlockBegin = Builder.GetInsertBlock()->begin();
+  // Scanning starts from the last instruction before the insertion point.
+  BasicBlock::iterator IP = Builder.GetInsertPoint();
+  if (IP != BlockBegin) {
+    --IP;
+    for (; ScanLimit; --IP, --ScanLimit) {
+      // Don't count dbg.value against the ScanLimit, to avoid perturbing the
+      // generated code.
+      if (isa<DbgInfoIntrinsic>(IP))
+        ScanLimit++;
+
+      auto canGenerateIncompatiblePoison = [&Flags](Instruction *I) {
+        // Ensure that no-wrap flags match.
+        if (isa<OverflowingBinaryOperator>(I)) {
+          if (I->hasNoSignedWrap() != (Flags & SCEV::FlagNSW))
+            return true;
+          if (I->hasNoUnsignedWrap() != (Flags & SCEV::FlagNUW))
+            return true;
+        }
+        // Conservatively, do not reuse any instruction which has the exact
+        // flag set.
+        if (isa<PossiblyExactOperator>(I) && I->isExact())
+          return true;
+        return false;
+      };
+      if (IP->getOpcode() == (unsigned)Opcode && IP->getOperand(0) == LHS &&
+          IP->getOperand(1) == RHS && !canGenerateIncompatiblePoison(&*IP))
+        return &*IP;
+      if (IP == BlockBegin) break;
+    }
+  }
+
+  // Save the original insertion point so we can restore it when we're done.
+  DebugLoc Loc = Builder.GetInsertPoint()->getDebugLoc();
+  SCEVInsertPointGuard Guard(Builder, this);
+
+  if (IsSafeToHoist) {
+    // Move the insertion point out of as many loops as we can.
+    while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) {
+      if (!L->isLoopInvariant(LHS) || !L->isLoopInvariant(RHS)) break;
+      BasicBlock *Preheader = L->getLoopPreheader();
+      if (!Preheader) break;
+
+      // Ok, move up a level.
+      Builder.SetInsertPoint(Preheader->getTerminator());
+    }
+  }
+
+  // If we haven't found this binop, insert it.
+  Instruction *BO = cast<Instruction>(Builder.CreateBinOp(Opcode, LHS, RHS));
+  BO->setDebugLoc(Loc);
+  if (Flags & SCEV::FlagNUW)
+    BO->setHasNoUnsignedWrap();
+  if (Flags & SCEV::FlagNSW)
+    BO->setHasNoSignedWrap();
+  rememberInstruction(BO);
+
+  return BO;
+}
+
+/// FactorOutConstant - Test if S is divisible by Factor, using signed
+/// division. If so, update S with Factor divided out and return true.
+/// S need not be evenly divisible if a reasonable remainder can be
+/// computed.
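+///
+/// For illustration (operands chosen arbitrarily): factoring the constant 6
+/// by 4 yields 1 with a remainder of 2; factoring 8*i by 4 yields 2*i; and
+/// factoring the addrec {8,+,12} by 4 yields {2,+,3}.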
+static bool FactorOutConstant(const SCEV *&S, const SCEV *&Remainder,
+                              const SCEV *Factor, ScalarEvolution &SE,
+                              const DataLayout &DL) {
+  // Everything is divisible by one.
+  if (Factor->isOne())
+    return true;
+
+  // x/x == 1.
+  if (S == Factor) {
+    S = SE.getConstant(S->getType(), 1);
+    return true;
+  }
+
+  // For a Constant, check for a multiple of the given factor.
+  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
+    // 0/x == 0.
+    if (C->isZero())
+      return true;
+    // Check for divisibility.
+    if (const SCEVConstant *FC = dyn_cast<SCEVConstant>(Factor)) {
+      ConstantInt *CI =
+          ConstantInt::get(SE.getContext(), C->getAPInt().sdiv(FC->getAPInt()));
+      // If the quotient is zero and the remainder is non-zero, reject
+      // the value at this scale. It will be considered for subsequent
+      // smaller scales.
+      if (!CI->isZero()) {
+        const SCEV *Div = SE.getConstant(CI);
+        S = Div;
+        Remainder = SE.getAddExpr(
+            Remainder, SE.getConstant(C->getAPInt().srem(FC->getAPInt())));
+        return true;
+      }
+    }
+  }
+
+  // In a Mul, check if there is a constant operand which is a multiple
+  // of the given factor.
+  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
+    // Factor is known to be a constant here. If the Mul's leading operand is
+    // a constant multiple of it, we can factor it out.
+    const SCEVConstant *FC = cast<SCEVConstant>(Factor);
+    if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0)))
+      if (!C->getAPInt().srem(FC->getAPInt())) {
+        SmallVector<const SCEV *, 4> NewMulOps(M->op_begin(), M->op_end());
+        NewMulOps[0] = SE.getConstant(C->getAPInt().sdiv(FC->getAPInt()));
+        S = SE.getMulExpr(NewMulOps);
+        return true;
+      }
+  }
+
+  // In an AddRec, check if both start and step are divisible.
+  if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
+    const SCEV *Step = A->getStepRecurrence(SE);
+    const SCEV *StepRem = SE.getConstant(Step->getType(), 0);
+    if (!FactorOutConstant(Step, StepRem, Factor, SE, DL))
+      return false;
+    if (!StepRem->isZero())
+      return false;
+    const SCEV *Start = A->getStart();
+    if (!FactorOutConstant(Start, Remainder, Factor, SE, DL))
+      return false;
+    S = SE.getAddRecExpr(Start, Step, A->getLoop(),
+                         A->getNoWrapFlags(SCEV::FlagNW));
+    return true;
+  }
+
+  return false;
+}
+
+/// SimplifyAddOperands - Sort and simplify a list of add operands. NumAddRecs
+/// is the number of SCEVAddRecExprs present, which are kept at the end of
+/// the list.
+///
+static void SimplifyAddOperands(SmallVectorImpl<const SCEV *> &Ops,
+                                Type *Ty,
+                                ScalarEvolution &SE) {
+  unsigned NumAddRecs = 0;
+  for (unsigned i = Ops.size(); i > 0 && isa<SCEVAddRecExpr>(Ops[i-1]); --i)
+    ++NumAddRecs;
+  // Group Ops into non-addrecs and addrecs.
+  SmallVector<const SCEV *, 8> NoAddRecs(Ops.begin(), Ops.end() - NumAddRecs);
+  SmallVector<const SCEV *, 8> AddRecs(Ops.end() - NumAddRecs, Ops.end());
+  // Let ScalarEvolution sort and simplify the non-addrecs list.
+  const SCEV *Sum = NoAddRecs.empty() ?
+                    SE.getConstant(Ty, 0) :
+                    SE.getAddExpr(NoAddRecs);
+  // If it returned an add, use the operands. Otherwise it simplified
+  // the sum into a single value, so just use that.
+  Ops.clear();
+  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Sum))
+    Ops.append(Add->op_begin(), Add->op_end());
+  else if (!Sum->isZero())
+    Ops.push_back(Sum);
+  // Then append the addrecs.
+  Ops.append(AddRecs.begin(), AddRecs.end());
+}
+
+/// SplitAddRecs - Flatten a list of add operands, moving addrec start values
+/// out to the top level. For example, convert {a + b,+,c} to a, b, {0,+,c}.
+/// This helps expose more opportunities for folding parts of the expressions +/// into GEP indices. +/// +static void SplitAddRecs(SmallVectorImpl<const SCEV *> &Ops, + Type *Ty, + ScalarEvolution &SE) { + // Find the addrecs. + SmallVector<const SCEV *, 8> AddRecs; + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + while (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(Ops[i])) { + const SCEV *Start = A->getStart(); + if (Start->isZero()) break; + const SCEV *Zero = SE.getConstant(Ty, 0); + AddRecs.push_back(SE.getAddRecExpr(Zero, + A->getStepRecurrence(SE), + A->getLoop(), + A->getNoWrapFlags(SCEV::FlagNW))); + if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Start)) { + Ops[i] = Zero; + Ops.append(Add->op_begin(), Add->op_end()); + e += Add->getNumOperands(); + } else { + Ops[i] = Start; + } + } + if (!AddRecs.empty()) { + // Add the addrecs onto the end of the list. + Ops.append(AddRecs.begin(), AddRecs.end()); + // Resort the operand list, moving any constants to the front. + SimplifyAddOperands(Ops, Ty, SE); + } +} + +/// expandAddToGEP - Expand an addition expression with a pointer type into +/// a GEP instead of using ptrtoint+arithmetic+inttoptr. This helps +/// BasicAliasAnalysis and other passes analyze the result. See the rules +/// for getelementptr vs. inttoptr in +/// http://llvm.org/docs/LangRef.html#pointeraliasing +/// for details. +/// +/// Design note: The correctness of using getelementptr here depends on +/// ScalarEvolution not recognizing inttoptr and ptrtoint operators, as +/// they may introduce pointer arithmetic which may not be safely converted +/// into getelementptr. +/// +/// Design note: It might seem desirable for this function to be more +/// loop-aware. If some of the indices are loop-invariant while others +/// aren't, it might seem desirable to emit multiple GEPs, keeping the +/// loop-invariant portions of the overall computation outside the loop. +/// However, there are a few reasons this is not done here. Hoisting simple +/// arithmetic is a low-level optimization that often isn't very +/// important until late in the optimization process. In fact, passes +/// like InstructionCombining will combine GEPs, even if it means +/// pushing loop-invariant computation down into loops, so even if the +/// GEPs were split here, the work would quickly be undone. The +/// LoopStrengthReduction pass, which is usually run quite late (and +/// after the last InstructionCombining pass), takes care of hoisting +/// loop-invariant portions of expressions, after considering what +/// can be folded using target addressing modes. +/// +Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin, + const SCEV *const *op_end, + PointerType *PTy, + Type *Ty, + Value *V) { + Type *OriginalElTy = PTy->getElementType(); + Type *ElTy = OriginalElTy; + SmallVector<Value *, 4> GepIndices; + SmallVector<const SCEV *, 8> Ops(op_begin, op_end); + bool AnyNonZeroIndices = false; + + // Split AddRecs up into parts as either of the parts may be usable + // without the other. + SplitAddRecs(Ops, Ty, SE); + + Type *IntPtrTy = DL.getIntPtrType(PTy); + + // Descend down the pointer's type and attempt to convert the other + // operands into GEP indices, at each level. The first index in a GEP + // indexes into the array implied by the pointer operand; the rest of + // the indices index into the element or field type selected by the + // preceding index. + for (;;) { + // If the scale size is not 0, attempt to factor out a scale for + // array indexing. 
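+    // Illustration (hypothetical operands): with ElTy = i32 (so ElSize is 4)
+    // and Ops = {8, 4*i}, factoring out 4 leaves indices {2, i}, allowing a
+    //   getelementptr i32, i32* %base, i64 %idx
+    // in place of ptrtoint/add/inttoptr arithmetic.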
+ SmallVector<const SCEV *, 8> ScaledOps; + if (ElTy->isSized()) { + const SCEV *ElSize = SE.getSizeOfExpr(IntPtrTy, ElTy); + if (!ElSize->isZero()) { + SmallVector<const SCEV *, 8> NewOps; + for (const SCEV *Op : Ops) { + const SCEV *Remainder = SE.getConstant(Ty, 0); + if (FactorOutConstant(Op, Remainder, ElSize, SE, DL)) { + // Op now has ElSize factored out. + ScaledOps.push_back(Op); + if (!Remainder->isZero()) + NewOps.push_back(Remainder); + AnyNonZeroIndices = true; + } else { + // The operand was not divisible, so add it to the list of operands + // we'll scan next iteration. + NewOps.push_back(Op); + } + } + // If we made any changes, update Ops. + if (!ScaledOps.empty()) { + Ops = NewOps; + SimplifyAddOperands(Ops, Ty, SE); + } + } + } + + // Record the scaled array index for this level of the type. If + // we didn't find any operands that could be factored, tentatively + // assume that element zero was selected (since the zero offset + // would obviously be folded away). + Value *Scaled = ScaledOps.empty() ? + Constant::getNullValue(Ty) : + expandCodeFor(SE.getAddExpr(ScaledOps), Ty); + GepIndices.push_back(Scaled); + + // Collect struct field index operands. + while (StructType *STy = dyn_cast<StructType>(ElTy)) { + bool FoundFieldNo = false; + // An empty struct has no fields. + if (STy->getNumElements() == 0) break; + // Field offsets are known. See if a constant offset falls within any of + // the struct fields. + if (Ops.empty()) + break; + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[0])) + if (SE.getTypeSizeInBits(C->getType()) <= 64) { + const StructLayout &SL = *DL.getStructLayout(STy); + uint64_t FullOffset = C->getValue()->getZExtValue(); + if (FullOffset < SL.getSizeInBytes()) { + unsigned ElIdx = SL.getElementContainingOffset(FullOffset); + GepIndices.push_back( + ConstantInt::get(Type::getInt32Ty(Ty->getContext()), ElIdx)); + ElTy = STy->getTypeAtIndex(ElIdx); + Ops[0] = + SE.getConstant(Ty, FullOffset - SL.getElementOffset(ElIdx)); + AnyNonZeroIndices = true; + FoundFieldNo = true; + } + } + // If no struct field offsets were found, tentatively assume that + // field zero was selected (since the zero offset would obviously + // be folded away). + if (!FoundFieldNo) { + ElTy = STy->getTypeAtIndex(0u); + GepIndices.push_back( + Constant::getNullValue(Type::getInt32Ty(Ty->getContext()))); + } + } + + if (ArrayType *ATy = dyn_cast<ArrayType>(ElTy)) + ElTy = ATy->getElementType(); + else + break; + } + + // If none of the operands were convertible to proper GEP indices, cast + // the base to i8* and do an ugly getelementptr with that. It's still + // better than ptrtoint+arithmetic+inttoptr at least. + if (!AnyNonZeroIndices) { + // Cast the base to i8*. + V = InsertNoopCastOfTo(V, + Type::getInt8PtrTy(Ty->getContext(), PTy->getAddressSpace())); + + assert(!isa<Instruction>(V) || + SE.DT.dominates(cast<Instruction>(V), &*Builder.GetInsertPoint())); + + // Expand the operands for a plain byte offset. + Value *Idx = expandCodeFor(SE.getAddExpr(Ops), Ty); + + // Fold a GEP with constant operands. + if (Constant *CLHS = dyn_cast<Constant>(V)) + if (Constant *CRHS = dyn_cast<Constant>(Idx)) + return ConstantExpr::getGetElementPtr(Type::getInt8Ty(Ty->getContext()), + CLHS, CRHS); + + // Do a quick scan to see if we have this GEP nearby. If so, reuse it. + unsigned ScanLimit = 6; + BasicBlock::iterator BlockBegin = Builder.GetInsertBlock()->begin(); + // Scanning starts from the last instruction before the insertion point. 
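+    // (Reuse sketch, hypothetical IR names: within this small window, a
+    // structurally identical "getelementptr i8, i8* %V, i64 %Idx" is
+    // returned directly rather than emitting a duplicate.)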
+ BasicBlock::iterator IP = Builder.GetInsertPoint(); + if (IP != BlockBegin) { + --IP; + for (; ScanLimit; --IP, --ScanLimit) { + // Don't count dbg.value against the ScanLimit, to avoid perturbing the + // generated code. + if (isa<DbgInfoIntrinsic>(IP)) + ScanLimit++; + if (IP->getOpcode() == Instruction::GetElementPtr && + IP->getOperand(0) == V && IP->getOperand(1) == Idx) + return &*IP; + if (IP == BlockBegin) break; + } + } + + // Save the original insertion point so we can restore it when we're done. + SCEVInsertPointGuard Guard(Builder, this); + + // Move the insertion point out of as many loops as we can. + while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { + if (!L->isLoopInvariant(V) || !L->isLoopInvariant(Idx)) break; + BasicBlock *Preheader = L->getLoopPreheader(); + if (!Preheader) break; + + // Ok, move up a level. + Builder.SetInsertPoint(Preheader->getTerminator()); + } + + // Emit a GEP. + Value *GEP = Builder.CreateGEP(Builder.getInt8Ty(), V, Idx, "uglygep"); + rememberInstruction(GEP); + + return GEP; + } + + { + SCEVInsertPointGuard Guard(Builder, this); + + // Move the insertion point out of as many loops as we can. + while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { + if (!L->isLoopInvariant(V)) break; + + bool AnyIndexNotLoopInvariant = any_of( + GepIndices, [L](Value *Op) { return !L->isLoopInvariant(Op); }); + + if (AnyIndexNotLoopInvariant) + break; + + BasicBlock *Preheader = L->getLoopPreheader(); + if (!Preheader) break; + + // Ok, move up a level. + Builder.SetInsertPoint(Preheader->getTerminator()); + } + + // Insert a pretty getelementptr. Note that this GEP is not marked inbounds, + // because ScalarEvolution may have changed the address arithmetic to + // compute a value which is beyond the end of the allocated object. + Value *Casted = V; + if (V->getType() != PTy) + Casted = InsertNoopCastOfTo(Casted, PTy); + Value *GEP = Builder.CreateGEP(OriginalElTy, Casted, GepIndices, "scevgep"); + Ops.push_back(SE.getUnknown(GEP)); + rememberInstruction(GEP); + } + + return expand(SE.getAddExpr(Ops)); +} + +Value *SCEVExpander::expandAddToGEP(const SCEV *Op, PointerType *PTy, Type *Ty, + Value *V) { + const SCEV *const Ops[1] = {Op}; + return expandAddToGEP(Ops, Ops + 1, PTy, Ty, V); +} + +/// PickMostRelevantLoop - Given two loops pick the one that's most relevant for +/// SCEV expansion. If they are nested, this is the most nested. If they are +/// neighboring, pick the later. +static const Loop *PickMostRelevantLoop(const Loop *A, const Loop *B, + DominatorTree &DT) { + if (!A) return B; + if (!B) return A; + if (A->contains(B)) return B; + if (B->contains(A)) return A; + if (DT.dominates(A->getHeader(), B->getHeader())) return B; + if (DT.dominates(B->getHeader(), A->getHeader())) return A; + return A; // Arbitrarily break the tie. +} + +/// getRelevantLoop - Get the most relevant loop associated with the given +/// expression, according to PickMostRelevantLoop. +const Loop *SCEVExpander::getRelevantLoop(const SCEV *S) { + // Test whether we've already computed the most relevant loop for this SCEV. + auto Pair = RelevantLoops.insert(std::make_pair(S, nullptr)); + if (!Pair.second) + return Pair.first->second; + + if (isa<SCEVConstant>(S)) + // A constant has no relevant loops. 
+ return nullptr; + if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { + if (const Instruction *I = dyn_cast<Instruction>(U->getValue())) + return Pair.first->second = SE.LI.getLoopFor(I->getParent()); + // A non-instruction has no relevant loops. + return nullptr; + } + if (const SCEVNAryExpr *N = dyn_cast<SCEVNAryExpr>(S)) { + const Loop *L = nullptr; + if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) + L = AR->getLoop(); + for (const SCEV *Op : N->operands()) + L = PickMostRelevantLoop(L, getRelevantLoop(Op), SE.DT); + return RelevantLoops[N] = L; + } + if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(S)) { + const Loop *Result = getRelevantLoop(C->getOperand()); + return RelevantLoops[C] = Result; + } + if (const SCEVUDivExpr *D = dyn_cast<SCEVUDivExpr>(S)) { + const Loop *Result = PickMostRelevantLoop( + getRelevantLoop(D->getLHS()), getRelevantLoop(D->getRHS()), SE.DT); + return RelevantLoops[D] = Result; + } + llvm_unreachable("Unexpected SCEV type!"); +} + +namespace { + +/// LoopCompare - Compare loops by PickMostRelevantLoop. +class LoopCompare { + DominatorTree &DT; +public: + explicit LoopCompare(DominatorTree &dt) : DT(dt) {} + + bool operator()(std::pair<const Loop *, const SCEV *> LHS, + std::pair<const Loop *, const SCEV *> RHS) const { + // Keep pointer operands sorted at the end. + if (LHS.second->getType()->isPointerTy() != + RHS.second->getType()->isPointerTy()) + return LHS.second->getType()->isPointerTy(); + + // Compare loops with PickMostRelevantLoop. + if (LHS.first != RHS.first) + return PickMostRelevantLoop(LHS.first, RHS.first, DT) != LHS.first; + + // If one operand is a non-constant negative and the other is not, + // put the non-constant negative on the right so that a sub can + // be used instead of a negate and add. + if (LHS.second->isNonConstantNegative()) { + if (!RHS.second->isNonConstantNegative()) + return false; + } else if (RHS.second->isNonConstantNegative()) + return true; + + // Otherwise they are equivalent according to this comparison. + return false; + } +}; + +} + +Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { + Type *Ty = SE.getEffectiveSCEVType(S->getType()); + + // Collect all the add operands in a loop, along with their associated loops. + // Iterate in reverse so that constants are emitted last, all else equal, and + // so that pointer operands are inserted first, which the code below relies on + // to form more involved GEPs. + SmallVector<std::pair<const Loop *, const SCEV *>, 8> OpsAndLoops; + for (std::reverse_iterator<SCEVAddExpr::op_iterator> I(S->op_end()), + E(S->op_begin()); I != E; ++I) + OpsAndLoops.push_back(std::make_pair(getRelevantLoop(*I), *I)); + + // Sort by loop. Use a stable sort so that constants follow non-constants and + // pointer operands precede non-pointer operands. + llvm::stable_sort(OpsAndLoops, LoopCompare(SE.DT)); + + // Emit instructions to add all the operands. Hoist as much as possible + // out of loops, and form meaningful getelementptrs where possible. + Value *Sum = nullptr; + for (auto I = OpsAndLoops.begin(), E = OpsAndLoops.end(); I != E;) { + const Loop *CurLoop = I->first; + const SCEV *Op = I->second; + if (!Sum) { + // This is the first operand. Just expand it. + Sum = expand(Op); + ++I; + } else if (PointerType *PTy = dyn_cast<PointerType>(Sum->getType())) { + // The running sum expression is a pointer. Try to form a getelementptr + // at this level with that as the base. 
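+      // e.g. (illustrative) with Sum = %p of type i32* and remaining operands
+      // 4*i + 8 at this loop level, the sum can become
+      //   getelementptr i32, i32* %p, i64 (i + 2)
+      // rather than integer arithmetic on a ptrtoint'd base.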
+      SmallVector<const SCEV *, 4> NewOps;
+      for (; I != E && I->first == CurLoop; ++I) {
+        // If the operand is SCEVUnknown and not an instruction, peek through
+        // it, to enable more of it to be folded into the GEP.
+        const SCEV *X = I->second;
+        if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(X))
+          if (!isa<Instruction>(U->getValue()))
+            X = SE.getSCEV(U->getValue());
+        NewOps.push_back(X);
+      }
+      Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, Sum);
+    } else if (PointerType *PTy = dyn_cast<PointerType>(Op->getType())) {
+      // The running sum is an integer, and there's a pointer at this level.
+      // Try to form a getelementptr. If the running sum is an instruction,
+      // wrap it in a SCEVUnknown to avoid re-analyzing it.
+      SmallVector<const SCEV *, 4> NewOps;
+      NewOps.push_back(isa<Instruction>(Sum) ? SE.getUnknown(Sum) :
+                                               SE.getSCEV(Sum));
+      for (++I; I != E && I->first == CurLoop; ++I)
+        NewOps.push_back(I->second);
+      Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, expand(Op));
+    } else if (Op->isNonConstantNegative()) {
+      // Instead of doing a negate and add, just do a subtract.
+      Value *W = expandCodeFor(SE.getNegativeSCEV(Op), Ty);
+      Sum = InsertNoopCastOfTo(Sum, Ty);
+      Sum = InsertBinop(Instruction::Sub, Sum, W, SCEV::FlagAnyWrap,
+                        /*IsSafeToHoist*/ true);
+      ++I;
+    } else {
+      // A simple add.
+      Value *W = expandCodeFor(Op, Ty);
+      Sum = InsertNoopCastOfTo(Sum, Ty);
+      // Canonicalize a constant to the RHS.
+      if (isa<Constant>(Sum)) std::swap(Sum, W);
+      Sum = InsertBinop(Instruction::Add, Sum, W, S->getNoWrapFlags(),
+                        /*IsSafeToHoist*/ true);
+      ++I;
+    }
+  }
+
+  return Sum;
+}
+
+Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
+  Type *Ty = SE.getEffectiveSCEVType(S->getType());
+
+  // Collect all the mul operands in a loop, along with their associated loops.
+  // Iterate in reverse so that constants are emitted last, all else equal.
+  SmallVector<std::pair<const Loop *, const SCEV *>, 8> OpsAndLoops;
+  for (std::reverse_iterator<SCEVMulExpr::op_iterator> I(S->op_end()),
+       E(S->op_begin()); I != E; ++I)
+    OpsAndLoops.push_back(std::make_pair(getRelevantLoop(*I), *I));
+
+  // Sort by loop. Use a stable sort so that constants follow non-constants.
+  llvm::stable_sort(OpsAndLoops, LoopCompare(SE.DT));
+
+  // Emit instructions to mul all the operands. Hoist as much as possible
+  // out of loops.
+  Value *Prod = nullptr;
+  auto I = OpsAndLoops.begin();
+
+  // Expand the calculation of X pow N in the following manner:
+  // Let N = P1 + P2 + ... + PK, where all P are powers of 2. Then:
+  // X pow N = (X pow P1) * (X pow P2) * ... * (X pow PK).
+  const auto ExpandOpBinPowN = [this, &I, &OpsAndLoops, &Ty]() {
+    auto E = I;
+    // Calculate how many times the same operand from the same loop is included
+    // into this power.
+    uint64_t Exponent = 0;
+    const uint64_t MaxExponent = UINT64_MAX >> 1;
+    // No one sane will ever try to calculate such huge exponents, but if we
+    // need this, we cap Exponent at UINT64_MAX / 2 so that the loop below can
+    // exit once the running power of 2 exceeds Exponent, without that power
+    // itself overflowing an unsigned 64-bit value.
+    while (E != OpsAndLoops.end() && *I == *E && Exponent != MaxExponent) {
+      ++Exponent;
+      ++E;
+    }
+    assert(Exponent > 0 && "Trying to calculate a zeroth exponent of operand?");
+
+    // Calculate powers with exponents 1, 2, 4, 8 etc. and include those of them
+    // that are needed into the result.
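+    // For instance, with Exponent = 13 = 1 + 4 + 8, the loop below computes
+    // P2 = P*P, P4 = P2*P2 and P8 = P4*P4, and multiplies together P, P4 and
+    // P8: five multiplies instead of the twelve a naive expansion would emit.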
+    Value *P = expandCodeFor(I->second, Ty);
+    Value *Result = nullptr;
+    if (Exponent & 1)
+      Result = P;
+    for (uint64_t BinExp = 2; BinExp <= Exponent; BinExp <<= 1) {
+      P = InsertBinop(Instruction::Mul, P, P, SCEV::FlagAnyWrap,
+                      /*IsSafeToHoist*/ true);
+      if (Exponent & BinExp)
+        Result = Result ? InsertBinop(Instruction::Mul, Result, P,
+                                      SCEV::FlagAnyWrap,
+                                      /*IsSafeToHoist*/ true)
+                        : P;
+    }
+
+    I = E;
+    assert(Result && "Nothing was expanded?");
+    return Result;
+  };
+
+  while (I != OpsAndLoops.end()) {
+    if (!Prod) {
+      // This is the first operand. Just expand it.
+      Prod = ExpandOpBinPowN();
+    } else if (I->second->isAllOnesValue()) {
+      // Instead of doing a multiply by negative one, just do a negate.
+      Prod = InsertNoopCastOfTo(Prod, Ty);
+      Prod = InsertBinop(Instruction::Sub, Constant::getNullValue(Ty), Prod,
+                         SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true);
+      ++I;
+    } else {
+      // A simple mul.
+      Value *W = ExpandOpBinPowN();
+      Prod = InsertNoopCastOfTo(Prod, Ty);
+      // Canonicalize a constant to the RHS.
+      if (isa<Constant>(Prod)) std::swap(Prod, W);
+      const APInt *RHS;
+      if (match(W, m_Power2(RHS))) {
+        // Canonicalize Prod*(1<<C) to Prod<<C.
+        assert(!Ty->isVectorTy() && "vector types are not SCEVable");
+        auto NWFlags = S->getNoWrapFlags();
+        // Clear the nsw flag if the shl would produce a poison value.
+        if (RHS->logBase2() == RHS->getBitWidth() - 1)
+          NWFlags = ScalarEvolution::clearFlags(NWFlags, SCEV::FlagNSW);
+        Prod = InsertBinop(Instruction::Shl, Prod,
+                           ConstantInt::get(Ty, RHS->logBase2()), NWFlags,
+                           /*IsSafeToHoist*/ true);
+      } else {
+        Prod = InsertBinop(Instruction::Mul, Prod, W, S->getNoWrapFlags(),
+                           /*IsSafeToHoist*/ true);
+      }
+    }
+  }
+
+  return Prod;
+}
+
+Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
+  Type *Ty = SE.getEffectiveSCEVType(S->getType());
+
+  Value *LHS = expandCodeFor(S->getLHS(), Ty);
+  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getRHS())) {
+    const APInt &RHS = SC->getAPInt();
+    if (RHS.isPowerOf2())
+      return InsertBinop(Instruction::LShr, LHS,
+                         ConstantInt::get(Ty, RHS.logBase2()),
+                         SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true);
+  }
+
+  Value *RHS = expandCodeFor(S->getRHS(), Ty);
+  return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap,
+                     /*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS()));
+}
+
+/// Move parts of Base into Rest to leave Base with the minimal
+/// expression that provides a pointer operand suitable for a
+/// GEP expansion.
+static void ExposePointerBase(const SCEV *&Base, const SCEV *&Rest,
+                              ScalarEvolution &SE) {
+  while (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(Base)) {
+    Base = A->getStart();
+    Rest = SE.getAddExpr(Rest,
+                         SE.getAddRecExpr(SE.getConstant(A->getType(), 0),
+                                          A->getStepRecurrence(SE),
+                                          A->getLoop(),
+                                          A->getNoWrapFlags(SCEV::FlagNW)));
+  }
+  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(Base)) {
+    Base = A->getOperand(A->getNumOperands()-1);
+    SmallVector<const SCEV *, 8> NewAddOps(A->op_begin(), A->op_end());
+    NewAddOps.back() = Rest;
+    Rest = SE.getAddExpr(NewAddOps);
+    ExposePointerBase(Base, Rest, SE);
+  }
+}
+
+/// Determine if this is a well-behaved chain of instructions leading back to
+/// the PHI. If so, it may be reused by expanded expressions.
+bool SCEVExpander::isNormalAddRecExprPHI(PHINode *PN, Instruction *IncV,
+                                         const Loop *L) {
+  if (IncV->getNumOperands() == 0 || isa<PHINode>(IncV) ||
+      (isa<CastInst>(IncV) && !isa<BitCastInst>(IncV)))
+    return false;
+  // If any of the operands don't dominate the insert position, bail.
+  // Addrec operands are always loop-invariant, so this can only happen
+  // if there are instructions which haven't been hoisted.
+  if (L == IVIncInsertLoop) {
+    for (User::op_iterator OI = IncV->op_begin()+1,
+           OE = IncV->op_end(); OI != OE; ++OI)
+      if (Instruction *OInst = dyn_cast<Instruction>(OI))
+        if (!SE.DT.dominates(OInst, IVIncInsertPos))
+          return false;
+  }
+  // Advance to the next instruction.
+  IncV = dyn_cast<Instruction>(IncV->getOperand(0));
+  if (!IncV)
+    return false;
+
+  if (IncV->mayHaveSideEffects())
+    return false;
+
+  if (IncV == PN)
+    return true;
+
+  return isNormalAddRecExprPHI(PN, IncV, L);
+}
+
+/// getIVIncOperand returns an induction variable increment's induction
+/// variable operand.
+///
+/// If allowScale is set, any type of GEP is allowed as long as the nonIV
+/// operands dominate InsertPos.
+///
+/// If allowScale is not set, ensure that a GEP increment conforms to one of the
+/// simple patterns generated by getAddRecExprPHILiterally and
+/// expandAddToGEP. If the pattern isn't recognized, return nullptr.
+Instruction *SCEVExpander::getIVIncOperand(Instruction *IncV,
+                                           Instruction *InsertPos,
+                                           bool allowScale) {
+  if (IncV == InsertPos)
+    return nullptr;
+
+  switch (IncV->getOpcode()) {
+  default:
+    return nullptr;
+  // Check for a simple Add/Sub or GEP of a loop invariant step.
+  case Instruction::Add:
+  case Instruction::Sub: {
+    Instruction *OInst = dyn_cast<Instruction>(IncV->getOperand(1));
+    if (!OInst || SE.DT.dominates(OInst, InsertPos))
+      return dyn_cast<Instruction>(IncV->getOperand(0));
+    return nullptr;
+  }
+  case Instruction::BitCast:
+    return dyn_cast<Instruction>(IncV->getOperand(0));
+  case Instruction::GetElementPtr:
+    for (auto I = IncV->op_begin() + 1, E = IncV->op_end(); I != E; ++I) {
+      if (isa<Constant>(*I))
+        continue;
+      if (Instruction *OInst = dyn_cast<Instruction>(*I)) {
+        if (!SE.DT.dominates(OInst, InsertPos))
+          return nullptr;
+      }
+      if (allowScale) {
+        // Allow any kind of GEP as long as it can be hoisted.
+        continue;
+      }
+      // This must be a pointer addition of constants (pretty), which is already
+      // handled, or some number of address-size elements (ugly). Ugly geps
+      // have 2 operands. i1* is used by the expander to represent an
+      // address-size element.
+      if (IncV->getNumOperands() != 2)
+        return nullptr;
+      unsigned AS = cast<PointerType>(IncV->getType())->getAddressSpace();
+      if (IncV->getType() != Type::getInt1PtrTy(SE.getContext(), AS)
+          && IncV->getType() != Type::getInt8PtrTy(SE.getContext(), AS))
+        return nullptr;
+      break;
+    }
+    return dyn_cast<Instruction>(IncV->getOperand(0));
+  }
+}
+
+/// If the insert point of the current builder or any of the builders on the
+/// stack of saved builders has 'I' as its insert point, update it to point to
+/// the instruction after 'I'. This is intended to be used when the instruction
+/// 'I' is being moved. If this fixup is not done and 'I' is moved to a
+/// different block, the inconsistent insert point (with a mismatched
+/// Instruction and Block) can lead to an instruction being inserted in a block
+/// other than its parent.
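+///
+/// Sketch (hypothetical positions): if the builder currently inserts before
+/// %inc and %inc is about to be moved elsewhere, the insert point is bumped
+/// to the instruction that followed %inc, keeping the Instruction/Block pair
+/// consistent.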
+void SCEVExpander::fixupInsertPoints(Instruction *I) {
+  BasicBlock::iterator It(*I);
+  BasicBlock::iterator NewInsertPt = std::next(It);
+  if (Builder.GetInsertPoint() == It)
+    Builder.SetInsertPoint(&*NewInsertPt);
+  for (auto *InsertPtGuard : InsertPointGuards)
+    if (InsertPtGuard->GetInsertPoint() == It)
+      InsertPtGuard->SetInsertPoint(NewInsertPt);
+}
+
+/// hoistIVInc - Attempt to hoist a simple IV increment above InsertPos to make
+/// it available to other uses in this loop. Recursively hoist any operands,
+/// until we reach a value that dominates InsertPos.
+bool SCEVExpander::hoistIVInc(Instruction *IncV, Instruction *InsertPos) {
+  if (SE.DT.dominates(IncV, InsertPos))
+    return true;
+
+  // InsertPos must itself dominate IncV so that IncV's new position satisfies
+  // its existing users.
+  if (isa<PHINode>(InsertPos) ||
+      !SE.DT.dominates(InsertPos->getParent(), IncV->getParent()))
+    return false;
+
+  if (!SE.LI.movementPreservesLCSSAForm(IncV, InsertPos))
+    return false;
+
+  // Check that the chain of IV operands leading back to Phi can be hoisted.
+  SmallVector<Instruction*, 4> IVIncs;
+  for(;;) {
+    Instruction *Oper = getIVIncOperand(IncV, InsertPos, /*allowScale*/true);
+    if (!Oper)
+      return false;
+    // IncV is safe to hoist.
+    IVIncs.push_back(IncV);
+    IncV = Oper;
+    if (SE.DT.dominates(IncV, InsertPos))
+      break;
+  }
+  for (auto I = IVIncs.rbegin(), E = IVIncs.rend(); I != E; ++I) {
+    fixupInsertPoints(*I);
+    (*I)->moveBefore(InsertPos);
+  }
+  return true;
+}
+
+/// Determine if this cyclic phi is in a form that would have been generated by
+/// LSR. We don't care if the phi was actually expanded in this pass, as long
+/// as it is in a low-cost form, for example, no implied multiplication. This
+/// should match any patterns generated by getAddRecExprPHILiterally and
+/// expandAddToGEP.
+bool SCEVExpander::isExpandedAddRecExprPHI(PHINode *PN, Instruction *IncV,
+                                           const Loop *L) {
+  for(Instruction *IVOper = IncV;
+      (IVOper = getIVIncOperand(IVOper, L->getLoopPreheader()->getTerminator(),
+                                /*allowScale=*/false));) {
+    if (IVOper == PN)
+      return true;
+  }
+  return false;
+}
+
+/// expandIVInc - Expand an IV increment at Builder's current InsertPos.
+/// Typically this is the LatchBlock terminator or IVIncInsertPos, but we may
+/// need to materialize IV increments elsewhere to handle difficult situations.
+Value *SCEVExpander::expandIVInc(PHINode *PN, Value *StepV, const Loop *L,
+                                 Type *ExpandTy, Type *IntTy,
+                                 bool useSubtract) {
+  Value *IncV;
+  // If the PHI is a pointer, use a GEP, otherwise use an add or sub.
+  if (ExpandTy->isPointerTy()) {
+    PointerType *GEPPtrTy = cast<PointerType>(ExpandTy);
+    // If the step isn't constant, don't use an implicitly scaled GEP, because
+    // that would require a multiply inside the loop.
+    if (!isa<ConstantInt>(StepV))
+      GEPPtrTy = PointerType::get(Type::getInt1Ty(SE.getContext()),
+                                  GEPPtrTy->getAddressSpace());
+    IncV = expandAddToGEP(SE.getSCEV(StepV), GEPPtrTy, IntTy, PN);
+    if (IncV->getType() != PN->getType()) {
+      IncV = Builder.CreateBitCast(IncV, PN->getType());
+      rememberInstruction(IncV);
+    }
+  } else {
+    IncV = useSubtract ?
+      Builder.CreateSub(PN, StepV, Twine(IVName) + ".iv.next") :
+      Builder.CreateAdd(PN, StepV, Twine(IVName) + ".iv.next");
+    rememberInstruction(IncV);
+  }
+  return IncV;
+}
+
+/// Hoist the addrec instruction chain rooted in the loop phi above the
+/// position. This routine assumes that this is possible (has been checked).
+void SCEVExpander::hoistBeforePos(DominatorTree *DT, Instruction *InstToHoist,
+                                  Instruction *Pos, PHINode *LoopPhi) {
+  do {
+    if (DT->dominates(InstToHoist, Pos))
+      break;
+    // Make sure the increment is where we want it. But don't move it
+    // down past a potential existing post-inc user.
+    fixupInsertPoints(InstToHoist);
+    InstToHoist->moveBefore(Pos);
+    Pos = InstToHoist;
+    InstToHoist = cast<Instruction>(InstToHoist->getOperand(0));
+  } while (InstToHoist != LoopPhi);
+}
+
+/// Check whether we can cheaply express the requested SCEV in terms of
+/// the available PHI SCEV by truncation and/or inversion of the step.
+static bool canBeCheaplyTransformed(ScalarEvolution &SE,
+                                    const SCEVAddRecExpr *Phi,
+                                    const SCEVAddRecExpr *Requested,
+                                    bool &InvertStep) {
+  Type *PhiTy = SE.getEffectiveSCEVType(Phi->getType());
+  Type *RequestedTy = SE.getEffectiveSCEVType(Requested->getType());
+
+  if (RequestedTy->getIntegerBitWidth() > PhiTy->getIntegerBitWidth())
+    return false;
+
+  // Try truncating it if necessary.
+  Phi = dyn_cast<SCEVAddRecExpr>(SE.getTruncateOrNoop(Phi, RequestedTy));
+  if (!Phi)
+    return false;
+
+  // Check whether truncation will help.
+  if (Phi == Requested) {
+    InvertStep = false;
+    return true;
+  }
+
+  // Check whether inverting will help: {R,+,-1} == R - {0,+,1}.
+  if (SE.getAddExpr(Requested->getStart(),
+                    SE.getNegativeSCEV(Requested)) == Phi) {
+    InvertStep = true;
+    return true;
+  }
+
+  return false;
+}
+
+static bool IsIncrementNSW(ScalarEvolution &SE, const SCEVAddRecExpr *AR) {
+  if (!isa<IntegerType>(AR->getType()))
+    return false;
+
+  unsigned BitWidth = cast<IntegerType>(AR->getType())->getBitWidth();
+  Type *WideTy = IntegerType::get(AR->getType()->getContext(), BitWidth * 2);
+  const SCEV *Step = AR->getStepRecurrence(SE);
+  const SCEV *OpAfterExtend = SE.getAddExpr(SE.getSignExtendExpr(Step, WideTy),
+                                            SE.getSignExtendExpr(AR, WideTy));
+  const SCEV *ExtendAfterOp =
+      SE.getSignExtendExpr(SE.getAddExpr(AR, Step), WideTy);
+  return ExtendAfterOp == OpAfterExtend;
+}
+
+static bool IsIncrementNUW(ScalarEvolution &SE, const SCEVAddRecExpr *AR) {
+  if (!isa<IntegerType>(AR->getType()))
+    return false;
+
+  unsigned BitWidth = cast<IntegerType>(AR->getType())->getBitWidth();
+  Type *WideTy = IntegerType::get(AR->getType()->getContext(), BitWidth * 2);
+  const SCEV *Step = AR->getStepRecurrence(SE);
+  const SCEV *OpAfterExtend = SE.getAddExpr(SE.getZeroExtendExpr(Step, WideTy),
+                                            SE.getZeroExtendExpr(AR, WideTy));
+  const SCEV *ExtendAfterOp =
+      SE.getZeroExtendExpr(SE.getAddExpr(AR, Step), WideTy);
+  return ExtendAfterOp == OpAfterExtend;
+}
+
+/// getAddRecExprPHILiterally - Helper for expandAddRecExprLiterally. Expand
+/// the base addrec, which is the addrec without any non-loop-dominating
+/// values, and return the PHI.
+PHINode *
+SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
+                                        const Loop *L,
+                                        Type *ExpandTy,
+                                        Type *IntTy,
+                                        Type *&TruncTy,
+                                        bool &InvertStep) {
+  assert((!IVIncInsertLoop||IVIncInsertPos) && "Uninitialized insert position");
+
+  // Reuse a previously-inserted PHI, if present.
+  BasicBlock *LatchBlock = L->getLoopLatch();
+  if (LatchBlock) {
+    PHINode *AddRecPhiMatch = nullptr;
+    Instruction *IncV = nullptr;
+    TruncTy = nullptr;
+    InvertStep = false;
+
+    // Only try partially matching scevs that need truncation and/or
+    // step-inversion if we know this loop is outside the current loop.
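+    // (Illustrative reuse cases: an existing i64 phi {0,+,1} can serve a
+    // request for i32 {0,+,1} via truncation, or a request for {R,+,-1} via
+    // step inversion, computed as R - {0,+,1}.)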
+    bool TryNonMatchingSCEV =
+        IVIncInsertLoop &&
+        SE.DT.properlyDominates(LatchBlock, IVIncInsertLoop->getHeader());
+
+    for (PHINode &PN : L->getHeader()->phis()) {
+      if (!SE.isSCEVable(PN.getType()))
+        continue;
+
+      const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
+      if (!PhiSCEV)
+        continue;
+
+      bool IsMatchingSCEV = PhiSCEV == Normalized;
+      // We only handle truncation and inversion of phi recurrences for the
+      // expanded expression if the expanded expression's loop dominates the
+      // loop we insert to. Check now, so we can bail out early.
+      if (!IsMatchingSCEV && !TryNonMatchingSCEV)
+        continue;
+
+      // TODO: this possibly can be reworked to avoid this cast at all.
+      Instruction *TempIncV =
+          dyn_cast<Instruction>(PN.getIncomingValueForBlock(LatchBlock));
+      if (!TempIncV)
+        continue;
+
+      // Check whether we can reuse this PHI node.
+      if (LSRMode) {
+        if (!isExpandedAddRecExprPHI(&PN, TempIncV, L))
+          continue;
+        if (L == IVIncInsertLoop && !hoistIVInc(TempIncV, IVIncInsertPos))
+          continue;
+      } else {
+        if (!isNormalAddRecExprPHI(&PN, TempIncV, L))
+          continue;
+      }
+
+      // Stop if we have found an exact match SCEV.
+      if (IsMatchingSCEV) {
+        IncV = TempIncV;
+        TruncTy = nullptr;
+        InvertStep = false;
+        AddRecPhiMatch = &PN;
+        break;
+      }
+
+      // Try whether the phi can be translated into the requested form
+      // (truncated and/or offset by a constant).
+      if ((!TruncTy || InvertStep) &&
+          canBeCheaplyTransformed(SE, PhiSCEV, Normalized, InvertStep)) {
+        // Record the phi node. But don't stop; we might find an exact match
+        // later.
+        AddRecPhiMatch = &PN;
+        IncV = TempIncV;
+        TruncTy = SE.getEffectiveSCEVType(Normalized->getType());
+      }
+    }
+
+    if (AddRecPhiMatch) {
+      // Potentially, move the increment. We have made sure in
+      // isExpandedAddRecExprPHI or hoistIVInc that this is possible.
+      if (L == IVIncInsertLoop)
+        hoistBeforePos(&SE.DT, IncV, IVIncInsertPos, AddRecPhiMatch);
+
+      // Ok, the add recurrence looks usable.
+      // Remember this PHI, even in post-inc mode.
+      InsertedValues.insert(AddRecPhiMatch);
+      // Remember the increment.
+      rememberInstruction(IncV);
+      return AddRecPhiMatch;
+    }
+  }
+
+  // Save the original insertion point so we can restore it when we're done.
+  SCEVInsertPointGuard Guard(Builder, this);
+
+  // Another AddRec may need to be recursively expanded below. For example, if
+  // this AddRec is quadratic, the StepV may itself be an AddRec in this
+  // loop. Remove this loop from the PostIncLoops set before expanding such
+  // AddRecs. Otherwise, we cannot find a valid position for the step
+  // (i.e. StepV can never dominate its loop header). Ideally, we could do
+  // SavedIncLoops.swap(PostIncLoops), but we generally have a single element,
+  // so it's not worth implementing SmallPtrSet::swap.
+  PostIncLoopSet SavedPostIncLoops = PostIncLoops;
+  PostIncLoops.clear();
+
+  // Expand code for the start value into the loop preheader.
+  assert(L->getLoopPreheader() &&
+         "Can't expand add recurrences without a loop preheader!");
+  Value *StartV = expandCodeFor(Normalized->getStart(), ExpandTy,
+                                L->getLoopPreheader()->getTerminator());
+
+  // StartV must have been inserted into L's preheader to dominate the new
+  // phi.
+  assert(!isa<Instruction>(StartV) ||
+         SE.DT.properlyDominates(cast<Instruction>(StartV)->getParent(),
+                                 L->getHeader()));
+
+  // Expand code for the step value. Do this before creating the PHI so that
+  // PHI reuse code doesn't see an incomplete PHI.
+ const SCEV *Step = Normalized->getStepRecurrence(SE); + // If the stride is negative, insert a sub instead of an add for the increment + // (unless it's a constant, because subtracts of constants are canonicalized + // to adds). + bool useSubtract = !ExpandTy->isPointerTy() && Step->isNonConstantNegative(); + if (useSubtract) + Step = SE.getNegativeSCEV(Step); + // Expand the step somewhere that dominates the loop header. + Value *StepV = expandCodeFor(Step, IntTy, &L->getHeader()->front()); + + // The no-wrap behavior proved by IsIncrement(NUW|NSW) is only applicable if + // we actually do emit an addition. It does not apply if we emit a + // subtraction. + bool IncrementIsNUW = !useSubtract && IsIncrementNUW(SE, Normalized); + bool IncrementIsNSW = !useSubtract && IsIncrementNSW(SE, Normalized); + + // Create the PHI. + BasicBlock *Header = L->getHeader(); + Builder.SetInsertPoint(Header, Header->begin()); + pred_iterator HPB = pred_begin(Header), HPE = pred_end(Header); + PHINode *PN = Builder.CreatePHI(ExpandTy, std::distance(HPB, HPE), + Twine(IVName) + ".iv"); + rememberInstruction(PN); + + // Create the step instructions and populate the PHI. + for (pred_iterator HPI = HPB; HPI != HPE; ++HPI) { + BasicBlock *Pred = *HPI; + + // Add a start value. + if (!L->contains(Pred)) { + PN->addIncoming(StartV, Pred); + continue; + } + + // Create a step value and add it to the PHI. + // If IVIncInsertLoop is non-null and equal to the addrec's loop, insert the + // instructions at IVIncInsertPos. + Instruction *InsertPos = L == IVIncInsertLoop ? + IVIncInsertPos : Pred->getTerminator(); + Builder.SetInsertPoint(InsertPos); + Value *IncV = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract); + + if (isa<OverflowingBinaryOperator>(IncV)) { + if (IncrementIsNUW) + cast<BinaryOperator>(IncV)->setHasNoUnsignedWrap(); + if (IncrementIsNSW) + cast<BinaryOperator>(IncV)->setHasNoSignedWrap(); + } + PN->addIncoming(IncV, Pred); + } + + // After expanding subexpressions, restore the PostIncLoops set so the caller + // can ensure that IVIncrement dominates the current uses. + PostIncLoops = SavedPostIncLoops; + + // Remember this PHI, even in post-inc mode. + InsertedValues.insert(PN); + + return PN; +} + +Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { + Type *STy = S->getType(); + Type *IntTy = SE.getEffectiveSCEVType(STy); + const Loop *L = S->getLoop(); + + // Determine a normalized form of this expression, which is the expression + // before any post-inc adjustment is made. + const SCEVAddRecExpr *Normalized = S; + if (PostIncLoops.count(L)) { + PostIncLoopSet Loops; + Loops.insert(L); + Normalized = cast<SCEVAddRecExpr>(normalizeForPostIncUse(S, Loops, SE)); + } + + // Strip off any non-loop-dominating component from the addrec start. + const SCEV *Start = Normalized->getStart(); + const SCEV *PostLoopOffset = nullptr; + if (!SE.properlyDominates(Start, L->getHeader())) { + PostLoopOffset = Start; + Start = SE.getConstant(Normalized->getType(), 0); + Normalized = cast<SCEVAddRecExpr>( + SE.getAddRecExpr(Start, Normalized->getStepRecurrence(SE), + Normalized->getLoop(), + Normalized->getNoWrapFlags(SCEV::FlagNW))); + } + + // Strip off any non-loop-dominating component from the addrec step. 
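+  // e.g. (illustrative) if the step x of {0,+,x} does not dominate the loop
+  // header, the addrec is expanded as the canonical {0,+,1} and x is
+  // re-applied afterwards as PostLoopScale.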
+ const SCEV *Step = Normalized->getStepRecurrence(SE); + const SCEV *PostLoopScale = nullptr; + if (!SE.dominates(Step, L->getHeader())) { + PostLoopScale = Step; + Step = SE.getConstant(Normalized->getType(), 1); + if (!Start->isZero()) { + // The normalization below assumes that Start is constant zero, so if + // it isn't re-associate Start to PostLoopOffset. + assert(!PostLoopOffset && "Start not-null but PostLoopOffset set?"); + PostLoopOffset = Start; + Start = SE.getConstant(Normalized->getType(), 0); + } + Normalized = + cast<SCEVAddRecExpr>(SE.getAddRecExpr( + Start, Step, Normalized->getLoop(), + Normalized->getNoWrapFlags(SCEV::FlagNW))); + } + + // Expand the core addrec. If we need post-loop scaling, force it to + // expand to an integer type to avoid the need for additional casting. + Type *ExpandTy = PostLoopScale ? IntTy : STy; + // We can't use a pointer type for the addrec if the pointer type is + // non-integral. + Type *AddRecPHIExpandTy = + DL.isNonIntegralPointerType(STy) ? Normalized->getType() : ExpandTy; + + // In some cases, we decide to reuse an existing phi node but need to truncate + // it and/or invert the step. + Type *TruncTy = nullptr; + bool InvertStep = false; + PHINode *PN = getAddRecExprPHILiterally(Normalized, L, AddRecPHIExpandTy, + IntTy, TruncTy, InvertStep); + + // Accommodate post-inc mode, if necessary. + Value *Result; + if (!PostIncLoops.count(L)) + Result = PN; + else { + // In PostInc mode, use the post-incremented value. + BasicBlock *LatchBlock = L->getLoopLatch(); + assert(LatchBlock && "PostInc mode requires a unique loop latch!"); + Result = PN->getIncomingValueForBlock(LatchBlock); + + // For an expansion to use the postinc form, the client must call + // expandCodeFor with an InsertPoint that is either outside the PostIncLoop + // or dominated by IVIncInsertPos. + if (isa<Instruction>(Result) && + !SE.DT.dominates(cast<Instruction>(Result), + &*Builder.GetInsertPoint())) { + // The induction variable's postinc expansion does not dominate this use. + // IVUsers tries to prevent this case, so it is rare. However, it can + // happen when an IVUser outside the loop is not dominated by the latch + // block. Adjusting IVIncInsertPos before expansion begins cannot handle + // all cases. Consider a phi outside whose operand is replaced during + // expansion with the value of the postinc user. Without fundamentally + // changing the way postinc users are tracked, the only remedy is + // inserting an extra IV increment. StepV might fold into PostLoopOffset, + // but hopefully expandCodeFor handles that. + bool useSubtract = + !ExpandTy->isPointerTy() && Step->isNonConstantNegative(); + if (useSubtract) + Step = SE.getNegativeSCEV(Step); + Value *StepV; + { + // Expand the step somewhere that dominates the loop header. + SCEVInsertPointGuard Guard(Builder, this); + StepV = expandCodeFor(Step, IntTy, &L->getHeader()->front()); + } + Result = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract); + } + } + + // We have decided to reuse an induction variable of a dominating loop. Apply + // truncation and/or inversion of the step. + if (TruncTy) { + Type *ResTy = Result->getType(); + // Normalize the result type. + if (ResTy != SE.getEffectiveSCEVType(ResTy)) + Result = InsertNoopCastOfTo(Result, SE.getEffectiveSCEVType(ResTy)); + // Truncate the result. + if (TruncTy != Result->getType()) { + Result = Builder.CreateTrunc(Result, TruncTy); + rememberInstruction(Result); + } + // Invert the result. 
+    if (InvertStep) {
+      Result = Builder.CreateSub(expandCodeFor(Normalized->getStart(), TruncTy),
+                                 Result);
+      rememberInstruction(Result);
+    }
+  }
+
+  // Re-apply any non-loop-dominating scale.
+  if (PostLoopScale) {
+    assert(S->isAffine() && "Can't linearly scale non-affine recurrences.");
+    Result = InsertNoopCastOfTo(Result, IntTy);
+    Result = Builder.CreateMul(Result,
+                               expandCodeFor(PostLoopScale, IntTy));
+    rememberInstruction(Result);
+  }
+
+  // Re-apply any non-loop-dominating offset.
+  if (PostLoopOffset) {
+    if (PointerType *PTy = dyn_cast<PointerType>(ExpandTy)) {
+      if (Result->getType()->isIntegerTy()) {
+        Value *Base = expandCodeFor(PostLoopOffset, ExpandTy);
+        Result = expandAddToGEP(SE.getUnknown(Result), PTy, IntTy, Base);
+      } else {
+        Result = expandAddToGEP(PostLoopOffset, PTy, IntTy, Result);
+      }
+    } else {
+      Result = InsertNoopCastOfTo(Result, IntTy);
+      Result = Builder.CreateAdd(Result,
+                                 expandCodeFor(PostLoopOffset, IntTy));
+      rememberInstruction(Result);
+    }
+  }
+
+  return Result;
+}
+
+Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
+  // In canonical mode we compute the addrec as an expression of a canonical IV
+  // using evaluateAtIteration and expand the resulting SCEV expression. This
+  // way we avoid introducing new IVs to carry on the computation of the addrec
+  // throughout the loop.
+  //
+  // For nested addrecs evaluateAtIteration might need a canonical IV of a
+  // type wider than the addrec itself. Emitting a canonical IV of the
+  // proper type might produce non-legal types, for example expanding an i64
+  // {0,+,2,+,1} addrec would need an i65 canonical IV. To avoid this just fall
+  // back to non-canonical mode for nested addrecs.
+  if (!CanonicalMode || (S->getNumOperands() > 2))
+    return expandAddRecExprLiterally(S);
+
+  Type *Ty = SE.getEffectiveSCEVType(S->getType());
+  const Loop *L = S->getLoop();
+
+  // First check for an existing canonical IV in a suitable type.
+  PHINode *CanonicalIV = nullptr;
+  if (PHINode *PN = L->getCanonicalInductionVariable())
+    if (SE.getTypeSizeInBits(PN->getType()) >= SE.getTypeSizeInBits(Ty))
+      CanonicalIV = PN;
+
+  // Rewrite an AddRec in terms of the canonical induction variable, if
+  // its type is narrower than the canonical IV's.
+  if (CanonicalIV &&
+      SE.getTypeSizeInBits(CanonicalIV->getType()) >
+      SE.getTypeSizeInBits(Ty)) {
+    SmallVector<const SCEV *, 4> NewOps(S->getNumOperands());
+    for (unsigned i = 0, e = S->getNumOperands(); i != e; ++i)
+      NewOps[i] = SE.getAnyExtendExpr(S->op_begin()[i], CanonicalIV->getType());
+    Value *V = expand(SE.getAddRecExpr(NewOps, S->getLoop(),
+                                       S->getNoWrapFlags(SCEV::FlagNW)));
+    BasicBlock::iterator NewInsertPt =
+        findInsertPointAfter(cast<Instruction>(V), Builder.GetInsertBlock());
+    V = expandCodeFor(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr,
+                      &*NewInsertPt);
+    return V;
+  }
+
+  // {X,+,F} --> X + {0,+,F}
+  if (!S->getStart()->isZero()) {
+    SmallVector<const SCEV *, 4> NewOps(S->op_begin(), S->op_end());
+    NewOps[0] = SE.getConstant(Ty, 0);
+    const SCEV *Rest = SE.getAddRecExpr(NewOps, L,
+                                        S->getNoWrapFlags(SCEV::FlagNW));
+
+    // Turn things like ptrtoint+arithmetic+inttoptr into GEP. See the
+    // comments on expandAddToGEP for details.
+    const SCEV *Base = S->getStart();
+    // Dig into the expression to find the pointer base for a GEP.
+    const SCEV *ExposedRest = Rest;
+    ExposePointerBase(Base, ExposedRest, SE);
+    // If we found a pointer, expand the AddRec with a GEP.
+ if (PointerType *PTy = dyn_cast<PointerType>(Base->getType())) { + // Make sure the Base isn't something exotic, such as a multiplied + // or divided pointer value. In those cases, the result type isn't + // actually a pointer type. + if (!isa<SCEVMulExpr>(Base) && !isa<SCEVUDivExpr>(Base)) { + Value *StartV = expand(Base); + assert(StartV->getType() == PTy && "Pointer type mismatch for GEP!"); + return expandAddToGEP(ExposedRest, PTy, Ty, StartV); + } + } + + // Just do a normal add. Pre-expand the operands to suppress folding. + // + // The LHS and RHS values are factored out of the expand call to make the + // output independent of the argument evaluation order. + const SCEV *AddExprLHS = SE.getUnknown(expand(S->getStart())); + const SCEV *AddExprRHS = SE.getUnknown(expand(Rest)); + return expand(SE.getAddExpr(AddExprLHS, AddExprRHS)); + } + + // If we don't yet have a canonical IV, create one. + if (!CanonicalIV) { + // Create and insert the PHI node for the induction variable in the + // specified loop. + BasicBlock *Header = L->getHeader(); + pred_iterator HPB = pred_begin(Header), HPE = pred_end(Header); + CanonicalIV = PHINode::Create(Ty, std::distance(HPB, HPE), "indvar", + &Header->front()); + rememberInstruction(CanonicalIV); + + SmallSet<BasicBlock *, 4> PredSeen; + Constant *One = ConstantInt::get(Ty, 1); + for (pred_iterator HPI = HPB; HPI != HPE; ++HPI) { + BasicBlock *HP = *HPI; + if (!PredSeen.insert(HP).second) { + // There must be an incoming value for each predecessor, even the + // duplicates! + CanonicalIV->addIncoming(CanonicalIV->getIncomingValueForBlock(HP), HP); + continue; + } + + if (L->contains(HP)) { + // Insert a unit add instruction right before the terminator + // corresponding to the back-edge. + Instruction *Add = BinaryOperator::CreateAdd(CanonicalIV, One, + "indvar.next", + HP->getTerminator()); + Add->setDebugLoc(HP->getTerminator()->getDebugLoc()); + rememberInstruction(Add); + CanonicalIV->addIncoming(Add, HP); + } else { + CanonicalIV->addIncoming(Constant::getNullValue(Ty), HP); + } + } + } + + // {0,+,1} --> Insert a canonical induction variable into the loop! + if (S->isAffine() && S->getOperand(1)->isOne()) { + assert(Ty == SE.getEffectiveSCEVType(CanonicalIV->getType()) && + "IVs with types different from the canonical IV should " + "already have been handled!"); + return CanonicalIV; + } + + // {0,+,F} --> {0,+,1} * F + + // If this is a simple linear addrec, emit it now as a special case. + if (S->isAffine()) // {0,+,F} --> i*F + return + expand(SE.getTruncateOrNoop( + SE.getMulExpr(SE.getUnknown(CanonicalIV), + SE.getNoopOrAnyExtend(S->getOperand(1), + CanonicalIV->getType())), + Ty)); + + // If this is a chain of recurrences, turn it into a closed form, using the + // folders, then expandCodeFor the closed form. This allows the folders to + // simplify the expression without having to build a bunch of special code + // into this folder. + const SCEV *IH = SE.getUnknown(CanonicalIV); // Get I as a "symbolic" SCEV. + + // Promote S up to the canonical IV type, if the cast is foldable. + const SCEV *NewS = S; + const SCEV *Ext = SE.getNoopOrAnyExtend(S, CanonicalIV->getType()); + if (isa<SCEVAddRecExpr>(Ext)) + NewS = Ext; + + const SCEV *V = cast<SCEVAddRecExpr>(NewS)->evaluateAtIteration(IH, SE); + //cerr << "Evaluated: " << *this << "\n to: " << *V << "\n"; + + // Truncate the result down to the original type, if needed. 
+ const SCEV *T = SE.getTruncateOrNoop(V, Ty); + return expand(T); +} + +Value *SCEVExpander::visitTruncateExpr(const SCEVTruncateExpr *S) { + Type *Ty = SE.getEffectiveSCEVType(S->getType()); + Value *V = expandCodeFor(S->getOperand(), + SE.getEffectiveSCEVType(S->getOperand()->getType())); + Value *I = Builder.CreateTrunc(V, Ty); + rememberInstruction(I); + return I; +} + +Value *SCEVExpander::visitZeroExtendExpr(const SCEVZeroExtendExpr *S) { + Type *Ty = SE.getEffectiveSCEVType(S->getType()); + Value *V = expandCodeFor(S->getOperand(), + SE.getEffectiveSCEVType(S->getOperand()->getType())); + Value *I = Builder.CreateZExt(V, Ty); + rememberInstruction(I); + return I; +} + +Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) { + Type *Ty = SE.getEffectiveSCEVType(S->getType()); + Value *V = expandCodeFor(S->getOperand(), + SE.getEffectiveSCEVType(S->getOperand()->getType())); + Value *I = Builder.CreateSExt(V, Ty); + rememberInstruction(I); + return I; +} + +Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) { + Value *LHS = expand(S->getOperand(S->getNumOperands()-1)); + Type *Ty = LHS->getType(); + for (int i = S->getNumOperands()-2; i >= 0; --i) { + // In the case of mixed integer and pointer types, do the + // rest of the comparisons as integer. + Type *OpTy = S->getOperand(i)->getType(); + if (OpTy->isIntegerTy() != Ty->isIntegerTy()) { + Ty = SE.getEffectiveSCEVType(Ty); + LHS = InsertNoopCastOfTo(LHS, Ty); + } + Value *RHS = expandCodeFor(S->getOperand(i), Ty); + Value *ICmp = Builder.CreateICmpSGT(LHS, RHS); + rememberInstruction(ICmp); + Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "smax"); + rememberInstruction(Sel); + LHS = Sel; + } + // In the case of mixed integer and pointer types, cast the + // final result back to the pointer type. + if (LHS->getType() != S->getType()) + LHS = InsertNoopCastOfTo(LHS, S->getType()); + return LHS; +} + +Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) { + Value *LHS = expand(S->getOperand(S->getNumOperands()-1)); + Type *Ty = LHS->getType(); + for (int i = S->getNumOperands()-2; i >= 0; --i) { + // In the case of mixed integer and pointer types, do the + // rest of the comparisons as integer. + Type *OpTy = S->getOperand(i)->getType(); + if (OpTy->isIntegerTy() != Ty->isIntegerTy()) { + Ty = SE.getEffectiveSCEVType(Ty); + LHS = InsertNoopCastOfTo(LHS, Ty); + } + Value *RHS = expandCodeFor(S->getOperand(i), Ty); + Value *ICmp = Builder.CreateICmpUGT(LHS, RHS); + rememberInstruction(ICmp); + Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "umax"); + rememberInstruction(Sel); + LHS = Sel; + } + // In the case of mixed integer and pointer types, cast the + // final result back to the pointer type. + if (LHS->getType() != S->getType()) + LHS = InsertNoopCastOfTo(LHS, S->getType()); + return LHS; +} + +Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) { + Value *LHS = expand(S->getOperand(S->getNumOperands() - 1)); + Type *Ty = LHS->getType(); + for (int i = S->getNumOperands() - 2; i >= 0; --i) { + // In the case of mixed integer and pointer types, do the + // rest of the comparisons as integer. 
+ Type *OpTy = S->getOperand(i)->getType(); + if (OpTy->isIntegerTy() != Ty->isIntegerTy()) { + Ty = SE.getEffectiveSCEVType(Ty); + LHS = InsertNoopCastOfTo(LHS, Ty); + } + Value *RHS = expandCodeFor(S->getOperand(i), Ty); + Value *ICmp = Builder.CreateICmpSLT(LHS, RHS); + rememberInstruction(ICmp); + Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "smin"); + rememberInstruction(Sel); + LHS = Sel; + } + // In the case of mixed integer and pointer types, cast the + // final result back to the pointer type. + if (LHS->getType() != S->getType()) + LHS = InsertNoopCastOfTo(LHS, S->getType()); + return LHS; +} + +Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) { + Value *LHS = expand(S->getOperand(S->getNumOperands() - 1)); + Type *Ty = LHS->getType(); + for (int i = S->getNumOperands() - 2; i >= 0; --i) { + // In the case of mixed integer and pointer types, do the + // rest of the comparisons as integer. + Type *OpTy = S->getOperand(i)->getType(); + if (OpTy->isIntegerTy() != Ty->isIntegerTy()) { + Ty = SE.getEffectiveSCEVType(Ty); + LHS = InsertNoopCastOfTo(LHS, Ty); + } + Value *RHS = expandCodeFor(S->getOperand(i), Ty); + Value *ICmp = Builder.CreateICmpULT(LHS, RHS); + rememberInstruction(ICmp); + Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "umin"); + rememberInstruction(Sel); + LHS = Sel; + } + // In the case of mixed integer and pointer types, cast the + // final result back to the pointer type. + if (LHS->getType() != S->getType()) + LHS = InsertNoopCastOfTo(LHS, S->getType()); + return LHS; +} + +Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty, + Instruction *IP) { + setInsertPoint(IP); + return expandCodeFor(SH, Ty); +} + +Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty) { + // Expand the code for this SCEV. + Value *V = expand(SH); + if (Ty) { + assert(SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(SH->getType()) && + "non-trivial casts should be done with the SCEVs directly!"); + V = InsertNoopCastOfTo(V, Ty); + } + return V; +} + +ScalarEvolution::ValueOffsetPair +SCEVExpander::FindValueInExprValueMap(const SCEV *S, + const Instruction *InsertPt) { + SetVector<ScalarEvolution::ValueOffsetPair> *Set = SE.getSCEVValues(S); + // If the expansion is not in CanonicalMode, and the SCEV contains any + // sub scAddRecExpr type SCEV, it is required to expand the SCEV literally. + if (CanonicalMode || !SE.containsAddRecurrence(S)) { + // If S is scConstant, it may be worse to reuse an existing Value. + if (S->getSCEVType() != scConstant && Set) { + // Choose a Value from the set which dominates the insertPt. + // insertPt should be inside the Value's parent loop so as not to break + // the LCSSA form. + for (auto const &VOPair : *Set) { + Value *V = VOPair.first; + ConstantInt *Offset = VOPair.second; + Instruction *EntInst = nullptr; + if (V && isa<Instruction>(V) && (EntInst = cast<Instruction>(V)) && + S->getType() == V->getType() && + EntInst->getFunction() == InsertPt->getFunction() && + SE.DT.dominates(EntInst, InsertPt) && + (SE.LI.getLoopFor(EntInst->getParent()) == nullptr || + SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt))) + return {V, Offset}; + } + } + } + return {nullptr, nullptr}; +} + +// The expansion of SCEV will either reuse a previous Value in ExprValueMap, +// or expand the SCEV literally. Specifically, if the expansion is in LSRMode, +// and the SCEV contains any sub scAddRecExpr type SCEV, it will be expanded +// literally, to prevent LSR's transformed SCEV from being reverted. 
Otherwise,
+// the expansion will try to reuse a Value from ExprValueMap, and only when
+// that fails, expand the SCEV literally.
+Value *SCEVExpander::expand(const SCEV *S) {
+  // Compute an insertion point for this SCEV object. Hoist the instructions
+  // as far out in the loop nest as possible.
+  Instruction *InsertPt = &*Builder.GetInsertPoint();
+
+  // We can move the insertion point only if there are no div or rem
+  // operations; otherwise we risk moving it past a check for a zero
+  // denominator.
+  auto SafeToHoist = [](const SCEV *S) {
+    return !SCEVExprContains(S, [](const SCEV *S) {
+      if (const auto *D = dyn_cast<SCEVUDivExpr>(S)) {
+        if (const auto *SC = dyn_cast<SCEVConstant>(D->getRHS()))
+          // Division by non-zero constants can be hoisted.
+          return SC->getValue()->isZero();
+        // All other divisions should not be moved as they may be
+        // divisions by zero and should be kept within the
+        // conditions of the surrounding loops that guard their
+        // execution (see PR35406).
+        return true;
+      }
+      return false;
+    });
+  };
+  if (SafeToHoist(S)) {
+    for (Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock());;
+         L = L->getParentLoop()) {
+      if (SE.isLoopInvariant(S, L)) {
+        if (!L) break;
+        if (BasicBlock *Preheader = L->getLoopPreheader())
+          InsertPt = Preheader->getTerminator();
+        else
+          // LSR sets the insertion point for AddRec start/step values to the
+          // block start to simplify value reuse, even though it's an invalid
+          // position. SCEVExpander must correct for this in all cases.
+          InsertPt = &*L->getHeader()->getFirstInsertionPt();
+      } else {
+        // If the SCEV is computable at this level, insert it into the header
+        // after the PHIs (and after any other instructions that we've inserted
+        // there) so that it is guaranteed to dominate any user inside the loop.
+        if (L && SE.hasComputableLoopEvolution(S, L) && !PostIncLoops.count(L))
+          InsertPt = &*L->getHeader()->getFirstInsertionPt();
+        while (InsertPt->getIterator() != Builder.GetInsertPoint() &&
+               (isInsertedInstruction(InsertPt) ||
+                isa<DbgInfoIntrinsic>(InsertPt)))
+          InsertPt = &*std::next(InsertPt->getIterator());
+        break;
+      }
+    }
+  }
+
+  // IndVarSimplify sometimes sets the insertion point at the block start, even
+  // when there are PHIs at that point. We must correct for this.
+  if (isa<PHINode>(*InsertPt))
+    InsertPt = &*InsertPt->getParent()->getFirstInsertionPt();
+
+  // Check to see if we already expanded this here.
+  auto I = InsertedExpressions.find(std::make_pair(S, InsertPt));
+  if (I != InsertedExpressions.end())
+    return I->second;
+
+  SCEVInsertPointGuard Guard(Builder, this);
+  Builder.SetInsertPoint(InsertPt);
+
+  // Expand the expression into instructions.
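+  // (A worked example of the reuse-with-offset path below, using hypothetical
+  //  values: FindValueInExprValueMap may return a cached value V together
+  //  with a constant byte offset C such that S == V - C. For V of type i32*
+  //  and C == 8, the offset is a whole number of elements and we emit
+  //    %scevgep = getelementptr i32, i32* %V, i64 -2
+  //  while for C == 6 no whole-element index exists, so we fall back to the
+  //  byte-addressed form:
+  //    %0 = bitcast i32* %V to i8*
+  //    %uglygep = getelementptr i8, i8* %0, i64 -6
+  //    %1 = bitcast i8* %uglygep to i32*)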
+ ScalarEvolution::ValueOffsetPair VO = FindValueInExprValueMap(S, InsertPt); + Value *V = VO.first; + + if (!V) + V = visit(S); + else if (VO.second) { + if (PointerType *Vty = dyn_cast<PointerType>(V->getType())) { + Type *Ety = Vty->getPointerElementType(); + int64_t Offset = VO.second->getSExtValue(); + int64_t ESize = SE.getTypeSizeInBits(Ety); + if ((Offset * 8) % ESize == 0) { + ConstantInt *Idx = + ConstantInt::getSigned(VO.second->getType(), -(Offset * 8) / ESize); + V = Builder.CreateGEP(Ety, V, Idx, "scevgep"); + } else { + ConstantInt *Idx = + ConstantInt::getSigned(VO.second->getType(), -Offset); + unsigned AS = Vty->getAddressSpace(); + V = Builder.CreateBitCast(V, Type::getInt8PtrTy(SE.getContext(), AS)); + V = Builder.CreateGEP(Type::getInt8Ty(SE.getContext()), V, Idx, + "uglygep"); + V = Builder.CreateBitCast(V, Vty); + } + } else { + V = Builder.CreateSub(V, VO.second); + } + } + // Remember the expanded value for this SCEV at this location. + // + // This is independent of PostIncLoops. The mapped value simply materializes + // the expression at this insertion point. If the mapped value happened to be + // a postinc expansion, it could be reused by a non-postinc user, but only if + // its insertion point was already at the head of the loop. + InsertedExpressions[std::make_pair(S, InsertPt)] = V; + return V; +} + +void SCEVExpander::rememberInstruction(Value *I) { + if (!PostIncLoops.empty()) + InsertedPostIncValues.insert(I); + else + InsertedValues.insert(I); +} + +/// getOrInsertCanonicalInductionVariable - This method returns the +/// canonical induction variable of the specified type for the specified +/// loop (inserting one if there is none). A canonical induction variable +/// starts at zero and steps by one on each iteration. +PHINode * +SCEVExpander::getOrInsertCanonicalInductionVariable(const Loop *L, + Type *Ty) { + assert(Ty->isIntegerTy() && "Can only insert integer induction variables!"); + + // Build a SCEV for {0,+,1}<L>. + // Conservatively use FlagAnyWrap for now. + const SCEV *H = SE.getAddRecExpr(SE.getConstant(Ty, 0), + SE.getConstant(Ty, 1), L, SCEV::FlagAnyWrap); + + // Emit code for it. + SCEVInsertPointGuard Guard(Builder, this); + PHINode *V = + cast<PHINode>(expandCodeFor(H, nullptr, &L->getHeader()->front())); + + return V; +} + +/// replaceCongruentIVs - Check for congruent phis in this loop header and +/// replace them with their most canonical representative. Return the number of +/// phis eliminated. +/// +/// This does not depend on any SCEVExpander state but should be used in +/// the same context that SCEVExpander is used. +unsigned +SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, + SmallVectorImpl<WeakTrackingVH> &DeadInsts, + const TargetTransformInfo *TTI) { + // Find integer phis in order of increasing width. + SmallVector<PHINode*, 8> Phis; + for (PHINode &PN : L->getHeader()->phis()) + Phis.push_back(&PN); + + if (TTI) + llvm::sort(Phis, [](Value *LHS, Value *RHS) { + // Put pointers at the back and make sure pointer < pointer = false. + if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) + return RHS->getType()->isIntegerTy() && !LHS->getType()->isIntegerTy(); + return RHS->getType()->getPrimitiveSizeInBits() < + LHS->getType()->getPrimitiveSizeInBits(); + }); + + unsigned NumElim = 0; + DenseMap<const SCEV *, PHINode *> ExprToIVMap; + // Process phis from wide to narrow. Map wide phis to their truncation + // so narrow phis can reuse them. 
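+  // (The canonical case this loop handles, sketched with hypothetical IR:
+  //  two phis that SCEV proves compute the same recurrence, e.g.
+  //    %i = phi i64 [ 0, %preheader ], [ %i.next, %latch ]  ; {0,+,1}
+  //    %j = phi i64 [ 0, %preheader ], [ %j.next, %latch ]  ; {0,+,1}
+  //  map to the same SCEV, so %j and %j.next are replaced with %i and
+  //  %i.next and queued in DeadInsts.)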
+ for (PHINode *Phi : Phis) { + auto SimplifyPHINode = [&](PHINode *PN) -> Value * { + if (Value *V = SimplifyInstruction(PN, {DL, &SE.TLI, &SE.DT, &SE.AC})) + return V; + if (!SE.isSCEVable(PN->getType())) + return nullptr; + auto *Const = dyn_cast<SCEVConstant>(SE.getSCEV(PN)); + if (!Const) + return nullptr; + return Const->getValue(); + }; + + // Fold constant phis. They may be congruent to other constant phis and + // would confuse the logic below that expects proper IVs. + if (Value *V = SimplifyPHINode(Phi)) { + if (V->getType() != Phi->getType()) + continue; + Phi->replaceAllUsesWith(V); + DeadInsts.emplace_back(Phi); + ++NumElim; + DEBUG_WITH_TYPE(DebugType, dbgs() + << "INDVARS: Eliminated constant iv: " << *Phi << '\n'); + continue; + } + + if (!SE.isSCEVable(Phi->getType())) + continue; + + PHINode *&OrigPhiRef = ExprToIVMap[SE.getSCEV(Phi)]; + if (!OrigPhiRef) { + OrigPhiRef = Phi; + if (Phi->getType()->isIntegerTy() && TTI && + TTI->isTruncateFree(Phi->getType(), Phis.back()->getType())) { + // This phi can be freely truncated to the narrowest phi type. Map the + // truncated expression to it so it will be reused for narrow types. + const SCEV *TruncExpr = + SE.getTruncateExpr(SE.getSCEV(Phi), Phis.back()->getType()); + ExprToIVMap[TruncExpr] = Phi; + } + continue; + } + + // Replacing a pointer phi with an integer phi or vice-versa doesn't make + // sense. + if (OrigPhiRef->getType()->isPointerTy() != Phi->getType()->isPointerTy()) + continue; + + if (BasicBlock *LatchBlock = L->getLoopLatch()) { + Instruction *OrigInc = dyn_cast<Instruction>( + OrigPhiRef->getIncomingValueForBlock(LatchBlock)); + Instruction *IsomorphicInc = + dyn_cast<Instruction>(Phi->getIncomingValueForBlock(LatchBlock)); + + if (OrigInc && IsomorphicInc) { + // If this phi has the same width but is more canonical, replace the + // original with it. As part of the "more canonical" determination, + // respect a prior decision to use an IV chain. + if (OrigPhiRef->getType() == Phi->getType() && + !(ChainedPhis.count(Phi) || + isExpandedAddRecExprPHI(OrigPhiRef, OrigInc, L)) && + (ChainedPhis.count(Phi) || + isExpandedAddRecExprPHI(Phi, IsomorphicInc, L))) { + std::swap(OrigPhiRef, Phi); + std::swap(OrigInc, IsomorphicInc); + } + // Replacing the congruent phi is sufficient because acyclic + // redundancy elimination, CSE/GVN, should handle the + // rest. However, once SCEV proves that a phi is congruent, + // it's often the head of an IV user cycle that is isomorphic + // with the original phi. It's worth eagerly cleaning up the + // common case of a single IV increment so that DeleteDeadPHIs + // can remove cycles that had postinc uses. 
+ const SCEV *TruncExpr = + SE.getTruncateOrNoop(SE.getSCEV(OrigInc), IsomorphicInc->getType()); + if (OrigInc != IsomorphicInc && + TruncExpr == SE.getSCEV(IsomorphicInc) && + SE.LI.replacementPreservesLCSSAForm(IsomorphicInc, OrigInc) && + hoistIVInc(OrigInc, IsomorphicInc)) { + DEBUG_WITH_TYPE(DebugType, + dbgs() << "INDVARS: Eliminated congruent iv.inc: " + << *IsomorphicInc << '\n'); + Value *NewInc = OrigInc; + if (OrigInc->getType() != IsomorphicInc->getType()) { + Instruction *IP = nullptr; + if (PHINode *PN = dyn_cast<PHINode>(OrigInc)) + IP = &*PN->getParent()->getFirstInsertionPt(); + else + IP = OrigInc->getNextNode(); + + IRBuilder<> Builder(IP); + Builder.SetCurrentDebugLocation(IsomorphicInc->getDebugLoc()); + NewInc = Builder.CreateTruncOrBitCast( + OrigInc, IsomorphicInc->getType(), IVName); + } + IsomorphicInc->replaceAllUsesWith(NewInc); + DeadInsts.emplace_back(IsomorphicInc); + } + } + } + DEBUG_WITH_TYPE(DebugType, dbgs() << "INDVARS: Eliminated congruent iv: " + << *Phi << '\n'); + ++NumElim; + Value *NewIV = OrigPhiRef; + if (OrigPhiRef->getType() != Phi->getType()) { + IRBuilder<> Builder(&*L->getHeader()->getFirstInsertionPt()); + Builder.SetCurrentDebugLocation(Phi->getDebugLoc()); + NewIV = Builder.CreateTruncOrBitCast(OrigPhiRef, Phi->getType(), IVName); + } + Phi->replaceAllUsesWith(NewIV); + DeadInsts.emplace_back(Phi); + } + return NumElim; +} + +Value *SCEVExpander::getExactExistingExpansion(const SCEV *S, + const Instruction *At, Loop *L) { + Optional<ScalarEvolution::ValueOffsetPair> VO = + getRelatedExistingExpansion(S, At, L); + if (VO && VO.getValue().second == nullptr) + return VO.getValue().first; + return nullptr; +} + +Optional<ScalarEvolution::ValueOffsetPair> +SCEVExpander::getRelatedExistingExpansion(const SCEV *S, const Instruction *At, + Loop *L) { + using namespace llvm::PatternMatch; + + SmallVector<BasicBlock *, 4> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + + // Look for suitable value in simple conditions at the loop exits. + for (BasicBlock *BB : ExitingBlocks) { + ICmpInst::Predicate Pred; + Instruction *LHS, *RHS; + + if (!match(BB->getTerminator(), + m_Br(m_ICmp(Pred, m_Instruction(LHS), m_Instruction(RHS)), + m_BasicBlock(), m_BasicBlock()))) + continue; + + if (SE.getSCEV(LHS) == S && SE.DT.dominates(LHS, At)) + return ScalarEvolution::ValueOffsetPair(LHS, nullptr); + + if (SE.getSCEV(RHS) == S && SE.DT.dominates(RHS, At)) + return ScalarEvolution::ValueOffsetPair(RHS, nullptr); + } + + // Use expand's logic which is used for reusing a previous Value in + // ExprValueMap. + ScalarEvolution::ValueOffsetPair VO = FindValueInExprValueMap(S, At); + if (VO.first) + return VO; + + // There is potential to make this significantly smarter, but this simple + // heuristic already gets some interesting cases. + + // Can not find suitable value. + return None; +} + +bool SCEVExpander::isHighCostExpansionHelper( + const SCEV *S, Loop *L, const Instruction *At, + SmallPtrSetImpl<const SCEV *> &Processed) { + + // If we can find an existing value for this scev available at the point "At" + // then consider the expression cheap. 
+  if (At && getRelatedExistingExpansion(S, At, L))
+    return false;
+
+  // Zero/One operand expressions
+  switch (S->getSCEVType()) {
+  case scUnknown:
+  case scConstant:
+    return false;
+  case scTruncate:
+    return isHighCostExpansionHelper(cast<SCEVTruncateExpr>(S)->getOperand(),
+                                     L, At, Processed);
+  case scZeroExtend:
+    return isHighCostExpansionHelper(cast<SCEVZeroExtendExpr>(S)->getOperand(),
+                                     L, At, Processed);
+  case scSignExtend:
+    return isHighCostExpansionHelper(cast<SCEVSignExtendExpr>(S)->getOperand(),
+                                     L, At, Processed);
+  }
+
+  if (!Processed.insert(S).second)
+    return false;
+
+  if (auto *UDivExpr = dyn_cast<SCEVUDivExpr>(S)) {
+    // If the divisor is a power of two, the SCEV type fits in a native
+    // integer, and the LHS is not expensive, consider the division cheap
+    // irrespective of whether it occurs in the user code, since it can be
+    // lowered into a right shift.
+    if (auto *SC = dyn_cast<SCEVConstant>(UDivExpr->getRHS()))
+      if (SC->getAPInt().isPowerOf2()) {
+        if (isHighCostExpansionHelper(UDivExpr->getLHS(), L, At, Processed))
+          return true;
+        const DataLayout &DL =
+            L->getHeader()->getParent()->getParent()->getDataLayout();
+        unsigned Width = cast<IntegerType>(UDivExpr->getType())->getBitWidth();
+        return DL.isIllegalInteger(Width);
+      }
+
+    // UDivExpr is very likely a UDiv that ScalarEvolution's HowFarToZero or
+    // HowManyLessThans produced to compute a precise expression, rather than
+    // a UDiv from the user's code. If we can't find a UDiv in the code with
+    // some simple searching, assume the former and consider UDivExpr
+    // expensive to compute.
+    BasicBlock *ExitingBB = L->getExitingBlock();
+    if (!ExitingBB)
+      return true;
+
+    // At the beginning of this function we already tried to find an existing
+    // value for plain 'S'. Now try to look up 'S + 1', since it is a common
+    // pattern involving division. This is just a simple search heuristic.
+    if (!At)
+      At = &ExitingBB->back();
+    if (!getRelatedExistingExpansion(
+            SE.getAddExpr(S, SE.getConstant(S->getType(), 1)), At, L))
+      return true;
+  }
+
+  // HowManyLessThans uses a Max expression whenever the loop is not guarded
+  // by the exit condition.
+  if (isa<SCEVMinMaxExpr>(S))
+    return true;
+
+  // Recurse past nary expressions, which commonly occur in the
+  // BackedgeTakenCount. They may already exist in program code, and if not,
+  // they are not too expensive to rematerialize.
+  if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(S)) {
+    for (auto *Op : NAry->operands())
+      if (isHighCostExpansionHelper(Op, L, At, Processed))
+        return true;
+  }
+
+  // If we haven't recognized an expensive SCEV pattern, assume it's an
+  // expression produced by program code.
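+  // (Why a power-of-two divisor is "cheap" above, as a standalone sketch:
+  //
+  //    uint64_t div8(uint64_t n) { return n / 8; }  // lowers to n >> 3
+  //
+  //  a udiv by 1 << k needs no divide unit at all, so rematerializing it is
+  //  fine as long as the integer width is legal for the target.)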
+  return false;
+}
+
+Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
+                                            Instruction *IP) {
+  assert(IP);
+  switch (Pred->getKind()) {
+  case SCEVPredicate::P_Union:
+    return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP);
+  case SCEVPredicate::P_Equal:
+    return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP);
+  case SCEVPredicate::P_Wrap: {
+    auto *AddRecPred = cast<SCEVWrapPredicate>(Pred);
+    return expandWrapPredicate(AddRecPred, IP);
+  }
+  }
+  llvm_unreachable("Unknown SCEV predicate type");
+}
+
+Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred,
+                                          Instruction *IP) {
+  Value *Expr0 = expandCodeFor(Pred->getLHS(), Pred->getLHS()->getType(), IP);
+  Value *Expr1 = expandCodeFor(Pred->getRHS(), Pred->getRHS()->getType(), IP);
+
+  Builder.SetInsertPoint(IP);
+  auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check");
+  return I;
+}
+
+Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
+                                           Instruction *Loc, bool Signed) {
+  assert(AR->isAffine() && "Cannot generate RT check for "
+                           "non-affine expression");
+
+  SCEVUnionPredicate Pred;
+  const SCEV *ExitCount =
+      SE.getPredicatedBackedgeTakenCount(AR->getLoop(), Pred);
+
+  assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count");
+
+  const SCEV *Step = AR->getStepRecurrence(SE);
+  const SCEV *Start = AR->getStart();
+
+  Type *ARTy = AR->getType();
+  unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType());
+  unsigned DstBits = SE.getTypeSizeInBits(ARTy);
+
+  // The expression {Start,+,Step} has nusw/nssw if:
+  //   Step <  0: Start - |Step| * Backedge <= Start
+  //   Step >= 0: Start + |Step| * Backedge >  Start
+  // and |Step| * Backedge doesn't unsigned overflow.
+
+  IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits);
+  Builder.SetInsertPoint(Loc);
+  Value *TripCountVal = expandCodeFor(ExitCount, CountTy, Loc);
+
+  IntegerType *Ty =
+      IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(ARTy));
+  Type *ARExpandTy = DL.isNonIntegralPointerType(ARTy) ? ARTy : Ty;
+
+  Value *StepValue = expandCodeFor(Step, Ty, Loc);
+  Value *NegStepValue = expandCodeFor(SE.getNegativeSCEV(Step), Ty, Loc);
+  Value *StartValue = expandCodeFor(Start, ARExpandTy, Loc);
+
+  ConstantInt *Zero =
+      ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits));
+
+  Builder.SetInsertPoint(Loc);
+  // Compute |Step|
+  Value *StepCompare = Builder.CreateICmp(ICmpInst::ICMP_SLT, StepValue, Zero);
+  Value *AbsStep = Builder.CreateSelect(StepCompare, NegStepValue, StepValue);
+
+  // Get the backedge-taken count and truncate or extend it to the AR type.
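+  // (Overall shape of the predicate assembled below, in rough pseudo-C:
+  //
+  //    {Mul, Ofl} = umul.with.overflow(|Step|, BackedgeTakenCount);
+  //    Wrap = Step < 0 ? (Start - Mul > Start)   // SGT/UGT per Signed
+  //                    : (Start + Mul < Start);  // SLT/ULT per Signed
+  //    Check = Wrap | Ofl | <bits lost truncating the backedge count>;
+  //
+  //  i.e. the expansion returns true exactly when {Start,+,Step} may wrap.)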
+ Value *TruncTripCount = Builder.CreateZExtOrTrunc(TripCountVal, Ty); + auto *MulF = Intrinsic::getDeclaration(Loc->getModule(), + Intrinsic::umul_with_overflow, Ty); + + // Compute |Step| * Backedge + CallInst *Mul = Builder.CreateCall(MulF, {AbsStep, TruncTripCount}, "mul"); + Value *MulV = Builder.CreateExtractValue(Mul, 0, "mul.result"); + Value *OfMul = Builder.CreateExtractValue(Mul, 1, "mul.overflow"); + + // Compute: + // Start + |Step| * Backedge < Start + // Start - |Step| * Backedge > Start + Value *Add = nullptr, *Sub = nullptr; + if (PointerType *ARPtrTy = dyn_cast<PointerType>(ARExpandTy)) { + const SCEV *MulS = SE.getSCEV(MulV); + const SCEV *NegMulS = SE.getNegativeSCEV(MulS); + Add = Builder.CreateBitCast(expandAddToGEP(MulS, ARPtrTy, Ty, StartValue), + ARPtrTy); + Sub = Builder.CreateBitCast( + expandAddToGEP(NegMulS, ARPtrTy, Ty, StartValue), ARPtrTy); + } else { + Add = Builder.CreateAdd(StartValue, MulV); + Sub = Builder.CreateSub(StartValue, MulV); + } + + Value *EndCompareGT = Builder.CreateICmp( + Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, Sub, StartValue); + + Value *EndCompareLT = Builder.CreateICmp( + Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, Add, StartValue); + + // Select the answer based on the sign of Step. + Value *EndCheck = + Builder.CreateSelect(StepCompare, EndCompareGT, EndCompareLT); + + // If the backedge taken count type is larger than the AR type, + // check that we don't drop any bits by truncating it. If we are + // dropping bits, then we have overflow (unless the step is zero). + if (SE.getTypeSizeInBits(CountTy) > SE.getTypeSizeInBits(Ty)) { + auto MaxVal = APInt::getMaxValue(DstBits).zext(SrcBits); + auto *BackedgeCheck = + Builder.CreateICmp(ICmpInst::ICMP_UGT, TripCountVal, + ConstantInt::get(Loc->getContext(), MaxVal)); + BackedgeCheck = Builder.CreateAnd( + BackedgeCheck, Builder.CreateICmp(ICmpInst::ICMP_NE, StepValue, Zero)); + + EndCheck = Builder.CreateOr(EndCheck, BackedgeCheck); + } + + EndCheck = Builder.CreateOr(EndCheck, OfMul); + return EndCheck; +} + +Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred, + Instruction *IP) { + const auto *A = cast<SCEVAddRecExpr>(Pred->getExpr()); + Value *NSSWCheck = nullptr, *NUSWCheck = nullptr; + + // Add a check for NUSW + if (Pred->getFlags() & SCEVWrapPredicate::IncrementNUSW) + NUSWCheck = generateOverflowCheck(A, IP, false); + + // Add a check for NSSW + if (Pred->getFlags() & SCEVWrapPredicate::IncrementNSSW) + NSSWCheck = generateOverflowCheck(A, IP, true); + + if (NUSWCheck && NSSWCheck) + return Builder.CreateOr(NUSWCheck, NSSWCheck); + + if (NUSWCheck) + return NUSWCheck; + + if (NSSWCheck) + return NSSWCheck; + + return ConstantInt::getFalse(IP->getContext()); +} + +Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union, + Instruction *IP) { + auto *BoolType = IntegerType::get(IP->getContext(), 1); + Value *Check = ConstantInt::getNullValue(BoolType); + + // Loop over all checks in this set. + for (auto Pred : Union->getPredicates()) { + auto *NextCheck = expandCodeForPredicate(Pred, IP); + Builder.SetInsertPoint(IP); + Check = Builder.CreateOr(Check, NextCheck); + } + + return Check; +} + +namespace { +// Search for a SCEV subexpression that is not safe to expand. Any expression +// that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely +// UDiv expressions. 
We don't know if the UDiv is derived from an IR divide +// instruction, but the important thing is that we prove the denominator is +// nonzero before expansion. +// +// IVUsers already checks that IV-derived expressions are safe. So this check is +// only needed when the expression includes some subexpression that is not IV +// derived. +// +// Currently, we only allow division by a nonzero constant here. If this is +// inadequate, we could easily allow division by SCEVUnknown by using +// ValueTracking to check isKnownNonZero(). +// +// We cannot generally expand recurrences unless the step dominates the loop +// header. The expander handles the special case of affine recurrences by +// scaling the recurrence outside the loop, but this technique isn't generally +// applicable. Expanding a nested recurrence outside a loop requires computing +// binomial coefficients. This could be done, but the recurrence has to be in a +// perfectly reduced form, which can't be guaranteed. +struct SCEVFindUnsafe { + ScalarEvolution &SE; + bool IsUnsafe; + + SCEVFindUnsafe(ScalarEvolution &se): SE(se), IsUnsafe(false) {} + + bool follow(const SCEV *S) { + if (const SCEVUDivExpr *D = dyn_cast<SCEVUDivExpr>(S)) { + const SCEVConstant *SC = dyn_cast<SCEVConstant>(D->getRHS()); + if (!SC || SC->getValue()->isZero()) { + IsUnsafe = true; + return false; + } + } + if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { + const SCEV *Step = AR->getStepRecurrence(SE); + if (!AR->isAffine() && !SE.dominates(Step, AR->getLoop()->getHeader())) { + IsUnsafe = true; + return false; + } + } + return true; + } + bool isDone() const { return IsUnsafe; } +}; +} + +namespace llvm { +bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE) { + SCEVFindUnsafe Search(SE); + visitAll(S, Search); + return !Search.IsUnsafe; +} + +bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint, + ScalarEvolution &SE) { + if (!isSafeToExpand(S, SE)) + return false; + // We have to prove that the expanded site of S dominates InsertionPoint. + // This is easy when not in the same block, but hard when S is an instruction + // to be expanded somewhere inside the same block as our insertion point. + // What we really need here is something analogous to an OrderedBasicBlock, + // but for the moment, we paper over the problem by handling two common and + // cheap to check cases. + if (SE.properlyDominates(S, InsertionPoint->getParent())) + return true; + if (SE.dominates(S, InsertionPoint->getParent())) { + if (InsertionPoint->getParent()->getTerminator() == InsertionPoint) + return true; + if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) + for (const Value *V : InsertionPoint->operand_values()) + if (V == U->getValue()) + return true; + } + return false; +} +} diff --git a/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp b/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp new file mode 100644 index 000000000000..209ae66ca53e --- /dev/null +++ b/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp @@ -0,0 +1,117 @@ +//===- ScalarEvolutionNormalization.cpp - See below -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utilities for working with "normalized" expressions. 
+// See the comments at the top of ScalarEvolutionNormalization.h for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ScalarEvolutionNormalization.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +using namespace llvm; + +/// TransformKind - Different types of transformations that +/// TransformForPostIncUse can do. +enum TransformKind { + /// Normalize - Normalize according to the given loops. + Normalize, + /// Denormalize - Perform the inverse transform on the expression with the + /// given loop set. + Denormalize +}; + +namespace { +struct NormalizeDenormalizeRewriter + : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> { + const TransformKind Kind; + + // NB! Pred is a function_ref. Storing it here is okay only because + // we're careful about the lifetime of NormalizeDenormalizeRewriter. + const NormalizePredTy Pred; + + NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred, + ScalarEvolution &SE) + : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind), + Pred(Pred) {} + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr); +}; +} // namespace + +const SCEV * +NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) { + SmallVector<const SCEV *, 8> Operands; + + transform(AR->operands(), std::back_inserter(Operands), + [&](const SCEV *Op) { return visit(Op); }); + + if (!Pred(AR)) + return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap); + + // Normalization and denormalization are fancy names for decrementing and + // incrementing a SCEV expression with respect to a set of loops. Since + // Pred(AR) has returned true, we know we need to normalize or denormalize AR + // with respect to its loop. + + if (Kind == Denormalize) { + // Denormalization / "partial increment" is essentially the same as \c + // SCEVAddRecExpr::getPostIncExpr. Here we use an explicit loop to make the + // symmetry with Normalization clear. + for (int i = 0, e = Operands.size() - 1; i < e; i++) + Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]); + } else { + assert(Kind == Normalize && "Only two possibilities!"); + + // Normalization / "partial decrement" is a bit more subtle. Since + // incrementing a SCEV expression (in general) changes the step of the SCEV + // expression as well, we cannot use the step of the current expression. + // Instead, we have to use the step of the very expression we're trying to + // compute! + // + // We solve the issue by recursively building up the result, starting from + // the "least significant" operand in the add recurrence: + // + // Base case: + // Single operand add recurrence. It's its own normalization. + // + // N-operand case: + // {S_{N-1},+,S_{N-2},+,...,+,S_0} = S + // + // Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its + // normalization by induction. We subtract the normalized step + // recurrence from S_{N-1} to get the normalization of S. 
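+    // (A worked affine example, with hypothetical recurrences: for
+    //  S = {A,+,B}, normalization yields {A-B,+,B} and denormalization
+    //  yields {A+B,+,B}; e.g. the post-increment form of {0,+,1} is
+    //  {1,+,1}, and normalizing it gives back {0,+,1}.)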
+
+    for (int i = Operands.size() - 2; i >= 0; i--)
+      Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]);
+  }
+
+  return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
+}
+
+const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
+                                         const PostIncLoopSet &Loops,
+                                         ScalarEvolution &SE) {
+  auto Pred = [&](const SCEVAddRecExpr *AR) {
+    return Loops.count(AR->getLoop());
+  };
+  return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
+}
+
+const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
+                                           ScalarEvolution &SE) {
+  return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
+}
+
+const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
+                                           const PostIncLoopSet &Loops,
+                                           ScalarEvolution &SE) {
+  auto Pred = [&](const SCEVAddRecExpr *AR) {
+    return Loops.count(AR->getLoop());
+  };
+  return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S);
+}
diff --git a/llvm/lib/Analysis/ScopedNoAliasAA.cpp b/llvm/lib/Analysis/ScopedNoAliasAA.cpp
new file mode 100644
index 000000000000..094e4a3d5dc8
--- /dev/null
+++ b/llvm/lib/Analysis/ScopedNoAliasAA.cpp
@@ -0,0 +1,210 @@
+//===- ScopedNoAliasAA.cpp - Scoped No-Alias Alias Analysis ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ScopedNoAlias alias-analysis pass, which implements
+// metadata-based scoped no-alias support.
+//
+// Alias-analysis scopes are defined by an id (which can be a string or some
+// other metadata node), a domain node, and an optional descriptive string.
+// A domain is defined by an id (which can be a string or some other metadata
+// node), and an optional descriptive string.
+//
+// !dom0 = metadata !{ metadata !"domain of foo()" }
+// !scope1 = metadata !{ metadata !scope1, metadata !dom0, metadata !"scope 1" }
+// !scope2 = metadata !{ metadata !scope2, metadata !dom0, metadata !"scope 2" }
+//
+// Loads and stores can be tagged with an alias-analysis scope, and also with
+// a noalias tag for a specific scope:
+//
+// ... = load %ptr1, !alias.scope !{ !scope1 }
+// ... = load %ptr2, !alias.scope !{ !scope1, !scope2 }, !noalias !{ !scope1 }
+//
+// When evaluating an aliasing query, if one of the instructions is associated
+// with a set of noalias scopes in some domain that is a superset of the alias
+// scopes in that domain of the other instruction, then the two memory
+// accesses are assumed not to alias.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ScopedNoAliasAA.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace llvm;
+
+// A handy option for disabling scoped no-alias functionality. The same effect
+// can also be achieved by stripping the associated metadata tags from IR, but
+// this option is sometimes more convenient.
+static cl::opt<bool> EnableScopedNoAlias("enable-scoped-noalias", + cl::init(true), cl::Hidden); + +namespace { + +/// This is a simple wrapper around an MDNode which provides a higher-level +/// interface by hiding the details of how alias analysis information is encoded +/// in its operands. +class AliasScopeNode { + const MDNode *Node = nullptr; + +public: + AliasScopeNode() = default; + explicit AliasScopeNode(const MDNode *N) : Node(N) {} + + /// Get the MDNode for this AliasScopeNode. + const MDNode *getNode() const { return Node; } + + /// Get the MDNode for this AliasScopeNode's domain. + const MDNode *getDomain() const { + if (Node->getNumOperands() < 2) + return nullptr; + return dyn_cast_or_null<MDNode>(Node->getOperand(1)); + } +}; + +} // end anonymous namespace + +AliasResult ScopedNoAliasAAResult::alias(const MemoryLocation &LocA, + const MemoryLocation &LocB, + AAQueryInfo &AAQI) { + if (!EnableScopedNoAlias) + return AAResultBase::alias(LocA, LocB, AAQI); + + // Get the attached MDNodes. + const MDNode *AScopes = LocA.AATags.Scope, *BScopes = LocB.AATags.Scope; + + const MDNode *ANoAlias = LocA.AATags.NoAlias, *BNoAlias = LocB.AATags.NoAlias; + + if (!mayAliasInScopes(AScopes, BNoAlias)) + return NoAlias; + + if (!mayAliasInScopes(BScopes, ANoAlias)) + return NoAlias; + + // If they may alias, chain to the next AliasAnalysis. + return AAResultBase::alias(LocA, LocB, AAQI); +} + +ModRefInfo ScopedNoAliasAAResult::getModRefInfo(const CallBase *Call, + const MemoryLocation &Loc, + AAQueryInfo &AAQI) { + if (!EnableScopedNoAlias) + return AAResultBase::getModRefInfo(Call, Loc, AAQI); + + if (!mayAliasInScopes(Loc.AATags.Scope, + Call->getMetadata(LLVMContext::MD_noalias))) + return ModRefInfo::NoModRef; + + if (!mayAliasInScopes(Call->getMetadata(LLVMContext::MD_alias_scope), + Loc.AATags.NoAlias)) + return ModRefInfo::NoModRef; + + return AAResultBase::getModRefInfo(Call, Loc, AAQI); +} + +ModRefInfo ScopedNoAliasAAResult::getModRefInfo(const CallBase *Call1, + const CallBase *Call2, + AAQueryInfo &AAQI) { + if (!EnableScopedNoAlias) + return AAResultBase::getModRefInfo(Call1, Call2, AAQI); + + if (!mayAliasInScopes(Call1->getMetadata(LLVMContext::MD_alias_scope), + Call2->getMetadata(LLVMContext::MD_noalias))) + return ModRefInfo::NoModRef; + + if (!mayAliasInScopes(Call2->getMetadata(LLVMContext::MD_alias_scope), + Call1->getMetadata(LLVMContext::MD_noalias))) + return ModRefInfo::NoModRef; + + return AAResultBase::getModRefInfo(Call1, Call2, AAQI); +} + +static void collectMDInDomain(const MDNode *List, const MDNode *Domain, + SmallPtrSetImpl<const MDNode *> &Nodes) { + for (const MDOperand &MDOp : List->operands()) + if (const MDNode *MD = dyn_cast<MDNode>(MDOp)) + if (AliasScopeNode(MD).getDomain() == Domain) + Nodes.insert(MD); +} + +bool ScopedNoAliasAAResult::mayAliasInScopes(const MDNode *Scopes, + const MDNode *NoAlias) const { + if (!Scopes || !NoAlias) + return true; + + // Collect the set of scope domains relevant to the noalias scopes. + SmallPtrSet<const MDNode *, 16> Domains; + for (const MDOperand &MDOp : NoAlias->operands()) + if (const MDNode *NAMD = dyn_cast<MDNode>(MDOp)) + if (const MDNode *Domain = AliasScopeNode(NAMD).getDomain()) + Domains.insert(Domain); + + // We alias unless, for some domain, the set of noalias scopes in that domain + // is a superset of the set of alias scopes in that domain. 
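+  // (A small worked case in the notation from the file header: with both
+  //  scopes in !dom0, an access carrying !alias.scope !{ !scope1 } and
+  //  another carrying !noalias !{ !scope1, !scope2 } do not alias, since
+  //  {!scope1} is contained in {!scope1, !scope2}. If the first access
+  //  instead carried !alias.scope !{ !scope1, !scope3 } with !scope3 also in
+  //  !dom0, the superset test would fail and we conservatively return true.)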
+ for (const MDNode *Domain : Domains) { + SmallPtrSet<const MDNode *, 16> ScopeNodes; + collectMDInDomain(Scopes, Domain, ScopeNodes); + if (ScopeNodes.empty()) + continue; + + SmallPtrSet<const MDNode *, 16> NANodes; + collectMDInDomain(NoAlias, Domain, NANodes); + + // To not alias, all of the nodes in ScopeNodes must be in NANodes. + bool FoundAll = true; + for (const MDNode *SMD : ScopeNodes) + if (!NANodes.count(SMD)) { + FoundAll = false; + break; + } + + if (FoundAll) + return false; + } + + return true; +} + +AnalysisKey ScopedNoAliasAA::Key; + +ScopedNoAliasAAResult ScopedNoAliasAA::run(Function &F, + FunctionAnalysisManager &AM) { + return ScopedNoAliasAAResult(); +} + +char ScopedNoAliasAAWrapperPass::ID = 0; + +INITIALIZE_PASS(ScopedNoAliasAAWrapperPass, "scoped-noalias", + "Scoped NoAlias Alias Analysis", false, true) + +ImmutablePass *llvm::createScopedNoAliasAAWrapperPass() { + return new ScopedNoAliasAAWrapperPass(); +} + +ScopedNoAliasAAWrapperPass::ScopedNoAliasAAWrapperPass() : ImmutablePass(ID) { + initializeScopedNoAliasAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool ScopedNoAliasAAWrapperPass::doInitialization(Module &M) { + Result.reset(new ScopedNoAliasAAResult()); + return false; +} + +bool ScopedNoAliasAAWrapperPass::doFinalization(Module &M) { + Result.reset(); + return false; +} + +void ScopedNoAliasAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); +} diff --git a/llvm/lib/Analysis/StackSafetyAnalysis.cpp b/llvm/lib/Analysis/StackSafetyAnalysis.cpp new file mode 100644 index 000000000000..1b3638698950 --- /dev/null +++ b/llvm/lib/Analysis/StackSafetyAnalysis.cpp @@ -0,0 +1,676 @@ +//===- StackSafetyAnalysis.cpp - Stack memory safety analysis -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/StackSafetyAnalysis.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DEBUG_TYPE "stack-safety" + +static cl::opt<int> StackSafetyMaxIterations("stack-safety-max-iterations", + cl::init(20), cl::Hidden); + +namespace { + +/// Rewrite an SCEV expression for a memory access address to an expression that +/// represents offset from the given alloca. +class AllocaOffsetRewriter : public SCEVRewriteVisitor<AllocaOffsetRewriter> { + const Value *AllocaPtr; + +public: + AllocaOffsetRewriter(ScalarEvolution &SE, const Value *AllocaPtr) + : SCEVRewriteVisitor(SE), AllocaPtr(AllocaPtr) {} + + const SCEV *visit(const SCEV *Expr) { + // Only re-write the expression if the alloca is used in an addition + // expression (it can be used in other types of expressions if it's cast to + // an int and passed as an argument.) + if (!isa<SCEVAddRecExpr>(Expr) && !isa<SCEVAddExpr>(Expr) && + !isa<SCEVUnknown>(Expr)) + return Expr; + return SCEVRewriteVisitor<AllocaOffsetRewriter>::visit(Expr); + } + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + // FIXME: look through one or several levels of definitions? + // This can be inttoptr(AllocaPtr) and SCEV would not unwrap + // it for us. 
+    if (Expr->getValue() == AllocaPtr)
+      return SE.getZero(Expr->getType());
+    return Expr;
+  }
+};
+
+/// Describes a use of an address as a function call argument.
+struct PassAsArgInfo {
+  /// Function being called.
+  const GlobalValue *Callee = nullptr;
+  /// Index of the argument which passes the address.
+  size_t ParamNo = 0;
+  // Offset range of the address from the base address (alloca or calling
+  // function argument).
+  // The range must never be set to the empty set, which is an invalid access
+  // range that ConstantRange::add could propagate.
+  ConstantRange Offset;
+  PassAsArgInfo(const GlobalValue *Callee, size_t ParamNo, ConstantRange Offset)
+      : Callee(Callee), ParamNo(ParamNo), Offset(Offset) {}
+
+  StringRef getName() const { return Callee->getName(); }
+};
+
+raw_ostream &operator<<(raw_ostream &OS, const PassAsArgInfo &P) {
+  return OS << "@" << P.getName() << "(arg" << P.ParamNo << ", " << P.Offset
+            << ")";
+}
+
+/// Describes uses of an address (alloca or parameter) inside the function.
+struct UseInfo {
+  // Access range of the address (alloca or parameter).
+  // It is allowed to be the empty set when there are no known accesses.
+  ConstantRange Range;
+
+  // List of calls which pass the address as an argument.
+  SmallVector<PassAsArgInfo, 4> Calls;
+
+  explicit UseInfo(unsigned PointerSize) : Range{PointerSize, false} {}
+
+  void updateRange(ConstantRange R) { Range = Range.unionWith(R); }
+};
+
+raw_ostream &operator<<(raw_ostream &OS, const UseInfo &U) {
+  OS << U.Range;
+  for (auto &Call : U.Calls)
+    OS << ", " << Call;
+  return OS;
+}
+
+struct AllocaInfo {
+  const AllocaInst *AI = nullptr;
+  uint64_t Size = 0;
+  UseInfo Use;
+
+  AllocaInfo(unsigned PointerSize, const AllocaInst *AI, uint64_t Size)
+      : AI(AI), Size(Size), Use(PointerSize) {}
+
+  StringRef getName() const { return AI->getName(); }
+};
+
+raw_ostream &operator<<(raw_ostream &OS, const AllocaInfo &A) {
+  return OS << A.getName() << "[" << A.Size << "]: " << A.Use;
+}
+
+struct ParamInfo {
+  const Argument *Arg = nullptr;
+  UseInfo Use;
+
+  explicit ParamInfo(unsigned PointerSize, const Argument *Arg)
+      : Arg(Arg), Use(PointerSize) {}
+
+  StringRef getName() const { return Arg ? Arg->getName() : "<N/A>"; }
+};
+
+raw_ostream &operator<<(raw_ostream &OS, const ParamInfo &P) {
+  return OS << P.getName() << "[]: " << P.Use;
+}
+
+/// Calculate the allocation size of a given alloca. Returns 0 if the
+/// size cannot be statically determined.
+uint64_t getStaticAllocaAllocationSize(const AllocaInst *AI) {
+  const DataLayout &DL = AI->getModule()->getDataLayout();
+  uint64_t Size = DL.getTypeAllocSize(AI->getAllocatedType());
+  if (AI->isArrayAllocation()) {
+    auto C = dyn_cast<ConstantInt>(AI->getArraySize());
+    if (!C)
+      return 0;
+    Size *= C->getZExtValue();
+  }
+  return Size;
+}
+
+} // end anonymous namespace
+
+/// Describes uses of allocas and parameters inside of a single function.
+struct StackSafetyInfo::FunctionInfo {
+  // May be a Function or a GlobalAlias
+  const GlobalValue *GV = nullptr;
+  // Information about alloca uses.
+  SmallVector<AllocaInfo, 4> Allocas;
+  // Information about parameter uses.
+  SmallVector<ParamInfo, 4> Params;
+  // TODO: describe return value as depending on one or more of its arguments.
+
+  // StackSafetyDataFlowAnalysis counter stored here for faster access.
+ int UpdateCount = 0; + + FunctionInfo(const StackSafetyInfo &SSI) : FunctionInfo(*SSI.Info) {} + + explicit FunctionInfo(const Function *F) : GV(F){}; + // Creates FunctionInfo that forwards all the parameters to the aliasee. + explicit FunctionInfo(const GlobalAlias *A); + + FunctionInfo(FunctionInfo &&) = default; + + bool IsDSOLocal() const { return GV->isDSOLocal(); }; + + bool IsInterposable() const { return GV->isInterposable(); }; + + StringRef getName() const { return GV->getName(); } + + void print(raw_ostream &O) const { + // TODO: Consider different printout format after + // StackSafetyDataFlowAnalysis. Calls and parameters are irrelevant then. + O << " @" << getName() << (IsDSOLocal() ? "" : " dso_preemptable") + << (IsInterposable() ? " interposable" : "") << "\n"; + O << " args uses:\n"; + for (auto &P : Params) + O << " " << P << "\n"; + O << " allocas uses:\n"; + for (auto &AS : Allocas) + O << " " << AS << "\n"; + } + +private: + FunctionInfo(const FunctionInfo &) = default; +}; + +StackSafetyInfo::FunctionInfo::FunctionInfo(const GlobalAlias *A) : GV(A) { + unsigned PointerSize = A->getParent()->getDataLayout().getPointerSizeInBits(); + const GlobalObject *Aliasee = A->getBaseObject(); + const FunctionType *Type = cast<FunctionType>(Aliasee->getValueType()); + // 'Forward' all parameters to this alias to the aliasee + for (unsigned ArgNo = 0; ArgNo < Type->getNumParams(); ArgNo++) { + Params.emplace_back(PointerSize, nullptr); + UseInfo &US = Params.back().Use; + US.Calls.emplace_back(Aliasee, ArgNo, ConstantRange(APInt(PointerSize, 0))); + } +} + +namespace { + +class StackSafetyLocalAnalysis { + const Function &F; + const DataLayout &DL; + ScalarEvolution &SE; + unsigned PointerSize = 0; + + const ConstantRange UnknownRange; + + ConstantRange offsetFromAlloca(Value *Addr, const Value *AllocaPtr); + ConstantRange getAccessRange(Value *Addr, const Value *AllocaPtr, + uint64_t AccessSize); + ConstantRange getMemIntrinsicAccessRange(const MemIntrinsic *MI, const Use &U, + const Value *AllocaPtr); + + bool analyzeAllUses(const Value *Ptr, UseInfo &AS); + + ConstantRange getRange(uint64_t Lower, uint64_t Upper) const { + return ConstantRange(APInt(PointerSize, Lower), APInt(PointerSize, Upper)); + } + +public: + StackSafetyLocalAnalysis(const Function &F, ScalarEvolution &SE) + : F(F), DL(F.getParent()->getDataLayout()), SE(SE), + PointerSize(DL.getPointerSizeInBits()), + UnknownRange(PointerSize, true) {} + + // Run the transformation on the associated function. 
+ StackSafetyInfo run(); +}; + +ConstantRange +StackSafetyLocalAnalysis::offsetFromAlloca(Value *Addr, + const Value *AllocaPtr) { + if (!SE.isSCEVable(Addr->getType())) + return UnknownRange; + + AllocaOffsetRewriter Rewriter(SE, AllocaPtr); + const SCEV *Expr = Rewriter.visit(SE.getSCEV(Addr)); + ConstantRange Offset = SE.getUnsignedRange(Expr).zextOrTrunc(PointerSize); + assert(!Offset.isEmptySet()); + return Offset; +} + +ConstantRange StackSafetyLocalAnalysis::getAccessRange(Value *Addr, + const Value *AllocaPtr, + uint64_t AccessSize) { + if (!SE.isSCEVable(Addr->getType())) + return UnknownRange; + + AllocaOffsetRewriter Rewriter(SE, AllocaPtr); + const SCEV *Expr = Rewriter.visit(SE.getSCEV(Addr)); + + ConstantRange AccessStartRange = + SE.getUnsignedRange(Expr).zextOrTrunc(PointerSize); + ConstantRange SizeRange = getRange(0, AccessSize); + ConstantRange AccessRange = AccessStartRange.add(SizeRange); + assert(!AccessRange.isEmptySet()); + return AccessRange; +} + +ConstantRange StackSafetyLocalAnalysis::getMemIntrinsicAccessRange( + const MemIntrinsic *MI, const Use &U, const Value *AllocaPtr) { + if (auto MTI = dyn_cast<MemTransferInst>(MI)) { + if (MTI->getRawSource() != U && MTI->getRawDest() != U) + return getRange(0, 1); + } else { + if (MI->getRawDest() != U) + return getRange(0, 1); + } + const auto *Len = dyn_cast<ConstantInt>(MI->getLength()); + // Non-constant size => unsafe. FIXME: try SCEV getRange. + if (!Len) + return UnknownRange; + ConstantRange AccessRange = getAccessRange(U, AllocaPtr, Len->getZExtValue()); + return AccessRange; +} + +/// The function analyzes all local uses of Ptr (alloca or argument) and +/// calculates local access range and all function calls where it was used. +bool StackSafetyLocalAnalysis::analyzeAllUses(const Value *Ptr, UseInfo &US) { + SmallPtrSet<const Value *, 16> Visited; + SmallVector<const Value *, 8> WorkList; + WorkList.push_back(Ptr); + + // A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc. + while (!WorkList.empty()) { + const Value *V = WorkList.pop_back_val(); + for (const Use &UI : V->uses()) { + auto I = cast<const Instruction>(UI.getUser()); + assert(V == UI.get()); + + switch (I->getOpcode()) { + case Instruction::Load: { + US.updateRange( + getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType()))); + break; + } + + case Instruction::VAArg: + // "va-arg" from a pointer is safe. + break; + case Instruction::Store: { + if (V == I->getOperand(0)) { + // Stored the pointer - conservatively assume it may be unsafe. + US.updateRange(UnknownRange); + return false; + } + US.updateRange(getAccessRange( + UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType()))); + break; + } + + case Instruction::Ret: + // Information leak. + // FIXME: Process parameters correctly. This is a leak only if we return + // alloca. + US.updateRange(UnknownRange); + return false; + + case Instruction::Call: + case Instruction::Invoke: { + ImmutableCallSite CS(I); + + if (I->isLifetimeStartOrEnd()) + break; + + if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) { + US.updateRange(getMemIntrinsicAccessRange(MI, UI, Ptr)); + break; + } + + // FIXME: consult devirt? + // Do not follow aliases, otherwise we could inadvertently follow + // dso_preemptable aliases or aliases with interposable linkage. 
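+        // (Sketch of the bookkeeping that follows, with hypothetical code:
+        //  for
+        //    char buf[16]; f(buf + 8);
+        //  the use of buf is not an access by itself; it is recorded as
+        //  PassAsArgInfo{f, /*ParamNo=*/0, Offset=[8,9)} and resolved later
+        //  by the inter-procedural data-flow, which adds f's access range
+        //  for parameter 0 to the recorded offset.)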
+ const GlobalValue *Callee = + dyn_cast<GlobalValue>(CS.getCalledValue()->stripPointerCasts()); + if (!Callee) { + US.updateRange(UnknownRange); + return false; + } + + assert(isa<Function>(Callee) || isa<GlobalAlias>(Callee)); + + ImmutableCallSite::arg_iterator B = CS.arg_begin(), E = CS.arg_end(); + for (ImmutableCallSite::arg_iterator A = B; A != E; ++A) { + if (A->get() == V) { + ConstantRange Offset = offsetFromAlloca(UI, Ptr); + US.Calls.emplace_back(Callee, A - B, Offset); + } + } + + break; + } + + default: + if (Visited.insert(I).second) + WorkList.push_back(cast<const Instruction>(I)); + } + } + } + + return true; +} + +StackSafetyInfo StackSafetyLocalAnalysis::run() { + StackSafetyInfo::FunctionInfo Info(&F); + assert(!F.isDeclaration() && + "Can't run StackSafety on a function declaration"); + + LLVM_DEBUG(dbgs() << "[StackSafety] " << F.getName() << "\n"); + + for (auto &I : instructions(F)) { + if (auto AI = dyn_cast<AllocaInst>(&I)) { + Info.Allocas.emplace_back(PointerSize, AI, + getStaticAllocaAllocationSize(AI)); + AllocaInfo &AS = Info.Allocas.back(); + analyzeAllUses(AI, AS.Use); + } + } + + for (const Argument &A : make_range(F.arg_begin(), F.arg_end())) { + Info.Params.emplace_back(PointerSize, &A); + ParamInfo &PS = Info.Params.back(); + analyzeAllUses(&A, PS.Use); + } + + LLVM_DEBUG(dbgs() << "[StackSafety] done\n"); + LLVM_DEBUG(Info.print(dbgs())); + return StackSafetyInfo(std::move(Info)); +} + +class StackSafetyDataFlowAnalysis { + using FunctionMap = + std::map<const GlobalValue *, StackSafetyInfo::FunctionInfo>; + + FunctionMap Functions; + // Callee-to-Caller multimap. + DenseMap<const GlobalValue *, SmallVector<const GlobalValue *, 4>> Callers; + SetVector<const GlobalValue *> WorkList; + + unsigned PointerSize = 0; + const ConstantRange UnknownRange; + + ConstantRange getArgumentAccessRange(const GlobalValue *Callee, + unsigned ParamNo) const; + bool updateOneUse(UseInfo &US, bool UpdateToFullSet); + void updateOneNode(const GlobalValue *Callee, + StackSafetyInfo::FunctionInfo &FS); + void updateOneNode(const GlobalValue *Callee) { + updateOneNode(Callee, Functions.find(Callee)->second); + } + void updateAllNodes() { + for (auto &F : Functions) + updateOneNode(F.first, F.second); + } + void runDataFlow(); +#ifndef NDEBUG + void verifyFixedPoint(); +#endif + +public: + StackSafetyDataFlowAnalysis( + Module &M, std::function<const StackSafetyInfo &(Function &)> FI); + StackSafetyGlobalInfo run(); +}; + +StackSafetyDataFlowAnalysis::StackSafetyDataFlowAnalysis( + Module &M, std::function<const StackSafetyInfo &(Function &)> FI) + : PointerSize(M.getDataLayout().getPointerSizeInBits()), + UnknownRange(PointerSize, true) { + // Without ThinLTO, run the local analysis for every function in the TU and + // then run the DFA. + for (auto &F : M.functions()) + if (!F.isDeclaration()) + Functions.emplace(&F, FI(F)); + for (auto &A : M.aliases()) + if (isa<Function>(A.getBaseObject())) + Functions.emplace(&A, StackSafetyInfo::FunctionInfo(&A)); +} + +ConstantRange +StackSafetyDataFlowAnalysis::getArgumentAccessRange(const GlobalValue *Callee, + unsigned ParamNo) const { + auto IT = Functions.find(Callee); + // Unknown callee (outside of LTO domain or an indirect call). + if (IT == Functions.end()) + return UnknownRange; + const StackSafetyInfo::FunctionInfo &FS = IT->second; + // The definition of this symbol may not be the definition in this linkage + // unit. 
+ if (!FS.IsDSOLocal() || FS.IsInterposable()) + return UnknownRange; + if (ParamNo >= FS.Params.size()) // possibly vararg + return UnknownRange; + return FS.Params[ParamNo].Use.Range; +} + +bool StackSafetyDataFlowAnalysis::updateOneUse(UseInfo &US, + bool UpdateToFullSet) { + bool Changed = false; + for (auto &CS : US.Calls) { + assert(!CS.Offset.isEmptySet() && + "Param range can't be empty-set, invalid offset range"); + + ConstantRange CalleeRange = getArgumentAccessRange(CS.Callee, CS.ParamNo); + CalleeRange = CalleeRange.add(CS.Offset); + if (!US.Range.contains(CalleeRange)) { + Changed = true; + if (UpdateToFullSet) + US.Range = UnknownRange; + else + US.Range = US.Range.unionWith(CalleeRange); + } + } + return Changed; +} + +void StackSafetyDataFlowAnalysis::updateOneNode( + const GlobalValue *Callee, StackSafetyInfo::FunctionInfo &FS) { + bool UpdateToFullSet = FS.UpdateCount > StackSafetyMaxIterations; + bool Changed = false; + for (auto &AS : FS.Allocas) + Changed |= updateOneUse(AS.Use, UpdateToFullSet); + for (auto &PS : FS.Params) + Changed |= updateOneUse(PS.Use, UpdateToFullSet); + + if (Changed) { + LLVM_DEBUG(dbgs() << "=== update [" << FS.UpdateCount + << (UpdateToFullSet ? ", full-set" : "") << "] " + << FS.getName() << "\n"); + // Callers of this function may need updating. + for (auto &CallerID : Callers[Callee]) + WorkList.insert(CallerID); + + ++FS.UpdateCount; + } +} + +void StackSafetyDataFlowAnalysis::runDataFlow() { + Callers.clear(); + WorkList.clear(); + + SmallVector<const GlobalValue *, 16> Callees; + for (auto &F : Functions) { + Callees.clear(); + StackSafetyInfo::FunctionInfo &FS = F.second; + for (auto &AS : FS.Allocas) + for (auto &CS : AS.Use.Calls) + Callees.push_back(CS.Callee); + for (auto &PS : FS.Params) + for (auto &CS : PS.Use.Calls) + Callees.push_back(CS.Callee); + + llvm::sort(Callees); + Callees.erase(std::unique(Callees.begin(), Callees.end()), Callees.end()); + + for (auto &Callee : Callees) + Callers[Callee].push_back(F.first); + } + + updateAllNodes(); + + while (!WorkList.empty()) { + const GlobalValue *Callee = WorkList.back(); + WorkList.pop_back(); + updateOneNode(Callee); + } +} + +#ifndef NDEBUG +void StackSafetyDataFlowAnalysis::verifyFixedPoint() { + WorkList.clear(); + updateAllNodes(); + assert(WorkList.empty()); +} +#endif + +StackSafetyGlobalInfo StackSafetyDataFlowAnalysis::run() { + runDataFlow(); + LLVM_DEBUG(verifyFixedPoint()); + + StackSafetyGlobalInfo SSI; + for (auto &F : Functions) + SSI.emplace(F.first, std::move(F.second)); + return SSI; +} + +void print(const StackSafetyGlobalInfo &SSI, raw_ostream &O, const Module &M) { + size_t Count = 0; + for (auto &F : M.functions()) + if (!F.isDeclaration()) { + SSI.find(&F)->second.print(O); + O << "\n"; + ++Count; + } + for (auto &A : M.aliases()) { + SSI.find(&A)->second.print(O); + O << "\n"; + ++Count; + } + assert(Count == SSI.size() && "Unexpected functions in the result"); +} + +} // end anonymous namespace + +StackSafetyInfo::StackSafetyInfo() = default; +StackSafetyInfo::StackSafetyInfo(StackSafetyInfo &&) = default; +StackSafetyInfo &StackSafetyInfo::operator=(StackSafetyInfo &&) = default; + +StackSafetyInfo::StackSafetyInfo(FunctionInfo &&Info) + : Info(new FunctionInfo(std::move(Info))) {} + +StackSafetyInfo::~StackSafetyInfo() = default; + +void StackSafetyInfo::print(raw_ostream &O) const { Info->print(O); } + +AnalysisKey StackSafetyAnalysis::Key; + +StackSafetyInfo StackSafetyAnalysis::run(Function &F, + FunctionAnalysisManager &AM) { + 
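+  // (Usage sketch from another pass over the same function, assuming it
+  //  declared the dependency; the result is the per-function summary only,
+  //  with calls left unresolved until the module-level data-flow runs:
+  //
+  //    StackSafetyInfo &SSI = AM.getResult<StackSafetyAnalysis>(F);
+  //    SSI.print(dbgs());
+  //  )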
StackSafetyLocalAnalysis SSLA(F, AM.getResult<ScalarEvolutionAnalysis>(F)); + return SSLA.run(); +} + +PreservedAnalyses StackSafetyPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + OS << "'Stack Safety Local Analysis' for function '" << F.getName() << "'\n"; + AM.getResult<StackSafetyAnalysis>(F).print(OS); + return PreservedAnalyses::all(); +} + +char StackSafetyInfoWrapperPass::ID = 0; + +StackSafetyInfoWrapperPass::StackSafetyInfoWrapperPass() : FunctionPass(ID) { + initializeStackSafetyInfoWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void StackSafetyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.setPreservesAll(); +} + +void StackSafetyInfoWrapperPass::print(raw_ostream &O, const Module *M) const { + SSI.print(O); +} + +bool StackSafetyInfoWrapperPass::runOnFunction(Function &F) { + StackSafetyLocalAnalysis SSLA( + F, getAnalysis<ScalarEvolutionWrapperPass>().getSE()); + SSI = StackSafetyInfo(SSLA.run()); + return false; +} + +AnalysisKey StackSafetyGlobalAnalysis::Key; + +StackSafetyGlobalInfo +StackSafetyGlobalAnalysis::run(Module &M, ModuleAnalysisManager &AM) { + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + + StackSafetyDataFlowAnalysis SSDFA( + M, [&FAM](Function &F) -> const StackSafetyInfo & { + return FAM.getResult<StackSafetyAnalysis>(F); + }); + return SSDFA.run(); +} + +PreservedAnalyses StackSafetyGlobalPrinterPass::run(Module &M, + ModuleAnalysisManager &AM) { + OS << "'Stack Safety Analysis' for module '" << M.getName() << "'\n"; + print(AM.getResult<StackSafetyGlobalAnalysis>(M), OS, M); + return PreservedAnalyses::all(); +} + +char StackSafetyGlobalInfoWrapperPass::ID = 0; + +StackSafetyGlobalInfoWrapperPass::StackSafetyGlobalInfoWrapperPass() + : ModulePass(ID) { + initializeStackSafetyGlobalInfoWrapperPassPass( + *PassRegistry::getPassRegistry()); +} + +void StackSafetyGlobalInfoWrapperPass::print(raw_ostream &O, + const Module *M) const { + ::print(SSI, O, *M); +} + +void StackSafetyGlobalInfoWrapperPass::getAnalysisUsage( + AnalysisUsage &AU) const { + AU.addRequired<StackSafetyInfoWrapperPass>(); +} + +bool StackSafetyGlobalInfoWrapperPass::runOnModule(Module &M) { + StackSafetyDataFlowAnalysis SSDFA( + M, [this](Function &F) -> const StackSafetyInfo & { + return getAnalysis<StackSafetyInfoWrapperPass>(F).getResult(); + }); + SSI = SSDFA.run(); + return false; +} + +static const char LocalPassArg[] = "stack-safety-local"; +static const char LocalPassName[] = "Stack Safety Local Analysis"; +INITIALIZE_PASS_BEGIN(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName, + false, true) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_END(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName, + false, true) + +static const char GlobalPassName[] = "Stack Safety Analysis"; +INITIALIZE_PASS_BEGIN(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE, + GlobalPassName, false, false) +INITIALIZE_PASS_DEPENDENCY(StackSafetyInfoWrapperPass) +INITIALIZE_PASS_END(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE, + GlobalPassName, false, false) diff --git a/llvm/lib/Analysis/StratifiedSets.h b/llvm/lib/Analysis/StratifiedSets.h new file mode 100644 index 000000000000..60ea2451b0ef --- /dev/null +++ b/llvm/lib/Analysis/StratifiedSets.h @@ -0,0 +1,596 @@ +//===- StratifiedSets.h - Abstract stratified sets implementation. --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ADT_STRATIFIEDSETS_H +#define LLVM_ADT_STRATIFIEDSETS_H + +#include "AliasAnalysisSummary.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include <bitset> +#include <cassert> +#include <cmath> +#include <type_traits> +#include <utility> +#include <vector> + +namespace llvm { +namespace cflaa { +/// An index into Stratified Sets. +typedef unsigned StratifiedIndex; +/// NOTE: ^ This can't be a short -- bootstrapping clang has a case where +/// ~1M sets exist. + +// Container of information related to a value in a StratifiedSet. +struct StratifiedInfo { + StratifiedIndex Index; + /// For field sensitivity, etc. we can tack fields on here. +}; + +/// A "link" between two StratifiedSets. +struct StratifiedLink { + /// This is a value used to signify "does not exist" where the + /// StratifiedIndex type is used. + /// + /// This is used instead of Optional<StratifiedIndex> because + /// Optional<StratifiedIndex> would eat up a considerable amount of extra + /// memory, after struct padding/alignment is taken into account. + static const StratifiedIndex SetSentinel; + + /// The index for the set "above" current + StratifiedIndex Above; + + /// The link for the set "below" current + StratifiedIndex Below; + + /// Attributes for these StratifiedSets. + AliasAttrs Attrs; + + StratifiedLink() : Above(SetSentinel), Below(SetSentinel) {} + + bool hasBelow() const { return Below != SetSentinel; } + bool hasAbove() const { return Above != SetSentinel; } + + void clearBelow() { Below = SetSentinel; } + void clearAbove() { Above = SetSentinel; } +}; + +/// These are stratified sets, as described in "Fast algorithms for +/// Dyck-CFL-reachability with applications to Alias Analysis" by Zhang Q, Lyu M +/// R, Yuan H, and Su Z. -- in short, this is meant to represent different sets +/// of Value*s. If two Value*s are in the same set, or if both sets have +/// overlapping attributes, then the Value*s are said to alias. +/// +/// Sets may be related by position, meaning that one set may be considered as +/// above or below another. In CFL Alias Analysis, this gives us an indication +/// of how two variables are related; if the set of variable A is below a set +/// containing variable B, then at some point, a variable that has interacted +/// with B (or B itself) was either used in order to extract the variable A, or +/// was used as storage of variable A. +/// +/// Sets may also have attributes (as noted above). These attributes are +/// generally used for noting whether a variable in the set has interacted with +/// a variable whose origins we don't quite know (i.e. globals/arguments), or if +/// the variable may have had operations performed on it (modified in a function +/// call). All attributes that exist in a set A must exist in all sets marked as +/// below set A. 
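+///
+/// A minimal query-side sketch (hypothetical client code; V1 is an assumed
+/// Value* and Sets an already-built instance):
+///
+///   if (auto MaybeInfo = Sets.find(V1)) {
+///     const StratifiedLink &Link = Sets.getLink(MaybeInfo->Index);
+///     // Same index, or overlapping Link.Attrs, suggests potential aliasing.
+///   }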
+template <typename T> class StratifiedSets {
+public:
+  StratifiedSets() = default;
+  StratifiedSets(StratifiedSets &&) = default;
+  StratifiedSets &operator=(StratifiedSets &&) = default;
+
+  StratifiedSets(DenseMap<T, StratifiedInfo> Map,
+                 std::vector<StratifiedLink> Links)
+      : Values(std::move(Map)), Links(std::move(Links)) {}
+
+  Optional<StratifiedInfo> find(const T &Elem) const {
+    auto Iter = Values.find(Elem);
+    if (Iter == Values.end())
+      return None;
+    return Iter->second;
+  }
+
+  const StratifiedLink &getLink(StratifiedIndex Index) const {
+    assert(inbounds(Index));
+    return Links[Index];
+  }
+
+private:
+  DenseMap<T, StratifiedInfo> Values;
+  std::vector<StratifiedLink> Links;
+
+  bool inbounds(StratifiedIndex Idx) const { return Idx < Links.size(); }
+};
+
+/// Generic Builder class that produces StratifiedSets instances.
+///
+/// The goal of this builder is to efficiently produce correct StratifiedSets
+/// instances. To this end, we use a few tricks:
+///   > Set chains (A method for linking sets together)
+///   > Set remaps (A method for marking a set as an alias [irony?] of another)
+///
+/// ==== Set chains ====
+/// This builder has a notion of some value A being above, below, or with some
+/// other value B:
+///   > The `A above B` relationship implies that there is a reference edge
+///   going from A to B. Namely, it notes that A can store anything in B's set.
+///   > The `A below B` relationship is the opposite of `A above B`. It implies
+///   that there's a dereference edge going from A to B.
+///   > The `A with B` relationship states that there's an assignment edge going
+///   from A to B, and that A and B should be treated as equals.
+///
+/// As an example, take the following code snippet:
+///
+/// %a = alloca i32, align 4
+/// %ap = alloca i32*, align 8
+/// %app = alloca i32**, align 8
+/// store %a, %ap
+/// store %ap, %app
+/// %aw = getelementptr %ap, i32 0
+///
+/// Given this, the following relations exist:
+///   - %a below %ap & %ap above %a
+///   - %ap below %app & %app above %ap
+///   - %aw with %ap & %ap with %aw
+///
+/// These relations produce the following sets:
+///   [{%a}, {%ap, %aw}, {%app}]
+///
+/// ...Which state that the only MayAlias relationship in the above program is
+/// between %ap and %aw.
+///
+/// Because LLVM allows arbitrary casts, code like the following needs to be
+/// supported:
+/// %ip = alloca i64, align 8
+/// %ipp = alloca i64*, align 8
+/// %i = bitcast i64** %ipp to i64
+/// store i64* %ip, i64** %ipp
+/// store i64 %i, i64* %ip
+///
+/// Which, because %ipp ends up *both* above and below %ip, is fun.
+///
+/// This is solved by merging %i and %ipp into a single set (...which is the
+/// only way to solve this, since their bit patterns are equivalent). Any sets
+/// that ended up in between %i and %ipp at the time of merging (in this case,
+/// the set containing %ip) also get conservatively merged into the set of %i
+/// and %ipp. In short, the resulting StratifiedSet from the above code would be
+/// {%ip, %ipp, %i}.
+///
+/// ==== Set remaps ====
+/// More of an implementation detail than anything -- when merging sets, we need
+/// to update the numbers of all of the elements mapped to those sets. Rather
+/// than doing this at each merge, we note in the BuilderLink structure that a
+/// remap has occurred, and use this information so we can defer renumbering set
+/// elements until build time.
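+///
+/// A minimal build-side sketch (hypothetical client code, mirroring the
+/// first example above; A, AP, APP and AW are assumed Value* handles):
+///
+///   StratifiedSetsBuilder<Value *> Builder;
+///   Builder.add(AP);            // fresh set for %ap
+///   Builder.addBelow(AP, A);    // %a below %ap
+///   Builder.addAbove(AP, APP);  // %app above %ap
+///   Builder.addWith(AP, AW);    // %aw with %ap
+///   StratifiedSets<Value *> Sets = Builder.build();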
+template <typename T> class StratifiedSetsBuilder { + /// Represents a Stratified Set, with information about the Stratified + /// Set above it, the set below it, and whether the current set has been + /// remapped to another. + struct BuilderLink { + const StratifiedIndex Number; + + BuilderLink(StratifiedIndex N) : Number(N) { + Remap = StratifiedLink::SetSentinel; + } + + bool hasAbove() const { + assert(!isRemapped()); + return Link.hasAbove(); + } + + bool hasBelow() const { + assert(!isRemapped()); + return Link.hasBelow(); + } + + void setBelow(StratifiedIndex I) { + assert(!isRemapped()); + Link.Below = I; + } + + void setAbove(StratifiedIndex I) { + assert(!isRemapped()); + Link.Above = I; + } + + void clearBelow() { + assert(!isRemapped()); + Link.clearBelow(); + } + + void clearAbove() { + assert(!isRemapped()); + Link.clearAbove(); + } + + StratifiedIndex getBelow() const { + assert(!isRemapped()); + assert(hasBelow()); + return Link.Below; + } + + StratifiedIndex getAbove() const { + assert(!isRemapped()); + assert(hasAbove()); + return Link.Above; + } + + AliasAttrs getAttrs() { + assert(!isRemapped()); + return Link.Attrs; + } + + void setAttrs(AliasAttrs Other) { + assert(!isRemapped()); + Link.Attrs |= Other; + } + + bool isRemapped() const { return Remap != StratifiedLink::SetSentinel; } + + /// For initial remapping to another set + void remapTo(StratifiedIndex Other) { + assert(!isRemapped()); + Remap = Other; + } + + StratifiedIndex getRemapIndex() const { + assert(isRemapped()); + return Remap; + } + + /// Should only be called when we're already remapped. + void updateRemap(StratifiedIndex Other) { + assert(isRemapped()); + Remap = Other; + } + + /// Prefer the above functions to calling things directly on what's returned + /// from this -- they guard against unexpected calls when the current + /// BuilderLink is remapped. + const StratifiedLink &getLink() const { return Link; } + + private: + StratifiedLink Link; + StratifiedIndex Remap; + }; + + /// This function performs all of the set unioning/value renumbering + /// that we've been putting off, and generates a vector<StratifiedLink> that + /// may be placed in a StratifiedSets instance. + void finalizeSets(std::vector<StratifiedLink> &StratLinks) { + DenseMap<StratifiedIndex, StratifiedIndex> Remaps; + for (auto &Link : Links) { + if (Link.isRemapped()) + continue; + + StratifiedIndex Number = StratLinks.size(); + Remaps.insert(std::make_pair(Link.Number, Number)); + StratLinks.push_back(Link.getLink()); + } + + for (auto &Link : StratLinks) { + if (Link.hasAbove()) { + auto &Above = linksAt(Link.Above); + auto Iter = Remaps.find(Above.Number); + assert(Iter != Remaps.end()); + Link.Above = Iter->second; + } + + if (Link.hasBelow()) { + auto &Below = linksAt(Link.Below); + auto Iter = Remaps.find(Below.Number); + assert(Iter != Remaps.end()); + Link.Below = Iter->second; + } + } + + for (auto &Pair : Values) { + auto &Info = Pair.second; + auto &Link = linksAt(Info.Index); + auto Iter = Remaps.find(Link.Number); + assert(Iter != Remaps.end()); + Info.Index = Iter->second; + } + } + + /// There's a guarantee in StratifiedLink where all bits set in a + /// Link.externals will be set in all Link.externals "below" it. 
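+  /// ("externals" refers to the attribute bits now held in
+  /// StratifiedLink::Attrs; propagateAttrs below establishes the guarantee
+  /// by OR-ing each set's attributes into the set beneath it, walking down
+  /// from the topmost set of every chain.)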
+ static void propagateAttrs(std::vector<StratifiedLink> &Links) { + const auto getHighestParentAbove = [&Links](StratifiedIndex Idx) { + const auto *Link = &Links[Idx]; + while (Link->hasAbove()) { + Idx = Link->Above; + Link = &Links[Idx]; + } + return Idx; + }; + + SmallSet<StratifiedIndex, 16> Visited; + for (unsigned I = 0, E = Links.size(); I < E; ++I) { + auto CurrentIndex = getHighestParentAbove(I); + if (!Visited.insert(CurrentIndex).second) + continue; + + while (Links[CurrentIndex].hasBelow()) { + auto &CurrentBits = Links[CurrentIndex].Attrs; + auto NextIndex = Links[CurrentIndex].Below; + auto &NextBits = Links[NextIndex].Attrs; + NextBits |= CurrentBits; + CurrentIndex = NextIndex; + } + } + } + +public: + /// Builds a StratifiedSet from the information we've been given since either + /// construction or the prior build() call. + StratifiedSets<T> build() { + std::vector<StratifiedLink> StratLinks; + finalizeSets(StratLinks); + propagateAttrs(StratLinks); + Links.clear(); + return StratifiedSets<T>(std::move(Values), std::move(StratLinks)); + } + + bool has(const T &Elem) const { return get(Elem).hasValue(); } + + bool add(const T &Main) { + if (get(Main).hasValue()) + return false; + + auto NewIndex = getNewUnlinkedIndex(); + return addAtMerging(Main, NewIndex); + } + + /// Restructures the stratified sets as necessary to make "ToAdd" in a + /// set above "Main". There are some cases where this is not possible (see + /// above), so we merge them such that ToAdd and Main are in the same set. + bool addAbove(const T &Main, const T &ToAdd) { + assert(has(Main)); + auto Index = *indexOf(Main); + if (!linksAt(Index).hasAbove()) + addLinkAbove(Index); + + auto Above = linksAt(Index).getAbove(); + return addAtMerging(ToAdd, Above); + } + + /// Restructures the stratified sets as necessary to make "ToAdd" in a + /// set below "Main". There are some cases where this is not possible (see + /// above), so we merge them such that ToAdd and Main are in the same set. + bool addBelow(const T &Main, const T &ToAdd) { + assert(has(Main)); + auto Index = *indexOf(Main); + if (!linksAt(Index).hasBelow()) + addLinkBelow(Index); + + auto Below = linksAt(Index).getBelow(); + return addAtMerging(ToAdd, Below); + } + + bool addWith(const T &Main, const T &ToAdd) { + assert(has(Main)); + auto MainIndex = *indexOf(Main); + return addAtMerging(ToAdd, MainIndex); + } + + void noteAttributes(const T &Main, AliasAttrs NewAttrs) { + assert(has(Main)); + auto *Info = *get(Main); + auto &Link = linksAt(Info->Index); + Link.setAttrs(NewAttrs); + } + +private: + DenseMap<T, StratifiedInfo> Values; + std::vector<BuilderLink> Links; + + /// Adds the given element at the given index, merging sets if necessary. + bool addAtMerging(const T &ToAdd, StratifiedIndex Index) { + StratifiedInfo Info = {Index}; + auto Pair = Values.insert(std::make_pair(ToAdd, Info)); + if (Pair.second) + return true; + + auto &Iter = Pair.first; + auto &IterSet = linksAt(Iter->second.Index); + auto &ReqSet = linksAt(Index); + + // Failed to add where we wanted to. Merge the sets. + if (&IterSet != &ReqSet) + merge(IterSet.Number, ReqSet.Number); + + return false; + } + + /// Gets the BuilderLink at the given index, taking set remapping into + /// account. 
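+  /// Remap chains are followed to their final target, and every link
+  /// visited along the way is then re-pointed directly at that target
+  /// (union-find style path compression), keeping repeated lookups cheap.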
+ BuilderLink &linksAt(StratifiedIndex Index) { + auto *Start = &Links[Index]; + if (!Start->isRemapped()) + return *Start; + + auto *Current = Start; + while (Current->isRemapped()) + Current = &Links[Current->getRemapIndex()]; + + auto NewRemap = Current->Number; + + // Run through everything that has yet to be updated, and update them to + // remap to NewRemap + Current = Start; + while (Current->isRemapped()) { + auto *Next = &Links[Current->getRemapIndex()]; + Current->updateRemap(NewRemap); + Current = Next; + } + + return *Current; + } + + /// Merges two sets into one another. Assumes that these sets are not + /// already one in the same. + void merge(StratifiedIndex Idx1, StratifiedIndex Idx2) { + assert(inbounds(Idx1) && inbounds(Idx2)); + assert(&linksAt(Idx1) != &linksAt(Idx2) && + "Merging a set into itself is not allowed"); + + // CASE 1: If the set at `Idx1` is above or below `Idx2`, we need to merge + // both the + // given sets, and all sets between them, into one. + if (tryMergeUpwards(Idx1, Idx2)) + return; + + if (tryMergeUpwards(Idx2, Idx1)) + return; + + // CASE 2: The set at `Idx1` is not in the same chain as the set at `Idx2`. + // We therefore need to merge the two chains together. + mergeDirect(Idx1, Idx2); + } + + /// Merges two sets assuming that the set at `Idx1` is unreachable from + /// traversing above or below the set at `Idx2`. + void mergeDirect(StratifiedIndex Idx1, StratifiedIndex Idx2) { + assert(inbounds(Idx1) && inbounds(Idx2)); + + auto *LinksInto = &linksAt(Idx1); + auto *LinksFrom = &linksAt(Idx2); + // Merging everything above LinksInto then proceeding to merge everything + // below LinksInto becomes problematic, so we go as far "up" as possible! + while (LinksInto->hasAbove() && LinksFrom->hasAbove()) { + LinksInto = &linksAt(LinksInto->getAbove()); + LinksFrom = &linksAt(LinksFrom->getAbove()); + } + + if (LinksFrom->hasAbove()) { + LinksInto->setAbove(LinksFrom->getAbove()); + auto &NewAbove = linksAt(LinksInto->getAbove()); + NewAbove.setBelow(LinksInto->Number); + } + + // Merging strategy: + // > If neither has links below, stop. + // > If only `LinksInto` has links below, stop. + // > If only `LinksFrom` has links below, reset `LinksInto.Below` to + // match `LinksFrom.Below` + // > If both have links above, deal with those next. + while (LinksInto->hasBelow() && LinksFrom->hasBelow()) { + auto FromAttrs = LinksFrom->getAttrs(); + LinksInto->setAttrs(FromAttrs); + + // Remap needs to happen after getBelow(), but before + // assignment of LinksFrom + auto *NewLinksFrom = &linksAt(LinksFrom->getBelow()); + LinksFrom->remapTo(LinksInto->Number); + LinksFrom = NewLinksFrom; + LinksInto = &linksAt(LinksInto->getBelow()); + } + + if (LinksFrom->hasBelow()) { + LinksInto->setBelow(LinksFrom->getBelow()); + auto &NewBelow = linksAt(LinksInto->getBelow()); + NewBelow.setAbove(LinksInto->Number); + } + + LinksInto->setAttrs(LinksFrom->getAttrs()); + LinksFrom->remapTo(LinksInto->Number); + } + + /// Checks to see if lowerIndex is at a level lower than upperIndex. If so, it + /// will merge lowerIndex with upperIndex (and all of the sets between) and + /// return true. Otherwise, it will return false. 
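+  /// The walk climbs from lowerIndex via getAbove(), accumulating the
+  /// attributes of every set it passes; on success those attributes land on
+  /// upperIndex and each traversed set is remapped onto it.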
+ bool tryMergeUpwards(StratifiedIndex LowerIndex, StratifiedIndex UpperIndex) { + assert(inbounds(LowerIndex) && inbounds(UpperIndex)); + auto *Lower = &linksAt(LowerIndex); + auto *Upper = &linksAt(UpperIndex); + if (Lower == Upper) + return true; + + SmallVector<BuilderLink *, 8> Found; + auto *Current = Lower; + auto Attrs = Current->getAttrs(); + while (Current->hasAbove() && Current != Upper) { + Found.push_back(Current); + Attrs |= Current->getAttrs(); + Current = &linksAt(Current->getAbove()); + } + + if (Current != Upper) + return false; + + Upper->setAttrs(Attrs); + + if (Lower->hasBelow()) { + auto NewBelowIndex = Lower->getBelow(); + Upper->setBelow(NewBelowIndex); + auto &NewBelow = linksAt(NewBelowIndex); + NewBelow.setAbove(UpperIndex); + } else { + Upper->clearBelow(); + } + + for (const auto &Ptr : Found) + Ptr->remapTo(Upper->Number); + + return true; + } + + Optional<const StratifiedInfo *> get(const T &Val) const { + auto Result = Values.find(Val); + if (Result == Values.end()) + return None; + return &Result->second; + } + + Optional<StratifiedInfo *> get(const T &Val) { + auto Result = Values.find(Val); + if (Result == Values.end()) + return None; + return &Result->second; + } + + Optional<StratifiedIndex> indexOf(const T &Val) { + auto MaybeVal = get(Val); + if (!MaybeVal.hasValue()) + return None; + auto *Info = *MaybeVal; + auto &Link = linksAt(Info->Index); + return Link.Number; + } + + StratifiedIndex addLinkBelow(StratifiedIndex Set) { + auto At = addLinks(); + Links[Set].setBelow(At); + Links[At].setAbove(Set); + return At; + } + + StratifiedIndex addLinkAbove(StratifiedIndex Set) { + auto At = addLinks(); + Links[At].setBelow(Set); + Links[Set].setAbove(At); + return At; + } + + StratifiedIndex getNewUnlinkedIndex() { return addLinks(); } + + StratifiedIndex addLinks() { + auto Link = Links.size(); + Links.push_back(BuilderLink(Link)); + return Link; + } + + bool inbounds(StratifiedIndex N) const { return N < Links.size(); } +}; +} +} +#endif // LLVM_ADT_STRATIFIEDSETS_H diff --git a/llvm/lib/Analysis/SyncDependenceAnalysis.cpp b/llvm/lib/Analysis/SyncDependenceAnalysis.cpp new file mode 100644 index 000000000000..8447dc87069d --- /dev/null +++ b/llvm/lib/Analysis/SyncDependenceAnalysis.cpp @@ -0,0 +1,380 @@ +//===- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation +//--===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements an algorithm that returns for a divergent branch +// the set of basic blocks whose phi nodes become divergent due to divergent +// control. These are the blocks that are reachable by two disjoint paths from +// the branch or loop exits that have a reaching path that is disjoint from a +// path to the loop latch. +// +// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model +// control-induced divergence in phi nodes. +// +// -- Summary -- +// The SyncDependenceAnalysis lazily computes sync dependences [3]. +// The analysis evaluates the disjoint path criterion [2] by a reduction +// to SSA construction. The SSA construction algorithm is implemented as +// a simple data-flow analysis [1]. 
+//
+// [1] "A Simple, Fast Dominance Algorithm", SPI '01, Cooper, Harvey and Kennedy
+// [2] "Efficiently Computing Static Single Assignment Form
+//     and the Control Dependence Graph", TOPLAS '91,
+//     Cytron, Ferrante, Rosen, Wegman and Zadeck
+// [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
+// [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
+//
+// -- Sync dependence --
+// Sync dependence [4] characterizes the control flow aspect of the
+// propagation of branch divergence. For example,
+//
+//   %cond = icmp slt i32 %tid, 10
+//   br i1 %cond, label %then, label %else
+// then:
+//   br label %merge
+// else:
+//   br label %merge
+// merge:
+//   %a = phi i32 [ 0, %then ], [ 1, %else ]
+//
+// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
+// because %tid is not on its use-def chains, %a is sync dependent on %tid
+// because the branch "br i1 %cond" depends on %tid and affects which value %a
+// is assigned to.
+//
+// -- Reduction to SSA construction --
+// There are two disjoint paths from A to X if a certain variant of SSA
+// construction places a phi node in X under the following set-up scheme [2].
+//
+// This variant of SSA construction ignores incoming undef values.
+// That is, paths from the entry without a definition do not result in
+// phi nodes.
+//
+//          entry
+//         /      \
+//        A        \
+//       / \        Y
+//      B   C      /
+//       \ / \    /
+//        D   E
+//         \ /
+//          F
+// Assume that A contains a divergent branch. We are interested
+// in the set of all blocks where each block is reachable from A
+// via two disjoint paths. This would be the set {D, F} in this
+// case.
+// To generally reduce this query to SSA construction we introduce
+// a virtual variable x and assign to x different values in each
+// successor block of A.
+//          entry
+//         /      \
+//        A        \
+//       / \        Y
+//  x = 0   x = 1  /
+//       \ / \    /
+//        D   E
+//         \ /
+//          F
+// Our flavor of SSA construction for x will construct the following
+//          entry
+//         /      \
+//        A        \
+//       / \        Y
+// x0 = 0   x1 = 1 /
+//       \ / \    /
+//     x2=phi  E
+//         \  /
+//        x3=phi
+// The blocks D and F contain phi nodes and are thus each reachable
+// by two disjoint paths from A.
+//
+// -- Remarks --
+// In case of loop exits we need to check the disjoint path criterion for loops
+// [2]. To this end, we check whether the definition of x differs between the
+// loop exit and the loop header (_after_ SSA construction).
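+//
+// A minimal usage sketch (hypothetical driver code; markJoinDivergent is an
+// assumed callback, DivergentBranch an already-identified terminator):
+//
+//   SyncDependenceAnalysis SDA(DT, PDT, LI);
+//   for (const BasicBlock *JoinBB : SDA.join_blocks(*DivergentBranch))
+//     markJoinDivergent(*JoinBB);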
+// +//===----------------------------------------------------------------------===// +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/SyncDependenceAnalysis.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" + +#include <stack> +#include <unordered_set> + +#define DEBUG_TYPE "sync-dependence" + +namespace llvm { + +ConstBlockSet SyncDependenceAnalysis::EmptyBlockSet; + +SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, + const PostDominatorTree &PDT, + const LoopInfo &LI) + : FuncRPOT(DT.getRoot()->getParent()), DT(DT), PDT(PDT), LI(LI) {} + +SyncDependenceAnalysis::~SyncDependenceAnalysis() {} + +using FunctionRPOT = ReversePostOrderTraversal<const Function *>; + +// divergence propagator for reducible CFGs +struct DivergencePropagator { + const FunctionRPOT &FuncRPOT; + const DominatorTree &DT; + const PostDominatorTree &PDT; + const LoopInfo &LI; + + // identified join points + std::unique_ptr<ConstBlockSet> JoinBlocks; + + // reached loop exits (by a path disjoint to a path to the loop header) + SmallPtrSet<const BasicBlock *, 4> ReachedLoopExits; + + // if DefMap[B] == C then C is the dominating definition at block B + // if DefMap[B] ~ undef then we haven't seen B yet + // if DefMap[B] == B then B is a join point of disjoint paths from X or B is + // an immediate successor of X (initial value). + using DefiningBlockMap = std::map<const BasicBlock *, const BasicBlock *>; + DefiningBlockMap DefMap; + + // all blocks with pending visits + std::unordered_set<const BasicBlock *> PendingUpdates; + + DivergencePropagator(const FunctionRPOT &FuncRPOT, const DominatorTree &DT, + const PostDominatorTree &PDT, const LoopInfo &LI) + : FuncRPOT(FuncRPOT), DT(DT), PDT(PDT), LI(LI), + JoinBlocks(new ConstBlockSet) {} + + // set the definition at @block and mark @block as pending for a visit + void addPending(const BasicBlock &Block, const BasicBlock &DefBlock) { + bool WasAdded = DefMap.emplace(&Block, &DefBlock).second; + if (WasAdded) + PendingUpdates.insert(&Block); + } + + void printDefs(raw_ostream &Out) { + Out << "Propagator::DefMap {\n"; + for (const auto *Block : FuncRPOT) { + auto It = DefMap.find(Block); + Out << Block->getName() << " : "; + if (It == DefMap.end()) { + Out << "\n"; + } else { + const auto *DefBlock = It->second; + Out << (DefBlock ? DefBlock->getName() : "<null>") << "\n"; + } + } + Out << "}\n"; + } + + // process @succBlock with reaching definition @defBlock + // the original divergent branch was in @parentLoop (if any) + void visitSuccessor(const BasicBlock &SuccBlock, const Loop *ParentLoop, + const BasicBlock &DefBlock) { + + // @succBlock is a loop exit + if (ParentLoop && !ParentLoop->contains(&SuccBlock)) { + DefMap.emplace(&SuccBlock, &DefBlock); + ReachedLoopExits.insert(&SuccBlock); + return; + } + + // first reaching def? + auto ItLastDef = DefMap.find(&SuccBlock); + if (ItLastDef == DefMap.end()) { + addPending(SuccBlock, DefBlock); + return; + } + + // a join of at least two definitions + if (ItLastDef->second != &DefBlock) { + // do we know this join already? + if (!JoinBlocks->insert(&SuccBlock).second) + return; + + // update the definition + addPending(SuccBlock, SuccBlock); + } + } + + // find all blocks reachable by two disjoint paths from @rootTerm. + // This method works for both divergent terminators and loops with + // divergent exits. 
+  // @rootBlock is either the block containing the branch or the header of the
+  // divergent loop.
+  // @nodeSuccessors is the set of successors of the node (Loop or Terminator)
+  // headed by @rootBlock.
+  // @parentLoop is the parent loop of the Loop or the loop that contains the
+  // Terminator.
+  template <typename SuccessorIterable>
+  std::unique_ptr<ConstBlockSet>
+  computeJoinPoints(const BasicBlock &RootBlock,
+                    SuccessorIterable NodeSuccessors, const Loop *ParentLoop) {
+    assert(JoinBlocks);
+
+    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints. Parent loop: " << (ParentLoop ? ParentLoop->getName() : "<null>") << "\n" );
+
+    // bootstrap with branch targets
+    for (const auto *SuccBlock : NodeSuccessors) {
+      DefMap.emplace(SuccBlock, SuccBlock);
+
+      if (ParentLoop && !ParentLoop->contains(SuccBlock)) {
+        // immediate loop exit from node.
+        ReachedLoopExits.insert(SuccBlock);
+      } else {
+        // regular successor
+        PendingUpdates.insert(SuccBlock);
+      }
+    }
+
+    LLVM_DEBUG(
+      dbgs() << "SDA: rpo order:\n";
+      for (const auto * RpoBlock : FuncRPOT) {
+        dbgs() << "- " << RpoBlock->getName() << "\n";
+      }
+    );
+
+    auto ItBeginRPO = FuncRPOT.begin();
+
+    // skip until term (TODO RPOT won't let us start at @term directly)
+    for (; *ItBeginRPO != &RootBlock; ++ItBeginRPO) {}
+
+    auto ItEndRPO = FuncRPOT.end();
+    assert(ItBeginRPO != ItEndRPO);
+
+    // propagate definitions at the immediate successors of the node in RPO
+    auto ItBlockRPO = ItBeginRPO;
+    while ((++ItBlockRPO != ItEndRPO) &&
+           !PendingUpdates.empty()) {
+      const auto *Block = *ItBlockRPO;
+      LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
+
+      // skip Block if not pending update
+      auto ItPending = PendingUpdates.find(Block);
+      if (ItPending == PendingUpdates.end())
+        continue;
+      PendingUpdates.erase(ItPending);
+
+      // propagate definition at Block to its successors
+      auto ItDef = DefMap.find(Block);
+      const auto *DefBlock = ItDef->second;
+      assert(DefBlock);
+
+      auto *BlockLoop = LI.getLoopFor(Block);
+      if (ParentLoop &&
+          (ParentLoop != BlockLoop && ParentLoop->contains(BlockLoop))) {
+        // if the successor is the header of a nested loop pretend it's a
+        // single node with the loop's exits as successors
+        SmallVector<BasicBlock *, 4> BlockLoopExits;
+        BlockLoop->getExitBlocks(BlockLoopExits);
+        for (const auto *BlockLoopExit : BlockLoopExits) {
+          visitSuccessor(*BlockLoopExit, ParentLoop, *DefBlock);
+        }
+
+      } else {
+        // the successors are either on the same loop level or loop exits
+        for (const auto *SuccBlock : successors(Block)) {
+          visitSuccessor(*SuccBlock, ParentLoop, *DefBlock);
+        }
+      }
+    }
+
+    LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
+
+    // We need to know the definition at the parent loop header to decide
+    // whether the definition at the header is different from the definition at
+    // the loop exits, which would indicate divergent loop exits.
+    //
+    //   A   // loop header
+    //   |
+    //   B   // nested loop header
+    //   |
+    //   C -> X (exit from B loop) -..-> (A latch)
+    //   |
+    //   D -> back to B (B latch)
+    //   |
+    //   proper exit from both loops
+    //
+    // analyze reached loop exits
+    if (!ReachedLoopExits.empty()) {
+      const BasicBlock *ParentLoopHeader =
+          ParentLoop ? ParentLoop->getHeader() : nullptr;
+
+      assert(ParentLoop);
+      auto ItHeaderDef = DefMap.find(ParentLoopHeader);
+      const auto *HeaderDefBlock = (ItHeaderDef == DefMap.end()) ?
nullptr : ItHeaderDef->second; + + LLVM_DEBUG(printDefs(dbgs())); + assert(HeaderDefBlock && "no definition at header of carrying loop"); + + for (const auto *ExitBlock : ReachedLoopExits) { + auto ItExitDef = DefMap.find(ExitBlock); + assert((ItExitDef != DefMap.end()) && + "no reaching def at reachable loop exit"); + if (ItExitDef->second != HeaderDefBlock) { + JoinBlocks->insert(ExitBlock); + } + } + } + + return std::move(JoinBlocks); + } +}; + +const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) { + using LoopExitVec = SmallVector<BasicBlock *, 4>; + LoopExitVec LoopExits; + Loop.getExitBlocks(LoopExits); + if (LoopExits.size() < 1) { + return EmptyBlockSet; + } + + // already available in cache? + auto ItCached = CachedLoopExitJoins.find(&Loop); + if (ItCached != CachedLoopExitJoins.end()) { + return *ItCached->second; + } + + // compute all join points + DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; + auto JoinBlocks = Propagator.computeJoinPoints<const LoopExitVec &>( + *Loop.getHeader(), LoopExits, Loop.getParentLoop()); + + auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks)); + assert(ItInserted.second); + return *ItInserted.first->second; +} + +const ConstBlockSet & +SyncDependenceAnalysis::join_blocks(const Instruction &Term) { + // trivial case + if (Term.getNumSuccessors() < 1) { + return EmptyBlockSet; + } + + // already available in cache? + auto ItCached = CachedBranchJoins.find(&Term); + if (ItCached != CachedBranchJoins.end()) + return *ItCached->second; + + // compute all join points + DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; + const auto &TermBlock = *Term.getParent(); + auto JoinBlocks = Propagator.computeJoinPoints<succ_const_range>( + TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock)); + + auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks)); + assert(ItInserted.second); + return *ItInserted.first->second; +} + +} // namespace llvm diff --git a/llvm/lib/Analysis/SyntheticCountsUtils.cpp b/llvm/lib/Analysis/SyntheticCountsUtils.cpp new file mode 100644 index 000000000000..22766e5f07f5 --- /dev/null +++ b/llvm/lib/Analysis/SyntheticCountsUtils.cpp @@ -0,0 +1,104 @@ +//===--- SyntheticCountsUtils.cpp - synthetic counts propagation utils ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines utilities for propagating synthetic counts. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/SyntheticCountsUtils.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/ModuleSummaryIndex.h" + +using namespace llvm; + +// Given an SCC, propagate entry counts along the edge of the SCC nodes. 
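+// A sketch of how the top-level propagate() entry point (defined further
+// below) might be driven (hypothetical; getEdgeCount and the Counts map are
+// assumptions of this example):
+//
+//   DenseMap<const CallGraphNode *, Scaled64> Counts;
+//   SyntheticCountsUtils<const CallGraph *>::propagate(
+//       &CG,
+//       [&](auto Caller, auto Edge) { return getEdgeCount(Caller, Edge); },
+//       [&](auto Node, Scaled64 New) { Counts[Node] += New; });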
+template <typename CallGraphType>
+void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
+    const SccTy &SCC, GetProfCountTy GetProfCount, AddCountTy AddCount) {
+
+  DenseSet<NodeRef> SCCNodes;
+  SmallVector<std::pair<NodeRef, EdgeRef>, 8> SCCEdges, NonSCCEdges;
+
+  for (auto &Node : SCC)
+    SCCNodes.insert(Node);
+
+  // Partition the edges coming out of the SCC into those whose destination is
+  // in the SCC and the rest.
+  for (const auto &Node : SCCNodes) {
+    for (auto &E : children_edges<CallGraphType>(Node)) {
+      if (SCCNodes.count(CGT::edge_dest(E)))
+        SCCEdges.emplace_back(Node, E);
+      else
+        NonSCCEdges.emplace_back(Node, E);
+    }
+  }
+
+  // For nodes in the same SCC, update the counts in two steps:
+  // 1. Compute the additional count for each node by propagating the counts
+  // along all incoming edges to the node that originate from within the same
+  // SCC and summing them up.
+  // 2. Add the additional counts to the nodes in the SCC.
+  // This ensures that the order of
+  // traversal of nodes within the SCC doesn't affect the final result.
+
+  DenseMap<NodeRef, Scaled64> AdditionalCounts;
+  for (auto &E : SCCEdges) {
+    auto OptProfCount = GetProfCount(E.first, E.second);
+    if (!OptProfCount)
+      continue;
+    auto Callee = CGT::edge_dest(E.second);
+    AdditionalCounts[Callee] += OptProfCount.getValue();
+  }
+
+  // Update the counts for the nodes in the SCC.
+  for (auto &Entry : AdditionalCounts)
+    AddCount(Entry.first, Entry.second);
+
+  // Now update the counts for nodes outside the SCC.
+  for (auto &E : NonSCCEdges) {
+    auto OptProfCount = GetProfCount(E.first, E.second);
+    if (!OptProfCount)
+      continue;
+    auto Callee = CGT::edge_dest(E.second);
+    AddCount(Callee, OptProfCount.getValue());
+  }
+}
+
+/// Propagate synthetic entry counts on a callgraph \p CG.
+///
+/// This performs a reverse post-order traversal of the callgraph SCC. For each
+/// SCC, it first propagates the entry counts to the nodes within the SCC
+/// through call edges and updates them in one shot. Then the entry counts are
+/// propagated to nodes outside the SCC. This requires \p GraphTraits
+/// to have a specialization for \p CallGraphType.
+
+template <typename CallGraphType>
+void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG,
+                                                    GetProfCountTy GetProfCount,
+                                                    AddCountTy AddCount) {
+  std::vector<SccTy> SCCs;
+
+  // Collect all the SCCs.
+  for (auto I = scc_begin(CG); !I.isAtEnd(); ++I)
+    SCCs.push_back(*I);
+
+  // The callgraph-scc needs to be visited in top-down order for propagation.
+  // The scc iterator returns the scc in bottom-up order, so reverse the SCCs
+  // and call propagateFromSCC.
+  for (auto &SCC : reverse(SCCs))
+    propagateFromSCC(SCC, GetProfCount, AddCount);
+}
+
+template class llvm::SyntheticCountsUtils<const CallGraph *>;
+template class llvm::SyntheticCountsUtils<ModuleSummaryIndex *>;
diff --git a/llvm/lib/Analysis/TargetLibraryInfo.cpp b/llvm/lib/Analysis/TargetLibraryInfo.cpp
new file mode 100644
index 000000000000..230969698054
--- /dev/null
+++ b/llvm/lib/Analysis/TargetLibraryInfo.cpp
@@ -0,0 +1,1638 @@
+//===-- TargetLibraryInfo.cpp - Runtime library information ----------------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the TargetLibraryInfo class.
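+//
+// A minimal usage sketch (hypothetical caller holding a Module M):
+//
+//   TargetLibraryInfoImpl TLII(Triple(M.getTargetTriple()));
+//   TargetLibraryInfo TLI(TLII);
+//   LibFunc F;
+//   if (TLI.getLibFunc("memcpy", F) && TLI.has(F))
+//     ; // memcpy is recognized and available on this target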
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/ADT/Triple.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/Support/CommandLine.h"
+using namespace llvm;
+
+static cl::opt<TargetLibraryInfoImpl::VectorLibrary> ClVectorLibrary(
+    "vector-library", cl::Hidden, cl::desc("Vector functions library"),
+    cl::init(TargetLibraryInfoImpl::NoLibrary),
+    cl::values(clEnumValN(TargetLibraryInfoImpl::NoLibrary, "none",
+                          "No vector functions library"),
+               clEnumValN(TargetLibraryInfoImpl::Accelerate, "Accelerate",
+                          "Accelerate framework"),
+               clEnumValN(TargetLibraryInfoImpl::MASSV, "MASSV",
+                          "IBM MASS vector library"),
+               clEnumValN(TargetLibraryInfoImpl::SVML, "SVML",
+                          "Intel SVML library")));
+
+StringLiteral const TargetLibraryInfoImpl::StandardNames[LibFunc::NumLibFuncs] =
+    {
+#define TLI_DEFINE_STRING
+#include "llvm/Analysis/TargetLibraryInfo.def"
+};
+
+static bool hasSinCosPiStret(const Triple &T) {
+  // Only Darwin variants have _stret versions of combined trig functions.
+  if (!T.isOSDarwin())
+    return false;
+
+  // The ABI is rather complicated on x86, so don't do anything special there.
+  if (T.getArch() == Triple::x86)
+    return false;
+
+  if (T.isMacOSX() && T.isMacOSXVersionLT(10, 9))
+    return false;
+
+  if (T.isiOS() && T.isOSVersionLT(7, 0))
+    return false;
+
+  return true;
+}
+
+static bool hasBcmp(const Triple &TT) {
+  // POSIX removed bcmp() in 2001, but glibc and several other libc
+  // implementations still provide it.
+  if (TT.isOSLinux())
+    return TT.isGNUEnvironment() || TT.isMusl();
+  // Both NetBSD and OpenBSD are planning to remove the function. Windows does
+  // not have it.
+  return TT.isOSFreeBSD() || TT.isOSSolaris();
+}
+
+/// Initialize the set of available library functions based on the specified
+/// target triple. This should be carefully written so that a missing target
+/// triple gets a sane set of defaults.
+static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T,
+                       ArrayRef<StringLiteral> StandardNames) {
+  // Verify that the StandardNames array is in alphabetical order.
+  assert(std::is_sorted(StandardNames.begin(), StandardNames.end(),
+                        [](StringRef LHS, StringRef RHS) {
+                          return LHS < RHS;
+                        }) &&
+         "TargetLibraryInfoImpl function names must be sorted");
+
+  // Set IO unlocked variants as unavailable
+  // Set them as available per system below
+  TLI.setUnavailable(LibFunc_getchar_unlocked);
+  TLI.setUnavailable(LibFunc_putc_unlocked);
+  TLI.setUnavailable(LibFunc_putchar_unlocked);
+  TLI.setUnavailable(LibFunc_fputc_unlocked);
+  TLI.setUnavailable(LibFunc_fgetc_unlocked);
+  TLI.setUnavailable(LibFunc_fread_unlocked);
+  TLI.setUnavailable(LibFunc_fwrite_unlocked);
+  TLI.setUnavailable(LibFunc_fputs_unlocked);
+  TLI.setUnavailable(LibFunc_fgets_unlocked);
+
+  bool ShouldExtI32Param = false, ShouldExtI32Return = false,
+       ShouldSignExtI32Param = false;
+  // PowerPC64, Sparc64, SystemZ need signext/zeroext on i32 parameters and
+  // returns corresponding to C-level ints and unsigned ints.
+  if (T.isPPC64() || T.getArch() == Triple::sparcv9 ||
+      T.getArch() == Triple::systemz) {
+    ShouldExtI32Param = true;
+    ShouldExtI32Return = true;
+  }
+  // Mips, on the other hand, needs signext on i32 parameters corresponding
+  // to both signed and unsigned ints.
+ if (T.isMIPS()) { + ShouldSignExtI32Param = true; + } + TLI.setShouldExtI32Param(ShouldExtI32Param); + TLI.setShouldExtI32Return(ShouldExtI32Return); + TLI.setShouldSignExtI32Param(ShouldSignExtI32Param); + + if (T.getArch() == Triple::r600 || + T.getArch() == Triple::amdgcn) + TLI.disableAllFunctions(); + + // There are no library implementations of memcpy and memset for AMD gpus and + // these can be difficult to lower in the backend. + if (T.getArch() == Triple::r600 || + T.getArch() == Triple::amdgcn) { + TLI.setUnavailable(LibFunc_memcpy); + TLI.setUnavailable(LibFunc_memset); + TLI.setUnavailable(LibFunc_memset_pattern16); + return; + } + + // memset_pattern16 is only available on iOS 3.0 and Mac OS X 10.5 and later. + // All versions of watchOS support it. + if (T.isMacOSX()) { + // available IO unlocked variants on Mac OS X + TLI.setAvailable(LibFunc_getc_unlocked); + TLI.setAvailable(LibFunc_getchar_unlocked); + TLI.setAvailable(LibFunc_putc_unlocked); + TLI.setAvailable(LibFunc_putchar_unlocked); + + if (T.isMacOSXVersionLT(10, 5)) + TLI.setUnavailable(LibFunc_memset_pattern16); + } else if (T.isiOS()) { + if (T.isOSVersionLT(3, 0)) + TLI.setUnavailable(LibFunc_memset_pattern16); + } else if (!T.isWatchOS()) { + TLI.setUnavailable(LibFunc_memset_pattern16); + } + + if (!hasSinCosPiStret(T)) { + TLI.setUnavailable(LibFunc_sinpi); + TLI.setUnavailable(LibFunc_sinpif); + TLI.setUnavailable(LibFunc_cospi); + TLI.setUnavailable(LibFunc_cospif); + TLI.setUnavailable(LibFunc_sincospi_stret); + TLI.setUnavailable(LibFunc_sincospif_stret); + } + + if (!hasBcmp(T)) + TLI.setUnavailable(LibFunc_bcmp); + + if (T.isMacOSX() && T.getArch() == Triple::x86 && + !T.isMacOSXVersionLT(10, 7)) { + // x86-32 OSX has a scheme where fwrite and fputs (and some other functions + // we don't care about) have two versions; on recent OSX, the one we want + // has a $UNIX2003 suffix. The two implementations are identical except + // for the return value in some edge cases. However, we don't want to + // generate code that depends on the old symbols. + TLI.setAvailableWithName(LibFunc_fwrite, "fwrite$UNIX2003"); + TLI.setAvailableWithName(LibFunc_fputs, "fputs$UNIX2003"); + } + + // iprintf and friends are only available on XCore, TCE, and Emscripten. + if (T.getArch() != Triple::xcore && T.getArch() != Triple::tce && + T.getOS() != Triple::Emscripten) { + TLI.setUnavailable(LibFunc_iprintf); + TLI.setUnavailable(LibFunc_siprintf); + TLI.setUnavailable(LibFunc_fiprintf); + } + + // __small_printf and friends are only available on Emscripten. + if (T.getOS() != Triple::Emscripten) { + TLI.setUnavailable(LibFunc_small_printf); + TLI.setUnavailable(LibFunc_small_sprintf); + TLI.setUnavailable(LibFunc_small_fprintf); + } + + if (T.isOSWindows() && !T.isOSCygMing()) { + // XXX: The earliest documentation available at the moment is for VS2015/VC19: + // https://docs.microsoft.com/en-us/cpp/c-runtime-library/floating-point-support?view=vs-2015 + // XXX: In order to use an MSVCRT older than VC19, + // the specific library version must be explicit in the target triple, + // e.g., x86_64-pc-windows-msvc18. + bool hasPartialC99 = true; + if (T.isKnownWindowsMSVCEnvironment()) { + unsigned Major, Minor, Micro; + T.getEnvironmentVersion(Major, Minor, Micro); + hasPartialC99 = (Major == 0 || Major >= 19); + } + + // Latest targets support C89 math functions, in part. 
+ bool isARM = (T.getArch() == Triple::aarch64 || + T.getArch() == Triple::arm); + bool hasPartialFloat = (isARM || + T.getArch() == Triple::x86_64); + + // Win32 does not support float C89 math functions, in general. + if (!hasPartialFloat) { + TLI.setUnavailable(LibFunc_acosf); + TLI.setUnavailable(LibFunc_asinf); + TLI.setUnavailable(LibFunc_atan2f); + TLI.setUnavailable(LibFunc_atanf); + TLI.setUnavailable(LibFunc_ceilf); + TLI.setUnavailable(LibFunc_cosf); + TLI.setUnavailable(LibFunc_coshf); + TLI.setUnavailable(LibFunc_expf); + TLI.setUnavailable(LibFunc_floorf); + TLI.setUnavailable(LibFunc_fmodf); + TLI.setUnavailable(LibFunc_log10f); + TLI.setUnavailable(LibFunc_logf); + TLI.setUnavailable(LibFunc_modff); + TLI.setUnavailable(LibFunc_powf); + TLI.setUnavailable(LibFunc_sinf); + TLI.setUnavailable(LibFunc_sinhf); + TLI.setUnavailable(LibFunc_sqrtf); + TLI.setUnavailable(LibFunc_tanf); + TLI.setUnavailable(LibFunc_tanhf); + } + if (!isARM) + TLI.setUnavailable(LibFunc_fabsf); + TLI.setUnavailable(LibFunc_frexpf); + TLI.setUnavailable(LibFunc_ldexpf); + + // Win32 does not support long double C89 math functions. + TLI.setUnavailable(LibFunc_acosl); + TLI.setUnavailable(LibFunc_asinl); + TLI.setUnavailable(LibFunc_atan2l); + TLI.setUnavailable(LibFunc_atanl); + TLI.setUnavailable(LibFunc_ceill); + TLI.setUnavailable(LibFunc_cosl); + TLI.setUnavailable(LibFunc_coshl); + TLI.setUnavailable(LibFunc_expl); + TLI.setUnavailable(LibFunc_fabsl); + TLI.setUnavailable(LibFunc_floorl); + TLI.setUnavailable(LibFunc_fmodl); + TLI.setUnavailable(LibFunc_frexpl); + TLI.setUnavailable(LibFunc_ldexpl); + TLI.setUnavailable(LibFunc_log10l); + TLI.setUnavailable(LibFunc_logl); + TLI.setUnavailable(LibFunc_modfl); + TLI.setUnavailable(LibFunc_powl); + TLI.setUnavailable(LibFunc_sinl); + TLI.setUnavailable(LibFunc_sinhl); + TLI.setUnavailable(LibFunc_sqrtl); + TLI.setUnavailable(LibFunc_tanl); + TLI.setUnavailable(LibFunc_tanhl); + + // Win32 does not fully support C99 math functions. + if (!hasPartialC99) { + TLI.setUnavailable(LibFunc_acosh); + TLI.setUnavailable(LibFunc_acoshf); + TLI.setUnavailable(LibFunc_asinh); + TLI.setUnavailable(LibFunc_asinhf); + TLI.setUnavailable(LibFunc_atanh); + TLI.setUnavailable(LibFunc_atanhf); + TLI.setAvailableWithName(LibFunc_cabs, "_cabs"); + TLI.setUnavailable(LibFunc_cabsf); + TLI.setUnavailable(LibFunc_cbrt); + TLI.setUnavailable(LibFunc_cbrtf); + TLI.setAvailableWithName(LibFunc_copysign, "_copysign"); + TLI.setAvailableWithName(LibFunc_copysignf, "_copysignf"); + TLI.setUnavailable(LibFunc_exp2); + TLI.setUnavailable(LibFunc_exp2f); + TLI.setUnavailable(LibFunc_expm1); + TLI.setUnavailable(LibFunc_expm1f); + TLI.setUnavailable(LibFunc_fmax); + TLI.setUnavailable(LibFunc_fmaxf); + TLI.setUnavailable(LibFunc_fmin); + TLI.setUnavailable(LibFunc_fminf); + TLI.setUnavailable(LibFunc_log1p); + TLI.setUnavailable(LibFunc_log1pf); + TLI.setUnavailable(LibFunc_log2); + TLI.setUnavailable(LibFunc_log2f); + TLI.setAvailableWithName(LibFunc_logb, "_logb"); + if (hasPartialFloat) + TLI.setAvailableWithName(LibFunc_logbf, "_logbf"); + else + TLI.setUnavailable(LibFunc_logbf); + TLI.setUnavailable(LibFunc_rint); + TLI.setUnavailable(LibFunc_rintf); + TLI.setUnavailable(LibFunc_round); + TLI.setUnavailable(LibFunc_roundf); + TLI.setUnavailable(LibFunc_trunc); + TLI.setUnavailable(LibFunc_truncf); + } + + // Win32 does not support long double C99 math functions. 
+ TLI.setUnavailable(LibFunc_acoshl); + TLI.setUnavailable(LibFunc_asinhl); + TLI.setUnavailable(LibFunc_atanhl); + TLI.setUnavailable(LibFunc_cabsl); + TLI.setUnavailable(LibFunc_cbrtl); + TLI.setUnavailable(LibFunc_copysignl); + TLI.setUnavailable(LibFunc_exp2l); + TLI.setUnavailable(LibFunc_expm1l); + TLI.setUnavailable(LibFunc_fmaxl); + TLI.setUnavailable(LibFunc_fminl); + TLI.setUnavailable(LibFunc_log1pl); + TLI.setUnavailable(LibFunc_log2l); + TLI.setUnavailable(LibFunc_logbl); + TLI.setUnavailable(LibFunc_nearbyintl); + TLI.setUnavailable(LibFunc_rintl); + TLI.setUnavailable(LibFunc_roundl); + TLI.setUnavailable(LibFunc_truncl); + + // Win32 does not support these functions, but + // they are generally available on POSIX-compliant systems. + TLI.setUnavailable(LibFunc_access); + TLI.setUnavailable(LibFunc_bcmp); + TLI.setUnavailable(LibFunc_bcopy); + TLI.setUnavailable(LibFunc_bzero); + TLI.setUnavailable(LibFunc_chmod); + TLI.setUnavailable(LibFunc_chown); + TLI.setUnavailable(LibFunc_closedir); + TLI.setUnavailable(LibFunc_ctermid); + TLI.setUnavailable(LibFunc_fdopen); + TLI.setUnavailable(LibFunc_ffs); + TLI.setUnavailable(LibFunc_fileno); + TLI.setUnavailable(LibFunc_flockfile); + TLI.setUnavailable(LibFunc_fseeko); + TLI.setUnavailable(LibFunc_fstat); + TLI.setUnavailable(LibFunc_fstatvfs); + TLI.setUnavailable(LibFunc_ftello); + TLI.setUnavailable(LibFunc_ftrylockfile); + TLI.setUnavailable(LibFunc_funlockfile); + TLI.setUnavailable(LibFunc_getitimer); + TLI.setUnavailable(LibFunc_getlogin_r); + TLI.setUnavailable(LibFunc_getpwnam); + TLI.setUnavailable(LibFunc_gettimeofday); + TLI.setUnavailable(LibFunc_htonl); + TLI.setUnavailable(LibFunc_htons); + TLI.setUnavailable(LibFunc_lchown); + TLI.setUnavailable(LibFunc_lstat); + TLI.setUnavailable(LibFunc_memccpy); + TLI.setUnavailable(LibFunc_mkdir); + TLI.setUnavailable(LibFunc_ntohl); + TLI.setUnavailable(LibFunc_ntohs); + TLI.setUnavailable(LibFunc_open); + TLI.setUnavailable(LibFunc_opendir); + TLI.setUnavailable(LibFunc_pclose); + TLI.setUnavailable(LibFunc_popen); + TLI.setUnavailable(LibFunc_pread); + TLI.setUnavailable(LibFunc_pwrite); + TLI.setUnavailable(LibFunc_read); + TLI.setUnavailable(LibFunc_readlink); + TLI.setUnavailable(LibFunc_realpath); + TLI.setUnavailable(LibFunc_rmdir); + TLI.setUnavailable(LibFunc_setitimer); + TLI.setUnavailable(LibFunc_stat); + TLI.setUnavailable(LibFunc_statvfs); + TLI.setUnavailable(LibFunc_stpcpy); + TLI.setUnavailable(LibFunc_stpncpy); + TLI.setUnavailable(LibFunc_strcasecmp); + TLI.setUnavailable(LibFunc_strncasecmp); + TLI.setUnavailable(LibFunc_times); + TLI.setUnavailable(LibFunc_uname); + TLI.setUnavailable(LibFunc_unlink); + TLI.setUnavailable(LibFunc_unsetenv); + TLI.setUnavailable(LibFunc_utime); + TLI.setUnavailable(LibFunc_utimes); + TLI.setUnavailable(LibFunc_write); + } + + switch (T.getOS()) { + case Triple::MacOSX: + // exp10 and exp10f are not available on OS X until 10.9 and iOS until 7.0 + // and their names are __exp10 and __exp10f. exp10l is not available on + // OS X or iOS. 
+ TLI.setUnavailable(LibFunc_exp10l); + if (T.isMacOSXVersionLT(10, 9)) { + TLI.setUnavailable(LibFunc_exp10); + TLI.setUnavailable(LibFunc_exp10f); + } else { + TLI.setAvailableWithName(LibFunc_exp10, "__exp10"); + TLI.setAvailableWithName(LibFunc_exp10f, "__exp10f"); + } + break; + case Triple::IOS: + case Triple::TvOS: + case Triple::WatchOS: + TLI.setUnavailable(LibFunc_exp10l); + if (!T.isWatchOS() && (T.isOSVersionLT(7, 0) || + (T.isOSVersionLT(9, 0) && + (T.getArch() == Triple::x86 || + T.getArch() == Triple::x86_64)))) { + TLI.setUnavailable(LibFunc_exp10); + TLI.setUnavailable(LibFunc_exp10f); + } else { + TLI.setAvailableWithName(LibFunc_exp10, "__exp10"); + TLI.setAvailableWithName(LibFunc_exp10f, "__exp10f"); + } + break; + case Triple::Linux: + // exp10, exp10f, exp10l is available on Linux (GLIBC) but are extremely + // buggy prior to glibc version 2.18. Until this version is widely deployed + // or we have a reasonable detection strategy, we cannot use exp10 reliably + // on Linux. + // + // Fall through to disable all of them. + LLVM_FALLTHROUGH; + default: + TLI.setUnavailable(LibFunc_exp10); + TLI.setUnavailable(LibFunc_exp10f); + TLI.setUnavailable(LibFunc_exp10l); + } + + // ffsl is available on at least Darwin, Mac OS X, iOS, FreeBSD, and + // Linux (GLIBC): + // http://developer.apple.com/library/mac/#documentation/Darwin/Reference/ManPages/man3/ffsl.3.html + // http://svn.freebsd.org/base/head/lib/libc/string/ffsl.c + // http://www.gnu.org/software/gnulib/manual/html_node/ffsl.html + switch (T.getOS()) { + case Triple::Darwin: + case Triple::MacOSX: + case Triple::IOS: + case Triple::TvOS: + case Triple::WatchOS: + case Triple::FreeBSD: + case Triple::Linux: + break; + default: + TLI.setUnavailable(LibFunc_ffsl); + } + + // ffsll is available on at least FreeBSD and Linux (GLIBC): + // http://svn.freebsd.org/base/head/lib/libc/string/ffsll.c + // http://www.gnu.org/software/gnulib/manual/html_node/ffsll.html + switch (T.getOS()) { + case Triple::Darwin: + case Triple::MacOSX: + case Triple::IOS: + case Triple::TvOS: + case Triple::WatchOS: + case Triple::FreeBSD: + case Triple::Linux: + break; + default: + TLI.setUnavailable(LibFunc_ffsll); + } + + // The following functions are available on at least FreeBSD: + // http://svn.freebsd.org/base/head/lib/libc/string/fls.c + // http://svn.freebsd.org/base/head/lib/libc/string/flsl.c + // http://svn.freebsd.org/base/head/lib/libc/string/flsll.c + if (!T.isOSFreeBSD()) { + TLI.setUnavailable(LibFunc_fls); + TLI.setUnavailable(LibFunc_flsl); + TLI.setUnavailable(LibFunc_flsll); + } + + // The following functions are only available on GNU/Linux (using glibc). + // Linux variants without glibc (eg: bionic, musl) may have some subset. + if (!T.isOSLinux() || !T.isGNUEnvironment()) { + TLI.setUnavailable(LibFunc_dunder_strdup); + TLI.setUnavailable(LibFunc_dunder_strtok_r); + TLI.setUnavailable(LibFunc_dunder_isoc99_scanf); + TLI.setUnavailable(LibFunc_dunder_isoc99_sscanf); + TLI.setUnavailable(LibFunc_under_IO_getc); + TLI.setUnavailable(LibFunc_under_IO_putc); + // But, Android and musl have memalign. 
+ if (!T.isAndroid() && !T.isMusl()) + TLI.setUnavailable(LibFunc_memalign); + TLI.setUnavailable(LibFunc_fopen64); + TLI.setUnavailable(LibFunc_fseeko64); + TLI.setUnavailable(LibFunc_fstat64); + TLI.setUnavailable(LibFunc_fstatvfs64); + TLI.setUnavailable(LibFunc_ftello64); + TLI.setUnavailable(LibFunc_lstat64); + TLI.setUnavailable(LibFunc_open64); + TLI.setUnavailable(LibFunc_stat64); + TLI.setUnavailable(LibFunc_statvfs64); + TLI.setUnavailable(LibFunc_tmpfile64); + + // Relaxed math functions are included in math-finite.h on Linux (GLIBC). + TLI.setUnavailable(LibFunc_acos_finite); + TLI.setUnavailable(LibFunc_acosf_finite); + TLI.setUnavailable(LibFunc_acosl_finite); + TLI.setUnavailable(LibFunc_acosh_finite); + TLI.setUnavailable(LibFunc_acoshf_finite); + TLI.setUnavailable(LibFunc_acoshl_finite); + TLI.setUnavailable(LibFunc_asin_finite); + TLI.setUnavailable(LibFunc_asinf_finite); + TLI.setUnavailable(LibFunc_asinl_finite); + TLI.setUnavailable(LibFunc_atan2_finite); + TLI.setUnavailable(LibFunc_atan2f_finite); + TLI.setUnavailable(LibFunc_atan2l_finite); + TLI.setUnavailable(LibFunc_atanh_finite); + TLI.setUnavailable(LibFunc_atanhf_finite); + TLI.setUnavailable(LibFunc_atanhl_finite); + TLI.setUnavailable(LibFunc_cosh_finite); + TLI.setUnavailable(LibFunc_coshf_finite); + TLI.setUnavailable(LibFunc_coshl_finite); + TLI.setUnavailable(LibFunc_exp10_finite); + TLI.setUnavailable(LibFunc_exp10f_finite); + TLI.setUnavailable(LibFunc_exp10l_finite); + TLI.setUnavailable(LibFunc_exp2_finite); + TLI.setUnavailable(LibFunc_exp2f_finite); + TLI.setUnavailable(LibFunc_exp2l_finite); + TLI.setUnavailable(LibFunc_exp_finite); + TLI.setUnavailable(LibFunc_expf_finite); + TLI.setUnavailable(LibFunc_expl_finite); + TLI.setUnavailable(LibFunc_log10_finite); + TLI.setUnavailable(LibFunc_log10f_finite); + TLI.setUnavailable(LibFunc_log10l_finite); + TLI.setUnavailable(LibFunc_log2_finite); + TLI.setUnavailable(LibFunc_log2f_finite); + TLI.setUnavailable(LibFunc_log2l_finite); + TLI.setUnavailable(LibFunc_log_finite); + TLI.setUnavailable(LibFunc_logf_finite); + TLI.setUnavailable(LibFunc_logl_finite); + TLI.setUnavailable(LibFunc_pow_finite); + TLI.setUnavailable(LibFunc_powf_finite); + TLI.setUnavailable(LibFunc_powl_finite); + TLI.setUnavailable(LibFunc_sinh_finite); + TLI.setUnavailable(LibFunc_sinhf_finite); + TLI.setUnavailable(LibFunc_sinhl_finite); + } + + if ((T.isOSLinux() && T.isGNUEnvironment()) || + (T.isAndroid() && !T.isAndroidVersionLT(28))) { + // available IO unlocked variants on GNU/Linux and Android P or later + TLI.setAvailable(LibFunc_getc_unlocked); + TLI.setAvailable(LibFunc_getchar_unlocked); + TLI.setAvailable(LibFunc_putc_unlocked); + TLI.setAvailable(LibFunc_putchar_unlocked); + TLI.setAvailable(LibFunc_fputc_unlocked); + TLI.setAvailable(LibFunc_fgetc_unlocked); + TLI.setAvailable(LibFunc_fread_unlocked); + TLI.setAvailable(LibFunc_fwrite_unlocked); + TLI.setAvailable(LibFunc_fputs_unlocked); + TLI.setAvailable(LibFunc_fgets_unlocked); + } + + // As currently implemented in clang, NVPTX code has no standard library to + // speak of. Headers provide a standard-ish library implementation, but many + // of the signatures are wrong -- for example, many libm functions are not + // extern "C". + // + // libdevice, an IR library provided by nvidia, is linked in by the front-end, + // but only used functions are provided to llvm. Moreover, most of the + // functions in libdevice don't map precisely to standard library functions. 
+ // + // FIXME: Having no standard library prevents e.g. many fastmath + // optimizations, so this situation should be fixed. + if (T.isNVPTX()) { + TLI.disableAllFunctions(); + TLI.setAvailable(LibFunc_nvvm_reflect); + } else { + TLI.setUnavailable(LibFunc_nvvm_reflect); + } + + TLI.addVectorizableFunctionsFromVecLib(ClVectorLibrary); +} + +TargetLibraryInfoImpl::TargetLibraryInfoImpl() { + // Default to everything being available. + memset(AvailableArray, -1, sizeof(AvailableArray)); + + initialize(*this, Triple(), StandardNames); +} + +TargetLibraryInfoImpl::TargetLibraryInfoImpl(const Triple &T) { + // Default to everything being available. + memset(AvailableArray, -1, sizeof(AvailableArray)); + + initialize(*this, T, StandardNames); +} + +TargetLibraryInfoImpl::TargetLibraryInfoImpl(const TargetLibraryInfoImpl &TLI) + : CustomNames(TLI.CustomNames), ShouldExtI32Param(TLI.ShouldExtI32Param), + ShouldExtI32Return(TLI.ShouldExtI32Return), + ShouldSignExtI32Param(TLI.ShouldSignExtI32Param) { + memcpy(AvailableArray, TLI.AvailableArray, sizeof(AvailableArray)); + VectorDescs = TLI.VectorDescs; + ScalarDescs = TLI.ScalarDescs; +} + +TargetLibraryInfoImpl::TargetLibraryInfoImpl(TargetLibraryInfoImpl &&TLI) + : CustomNames(std::move(TLI.CustomNames)), + ShouldExtI32Param(TLI.ShouldExtI32Param), + ShouldExtI32Return(TLI.ShouldExtI32Return), + ShouldSignExtI32Param(TLI.ShouldSignExtI32Param) { + std::move(std::begin(TLI.AvailableArray), std::end(TLI.AvailableArray), + AvailableArray); + VectorDescs = TLI.VectorDescs; + ScalarDescs = TLI.ScalarDescs; +} + +TargetLibraryInfoImpl &TargetLibraryInfoImpl::operator=(const TargetLibraryInfoImpl &TLI) { + CustomNames = TLI.CustomNames; + ShouldExtI32Param = TLI.ShouldExtI32Param; + ShouldExtI32Return = TLI.ShouldExtI32Return; + ShouldSignExtI32Param = TLI.ShouldSignExtI32Param; + memcpy(AvailableArray, TLI.AvailableArray, sizeof(AvailableArray)); + return *this; +} + +TargetLibraryInfoImpl &TargetLibraryInfoImpl::operator=(TargetLibraryInfoImpl &&TLI) { + CustomNames = std::move(TLI.CustomNames); + ShouldExtI32Param = TLI.ShouldExtI32Param; + ShouldExtI32Return = TLI.ShouldExtI32Return; + ShouldSignExtI32Param = TLI.ShouldSignExtI32Param; + std::move(std::begin(TLI.AvailableArray), std::end(TLI.AvailableArray), + AvailableArray); + return *this; +} + +static StringRef sanitizeFunctionName(StringRef funcName) { + // Filter out empty names and names containing null bytes, those can't be in + // our table. + if (funcName.empty() || funcName.find('\0') != StringRef::npos) + return StringRef(); + + // Check for \01 prefix that is used to mangle __asm declarations and + // strip it if present. + return GlobalValue::dropLLVMManglingEscape(funcName); +} + +bool TargetLibraryInfoImpl::getLibFunc(StringRef funcName, LibFunc &F) const { + funcName = sanitizeFunctionName(funcName); + if (funcName.empty()) + return false; + + const auto *Start = std::begin(StandardNames); + const auto *End = std::end(StandardNames); + const auto *I = std::lower_bound(Start, End, funcName); + if (I != End && *I == funcName) { + F = (LibFunc)(I - Start); + return true; + } + return false; +} + +bool TargetLibraryInfoImpl::isValidProtoForLibFunc(const FunctionType &FTy, + LibFunc F, + const DataLayout *DL) const { + LLVMContext &Ctx = FTy.getContext(); + Type *PCharTy = Type::getInt8PtrTy(Ctx); + Type *SizeTTy = DL ? DL->getIntPtrType(Ctx, /*AS=*/0) : nullptr; + auto IsSizeTTy = [SizeTTy](Type *Ty) { + return SizeTTy ? 
Ty == SizeTTy : Ty->isIntegerTy(); + }; + unsigned NumParams = FTy.getNumParams(); + + switch (F) { + case LibFunc_execl: + case LibFunc_execlp: + case LibFunc_execle: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + FTy.getReturnType()->isIntegerTy(32)); + case LibFunc_execv: + case LibFunc_execvp: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + FTy.getReturnType()->isIntegerTy(32)); + case LibFunc_execvP: + case LibFunc_execvpe: + case LibFunc_execve: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + FTy.getParamType(2)->isPointerTy() && + FTy.getReturnType()->isIntegerTy(32)); + case LibFunc_strlen: + return (NumParams == 1 && FTy.getParamType(0)->isPointerTy() && + FTy.getReturnType()->isIntegerTy()); + + case LibFunc_strchr: + case LibFunc_strrchr: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0) == FTy.getReturnType() && + FTy.getParamType(1)->isIntegerTy()); + + case LibFunc_strtol: + case LibFunc_strtod: + case LibFunc_strtof: + case LibFunc_strtoul: + case LibFunc_strtoll: + case LibFunc_strtold: + case LibFunc_strtoull: + return ((NumParams == 2 || NumParams == 3) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_strcat_chk: + --NumParams; + if (!IsSizeTTy(FTy.getParamType(NumParams))) + return false; + LLVM_FALLTHROUGH; + case LibFunc_strcat: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0) == FTy.getReturnType() && + FTy.getParamType(1) == FTy.getReturnType()); + + case LibFunc_strncat_chk: + --NumParams; + if (!IsSizeTTy(FTy.getParamType(NumParams))) + return false; + LLVM_FALLTHROUGH; + case LibFunc_strncat: + return (NumParams == 3 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0) == FTy.getReturnType() && + FTy.getParamType(1) == FTy.getReturnType() && + IsSizeTTy(FTy.getParamType(2))); + + case LibFunc_strcpy_chk: + case LibFunc_stpcpy_chk: + --NumParams; + if (!IsSizeTTy(FTy.getParamType(NumParams))) + return false; + LLVM_FALLTHROUGH; + case LibFunc_strcpy: + case LibFunc_stpcpy: + return (NumParams == 2 && FTy.getReturnType() == FTy.getParamType(0) && + FTy.getParamType(0) == FTy.getParamType(1) && + FTy.getParamType(0) == PCharTy); + + case LibFunc_strlcat_chk: + case LibFunc_strlcpy_chk: + --NumParams; + if (!IsSizeTTy(FTy.getParamType(NumParams))) + return false; + LLVM_FALLTHROUGH; + case LibFunc_strlcat: + case LibFunc_strlcpy: + return NumParams == 3 && IsSizeTTy(FTy.getReturnType()) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + IsSizeTTy(FTy.getParamType(2)); + + case LibFunc_strncpy_chk: + case LibFunc_stpncpy_chk: + --NumParams; + if (!IsSizeTTy(FTy.getParamType(NumParams))) + return false; + LLVM_FALLTHROUGH; + case LibFunc_strncpy: + case LibFunc_stpncpy: + return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) && + FTy.getParamType(0) == FTy.getParamType(1) && + FTy.getParamType(0) == PCharTy && + IsSizeTTy(FTy.getParamType(2))); + + case LibFunc_strxfrm: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + + case LibFunc_strcmp: + return (NumParams == 2 && FTy.getReturnType()->isIntegerTy(32) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(0) == FTy.getParamType(1)); + + case LibFunc_strncmp: + return (NumParams == 3 && 
FTy.getReturnType()->isIntegerTy(32) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(0) == FTy.getParamType(1) && + IsSizeTTy(FTy.getParamType(2))); + + case LibFunc_strspn: + case LibFunc_strcspn: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(0) == FTy.getParamType(1) && + FTy.getReturnType()->isIntegerTy()); + + case LibFunc_strcoll: + case LibFunc_strcasecmp: + case LibFunc_strncasecmp: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + + case LibFunc_strstr: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + + case LibFunc_strpbrk: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getReturnType() == FTy.getParamType(0) && + FTy.getParamType(0) == FTy.getParamType(1)); + + case LibFunc_strtok: + case LibFunc_strtok_r: + return (NumParams >= 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc_scanf: + case LibFunc_setbuf: + case LibFunc_setvbuf: + return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy()); + case LibFunc_strdup: + case LibFunc_strndup: + return (NumParams >= 1 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0)->isPointerTy()); + case LibFunc_sscanf: + case LibFunc_stat: + case LibFunc_statvfs: + case LibFunc_siprintf: + case LibFunc_small_sprintf: + case LibFunc_sprintf: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + FTy.getReturnType()->isIntegerTy(32)); + + case LibFunc_sprintf_chk: + return NumParams == 4 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isIntegerTy(32) && + IsSizeTTy(FTy.getParamType(2)) && + FTy.getParamType(3)->isPointerTy() && + FTy.getReturnType()->isIntegerTy(32); + + case LibFunc_snprintf: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(2)->isPointerTy() && + FTy.getReturnType()->isIntegerTy(32)); + + case LibFunc_snprintf_chk: + return NumParams == 5 && FTy.getParamType(0)->isPointerTy() && + IsSizeTTy(FTy.getParamType(1)) && + FTy.getParamType(2)->isIntegerTy(32) && + IsSizeTTy(FTy.getParamType(3)) && + FTy.getParamType(4)->isPointerTy() && + FTy.getReturnType()->isIntegerTy(32); + + case LibFunc_setitimer: + return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() && + FTy.getParamType(2)->isPointerTy()); + case LibFunc_system: + return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); + case LibFunc_malloc: + return (NumParams == 1 && FTy.getReturnType()->isPointerTy()); + case LibFunc_memcmp: + return (NumParams == 3 && FTy.getReturnType()->isIntegerTy(32) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + + case LibFunc_memchr: + case LibFunc_memrchr: + return (NumParams == 3 && FTy.getReturnType()->isPointerTy() && + FTy.getReturnType() == FTy.getParamType(0) && + FTy.getParamType(1)->isIntegerTy(32) && + IsSizeTTy(FTy.getParamType(2))); + case LibFunc_modf: + case LibFunc_modff: + case LibFunc_modfl: + return (NumParams >= 2 && FTy.getParamType(1)->isPointerTy()); + + case LibFunc_memcpy_chk: + case LibFunc_memmove_chk: + --NumParams; + if (!IsSizeTTy(FTy.getParamType(NumParams))) + return false; + LLVM_FALLTHROUGH; + case LibFunc_memcpy: + case LibFunc_mempcpy: + case LibFunc_memmove: + return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + 
IsSizeTTy(FTy.getParamType(2))); + + case LibFunc_memset_chk: + --NumParams; + if (!IsSizeTTy(FTy.getParamType(NumParams))) + return false; + LLVM_FALLTHROUGH; + case LibFunc_memset: + return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isIntegerTy() && + IsSizeTTy(FTy.getParamType(2))); + + case LibFunc_memccpy_chk: + --NumParams; + if (!IsSizeTTy(FTy.getParamType(NumParams))) + return false; + LLVM_FALLTHROUGH; + case LibFunc_memccpy: + return (NumParams >= 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc_memalign: + return (FTy.getReturnType()->isPointerTy()); + case LibFunc_realloc: + case LibFunc_reallocf: + return (NumParams == 2 && FTy.getReturnType() == PCharTy && + FTy.getParamType(0) == FTy.getReturnType() && + IsSizeTTy(FTy.getParamType(1))); + case LibFunc_read: + return (NumParams == 3 && FTy.getParamType(1)->isPointerTy()); + case LibFunc_rewind: + case LibFunc_rmdir: + case LibFunc_remove: + case LibFunc_realpath: + return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy()); + case LibFunc_rename: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_readlink: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_write: + return (NumParams == 3 && FTy.getParamType(1)->isPointerTy()); + case LibFunc_bcopy: + case LibFunc_bcmp: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_bzero: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy()); + case LibFunc_calloc: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy()); + + case LibFunc_atof: + case LibFunc_atoi: + case LibFunc_atol: + case LibFunc_atoll: + case LibFunc_ferror: + case LibFunc_getenv: + case LibFunc_getpwnam: + case LibFunc_iprintf: + case LibFunc_small_printf: + case LibFunc_pclose: + case LibFunc_perror: + case LibFunc_printf: + case LibFunc_puts: + case LibFunc_uname: + case LibFunc_under_IO_getc: + case LibFunc_unlink: + case LibFunc_unsetenv: + return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); + + case LibFunc_access: + case LibFunc_chmod: + case LibFunc_chown: + case LibFunc_clearerr: + case LibFunc_closedir: + case LibFunc_ctermid: + case LibFunc_fclose: + case LibFunc_feof: + case LibFunc_fflush: + case LibFunc_fgetc: + case LibFunc_fgetc_unlocked: + case LibFunc_fileno: + case LibFunc_flockfile: + case LibFunc_free: + case LibFunc_fseek: + case LibFunc_fseeko64: + case LibFunc_fseeko: + case LibFunc_fsetpos: + case LibFunc_ftell: + case LibFunc_ftello64: + case LibFunc_ftello: + case LibFunc_ftrylockfile: + case LibFunc_funlockfile: + case LibFunc_getc: + case LibFunc_getc_unlocked: + case LibFunc_getlogin_r: + case LibFunc_mkdir: + case LibFunc_mktime: + case LibFunc_times: + return (NumParams != 0 && FTy.getParamType(0)->isPointerTy()); + + case LibFunc_fopen: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_fork: + return (NumParams == 0 && FTy.getReturnType()->isIntegerTy(32)); + case LibFunc_fdopen: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_fputc: + case LibFunc_fputc_unlocked: + case LibFunc_fstat: + case LibFunc_frexp: + case LibFunc_frexpf: + case LibFunc_frexpl: + case LibFunc_fstatvfs: + 
return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc_fgets: + case LibFunc_fgets_unlocked: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(2)->isPointerTy()); + case LibFunc_fread: + case LibFunc_fread_unlocked: + return (NumParams == 4 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(3)->isPointerTy()); + case LibFunc_fwrite: + case LibFunc_fwrite_unlocked: + return (NumParams == 4 && FTy.getReturnType()->isIntegerTy() && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isIntegerTy() && + FTy.getParamType(2)->isIntegerTy() && + FTy.getParamType(3)->isPointerTy()); + case LibFunc_fputs: + case LibFunc_fputs_unlocked: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_fscanf: + case LibFunc_fiprintf: + case LibFunc_small_fprintf: + case LibFunc_fprintf: + return (NumParams >= 2 && FTy.getReturnType()->isIntegerTy() && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_fgetpos: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_getchar: + case LibFunc_getchar_unlocked: + return (NumParams == 0 && FTy.getReturnType()->isIntegerTy()); + case LibFunc_gets: + return (NumParams == 1 && FTy.getParamType(0) == PCharTy); + case LibFunc_getitimer: + return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc_ungetc: + return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc_utime: + case LibFunc_utimes: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_putc: + case LibFunc_putc_unlocked: + return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc_pread: + case LibFunc_pwrite: + return (NumParams == 4 && FTy.getParamType(1)->isPointerTy()); + case LibFunc_popen: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_vscanf: + return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc_vsscanf: + return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() && + FTy.getParamType(2)->isPointerTy()); + case LibFunc_vfscanf: + return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() && + FTy.getParamType(2)->isPointerTy()); + case LibFunc_valloc: + return (FTy.getReturnType()->isPointerTy()); + case LibFunc_vprintf: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy()); + case LibFunc_vfprintf: + case LibFunc_vsprintf: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_vsprintf_chk: + return NumParams == 5 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isIntegerTy(32) && + IsSizeTTy(FTy.getParamType(2)) && FTy.getParamType(3)->isPointerTy(); + case LibFunc_vsnprintf: + return (NumParams == 4 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(2)->isPointerTy()); + case LibFunc_vsnprintf_chk: + return NumParams == 6 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(2)->isIntegerTy(32) && + IsSizeTTy(FTy.getParamType(3)) && FTy.getParamType(4)->isPointerTy(); + case LibFunc_open: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy()); + case LibFunc_opendir: + return (NumParams == 1 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0)->isPointerTy()); + case 
LibFunc_tmpfile: + return (FTy.getReturnType()->isPointerTy()); + case LibFunc_htonl: + case LibFunc_ntohl: + return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(32) && + FTy.getReturnType() == FTy.getParamType(0)); + case LibFunc_htons: + case LibFunc_ntohs: + return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(16) && + FTy.getReturnType() == FTy.getParamType(0)); + case LibFunc_lstat: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_lchown: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy()); + case LibFunc_qsort: + return (NumParams == 4 && FTy.getParamType(3)->isPointerTy()); + case LibFunc_dunder_strdup: + case LibFunc_dunder_strndup: + return (NumParams >= 1 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0)->isPointerTy()); + case LibFunc_dunder_strtok_r: + return (NumParams == 3 && FTy.getParamType(1)->isPointerTy()); + case LibFunc_under_IO_putc: + return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc_dunder_isoc99_scanf: + return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy()); + case LibFunc_stat64: + case LibFunc_lstat64: + case LibFunc_statvfs64: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_dunder_isoc99_sscanf: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_fopen64: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc_tmpfile64: + return (FTy.getReturnType()->isPointerTy()); + case LibFunc_fstat64: + case LibFunc_fstatvfs64: + return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc_open64: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy()); + case LibFunc_gettimeofday: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + + // new(unsigned int); + case LibFunc_Znwj: + // new(unsigned long); + case LibFunc_Znwm: + // new[](unsigned int); + case LibFunc_Znaj: + // new[](unsigned long); + case LibFunc_Znam: + // new(unsigned int); + case LibFunc_msvc_new_int: + // new(unsigned long long); + case LibFunc_msvc_new_longlong: + // new[](unsigned int); + case LibFunc_msvc_new_array_int: + // new[](unsigned long long); + case LibFunc_msvc_new_array_longlong: + return (NumParams == 1 && FTy.getReturnType()->isPointerTy()); + + // new(unsigned int, nothrow); + case LibFunc_ZnwjRKSt9nothrow_t: + // new(unsigned long, nothrow); + case LibFunc_ZnwmRKSt9nothrow_t: + // new[](unsigned int, nothrow); + case LibFunc_ZnajRKSt9nothrow_t: + // new[](unsigned long, nothrow); + case LibFunc_ZnamRKSt9nothrow_t: + // new(unsigned int, nothrow); + case LibFunc_msvc_new_int_nothrow: + // new(unsigned long long, nothrow); + case LibFunc_msvc_new_longlong_nothrow: + // new[](unsigned int, nothrow); + case LibFunc_msvc_new_array_int_nothrow: + // new[](unsigned long long, nothrow); + case LibFunc_msvc_new_array_longlong_nothrow: + // new(unsigned int, align_val_t) + case LibFunc_ZnwjSt11align_val_t: + // new(unsigned long, align_val_t) + case LibFunc_ZnwmSt11align_val_t: + // new[](unsigned int, align_val_t) + case LibFunc_ZnajSt11align_val_t: + // new[](unsigned long, align_val_t) + case LibFunc_ZnamSt11align_val_t: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy()); + + // new(unsigned int, align_val_t, nothrow) + 
case LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t: + // new(unsigned long, align_val_t, nothrow) + case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t: + // new[](unsigned int, align_val_t, nothrow) + case LibFunc_ZnajSt11align_val_tRKSt9nothrow_t: + // new[](unsigned long, align_val_t, nothrow) + case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t: + return (NumParams == 3 && FTy.getReturnType()->isPointerTy()); + + // void operator delete[](void*); + case LibFunc_ZdaPv: + // void operator delete(void*); + case LibFunc_ZdlPv: + // void operator delete[](void*); + case LibFunc_msvc_delete_array_ptr32: + // void operator delete[](void*); + case LibFunc_msvc_delete_array_ptr64: + // void operator delete(void*); + case LibFunc_msvc_delete_ptr32: + // void operator delete(void*); + case LibFunc_msvc_delete_ptr64: + return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); + + // void operator delete[](void*, nothrow); + case LibFunc_ZdaPvRKSt9nothrow_t: + // void operator delete[](void*, unsigned int); + case LibFunc_ZdaPvj: + // void operator delete[](void*, unsigned long); + case LibFunc_ZdaPvm: + // void operator delete(void*, nothrow); + case LibFunc_ZdlPvRKSt9nothrow_t: + // void operator delete(void*, unsigned int); + case LibFunc_ZdlPvj: + // void operator delete(void*, unsigned long); + case LibFunc_ZdlPvm: + // void operator delete(void*, align_val_t) + case LibFunc_ZdlPvSt11align_val_t: + // void operator delete[](void*, align_val_t) + case LibFunc_ZdaPvSt11align_val_t: + // void operator delete[](void*, unsigned int); + case LibFunc_msvc_delete_array_ptr32_int: + // void operator delete[](void*, nothrow); + case LibFunc_msvc_delete_array_ptr32_nothrow: + // void operator delete[](void*, unsigned long long); + case LibFunc_msvc_delete_array_ptr64_longlong: + // void operator delete[](void*, nothrow); + case LibFunc_msvc_delete_array_ptr64_nothrow: + // void operator delete(void*, unsigned int); + case LibFunc_msvc_delete_ptr32_int: + // void operator delete(void*, nothrow); + case LibFunc_msvc_delete_ptr32_nothrow: + // void operator delete(void*, unsigned long long); + case LibFunc_msvc_delete_ptr64_longlong: + // void operator delete(void*, nothrow); + case LibFunc_msvc_delete_ptr64_nothrow: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy()); + + // void operator delete(void*, align_val_t, nothrow) + case LibFunc_ZdlPvSt11align_val_tRKSt9nothrow_t: + // void operator delete[](void*, align_val_t, nothrow) + case LibFunc_ZdaPvSt11align_val_tRKSt9nothrow_t: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy()); + + case LibFunc_memset_pattern16: + return (!FTy.isVarArg() && NumParams == 3 && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + FTy.getParamType(2)->isIntegerTy()); + + case LibFunc_cxa_guard_abort: + case LibFunc_cxa_guard_acquire: + case LibFunc_cxa_guard_release: + case LibFunc_nvvm_reflect: + return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); + + case LibFunc_sincospi_stret: + case LibFunc_sincospif_stret: + return (NumParams == 1 && FTy.getParamType(0)->isFloatingPointTy()); + + case LibFunc_acos: + case LibFunc_acos_finite: + case LibFunc_acosf: + case LibFunc_acosf_finite: + case LibFunc_acosh: + case LibFunc_acosh_finite: + case LibFunc_acoshf: + case LibFunc_acoshf_finite: + case LibFunc_acoshl: + case LibFunc_acoshl_finite: + case LibFunc_acosl: + case LibFunc_acosl_finite: + case LibFunc_asin: + case LibFunc_asin_finite: + case LibFunc_asinf: + case LibFunc_asinf_finite: + case LibFunc_asinh: + case 
LibFunc_asinhf: + case LibFunc_asinhl: + case LibFunc_asinl: + case LibFunc_asinl_finite: + case LibFunc_atan: + case LibFunc_atanf: + case LibFunc_atanh: + case LibFunc_atanh_finite: + case LibFunc_atanhf: + case LibFunc_atanhf_finite: + case LibFunc_atanhl: + case LibFunc_atanhl_finite: + case LibFunc_atanl: + case LibFunc_cbrt: + case LibFunc_cbrtf: + case LibFunc_cbrtl: + case LibFunc_ceil: + case LibFunc_ceilf: + case LibFunc_ceill: + case LibFunc_cos: + case LibFunc_cosf: + case LibFunc_cosh: + case LibFunc_cosh_finite: + case LibFunc_coshf: + case LibFunc_coshf_finite: + case LibFunc_coshl: + case LibFunc_coshl_finite: + case LibFunc_cosl: + case LibFunc_exp10: + case LibFunc_exp10_finite: + case LibFunc_exp10f: + case LibFunc_exp10f_finite: + case LibFunc_exp10l: + case LibFunc_exp10l_finite: + case LibFunc_exp2: + case LibFunc_exp2_finite: + case LibFunc_exp2f: + case LibFunc_exp2f_finite: + case LibFunc_exp2l: + case LibFunc_exp2l_finite: + case LibFunc_exp: + case LibFunc_exp_finite: + case LibFunc_expf: + case LibFunc_expf_finite: + case LibFunc_expl: + case LibFunc_expl_finite: + case LibFunc_expm1: + case LibFunc_expm1f: + case LibFunc_expm1l: + case LibFunc_fabs: + case LibFunc_fabsf: + case LibFunc_fabsl: + case LibFunc_floor: + case LibFunc_floorf: + case LibFunc_floorl: + case LibFunc_log10: + case LibFunc_log10_finite: + case LibFunc_log10f: + case LibFunc_log10f_finite: + case LibFunc_log10l: + case LibFunc_log10l_finite: + case LibFunc_log1p: + case LibFunc_log1pf: + case LibFunc_log1pl: + case LibFunc_log2: + case LibFunc_log2_finite: + case LibFunc_log2f: + case LibFunc_log2f_finite: + case LibFunc_log2l: + case LibFunc_log2l_finite: + case LibFunc_log: + case LibFunc_log_finite: + case LibFunc_logb: + case LibFunc_logbf: + case LibFunc_logbl: + case LibFunc_logf: + case LibFunc_logf_finite: + case LibFunc_logl: + case LibFunc_logl_finite: + case LibFunc_nearbyint: + case LibFunc_nearbyintf: + case LibFunc_nearbyintl: + case LibFunc_rint: + case LibFunc_rintf: + case LibFunc_rintl: + case LibFunc_round: + case LibFunc_roundf: + case LibFunc_roundl: + case LibFunc_sin: + case LibFunc_sinf: + case LibFunc_sinh: + case LibFunc_sinh_finite: + case LibFunc_sinhf: + case LibFunc_sinhf_finite: + case LibFunc_sinhl: + case LibFunc_sinhl_finite: + case LibFunc_sinl: + case LibFunc_sqrt: + case LibFunc_sqrt_finite: + case LibFunc_sqrtf: + case LibFunc_sqrtf_finite: + case LibFunc_sqrtl: + case LibFunc_sqrtl_finite: + case LibFunc_tan: + case LibFunc_tanf: + case LibFunc_tanh: + case LibFunc_tanhf: + case LibFunc_tanhl: + case LibFunc_tanl: + case LibFunc_trunc: + case LibFunc_truncf: + case LibFunc_truncl: + return (NumParams == 1 && FTy.getReturnType()->isFloatingPointTy() && + FTy.getReturnType() == FTy.getParamType(0)); + + case LibFunc_atan2: + case LibFunc_atan2_finite: + case LibFunc_atan2f: + case LibFunc_atan2f_finite: + case LibFunc_atan2l: + case LibFunc_atan2l_finite: + case LibFunc_fmin: + case LibFunc_fminf: + case LibFunc_fminl: + case LibFunc_fmax: + case LibFunc_fmaxf: + case LibFunc_fmaxl: + case LibFunc_fmod: + case LibFunc_fmodf: + case LibFunc_fmodl: + case LibFunc_copysign: + case LibFunc_copysignf: + case LibFunc_copysignl: + case LibFunc_pow: + case LibFunc_pow_finite: + case LibFunc_powf: + case LibFunc_powf_finite: + case LibFunc_powl: + case LibFunc_powl_finite: + return (NumParams == 2 && FTy.getReturnType()->isFloatingPointTy() && + FTy.getReturnType() == FTy.getParamType(0) && + FTy.getReturnType() == FTy.getParamType(1)); + + case LibFunc_ldexp: + 
case LibFunc_ldexpf: + case LibFunc_ldexpl: + return (NumParams == 2 && FTy.getReturnType()->isFloatingPointTy() && + FTy.getReturnType() == FTy.getParamType(0) && + FTy.getParamType(1)->isIntegerTy(32)); + + case LibFunc_ffs: + case LibFunc_ffsl: + case LibFunc_ffsll: + case LibFunc_fls: + case LibFunc_flsl: + case LibFunc_flsll: + return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(32) && + FTy.getParamType(0)->isIntegerTy()); + + case LibFunc_isdigit: + case LibFunc_isascii: + case LibFunc_toascii: + case LibFunc_putchar: + case LibFunc_putchar_unlocked: + return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(32) && + FTy.getReturnType() == FTy.getParamType(0)); + + case LibFunc_abs: + case LibFunc_labs: + case LibFunc_llabs: + return (NumParams == 1 && FTy.getReturnType()->isIntegerTy() && + FTy.getReturnType() == FTy.getParamType(0)); + + case LibFunc_cxa_atexit: + return (NumParams == 3 && FTy.getReturnType()->isIntegerTy() && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + FTy.getParamType(2)->isPointerTy()); + + case LibFunc_sinpi: + case LibFunc_cospi: + return (NumParams == 1 && FTy.getReturnType()->isDoubleTy() && + FTy.getReturnType() == FTy.getParamType(0)); + + case LibFunc_sinpif: + case LibFunc_cospif: + return (NumParams == 1 && FTy.getReturnType()->isFloatTy() && + FTy.getReturnType() == FTy.getParamType(0)); + + case LibFunc_strnlen: + return (NumParams == 2 && FTy.getReturnType() == FTy.getParamType(1) && + FTy.getParamType(0) == PCharTy && + FTy.getParamType(1) == SizeTTy); + + case LibFunc_posix_memalign: + return (NumParams == 3 && FTy.getReturnType()->isIntegerTy(32) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1) == SizeTTy && FTy.getParamType(2) == SizeTTy); + + case LibFunc_wcslen: + return (NumParams == 1 && FTy.getParamType(0)->isPointerTy() && + FTy.getReturnType()->isIntegerTy()); + + case LibFunc_cabs: + case LibFunc_cabsf: + case LibFunc_cabsl: { + Type* RetTy = FTy.getReturnType(); + if (!RetTy->isFloatingPointTy()) + return false; + + // NOTE: These prototypes are target specific and currently support + // "complex" passed as an array or discrete real & imaginary parameters. + // Add other calling conventions to enable libcall optimizations. + if (NumParams == 1) + return (FTy.getParamType(0)->isArrayTy() && + FTy.getParamType(0)->getArrayNumElements() == 2 && + FTy.getParamType(0)->getArrayElementType() == RetTy); + else if (NumParams == 2) + return (FTy.getParamType(0) == RetTy && FTy.getParamType(1) == RetTy); + else + return false; + } + case LibFunc::NumLibFuncs: + case LibFunc::NotLibFunc: + break; + } + + llvm_unreachable("Invalid libfunc"); +} + +bool TargetLibraryInfoImpl::getLibFunc(const Function &FDecl, + LibFunc &F) const { + // Intrinsics don't overlap w/libcalls; if our module has a large number of + // intrinsics, this ends up being an interesting compile time win since we + // avoid string normalization and comparison. + if (FDecl.isIntrinsic()) return false; + + const DataLayout *DL = + FDecl.getParent() ? 
      &FDecl.getParent()->getDataLayout() : nullptr;
+  return getLibFunc(FDecl.getName(), F) &&
+         isValidProtoForLibFunc(*FDecl.getFunctionType(), F, DL);
+}
+
+void TargetLibraryInfoImpl::disableAllFunctions() {
+  memset(AvailableArray, 0, sizeof(AvailableArray));
+}
+
+static bool compareByScalarFnName(const VecDesc &LHS, const VecDesc &RHS) {
+  return LHS.ScalarFnName < RHS.ScalarFnName;
+}
+
+static bool compareByVectorFnName(const VecDesc &LHS, const VecDesc &RHS) {
+  return LHS.VectorFnName < RHS.VectorFnName;
+}
+
+static bool compareWithScalarFnName(const VecDesc &LHS, StringRef S) {
+  return LHS.ScalarFnName < S;
+}
+
+static bool compareWithVectorFnName(const VecDesc &LHS, StringRef S) {
+  return LHS.VectorFnName < S;
+}
+
+void TargetLibraryInfoImpl::addVectorizableFunctions(ArrayRef<VecDesc> Fns) {
+  VectorDescs.insert(VectorDescs.end(), Fns.begin(), Fns.end());
+  llvm::sort(VectorDescs, compareByScalarFnName);
+
+  ScalarDescs.insert(ScalarDescs.end(), Fns.begin(), Fns.end());
+  llvm::sort(ScalarDescs, compareByVectorFnName);
+}
+
+void TargetLibraryInfoImpl::addVectorizableFunctionsFromVecLib(
+    enum VectorLibrary VecLib) {
+  switch (VecLib) {
+  case Accelerate: {
+    const VecDesc VecFuncs[] = {
+    #define TLI_DEFINE_ACCELERATE_VECFUNCS
+    #include "llvm/Analysis/VecFuncs.def"
+    };
+    addVectorizableFunctions(VecFuncs);
+    break;
+  }
+  case MASSV: {
+    const VecDesc VecFuncs[] = {
+    #define TLI_DEFINE_MASSV_VECFUNCS
+    #include "llvm/Analysis/VecFuncs.def"
+    };
+    addVectorizableFunctions(VecFuncs);
+    break;
+  }
+  case SVML: {
+    const VecDesc VecFuncs[] = {
+    #define TLI_DEFINE_SVML_VECFUNCS
+    #include "llvm/Analysis/VecFuncs.def"
+    };
+    addVectorizableFunctions(VecFuncs);
+    break;
+  }
+  case NoLibrary:
+    break;
+  }
+}
+
+bool TargetLibraryInfoImpl::isFunctionVectorizable(StringRef funcName) const {
+  funcName = sanitizeFunctionName(funcName);
+  if (funcName.empty())
+    return false;
+
+  std::vector<VecDesc>::const_iterator I =
+      llvm::lower_bound(VectorDescs, funcName, compareWithScalarFnName);
+  return I != VectorDescs.end() && StringRef(I->ScalarFnName) == funcName;
+}
+
+StringRef TargetLibraryInfoImpl::getVectorizedFunction(StringRef F,
+                                                       unsigned VF) const {
+  F = sanitizeFunctionName(F);
+  if (F.empty())
+    return F;
+  std::vector<VecDesc>::const_iterator I =
+      llvm::lower_bound(VectorDescs, F, compareWithScalarFnName);
+  while (I != VectorDescs.end() && StringRef(I->ScalarFnName) == F) {
+    if (I->VectorizationFactor == VF)
+      return I->VectorFnName;
+    ++I;
+  }
+  return StringRef();
+}
+
+StringRef TargetLibraryInfoImpl::getScalarizedFunction(StringRef F,
+                                                       unsigned &VF) const {
+  F = sanitizeFunctionName(F);
+  if (F.empty())
+    return F;
+
+  std::vector<VecDesc>::const_iterator I =
+      llvm::lower_bound(ScalarDescs, F, compareWithVectorFnName);
+  if (I == ScalarDescs.end() || StringRef(I->VectorFnName) != F)
+    return StringRef();
+  VF = I->VectorizationFactor;
+  return I->ScalarFnName;
+}
+
+TargetLibraryInfo TargetLibraryAnalysis::run(Function &F,
+                                             FunctionAnalysisManager &) {
+  if (PresetInfoImpl)
+    return TargetLibraryInfo(*PresetInfoImpl);
+
+  return TargetLibraryInfo(
+      lookupInfoImpl(Triple(F.getParent()->getTargetTriple())));
+}
+
+TargetLibraryInfoImpl &TargetLibraryAnalysis::lookupInfoImpl(const Triple &T) {
+  std::unique_ptr<TargetLibraryInfoImpl> &Impl =
+      Impls[T.normalize()];
+  if (!Impl)
+    Impl.reset(new TargetLibraryInfoImpl(T));
+
+  return *Impl;
+}
+
+unsigned TargetLibraryInfoImpl::getWCharSize(const Module &M) const {
+  if (auto *ShortWChar =
          cast_or_null<ConstantAsMetadata>(
+          M.getModuleFlag("wchar_size")))
+    return cast<ConstantInt>(ShortWChar->getValue())->getZExtValue();
+  return 0;
+}
+
+TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass()
+    : ImmutablePass(ID), TLIImpl(), TLI(TLIImpl) {
+  initializeTargetLibraryInfoWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass(const Triple &T)
+    : ImmutablePass(ID), TLIImpl(T), TLI(TLIImpl) {
+  initializeTargetLibraryInfoWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass(
+    const TargetLibraryInfoImpl &TLIImpl)
+    : ImmutablePass(ID), TLIImpl(TLIImpl), TLI(this->TLIImpl) {
+  initializeTargetLibraryInfoWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+AnalysisKey TargetLibraryAnalysis::Key;
+
+// Register the basic pass.
+INITIALIZE_PASS(TargetLibraryInfoWrapperPass, "targetlibinfo",
+                "Target Library Information", false, true)
+char TargetLibraryInfoWrapperPass::ID = 0;
+
+void TargetLibraryInfoWrapperPass::anchor() {}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
new file mode 100644
index 000000000000..c9c294873ea6
--- /dev/null
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -0,0 +1,1386 @@
+//===- llvm/Analysis/TargetTransformInfo.cpp ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/TargetTransformInfoImpl.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/LoopIterator.h"
+#include <utility>
+
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "tti"
+
+static cl::opt<bool> EnableReduxCost("costmodel-reduxcost", cl::init(false),
+                                     cl::Hidden,
+                                     cl::desc("Recognize reduction patterns."));
+
+namespace {
+/// No-op implementation of the TTI interface using the utility base
+/// classes.
+///
+/// This is used when no target specific information is available.
+struct NoTTIImpl : TargetTransformInfoImplCRTPBase<NoTTIImpl> {
+  explicit NoTTIImpl(const DataLayout &DL)
+      : TargetTransformInfoImplCRTPBase<NoTTIImpl>(DL) {}
+};
+}
+
+bool HardwareLoopInfo::canAnalyze(LoopInfo &LI) {
+  // If the loop has irreducible control flow, it cannot be converted to a
+  // hardware loop.
+  LoopBlocksRPO RPOT(L);
+  RPOT.perform(&LI);
+  if (containsIrreducibleCFG<const BasicBlock *>(RPOT, LI))
+    return false;
+  return true;
+}
+
+bool HardwareLoopInfo::isHardwareLoopCandidate(ScalarEvolution &SE,
+                                               LoopInfo &LI, DominatorTree &DT,
+                                               bool ForceNestedLoop,
+                                               bool ForceHardwareLoopPHI) {
+  SmallVector<BasicBlock *, 4> ExitingBlocks;
+  L->getExitingBlocks(ExitingBlocks);
+
+  for (BasicBlock *BB : ExitingBlocks) {
+    // If we pass the updated counter back through a phi, we need to know
+    // which latch the updated value will be coming from.
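+    // A non-latch exiting block cannot be that source, so when a counter phi
+    // is forced (or the counter must stay in a register) it is skipped here.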
+    if (!L->isLoopLatch(BB)) {
+      if (ForceHardwareLoopPHI || CounterInReg)
+        continue;
+    }
+
+    const SCEV *EC = SE.getExitCount(L, BB);
+    if (isa<SCEVCouldNotCompute>(EC))
+      continue;
+    if (const SCEVConstant *ConstEC = dyn_cast<SCEVConstant>(EC)) {
+      if (ConstEC->getValue()->isZero())
+        continue;
+    } else if (!SE.isLoopInvariant(EC, L))
+      continue;
+
+    if (SE.getTypeSizeInBits(EC->getType()) > CountType->getBitWidth())
+      continue;
+
+    // If this exiting block is contained in a nested loop, it is not eligible
+    // for insertion of the branch-and-decrement since the inner loop would
+    // end up messing up the value in the CTR.
+    if (!IsNestingLegal && LI.getLoopFor(BB) != L && !ForceNestedLoop)
+      continue;
+
+    // We now have a loop-invariant count of loop iterations (which is not the
+    // constant zero) for which we know that this loop will not exit via this
+    // exiting block.
+
+    // We need to make sure that this block will run on every loop iteration.
+    // For this to be true, we must dominate all blocks with backedges. Such
+    // blocks are in-loop predecessors to the header block.
+    bool NotAlways = false;
+    for (BasicBlock *Pred : predecessors(L->getHeader())) {
+      if (!L->contains(Pred))
+        continue;
+
+      if (!DT.dominates(BB, Pred)) {
+        NotAlways = true;
+        break;
+      }
+    }
+
+    if (NotAlways)
+      continue;
+
+    // Make sure this block ends with a conditional branch.
+    Instruction *TI = BB->getTerminator();
+    if (!TI)
+      continue;
+
+    if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
+      if (!BI->isConditional())
+        continue;
+
+      ExitBranch = BI;
+    } else
+      continue;
+
+    // Note that this block may not be the loop latch block, even if the loop
+    // has a latch block.
+    ExitBlock = BB;
+    ExitCount = EC;
+    break;
+  }
+
+  if (!ExitBlock)
+    return false;
+  return true;
+}
+
+TargetTransformInfo::TargetTransformInfo(const DataLayout &DL)
+    : TTIImpl(new Model<NoTTIImpl>(NoTTIImpl(DL))) {}
+
+TargetTransformInfo::~TargetTransformInfo() {}
+
+TargetTransformInfo::TargetTransformInfo(TargetTransformInfo &&Arg)
+    : TTIImpl(std::move(Arg.TTIImpl)) {}
+
+TargetTransformInfo &TargetTransformInfo::operator=(TargetTransformInfo &&RHS) {
+  TTIImpl = std::move(RHS.TTIImpl);
+  return *this;
+}
+
+int TargetTransformInfo::getOperationCost(unsigned Opcode, Type *Ty,
+                                          Type *OpTy) const {
+  int Cost = TTIImpl->getOperationCost(Opcode, Ty, OpTy);
+  assert(Cost >= 0 && "TTI should not produce negative costs!");
+  return Cost;
+}
+
+int TargetTransformInfo::getCallCost(FunctionType *FTy, int NumArgs,
+                                     const User *U) const {
+  int Cost = TTIImpl->getCallCost(FTy, NumArgs, U);
+  assert(Cost >= 0 && "TTI should not produce negative costs!");
+  return Cost;
+}
+
+int TargetTransformInfo::getCallCost(const Function *F,
+                                     ArrayRef<const Value *> Arguments,
+                                     const User *U) const {
+  int Cost = TTIImpl->getCallCost(F, Arguments, U);
+  assert(Cost >= 0 && "TTI should not produce negative costs!");
+  return Cost;
+}
+
+unsigned TargetTransformInfo::getInliningThresholdMultiplier() const {
+  return TTIImpl->getInliningThresholdMultiplier();
+}
+
+int TargetTransformInfo::getInlinerVectorBonusPercent() const {
+  return TTIImpl->getInlinerVectorBonusPercent();
+}
+
+int TargetTransformInfo::getGEPCost(Type *PointeeType, const Value *Ptr,
+                                    ArrayRef<const Value *> Operands) const {
+  return TTIImpl->getGEPCost(PointeeType, Ptr, Operands);
+}
+
+int TargetTransformInfo::getExtCost(const Instruction *I,
+                                    const Value *Src) const {
+  return TTIImpl->getExtCost(I, Src);
+}
+
+int TargetTransformInfo::getIntrinsicCost(
+    Intrinsic::ID
IID, Type *RetTy, ArrayRef<const Value *> Arguments, + const User *U) const { + int Cost = TTIImpl->getIntrinsicCost(IID, RetTy, Arguments, U); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +unsigned +TargetTransformInfo::getEstimatedNumberOfCaseClusters(const SwitchInst &SI, + unsigned &JTSize) const { + return TTIImpl->getEstimatedNumberOfCaseClusters(SI, JTSize); +} + +int TargetTransformInfo::getUserCost(const User *U, + ArrayRef<const Value *> Operands) const { + int Cost = TTIImpl->getUserCost(U, Operands); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +bool TargetTransformInfo::hasBranchDivergence() const { + return TTIImpl->hasBranchDivergence(); +} + +bool TargetTransformInfo::isSourceOfDivergence(const Value *V) const { + return TTIImpl->isSourceOfDivergence(V); +} + +bool llvm::TargetTransformInfo::isAlwaysUniform(const Value *V) const { + return TTIImpl->isAlwaysUniform(V); +} + +unsigned TargetTransformInfo::getFlatAddressSpace() const { + return TTIImpl->getFlatAddressSpace(); +} + +bool TargetTransformInfo::collectFlatAddressOperands( + SmallVectorImpl<int> &OpIndexes, Intrinsic::ID IID) const { + return TTIImpl->collectFlatAddressOperands(OpIndexes, IID); +} + +bool TargetTransformInfo::rewriteIntrinsicWithAddressSpace( + IntrinsicInst *II, Value *OldV, Value *NewV) const { + return TTIImpl->rewriteIntrinsicWithAddressSpace(II, OldV, NewV); +} + +bool TargetTransformInfo::isLoweredToCall(const Function *F) const { + return TTIImpl->isLoweredToCall(F); +} + +bool TargetTransformInfo::isHardwareLoopProfitable( + Loop *L, ScalarEvolution &SE, AssumptionCache &AC, + TargetLibraryInfo *LibInfo, HardwareLoopInfo &HWLoopInfo) const { + return TTIImpl->isHardwareLoopProfitable(L, SE, AC, LibInfo, HWLoopInfo); +} + +void TargetTransformInfo::getUnrollingPreferences( + Loop *L, ScalarEvolution &SE, UnrollingPreferences &UP) const { + return TTIImpl->getUnrollingPreferences(L, SE, UP); +} + +bool TargetTransformInfo::isLegalAddImmediate(int64_t Imm) const { + return TTIImpl->isLegalAddImmediate(Imm); +} + +bool TargetTransformInfo::isLegalICmpImmediate(int64_t Imm) const { + return TTIImpl->isLegalICmpImmediate(Imm); +} + +bool TargetTransformInfo::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, + int64_t BaseOffset, + bool HasBaseReg, + int64_t Scale, + unsigned AddrSpace, + Instruction *I) const { + return TTIImpl->isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, + Scale, AddrSpace, I); +} + +bool TargetTransformInfo::isLSRCostLess(LSRCost &C1, LSRCost &C2) const { + return TTIImpl->isLSRCostLess(C1, C2); +} + +bool TargetTransformInfo::canMacroFuseCmp() const { + return TTIImpl->canMacroFuseCmp(); +} + +bool TargetTransformInfo::canSaveCmp(Loop *L, BranchInst **BI, + ScalarEvolution *SE, LoopInfo *LI, + DominatorTree *DT, AssumptionCache *AC, + TargetLibraryInfo *LibInfo) const { + return TTIImpl->canSaveCmp(L, BI, SE, LI, DT, AC, LibInfo); +} + +bool TargetTransformInfo::shouldFavorPostInc() const { + return TTIImpl->shouldFavorPostInc(); +} + +bool TargetTransformInfo::shouldFavorBackedgeIndex(const Loop *L) const { + return TTIImpl->shouldFavorBackedgeIndex(L); +} + +bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, + MaybeAlign Alignment) const { + return TTIImpl->isLegalMaskedStore(DataType, Alignment); +} + +bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, + MaybeAlign Alignment) const { + return TTIImpl->isLegalMaskedLoad(DataType, Alignment); +} + +bool 
TargetTransformInfo::isLegalNTStore(Type *DataType, + Align Alignment) const { + return TTIImpl->isLegalNTStore(DataType, Alignment); +} + +bool TargetTransformInfo::isLegalNTLoad(Type *DataType, Align Alignment) const { + return TTIImpl->isLegalNTLoad(DataType, Alignment); +} + +bool TargetTransformInfo::isLegalMaskedGather(Type *DataType) const { + return TTIImpl->isLegalMaskedGather(DataType); +} + +bool TargetTransformInfo::isLegalMaskedScatter(Type *DataType) const { + return TTIImpl->isLegalMaskedScatter(DataType); +} + +bool TargetTransformInfo::isLegalMaskedCompressStore(Type *DataType) const { + return TTIImpl->isLegalMaskedCompressStore(DataType); +} + +bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const { + return TTIImpl->isLegalMaskedExpandLoad(DataType); +} + +bool TargetTransformInfo::hasDivRemOp(Type *DataType, bool IsSigned) const { + return TTIImpl->hasDivRemOp(DataType, IsSigned); +} + +bool TargetTransformInfo::hasVolatileVariant(Instruction *I, + unsigned AddrSpace) const { + return TTIImpl->hasVolatileVariant(I, AddrSpace); +} + +bool TargetTransformInfo::prefersVectorizedAddressing() const { + return TTIImpl->prefersVectorizedAddressing(); +} + +int TargetTransformInfo::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, + int64_t BaseOffset, + bool HasBaseReg, + int64_t Scale, + unsigned AddrSpace) const { + int Cost = TTIImpl->getScalingFactorCost(Ty, BaseGV, BaseOffset, HasBaseReg, + Scale, AddrSpace); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +bool TargetTransformInfo::LSRWithInstrQueries() const { + return TTIImpl->LSRWithInstrQueries(); +} + +bool TargetTransformInfo::isTruncateFree(Type *Ty1, Type *Ty2) const { + return TTIImpl->isTruncateFree(Ty1, Ty2); +} + +bool TargetTransformInfo::isProfitableToHoist(Instruction *I) const { + return TTIImpl->isProfitableToHoist(I); +} + +bool TargetTransformInfo::useAA() const { return TTIImpl->useAA(); } + +bool TargetTransformInfo::isTypeLegal(Type *Ty) const { + return TTIImpl->isTypeLegal(Ty); +} + +bool TargetTransformInfo::shouldBuildLookupTables() const { + return TTIImpl->shouldBuildLookupTables(); +} +bool TargetTransformInfo::shouldBuildLookupTablesForConstant(Constant *C) const { + return TTIImpl->shouldBuildLookupTablesForConstant(C); +} + +bool TargetTransformInfo::useColdCCForColdCall(Function &F) const { + return TTIImpl->useColdCCForColdCall(F); +} + +unsigned TargetTransformInfo:: +getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) const { + return TTIImpl->getScalarizationOverhead(Ty, Insert, Extract); +} + +unsigned TargetTransformInfo:: +getOperandsScalarizationOverhead(ArrayRef<const Value *> Args, + unsigned VF) const { + return TTIImpl->getOperandsScalarizationOverhead(Args, VF); +} + +bool TargetTransformInfo::supportsEfficientVectorElementLoadStore() const { + return TTIImpl->supportsEfficientVectorElementLoadStore(); +} + +bool TargetTransformInfo::enableAggressiveInterleaving(bool LoopHasReductions) const { + return TTIImpl->enableAggressiveInterleaving(LoopHasReductions); +} + +TargetTransformInfo::MemCmpExpansionOptions +TargetTransformInfo::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const { + return TTIImpl->enableMemCmpExpansion(OptSize, IsZeroCmp); +} + +bool TargetTransformInfo::enableInterleavedAccessVectorization() const { + return TTIImpl->enableInterleavedAccessVectorization(); +} + +bool TargetTransformInfo::enableMaskedInterleavedAccessVectorization() const { + return 
TTIImpl->enableMaskedInterleavedAccessVectorization(); +} + +bool TargetTransformInfo::isFPVectorizationPotentiallyUnsafe() const { + return TTIImpl->isFPVectorizationPotentiallyUnsafe(); +} + +bool TargetTransformInfo::allowsMisalignedMemoryAccesses(LLVMContext &Context, + unsigned BitWidth, + unsigned AddressSpace, + unsigned Alignment, + bool *Fast) const { + return TTIImpl->allowsMisalignedMemoryAccesses(Context, BitWidth, AddressSpace, + Alignment, Fast); +} + +TargetTransformInfo::PopcntSupportKind +TargetTransformInfo::getPopcntSupport(unsigned IntTyWidthInBit) const { + return TTIImpl->getPopcntSupport(IntTyWidthInBit); +} + +bool TargetTransformInfo::haveFastSqrt(Type *Ty) const { + return TTIImpl->haveFastSqrt(Ty); +} + +bool TargetTransformInfo::isFCmpOrdCheaperThanFCmpZero(Type *Ty) const { + return TTIImpl->isFCmpOrdCheaperThanFCmpZero(Ty); +} + +int TargetTransformInfo::getFPOpCost(Type *Ty) const { + int Cost = TTIImpl->getFPOpCost(Ty); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getIntImmCodeSizeCost(unsigned Opcode, unsigned Idx, + const APInt &Imm, + Type *Ty) const { + int Cost = TTIImpl->getIntImmCodeSizeCost(Opcode, Idx, Imm, Ty); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getIntImmCost(const APInt &Imm, Type *Ty) const { + int Cost = TTIImpl->getIntImmCost(Imm, Ty); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getIntImmCost(unsigned Opcode, unsigned Idx, + const APInt &Imm, Type *Ty) const { + int Cost = TTIImpl->getIntImmCost(Opcode, Idx, Imm, Ty); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getIntImmCost(Intrinsic::ID IID, unsigned Idx, + const APInt &Imm, Type *Ty) const { + int Cost = TTIImpl->getIntImmCost(IID, Idx, Imm, Ty); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +unsigned TargetTransformInfo::getNumberOfRegisters(unsigned ClassID) const { + return TTIImpl->getNumberOfRegisters(ClassID); +} + +unsigned TargetTransformInfo::getRegisterClassForType(bool Vector, Type *Ty) const { + return TTIImpl->getRegisterClassForType(Vector, Ty); +} + +const char* TargetTransformInfo::getRegisterClassName(unsigned ClassID) const { + return TTIImpl->getRegisterClassName(ClassID); +} + +unsigned TargetTransformInfo::getRegisterBitWidth(bool Vector) const { + return TTIImpl->getRegisterBitWidth(Vector); +} + +unsigned TargetTransformInfo::getMinVectorRegisterBitWidth() const { + return TTIImpl->getMinVectorRegisterBitWidth(); +} + +bool TargetTransformInfo::shouldMaximizeVectorBandwidth(bool OptSize) const { + return TTIImpl->shouldMaximizeVectorBandwidth(OptSize); +} + +unsigned TargetTransformInfo::getMinimumVF(unsigned ElemWidth) const { + return TTIImpl->getMinimumVF(ElemWidth); +} + +bool TargetTransformInfo::shouldConsiderAddressTypePromotion( + const Instruction &I, bool &AllowPromotionWithoutCommonHeader) const { + return TTIImpl->shouldConsiderAddressTypePromotion( + I, AllowPromotionWithoutCommonHeader); +} + +unsigned TargetTransformInfo::getCacheLineSize() const { + return TTIImpl->getCacheLineSize(); +} + +llvm::Optional<unsigned> TargetTransformInfo::getCacheSize(CacheLevel Level) + const { + return TTIImpl->getCacheSize(Level); +} + +llvm::Optional<unsigned> TargetTransformInfo::getCacheAssociativity( + CacheLevel Level) const { + return 
TTIImpl->getCacheAssociativity(Level); +} + +unsigned TargetTransformInfo::getPrefetchDistance() const { + return TTIImpl->getPrefetchDistance(); +} + +unsigned TargetTransformInfo::getMinPrefetchStride() const { + return TTIImpl->getMinPrefetchStride(); +} + +unsigned TargetTransformInfo::getMaxPrefetchIterationsAhead() const { + return TTIImpl->getMaxPrefetchIterationsAhead(); +} + +unsigned TargetTransformInfo::getMaxInterleaveFactor(unsigned VF) const { + return TTIImpl->getMaxInterleaveFactor(VF); +} + +TargetTransformInfo::OperandValueKind +TargetTransformInfo::getOperandInfo(Value *V, OperandValueProperties &OpProps) { + OperandValueKind OpInfo = OK_AnyValue; + OpProps = OP_None; + + if (auto *CI = dyn_cast<ConstantInt>(V)) { + if (CI->getValue().isPowerOf2()) + OpProps = OP_PowerOf2; + return OK_UniformConstantValue; + } + + // A broadcast shuffle creates a uniform value. + // TODO: Add support for non-zero index broadcasts. + // TODO: Add support for different source vector width. + if (auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V)) + if (ShuffleInst->isZeroEltSplat()) + OpInfo = OK_UniformValue; + + const Value *Splat = getSplatValue(V); + + // Check for a splat of a constant or for a non uniform vector of constants + // and check if the constant(s) are all powers of two. + if (isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) { + OpInfo = OK_NonUniformConstantValue; + if (Splat) { + OpInfo = OK_UniformConstantValue; + if (auto *CI = dyn_cast<ConstantInt>(Splat)) + if (CI->getValue().isPowerOf2()) + OpProps = OP_PowerOf2; + } else if (auto *CDS = dyn_cast<ConstantDataSequential>(V)) { + OpProps = OP_PowerOf2; + for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) { + if (auto *CI = dyn_cast<ConstantInt>(CDS->getElementAsConstant(I))) + if (CI->getValue().isPowerOf2()) + continue; + OpProps = OP_None; + break; + } + } + } + + // Check for a splat of a uniform value. 
This is not loop aware, so return + // true only for the obviously uniform cases (argument, globalvalue) + if (Splat && (isa<Argument>(Splat) || isa<GlobalValue>(Splat))) + OpInfo = OK_UniformValue; + + return OpInfo; +} + +int TargetTransformInfo::getArithmeticInstrCost( + unsigned Opcode, Type *Ty, OperandValueKind Opd1Info, + OperandValueKind Opd2Info, OperandValueProperties Opd1PropInfo, + OperandValueProperties Opd2PropInfo, + ArrayRef<const Value *> Args) const { + int Cost = TTIImpl->getArithmeticInstrCost(Opcode, Ty, Opd1Info, Opd2Info, + Opd1PropInfo, Opd2PropInfo, Args); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getShuffleCost(ShuffleKind Kind, Type *Ty, int Index, + Type *SubTp) const { + int Cost = TTIImpl->getShuffleCost(Kind, Ty, Index, SubTp); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getCastInstrCost(unsigned Opcode, Type *Dst, + Type *Src, const Instruction *I) const { + assert ((I == nullptr || I->getOpcode() == Opcode) && + "Opcode should reflect passed instruction."); + int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, I); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getExtractWithExtendCost(unsigned Opcode, Type *Dst, + VectorType *VecTy, + unsigned Index) const { + int Cost = TTIImpl->getExtractWithExtendCost(Opcode, Dst, VecTy, Index); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getCFInstrCost(unsigned Opcode) const { + int Cost = TTIImpl->getCFInstrCost(Opcode); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, + Type *CondTy, const Instruction *I) const { + assert ((I == nullptr || I->getOpcode() == Opcode) && + "Opcode should reflect passed instruction."); + int Cost = TTIImpl->getCmpSelInstrCost(Opcode, ValTy, CondTy, I); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getVectorInstrCost(unsigned Opcode, Type *Val, + unsigned Index) const { + int Cost = TTIImpl->getVectorInstrCost(Opcode, Val, Index); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getMemoryOpCost(unsigned Opcode, Type *Src, + unsigned Alignment, + unsigned AddressSpace, + const Instruction *I) const { + assert ((I == nullptr || I->getOpcode() == Opcode) && + "Opcode should reflect passed instruction."); + int Cost = TTIImpl->getMemoryOpCost(Opcode, Src, Alignment, AddressSpace, I); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getMaskedMemoryOpCost(unsigned Opcode, Type *Src, + unsigned Alignment, + unsigned AddressSpace) const { + int Cost = + TTIImpl->getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getGatherScatterOpCost(unsigned Opcode, Type *DataTy, + Value *Ptr, bool VariableMask, + unsigned Alignment) const { + int Cost = TTIImpl->getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, + Alignment); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getInterleavedMemoryOpCost( + unsigned Opcode, Type *VecTy, unsigned Factor, 
ArrayRef<unsigned> Indices, + unsigned Alignment, unsigned AddressSpace, bool UseMaskForCond, + bool UseMaskForGaps) const { + int Cost = TTIImpl->getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices, + Alignment, AddressSpace, + UseMaskForCond, + UseMaskForGaps); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, + ArrayRef<Type *> Tys, FastMathFlags FMF, + unsigned ScalarizationCostPassed) const { + int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Tys, FMF, + ScalarizationCostPassed); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, + ArrayRef<Value *> Args, FastMathFlags FMF, unsigned VF) const { + int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Args, FMF, VF); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getCallInstrCost(Function *F, Type *RetTy, + ArrayRef<Type *> Tys) const { + int Cost = TTIImpl->getCallInstrCost(F, RetTy, Tys); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +unsigned TargetTransformInfo::getNumberOfParts(Type *Tp) const { + return TTIImpl->getNumberOfParts(Tp); +} + +int TargetTransformInfo::getAddressComputationCost(Type *Tp, + ScalarEvolution *SE, + const SCEV *Ptr) const { + int Cost = TTIImpl->getAddressComputationCost(Tp, SE, Ptr); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getMemcpyCost(const Instruction *I) const { + int Cost = TTIImpl->getMemcpyCost(I); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getArithmeticReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwiseForm) const { + int Cost = TTIImpl->getArithmeticReductionCost(Opcode, Ty, IsPairwiseForm); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +int TargetTransformInfo::getMinMaxReductionCost(Type *Ty, Type *CondTy, + bool IsPairwiseForm, + bool IsUnsigned) const { + int Cost = + TTIImpl->getMinMaxReductionCost(Ty, CondTy, IsPairwiseForm, IsUnsigned); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + +unsigned +TargetTransformInfo::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) const { + return TTIImpl->getCostOfKeepingLiveOverCall(Tys); +} + +bool TargetTransformInfo::getTgtMemIntrinsic(IntrinsicInst *Inst, + MemIntrinsicInfo &Info) const { + return TTIImpl->getTgtMemIntrinsic(Inst, Info); +} + +unsigned TargetTransformInfo::getAtomicMemIntrinsicMaxElementSize() const { + return TTIImpl->getAtomicMemIntrinsicMaxElementSize(); +} + +Value *TargetTransformInfo::getOrCreateResultFromMemIntrinsic( + IntrinsicInst *Inst, Type *ExpectedType) const { + return TTIImpl->getOrCreateResultFromMemIntrinsic(Inst, ExpectedType); +} + +Type *TargetTransformInfo::getMemcpyLoopLoweringType(LLVMContext &Context, + Value *Length, + unsigned SrcAlign, + unsigned DestAlign) const { + return TTIImpl->getMemcpyLoopLoweringType(Context, Length, SrcAlign, + DestAlign); +} + +void TargetTransformInfo::getMemcpyLoopResidualLoweringType( + SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context, + unsigned RemainingBytes, unsigned SrcAlign, unsigned DestAlign) const { + TTIImpl->getMemcpyLoopResidualLoweringType(OpsOut, Context, RemainingBytes, + SrcAlign, DestAlign); +} + +bool 
TargetTransformInfo::areInlineCompatible(const Function *Caller, + const Function *Callee) const { + return TTIImpl->areInlineCompatible(Caller, Callee); +} + +bool TargetTransformInfo::areFunctionArgsABICompatible( + const Function *Caller, const Function *Callee, + SmallPtrSetImpl<Argument *> &Args) const { + return TTIImpl->areFunctionArgsABICompatible(Caller, Callee, Args); +} + +bool TargetTransformInfo::isIndexedLoadLegal(MemIndexedMode Mode, + Type *Ty) const { + return TTIImpl->isIndexedLoadLegal(Mode, Ty); +} + +bool TargetTransformInfo::isIndexedStoreLegal(MemIndexedMode Mode, + Type *Ty) const { + return TTIImpl->isIndexedStoreLegal(Mode, Ty); +} + +unsigned TargetTransformInfo::getLoadStoreVecRegBitWidth(unsigned AS) const { + return TTIImpl->getLoadStoreVecRegBitWidth(AS); +} + +bool TargetTransformInfo::isLegalToVectorizeLoad(LoadInst *LI) const { + return TTIImpl->isLegalToVectorizeLoad(LI); +} + +bool TargetTransformInfo::isLegalToVectorizeStore(StoreInst *SI) const { + return TTIImpl->isLegalToVectorizeStore(SI); +} + +bool TargetTransformInfo::isLegalToVectorizeLoadChain( + unsigned ChainSizeInBytes, unsigned Alignment, unsigned AddrSpace) const { + return TTIImpl->isLegalToVectorizeLoadChain(ChainSizeInBytes, Alignment, + AddrSpace); +} + +bool TargetTransformInfo::isLegalToVectorizeStoreChain( + unsigned ChainSizeInBytes, unsigned Alignment, unsigned AddrSpace) const { + return TTIImpl->isLegalToVectorizeStoreChain(ChainSizeInBytes, Alignment, + AddrSpace); +} + +unsigned TargetTransformInfo::getLoadVectorFactor(unsigned VF, + unsigned LoadSize, + unsigned ChainSizeInBytes, + VectorType *VecTy) const { + return TTIImpl->getLoadVectorFactor(VF, LoadSize, ChainSizeInBytes, VecTy); +} + +unsigned TargetTransformInfo::getStoreVectorFactor(unsigned VF, + unsigned StoreSize, + unsigned ChainSizeInBytes, + VectorType *VecTy) const { + return TTIImpl->getStoreVectorFactor(VF, StoreSize, ChainSizeInBytes, VecTy); +} + +bool TargetTransformInfo::useReductionIntrinsic(unsigned Opcode, + Type *Ty, ReductionFlags Flags) const { + return TTIImpl->useReductionIntrinsic(Opcode, Ty, Flags); +} + +bool TargetTransformInfo::shouldExpandReduction(const IntrinsicInst *II) const { + return TTIImpl->shouldExpandReduction(II); +} + +unsigned TargetTransformInfo::getGISelRematGlobalCost() const { + return TTIImpl->getGISelRematGlobalCost(); +} + +int TargetTransformInfo::getInstructionLatency(const Instruction *I) const { + return TTIImpl->getInstructionLatency(I); +} + +static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft, + unsigned Level) { + // We don't need a shuffle if we just want to have element 0 in position 0 of + // the vector. + if (!SI && Level == 0 && IsLeft) + return true; + else if (!SI) + return false; + + SmallVector<int, 32> Mask(SI->getType()->getVectorNumElements(), -1); + + // Build a mask of 0, 2, ... (left) or 1, 3, ... (right) depending on whether + // we look at the left or right side. + for (unsigned i = 0, e = (1 << Level), val = !IsLeft; i != e; ++i, val += 2) + Mask[i] = val; + + SmallVector<int, 16> ActualMask = SI->getShuffleMask(); + return Mask == ActualMask; +} + +namespace { +/// Kind of the reduction data. +enum ReductionKind { + RK_None, /// Not a reduction. + RK_Arithmetic, /// Binary reduction data. + RK_MinMax, /// Min/max reduction data. + RK_UnsignedMinMax, /// Unsigned min/max reduction data. +}; +/// Contains opcode + LHS/RHS parts of the reduction operations. 
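+/// As an illustrative sketch (hypothetical IR, not taken from this file): for
+/// "%sum = fadd <4 x float> %a, %b", getReductionData below yields
+/// ReductionData(RK_Arithmetic, Instruction::FAdd, %a, %b).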
+struct ReductionData {
+  ReductionData() = delete;
+  ReductionData(ReductionKind Kind, unsigned Opcode, Value *LHS, Value *RHS)
+      : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) {
+    assert(Kind != RK_None && "expected binary or min/max reduction only.");
+  }
+  unsigned Opcode = 0;
+  Value *LHS = nullptr;
+  Value *RHS = nullptr;
+  ReductionKind Kind = RK_None;
+  bool hasSameData(ReductionData &RD) const {
+    return Kind == RD.Kind && Opcode == RD.Opcode;
+  }
+};
+} // namespace
+
+static Optional<ReductionData> getReductionData(Instruction *I) {
+  Value *L, *R;
+  if (m_BinOp(m_Value(L), m_Value(R)).match(I))
+    return ReductionData(RK_Arithmetic, I->getOpcode(), L, R);
+  if (auto *SI = dyn_cast<SelectInst>(I)) {
+    if (m_SMin(m_Value(L), m_Value(R)).match(SI) ||
+        m_SMax(m_Value(L), m_Value(R)).match(SI) ||
+        m_OrdFMin(m_Value(L), m_Value(R)).match(SI) ||
+        m_OrdFMax(m_Value(L), m_Value(R)).match(SI) ||
+        m_UnordFMin(m_Value(L), m_Value(R)).match(SI) ||
+        m_UnordFMax(m_Value(L), m_Value(R)).match(SI)) {
+      auto *CI = cast<CmpInst>(SI->getCondition());
+      return ReductionData(RK_MinMax, CI->getOpcode(), L, R);
+    }
+    if (m_UMin(m_Value(L), m_Value(R)).match(SI) ||
+        m_UMax(m_Value(L), m_Value(R)).match(SI)) {
+      auto *CI = cast<CmpInst>(SI->getCondition());
+      return ReductionData(RK_UnsignedMinMax, CI->getOpcode(), L, R);
+    }
+  }
+  return llvm::None;
+}
+
+static ReductionKind matchPairwiseReductionAtLevel(Instruction *I,
+                                                   unsigned Level,
+                                                   unsigned NumLevels) {
+  // Match one level of pairwise operations.
+  //   %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
+  //       <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
+  //   %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
+  //       <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
+  //   %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
+  if (!I)
+    return RK_None;
+
+  assert(I->getType()->isVectorTy() && "Expecting a vector type");
+
+  Optional<ReductionData> RD = getReductionData(I);
+  if (!RD)
+    return RK_None;
+
+  ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(RD->LHS);
+  if (!LS && Level)
+    return RK_None;
+  ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(RD->RHS);
+  if (!RS && Level)
+    return RK_None;
+
+  // On level 0 we can omit one shufflevector instruction.
+  if (!Level && !RS && !LS)
+    return RK_None;
+
+  // Shuffle inputs must match.
+  Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr;
+  Value *NextLevelOpR = RS ? RS->getOperand(0) : nullptr;
+  Value *NextLevelOp = nullptr;
+  if (NextLevelOpR && NextLevelOpL) {
+    // If we have two shuffles, their operands must match.
+    if (NextLevelOpL != NextLevelOpR)
+      return RK_None;
+
+    NextLevelOp = NextLevelOpL;
+  } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) {
+    // On the first level we can omit the shufflevector <0, undef,...>. So the
+    // input to the other shufflevector <1, undef> must match one of the
+    // inputs to the current binary operation.
+    // Example:
+    //  %NextLevelOpL = shufflevector %R, <1, undef ...>
+    //  %BinOp        = fadd          %NextLevelOpL, %R
+    if (NextLevelOpL && NextLevelOpL != RD->RHS)
+      return RK_None;
+    else if (NextLevelOpR && NextLevelOpR != RD->LHS)
+      return RK_None;
+
+    NextLevelOp = NextLevelOpL ? RD->RHS : RD->LHS;
+  } else
+    return RK_None;
+
+  // Check that the next level's binary operation exists and matches the
+  // current one.
+ if (Level + 1 != NumLevels) { + Optional<ReductionData> NextLevelRD = + getReductionData(cast<Instruction>(NextLevelOp)); + if (!NextLevelRD || !RD->hasSameData(*NextLevelRD)) + return RK_None; + } + + // Shuffle mask for pairwise operation must match. + if (matchPairwiseShuffleMask(LS, /*IsLeft=*/true, Level)) { + if (!matchPairwiseShuffleMask(RS, /*IsLeft=*/false, Level)) + return RK_None; + } else if (matchPairwiseShuffleMask(RS, /*IsLeft=*/true, Level)) { + if (!matchPairwiseShuffleMask(LS, /*IsLeft=*/false, Level)) + return RK_None; + } else { + return RK_None; + } + + if (++Level == NumLevels) + return RD->Kind; + + // Match next level. + return matchPairwiseReductionAtLevel(cast<Instruction>(NextLevelOp), Level, + NumLevels); +} + +static ReductionKind matchPairwiseReduction(const ExtractElementInst *ReduxRoot, + unsigned &Opcode, Type *&Ty) { + if (!EnableReduxCost) + return RK_None; + + // Need to extract the first element. + ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1)); + unsigned Idx = ~0u; + if (CI) + Idx = CI->getZExtValue(); + if (Idx != 0) + return RK_None; + + auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0)); + if (!RdxStart) + return RK_None; + Optional<ReductionData> RD = getReductionData(RdxStart); + if (!RD) + return RK_None; + + Type *VecTy = RdxStart->getType(); + unsigned NumVecElems = VecTy->getVectorNumElements(); + if (!isPowerOf2_32(NumVecElems)) + return RK_None; + + // We look for a sequence of shuffle,shuffle,add triples like the following + // that builds a pairwise reduction tree. + // + // (X0, X1, X2, X3) + // (X0 + X1, X2 + X3, undef, undef) + // ((X0 + X1) + (X2 + X3), undef, undef, undef) + // + // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef, + // <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef> + // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef, + // <4 x i32> <i32 1, i32 3, i32 undef, i32 undef> + // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1 + // %rdx.shuf.1.0 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef, + // <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef> + // %rdx.shuf.1.1 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef, + // <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef> + // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1 + // %r = extractelement <4 x float> %bin.rdx8, i32 0 + if (matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems)) == + RK_None) + return RK_None; + + Opcode = RD->Opcode; + Ty = VecTy; + + return RD->Kind; +} + +static std::pair<Value *, ShuffleVectorInst *> +getShuffleAndOtherOprd(Value *L, Value *R) { + ShuffleVectorInst *S = nullptr; + + if ((S = dyn_cast<ShuffleVectorInst>(L))) + return std::make_pair(R, S); + + S = dyn_cast<ShuffleVectorInst>(R); + return std::make_pair(L, S); +} + +static ReductionKind +matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, + unsigned &Opcode, Type *&Ty) { + if (!EnableReduxCost) + return RK_None; + + // Need to extract the first element. 
+  ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
+  unsigned Idx = ~0u;
+  if (CI)
+    Idx = CI->getZExtValue();
+  if (Idx != 0)
+    return RK_None;
+
+  auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0));
+  if (!RdxStart)
+    return RK_None;
+  Optional<ReductionData> RD = getReductionData(RdxStart);
+  if (!RD)
+    return RK_None;
+
+  Type *VecTy = ReduxRoot->getOperand(0)->getType();
+  unsigned NumVecElems = VecTy->getVectorNumElements();
+  if (!isPowerOf2_32(NumVecElems))
+    return RK_None;
+
+  // We look for a sequence of shuffles and adds like the following, matching
+  // one fadd/shuffle pair at a time.
+  //
+  //   %rdx.shuf = shufflevector <4 x float> %rdx, <4 x float> undef,
+  //       <4 x i32> <i32 2, i32 3, i32 undef, i32 undef>
+  //   %bin.rdx = fadd <4 x float> %rdx, %rdx.shuf
+  //   %rdx.shuf7 = shufflevector <4 x float> %bin.rdx, <4 x float> undef,
+  //       <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
+  //   %bin.rdx8 = fadd <4 x float> %bin.rdx, %rdx.shuf7
+  //   %r = extractelement <4 x float> %bin.rdx8, i32 0
+
+  unsigned MaskStart = 1;
+  Instruction *RdxOp = RdxStart;
+  SmallVector<int, 32> ShuffleMask(NumVecElems, 0);
+  unsigned NumVecElemsRemain = NumVecElems;
+  while (NumVecElemsRemain - 1) {
+    // Check for the right reduction operation.
+    if (!RdxOp)
+      return RK_None;
+    Optional<ReductionData> RDLevel = getReductionData(RdxOp);
+    if (!RDLevel || !RDLevel->hasSameData(*RD))
+      return RK_None;
+
+    Value *NextRdxOp;
+    ShuffleVectorInst *Shuffle;
+    std::tie(NextRdxOp, Shuffle) =
+        getShuffleAndOtherOprd(RDLevel->LHS, RDLevel->RHS);
+
+    // Check that the current reduction operation and the shuffle use the same
+    // value.
+    if (Shuffle == nullptr)
+      return RK_None;
+    if (Shuffle->getOperand(0) != NextRdxOp)
+      return RK_None;
+
+    // Check that the shuffle mask matches.
+    for (unsigned j = 0; j != MaskStart; ++j)
+      ShuffleMask[j] = MaskStart + j;
+    // Fill the rest of the mask with -1 for undef.
+ std::fill(&ShuffleMask[MaskStart], ShuffleMask.end(), -1); + + SmallVector<int, 16> Mask = Shuffle->getShuffleMask(); + if (ShuffleMask != Mask) + return RK_None; + + RdxOp = dyn_cast<Instruction>(NextRdxOp); + NumVecElemsRemain /= 2; + MaskStart *= 2; + } + + Opcode = RD->Opcode; + Ty = VecTy; + return RD->Kind; +} + +int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const { + switch (I->getOpcode()) { + case Instruction::GetElementPtr: + return getUserCost(I); + + case Instruction::Ret: + case Instruction::PHI: + case Instruction::Br: { + return getCFInstrCost(I->getOpcode()); + } + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + TargetTransformInfo::OperandValueKind Op1VK, Op2VK; + TargetTransformInfo::OperandValueProperties Op1VP, Op2VP; + Op1VK = getOperandInfo(I->getOperand(0), Op1VP); + Op2VK = getOperandInfo(I->getOperand(1), Op2VP); + SmallVector<const Value *, 2> Operands(I->operand_values()); + return getArithmeticInstrCost(I->getOpcode(), I->getType(), Op1VK, Op2VK, + Op1VP, Op2VP, Operands); + } + case Instruction::FNeg: { + TargetTransformInfo::OperandValueKind Op1VK, Op2VK; + TargetTransformInfo::OperandValueProperties Op1VP, Op2VP; + Op1VK = getOperandInfo(I->getOperand(0), Op1VP); + Op2VK = OK_AnyValue; + Op2VP = OP_None; + SmallVector<const Value *, 2> Operands(I->operand_values()); + return getArithmeticInstrCost(I->getOpcode(), I->getType(), Op1VK, Op2VK, + Op1VP, Op2VP, Operands); + } + case Instruction::Select: { + const SelectInst *SI = cast<SelectInst>(I); + Type *CondTy = SI->getCondition()->getType(); + return getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy, I); + } + case Instruction::ICmp: + case Instruction::FCmp: { + Type *ValTy = I->getOperand(0)->getType(); + return getCmpSelInstrCost(I->getOpcode(), ValTy, I->getType(), I); + } + case Instruction::Store: { + const StoreInst *SI = cast<StoreInst>(I); + Type *ValTy = SI->getValueOperand()->getType(); + return getMemoryOpCost(I->getOpcode(), ValTy, + SI->getAlignment(), + SI->getPointerAddressSpace(), I); + } + case Instruction::Load: { + const LoadInst *LI = cast<LoadInst>(I); + return getMemoryOpCost(I->getOpcode(), I->getType(), + LI->getAlignment(), + LI->getPointerAddressSpace(), I); + } + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::SIToFP: + case Instruction::UIToFP: + case Instruction::Trunc: + case Instruction::FPTrunc: + case Instruction::BitCast: + case Instruction::AddrSpaceCast: { + Type *SrcTy = I->getOperand(0)->getType(); + return getCastInstrCost(I->getOpcode(), I->getType(), SrcTy, I); + } + case Instruction::ExtractElement: { + const ExtractElementInst * EEI = cast<ExtractElementInst>(I); + ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1)); + unsigned Idx = -1; + if (CI) + Idx = CI->getZExtValue(); + + // Try to match a reduction sequence (series of shufflevector and vector + // adds followed by a extractelement). 
+ unsigned ReduxOpCode; + Type *ReduxType; + + switch (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) { + case RK_Arithmetic: + return getArithmeticReductionCost(ReduxOpCode, ReduxType, + /*IsPairwiseForm=*/false); + case RK_MinMax: + return getMinMaxReductionCost( + ReduxType, CmpInst::makeCmpResultType(ReduxType), + /*IsPairwiseForm=*/false, /*IsUnsigned=*/false); + case RK_UnsignedMinMax: + return getMinMaxReductionCost( + ReduxType, CmpInst::makeCmpResultType(ReduxType), + /*IsPairwiseForm=*/false, /*IsUnsigned=*/true); + case RK_None: + break; + } + + switch (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) { + case RK_Arithmetic: + return getArithmeticReductionCost(ReduxOpCode, ReduxType, + /*IsPairwiseForm=*/true); + case RK_MinMax: + return getMinMaxReductionCost( + ReduxType, CmpInst::makeCmpResultType(ReduxType), + /*IsPairwiseForm=*/true, /*IsUnsigned=*/false); + case RK_UnsignedMinMax: + return getMinMaxReductionCost( + ReduxType, CmpInst::makeCmpResultType(ReduxType), + /*IsPairwiseForm=*/true, /*IsUnsigned=*/true); + case RK_None: + break; + } + + return getVectorInstrCost(I->getOpcode(), + EEI->getOperand(0)->getType(), Idx); + } + case Instruction::InsertElement: { + const InsertElementInst * IE = cast<InsertElementInst>(I); + ConstantInt *CI = dyn_cast<ConstantInt>(IE->getOperand(2)); + unsigned Idx = -1; + if (CI) + Idx = CI->getZExtValue(); + return getVectorInstrCost(I->getOpcode(), + IE->getType(), Idx); + } + case Instruction::ExtractValue: + return 0; // Model all ExtractValue nodes as free. + case Instruction::ShuffleVector: { + const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I); + Type *Ty = Shuffle->getType(); + Type *SrcTy = Shuffle->getOperand(0)->getType(); + + // TODO: Identify and add costs for insert subvector, etc. + int SubIndex; + if (Shuffle->isExtractSubvectorMask(SubIndex)) + return TTIImpl->getShuffleCost(SK_ExtractSubvector, SrcTy, SubIndex, Ty); + + if (Shuffle->changesLength()) + return -1; + + if (Shuffle->isIdentity()) + return 0; + + if (Shuffle->isReverse()) + return TTIImpl->getShuffleCost(SK_Reverse, Ty, 0, nullptr); + + if (Shuffle->isSelect()) + return TTIImpl->getShuffleCost(SK_Select, Ty, 0, nullptr); + + if (Shuffle->isTranspose()) + return TTIImpl->getShuffleCost(SK_Transpose, Ty, 0, nullptr); + + if (Shuffle->isZeroEltSplat()) + return TTIImpl->getShuffleCost(SK_Broadcast, Ty, 0, nullptr); + + if (Shuffle->isSingleSource()) + return TTIImpl->getShuffleCost(SK_PermuteSingleSrc, Ty, 0, nullptr); + + return TTIImpl->getShuffleCost(SK_PermuteTwoSrc, Ty, 0, nullptr); + } + case Instruction::Call: + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + SmallVector<Value *, 4> Args(II->arg_operands()); + + FastMathFlags FMF; + if (auto *FPMO = dyn_cast<FPMathOperator>(II)) + FMF = FPMO->getFastMathFlags(); + + return getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(), + Args, FMF); + } + return -1; + default: + // We don't have any information on this instruction. 
+ return -1; + } +} + +TargetTransformInfo::Concept::~Concept() {} + +TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {} + +TargetIRAnalysis::TargetIRAnalysis( + std::function<Result(const Function &)> TTICallback) + : TTICallback(std::move(TTICallback)) {} + +TargetIRAnalysis::Result TargetIRAnalysis::run(const Function &F, + FunctionAnalysisManager &) { + return TTICallback(F); +} + +AnalysisKey TargetIRAnalysis::Key; + +TargetIRAnalysis::Result TargetIRAnalysis::getDefaultTTI(const Function &F) { + return Result(F.getParent()->getDataLayout()); +} + +// Register the basic pass. +INITIALIZE_PASS(TargetTransformInfoWrapperPass, "tti", + "Target Transform Information", false, true) +char TargetTransformInfoWrapperPass::ID = 0; + +void TargetTransformInfoWrapperPass::anchor() {} + +TargetTransformInfoWrapperPass::TargetTransformInfoWrapperPass() + : ImmutablePass(ID) { + initializeTargetTransformInfoWrapperPassPass( + *PassRegistry::getPassRegistry()); +} + +TargetTransformInfoWrapperPass::TargetTransformInfoWrapperPass( + TargetIRAnalysis TIRA) + : ImmutablePass(ID), TIRA(std::move(TIRA)) { + initializeTargetTransformInfoWrapperPassPass( + *PassRegistry::getPassRegistry()); +} + +TargetTransformInfo &TargetTransformInfoWrapperPass::getTTI(const Function &F) { + FunctionAnalysisManager DummyFAM; + TTI = TIRA.run(F, DummyFAM); + return *TTI; +} + +ImmutablePass * +llvm::createTargetTransformInfoWrapperPass(TargetIRAnalysis TIRA) { + return new TargetTransformInfoWrapperPass(std::move(TIRA)); +} diff --git a/llvm/lib/Analysis/Trace.cpp b/llvm/lib/Analysis/Trace.cpp new file mode 100644 index 000000000000..879c7172d038 --- /dev/null +++ b/llvm/lib/Analysis/Trace.cpp @@ -0,0 +1,53 @@ +//===- Trace.cpp - Implementation of Trace class --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This class represents a single trace of LLVM basic blocks. A trace is a +// single entry, multiple exit, region of code that is often hot. Trace-based +// optimizations treat traces almost like they are a large, strange, basic +// block: because the trace path is assumed to be hot, optimizations for the +// fall-through path are made at the expense of the non-fall-through paths. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Trace.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +Function *Trace::getFunction() const { + return getEntryBasicBlock()->getParent(); +} + +Module *Trace::getModule() const { + return getFunction()->getParent(); +} + +/// print - Write trace to output stream. +void Trace::print(raw_ostream &O) const { + Function *F = getFunction(); + O << "; Trace from function " << F->getName() << ", blocks:\n"; + for (const_iterator i = begin(), e = end(); i != e; ++i) { + O << "; "; + (*i)->printAsOperand(O, true, getModule()); + O << "\n"; + } + O << "; Trace parent function: \n" << *F; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +/// dump - Debugger convenience method; writes trace to standard error +/// output stream. 
+LLVM_DUMP_METHOD void Trace::dump() const { + print(dbgs()); +} +#endif diff --git a/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp b/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp new file mode 100644 index 000000000000..3b9040aa0f52 --- /dev/null +++ b/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp @@ -0,0 +1,740 @@ +//===- TypeBasedAliasAnalysis.cpp - Type-Based Alias Analysis -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the TypeBasedAliasAnalysis pass, which implements +// metadata-based TBAA. +// +// In LLVM IR, memory does not have types, so LLVM's own type system is not +// suitable for doing TBAA. Instead, metadata is added to the IR to describe +// a type system of a higher level language. This can be used to implement +// typical C/C++ TBAA, but it can also be used to implement custom alias +// analysis behavior for other languages. +// +// We now support two types of metadata format: scalar TBAA and struct-path +// aware TBAA. After all testing cases are upgraded to use struct-path aware +// TBAA and we can auto-upgrade existing bc files, the support for scalar TBAA +// can be dropped. +// +// The scalar TBAA metadata format is very simple. TBAA MDNodes have up to +// three fields, e.g.: +// !0 = !{ !"an example type tree" } +// !1 = !{ !"int", !0 } +// !2 = !{ !"float", !0 } +// !3 = !{ !"const float", !2, i64 1 } +// +// The first field is an identity field. It can be any value, usually +// an MDString, which uniquely identifies the type. The most important +// name in the tree is the name of the root node. Two trees with +// different root node names are entirely disjoint, even if they +// have leaves with common names. +// +// The second field identifies the type's parent node in the tree, or +// is null or omitted for a root node. A type is considered to alias +// all of its descendants and all of its ancestors in the tree. Also, +// a type is considered to alias all types in other trees, so that +// bitcode produced from multiple front-ends is handled conservatively. +// +// If the third field is present, it's an integer which if equal to 1 +// indicates that the type is "constant" (meaning pointsToConstantMemory +// should return true; see +// http://llvm.org/docs/AliasAnalysis.html#OtherItfs). +// +// With struct-path aware TBAA, the MDNodes attached to an instruction using +// "!tbaa" are called path tag nodes. +// +// The path tag node has 4 fields with the last field being optional. +// +// The first field is the base type node, it can be a struct type node +// or a scalar type node. The second field is the access type node, it +// must be a scalar type node. The third field is the offset into the base type. +// The last field has the same meaning as the last field of our scalar TBAA: +// it's an integer which if equal to 1 indicates that the access is "constant". +// +// The struct type node has a name and a list of pairs, one pair for each member +// of the struct. The first element of each pair is a type node (a struct type +// node or a scalar type node), specifying the type of the member, the second +// element of each pair is the offset of the member. 
+// +// Given an example +// typedef struct { +// short s; +// } A; +// typedef struct { +// uint16_t s; +// A a; +// } B; +// +// For an access to B.a.s, we attach !5 (a path tag node) to the load/store +// instruction. The base type is !4 (struct B), the access type is !2 (scalar +// type short) and the offset is 4. +// +// !0 = !{!"Simple C/C++ TBAA"} +// !1 = !{!"omnipotent char", !0} // Scalar type node +// !2 = !{!"short", !1} // Scalar type node +// !3 = !{!"A", !2, i64 0} // Struct type node +// !4 = !{!"B", !2, i64 0, !3, i64 4} +// // Struct type node +// !5 = !{!4, !2, i64 4} // Path tag node +// +// The struct type nodes and the scalar type nodes form a type DAG. +// Root (!0) +// char (!1) -- edge to Root +// short (!2) -- edge to char +// A (!3) -- edge with offset 0 to short +// B (!4) -- edge with offset 0 to short and edge with offset 4 to A +// +// To check if two tags (tagX and tagY) can alias, we start from the base type +// of tagX, follow the edge with the correct offset in the type DAG and adjust +// the offset until we reach the base type of tagY or until we reach the Root +// node. +// If we reach the base type of tagY, compare the adjusted offset with +// offset of tagY, return Alias if the offsets are the same, return NoAlias +// otherwise. +// If we reach the Root node, perform the above starting from base type of tagY +// to see if we reach base type of tagX. +// +// If they have different roots, they're part of different potentially +// unrelated type systems, so we return Alias to be conservative. +// If neither node is an ancestor of the other and they have the same root, +// then we say NoAlias. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/TypeBasedAliasAnalysis.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" +#include <cassert> +#include <cstdint> + +using namespace llvm; + +// A handy option for disabling TBAA functionality. The same effect can also be +// achieved by stripping the !tbaa tags from IR, but this option is sometimes +// more convenient. +static cl::opt<bool> EnableTBAA("enable-tbaa", cl::init(true), cl::Hidden); + +namespace { + +/// isNewFormatTypeNode - Return true iff the given type node is in the new +/// size-aware format. +static bool isNewFormatTypeNode(const MDNode *N) { + if (N->getNumOperands() < 3) + return false; + // In the old format the first operand is a string. + if (!isa<MDNode>(N->getOperand(0))) + return false; + return true; +} + +/// This is a simple wrapper around an MDNode which provides a higher-level +/// interface by hiding the details of how alias analysis information is encoded +/// in its operands. +template<typename MDNodeTy> +class TBAANodeImpl { + MDNodeTy *Node = nullptr; + +public: + TBAANodeImpl() = default; + explicit TBAANodeImpl(MDNodeTy *N) : Node(N) {} + + /// getNode - Get the MDNode for this TBAANode. + MDNodeTy *getNode() const { return Node; } + + /// isNewFormat - Return true iff the wrapped type node is in the new + /// size-aware format. 
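+  /// In the new format the first operand is the parent type node rather than
+  /// a string; e.g. a new-format scalar type node has the (assumed) shape
+  /// !{!1, i64 4, !"int"}, where !1 is the parent and i64 4 is the type size
+  /// in bytes, as reflected by the accessors in this file.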
+  bool isNewFormat() const { return isNewFormatTypeNode(Node); }
+
+  /// getParent - Get this TBAANode's Alias tree parent.
+  TBAANodeImpl<MDNodeTy> getParent() const {
+    if (isNewFormat())
+      return TBAANodeImpl(cast<MDNodeTy>(Node->getOperand(0)));
+
+    if (Node->getNumOperands() < 2)
+      return TBAANodeImpl<MDNodeTy>();
+    MDNodeTy *P = dyn_cast_or_null<MDNodeTy>(Node->getOperand(1));
+    if (!P)
+      return TBAANodeImpl<MDNodeTy>();
+    // Ok, this node has a valid parent. Return it.
+    return TBAANodeImpl<MDNodeTy>(P);
+  }
+
+  /// Test if this TBAANode represents a type for objects which are
+  /// not modified (by any means) in the context where this
+  /// AliasAnalysis is relevant.
+  bool isTypeImmutable() const {
+    if (Node->getNumOperands() < 3)
+      return false;
+    ConstantInt *CI = mdconst::dyn_extract<ConstantInt>(Node->getOperand(2));
+    if (!CI)
+      return false;
+    return CI->getValue()[0];
+  }
+};
+
+/// \name Specializations of \c TBAANodeImpl for const and non const qualified
+/// \c MDNode.
+/// @{
+using TBAANode = TBAANodeImpl<const MDNode>;
+using MutableTBAANode = TBAANodeImpl<MDNode>;
+/// @}
+
+/// This is a simple wrapper around an MDNode which provides a
+/// higher-level interface by hiding the details of how alias analysis
+/// information is encoded in its operands.
+template<typename MDNodeTy>
+class TBAAStructTagNodeImpl {
+  /// This node should be created with createTBAAAccessTag().
+  MDNodeTy *Node;
+
+public:
+  explicit TBAAStructTagNodeImpl(MDNodeTy *N) : Node(N) {}
+
+  /// Get the MDNode for this TBAAStructTagNode.
+  MDNodeTy *getNode() const { return Node; }
+
+  /// isNewFormat - Return true iff the wrapped access tag is in the new
+  /// size-aware format.
+  bool isNewFormat() const {
+    if (Node->getNumOperands() < 4)
+      return false;
+    if (MDNodeTy *AccessType = getAccessType())
+      if (!TBAANodeImpl<MDNodeTy>(AccessType).isNewFormat())
+        return false;
+    return true;
+  }
+
+  MDNodeTy *getBaseType() const {
+    return dyn_cast_or_null<MDNode>(Node->getOperand(0));
+  }
+
+  MDNodeTy *getAccessType() const {
+    return dyn_cast_or_null<MDNode>(Node->getOperand(1));
+  }
+
+  uint64_t getOffset() const {
+    return mdconst::extract<ConstantInt>(Node->getOperand(2))->getZExtValue();
+  }
+
+  uint64_t getSize() const {
+    if (!isNewFormat())
+      return UINT64_MAX;
+    return mdconst::extract<ConstantInt>(Node->getOperand(3))->getZExtValue();
+  }
+
+  /// Test if this TBAAStructTagNode represents a type for objects
+  /// which are not modified (by any means) in the context where this
+  /// AliasAnalysis is relevant.
+  bool isTypeImmutable() const {
+    unsigned OpNo = isNewFormat() ? 4 : 3;
+    if (Node->getNumOperands() < OpNo + 1)
+      return false;
+    ConstantInt *CI = mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpNo));
+    if (!CI)
+      return false;
+    return CI->getValue()[0];
+  }
+};
+
+/// \name Specializations of \c TBAAStructTagNodeImpl for const and non const
+/// qualified \c MDNodes.
+/// @{
+using TBAAStructTagNode = TBAAStructTagNodeImpl<const MDNode>;
+using MutableTBAAStructTagNode = TBAAStructTagNodeImpl<MDNode>;
+/// @}
+
+/// This is a simple wrapper around an MDNode which provides a
+/// higher-level interface by hiding the details of how alias analysis
+/// information is encoded in its operands.
+class TBAAStructTypeNode {
+  /// This node should be created with createTBAATypeNode().
+  const MDNode *Node = nullptr;
+
+public:
+  TBAAStructTypeNode() = default;
+  explicit TBAAStructTypeNode(const MDNode *N) : Node(N) {}
+
+  /// Get the MDNode for this TBAAStructTypeNode.
+ const MDNode *getNode() const { return Node; } + + /// isNewFormat - Return true iff the wrapped type node is in the new + /// size-aware format. + bool isNewFormat() const { return isNewFormatTypeNode(Node); } + + bool operator==(const TBAAStructTypeNode &Other) const { + return getNode() == Other.getNode(); + } + + /// getId - Return type identifier. + Metadata *getId() const { + return Node->getOperand(isNewFormat() ? 2 : 0); + } + + unsigned getNumFields() const { + unsigned FirstFieldOpNo = isNewFormat() ? 3 : 1; + unsigned NumOpsPerField = isNewFormat() ? 3 : 2; + return (getNode()->getNumOperands() - FirstFieldOpNo) / NumOpsPerField; + } + + TBAAStructTypeNode getFieldType(unsigned FieldIndex) const { + unsigned FirstFieldOpNo = isNewFormat() ? 3 : 1; + unsigned NumOpsPerField = isNewFormat() ? 3 : 2; + unsigned OpIndex = FirstFieldOpNo + FieldIndex * NumOpsPerField; + auto *TypeNode = cast<MDNode>(getNode()->getOperand(OpIndex)); + return TBAAStructTypeNode(TypeNode); + } + + /// Get this TBAAStructTypeNode's field in the type DAG with + /// given offset. Update the offset to be relative to the field type. + TBAAStructTypeNode getField(uint64_t &Offset) const { + bool NewFormat = isNewFormat(); + if (NewFormat) { + // New-format root and scalar type nodes have no fields. + if (Node->getNumOperands() < 6) + return TBAAStructTypeNode(); + } else { + // Parent can be omitted for the root node. + if (Node->getNumOperands() < 2) + return TBAAStructTypeNode(); + + // Fast path for a scalar type node and a struct type node with a single + // field. + if (Node->getNumOperands() <= 3) { + uint64_t Cur = Node->getNumOperands() == 2 + ? 0 + : mdconst::extract<ConstantInt>(Node->getOperand(2)) + ->getZExtValue(); + Offset -= Cur; + MDNode *P = dyn_cast_or_null<MDNode>(Node->getOperand(1)); + if (!P) + return TBAAStructTypeNode(); + return TBAAStructTypeNode(P); + } + } + + // Assume the offsets are in order. We return the previous field if + // the current offset is bigger than the given offset. + unsigned FirstFieldOpNo = NewFormat ? 3 : 1; + unsigned NumOpsPerField = NewFormat ? 3 : 2; + unsigned TheIdx = 0; + for (unsigned Idx = FirstFieldOpNo; Idx < Node->getNumOperands(); + Idx += NumOpsPerField) { + uint64_t Cur = mdconst::extract<ConstantInt>(Node->getOperand(Idx + 1)) + ->getZExtValue(); + if (Cur > Offset) { + assert(Idx >= FirstFieldOpNo + NumOpsPerField && + "TBAAStructTypeNode::getField should have an offset match!"); + TheIdx = Idx - NumOpsPerField; + break; + } + } + // Move along the last field. + if (TheIdx == 0) + TheIdx = Node->getNumOperands() - NumOpsPerField; + uint64_t Cur = mdconst::extract<ConstantInt>(Node->getOperand(TheIdx + 1)) + ->getZExtValue(); + Offset -= Cur; + MDNode *P = dyn_cast_or_null<MDNode>(Node->getOperand(TheIdx)); + if (!P) + return TBAAStructTypeNode(); + return TBAAStructTypeNode(P); + } +}; + +} // end anonymous namespace + +/// Check the first operand of the tbaa tag node, if it is a MDNode, we treat +/// it as struct-path aware TBAA format, otherwise, we treat it as scalar TBAA +/// format. +static bool isStructPathTBAA(const MDNode *MD) { + // Anonymous TBAA root starts with a MDNode and dragonegg uses it as + // a TBAA tag. + return isa<MDNode>(MD->getOperand(0)) && MD->getNumOperands() >= 3; +} + +AliasResult TypeBasedAAResult::alias(const MemoryLocation &LocA, + const MemoryLocation &LocB, + AAQueryInfo &AAQI) { + if (!EnableTBAA) + return AAResultBase::alias(LocA, LocB, AAQI); + + // If accesses may alias, chain to the next AliasAnalysis. 
+ if (Aliases(LocA.AATags.TBAA, LocB.AATags.TBAA)) + return AAResultBase::alias(LocA, LocB, AAQI); + + // Otherwise return a definitive result. + return NoAlias; +} + +bool TypeBasedAAResult::pointsToConstantMemory(const MemoryLocation &Loc, + AAQueryInfo &AAQI, + bool OrLocal) { + if (!EnableTBAA) + return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal); + + const MDNode *M = Loc.AATags.TBAA; + if (!M) + return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal); + + // If this is an "immutable" type, we can assume the pointer is pointing + // to constant memory. + if ((!isStructPathTBAA(M) && TBAANode(M).isTypeImmutable()) || + (isStructPathTBAA(M) && TBAAStructTagNode(M).isTypeImmutable())) + return true; + + return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal); +} + +FunctionModRefBehavior +TypeBasedAAResult::getModRefBehavior(const CallBase *Call) { + if (!EnableTBAA) + return AAResultBase::getModRefBehavior(Call); + + FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior; + + // If this is an "immutable" type, we can assume the call doesn't write + // to memory. + if (const MDNode *M = Call->getMetadata(LLVMContext::MD_tbaa)) + if ((!isStructPathTBAA(M) && TBAANode(M).isTypeImmutable()) || + (isStructPathTBAA(M) && TBAAStructTagNode(M).isTypeImmutable())) + Min = FMRB_OnlyReadsMemory; + + return FunctionModRefBehavior(AAResultBase::getModRefBehavior(Call) & Min); +} + +FunctionModRefBehavior TypeBasedAAResult::getModRefBehavior(const Function *F) { + // Functions don't have metadata. Just chain to the next implementation. + return AAResultBase::getModRefBehavior(F); +} + +ModRefInfo TypeBasedAAResult::getModRefInfo(const CallBase *Call, + const MemoryLocation &Loc, + AAQueryInfo &AAQI) { + if (!EnableTBAA) + return AAResultBase::getModRefInfo(Call, Loc, AAQI); + + if (const MDNode *L = Loc.AATags.TBAA) + if (const MDNode *M = Call->getMetadata(LLVMContext::MD_tbaa)) + if (!Aliases(L, M)) + return ModRefInfo::NoModRef; + + return AAResultBase::getModRefInfo(Call, Loc, AAQI); +} + +ModRefInfo TypeBasedAAResult::getModRefInfo(const CallBase *Call1, + const CallBase *Call2, + AAQueryInfo &AAQI) { + if (!EnableTBAA) + return AAResultBase::getModRefInfo(Call1, Call2, AAQI); + + if (const MDNode *M1 = Call1->getMetadata(LLVMContext::MD_tbaa)) + if (const MDNode *M2 = Call2->getMetadata(LLVMContext::MD_tbaa)) + if (!Aliases(M1, M2)) + return ModRefInfo::NoModRef; + + return AAResultBase::getModRefInfo(Call1, Call2, AAQI); +} + +bool MDNode::isTBAAVtableAccess() const { + if (!isStructPathTBAA(this)) { + if (getNumOperands() < 1) + return false; + if (MDString *Tag1 = dyn_cast<MDString>(getOperand(0))) { + if (Tag1->getString() == "vtable pointer") + return true; + } + return false; + } + + // For struct-path aware TBAA, we use the access type of the tag. 
+ TBAAStructTagNode Tag(this); + TBAAStructTypeNode AccessType(Tag.getAccessType()); + if(auto *Id = dyn_cast<MDString>(AccessType.getId())) + if (Id->getString() == "vtable pointer") + return true; + return false; +} + +static bool matchAccessTags(const MDNode *A, const MDNode *B, + const MDNode **GenericTag = nullptr); + +MDNode *MDNode::getMostGenericTBAA(MDNode *A, MDNode *B) { + const MDNode *GenericTag; + matchAccessTags(A, B, &GenericTag); + return const_cast<MDNode*>(GenericTag); +} + +static const MDNode *getLeastCommonType(const MDNode *A, const MDNode *B) { + if (!A || !B) + return nullptr; + + if (A == B) + return A; + + SmallSetVector<const MDNode *, 4> PathA; + TBAANode TA(A); + while (TA.getNode()) { + if (PathA.count(TA.getNode())) + report_fatal_error("Cycle found in TBAA metadata."); + PathA.insert(TA.getNode()); + TA = TA.getParent(); + } + + SmallSetVector<const MDNode *, 4> PathB; + TBAANode TB(B); + while (TB.getNode()) { + if (PathB.count(TB.getNode())) + report_fatal_error("Cycle found in TBAA metadata."); + PathB.insert(TB.getNode()); + TB = TB.getParent(); + } + + int IA = PathA.size() - 1; + int IB = PathB.size() - 1; + + const MDNode *Ret = nullptr; + while (IA >= 0 && IB >= 0) { + if (PathA[IA] == PathB[IB]) + Ret = PathA[IA]; + else + break; + --IA; + --IB; + } + + return Ret; +} + +void Instruction::getAAMetadata(AAMDNodes &N, bool Merge) const { + if (Merge) + N.TBAA = + MDNode::getMostGenericTBAA(N.TBAA, getMetadata(LLVMContext::MD_tbaa)); + else + N.TBAA = getMetadata(LLVMContext::MD_tbaa); + + if (Merge) + N.Scope = MDNode::getMostGenericAliasScope( + N.Scope, getMetadata(LLVMContext::MD_alias_scope)); + else + N.Scope = getMetadata(LLVMContext::MD_alias_scope); + + if (Merge) + N.NoAlias = + MDNode::intersect(N.NoAlias, getMetadata(LLVMContext::MD_noalias)); + else + N.NoAlias = getMetadata(LLVMContext::MD_noalias); +} + +static const MDNode *createAccessTag(const MDNode *AccessType) { + // If there is no access type or the access type is the root node, then + // we don't have any useful access tag to return. + if (!AccessType || AccessType->getNumOperands() < 2) + return nullptr; + + Type *Int64 = IntegerType::get(AccessType->getContext(), 64); + auto *OffsetNode = ConstantAsMetadata::get(ConstantInt::get(Int64, 0)); + + if (TBAAStructTypeNode(AccessType).isNewFormat()) { + // TODO: Take access ranges into account when matching access tags and + // fix this code to generate actual access sizes for generic tags. + uint64_t AccessSize = UINT64_MAX; + auto *SizeNode = + ConstantAsMetadata::get(ConstantInt::get(Int64, AccessSize)); + Metadata *Ops[] = {const_cast<MDNode*>(AccessType), + const_cast<MDNode*>(AccessType), + OffsetNode, SizeNode}; + return MDNode::get(AccessType->getContext(), Ops); + } + + Metadata *Ops[] = {const_cast<MDNode*>(AccessType), + const_cast<MDNode*>(AccessType), + OffsetNode}; + return MDNode::get(AccessType->getContext(), Ops); +} + +static bool hasField(TBAAStructTypeNode BaseType, + TBAAStructTypeNode FieldType) { + for (unsigned I = 0, E = BaseType.getNumFields(); I != E; ++I) { + TBAAStructTypeNode T = BaseType.getFieldType(I); + if (T == FieldType || hasField(T, FieldType)) + return true; + } + return false; +} + +/// Return true if for two given accesses, one of the accessed objects may be a +/// subobject of the other. The \p BaseTag and \p SubobjectTag parameters +/// describe the accesses to the base object and the subobject respectively. 
+/// \p CommonType must be the metadata node describing the common type of the
+/// accessed objects. On return, \p MayAlias is set to true iff these accesses
+/// may alias and \p Generic, if not null, points to the most generic access
+/// tag for the given two.
+static bool mayBeAccessToSubobjectOf(TBAAStructTagNode BaseTag,
+                                     TBAAStructTagNode SubobjectTag,
+                                     const MDNode *CommonType,
+                                     const MDNode **GenericTag,
+                                     bool &MayAlias) {
+  // If the base object is of the least common type, then this may be an access
+  // to its subobject.
+  if (BaseTag.getAccessType() == BaseTag.getBaseType() &&
+      BaseTag.getAccessType() == CommonType) {
+    if (GenericTag)
+      *GenericTag = createAccessTag(CommonType);
+    MayAlias = true;
+    return true;
+  }
+
+  // If the access to the base object is through a field of the subobject's
+  // type, then this may be an access to that field. To check for that we start
+  // from the base type, follow the edge with the correct offset in the type DAG
+  // and adjust the offset until we reach the field type or until we reach the
+  // access type.
+  bool NewFormat = BaseTag.isNewFormat();
+  TBAAStructTypeNode BaseType(BaseTag.getBaseType());
+  uint64_t OffsetInBase = BaseTag.getOffset();
+
+  for (;;) {
+    // In the old format there is no distinction between fields and parent
+    // types, so in this case we consider all nodes up to the root.
+    if (!BaseType.getNode()) {
+      assert(!NewFormat && "Did not see access type in access path!");
+      break;
+    }
+
+    if (BaseType.getNode() == SubobjectTag.getBaseType()) {
+      bool SameMemberAccess = OffsetInBase == SubobjectTag.getOffset();
+      if (GenericTag) {
+        *GenericTag = SameMemberAccess ? SubobjectTag.getNode() :
+                                         createAccessTag(CommonType);
+      }
+      MayAlias = SameMemberAccess;
+      return true;
+    }
+
+    // With new-format nodes we stop at the access type.
+    if (NewFormat && BaseType.getNode() == BaseTag.getAccessType())
+      break;
+
+    // Follow the edge with the correct offset. Offset will be adjusted to
+    // be relative to the field type.
+    BaseType = BaseType.getField(OffsetInBase);
+  }
+
+  // If the base object has a direct or indirect field of the subobject's type,
+  // then this may be an access to that field. We need this to check now that
+  // we support aggregates as access types.
+  if (NewFormat) {
+    // TBAAStructTypeNode BaseAccessType(BaseTag.getAccessType());
+    TBAAStructTypeNode FieldType(SubobjectTag.getBaseType());
+    if (hasField(BaseType, FieldType)) {
+      if (GenericTag)
+        *GenericTag = createAccessTag(CommonType);
+      MayAlias = true;
+      return true;
+    }
+  }
+
+  return false;
+}
+
+/// matchAccessTags - Return true if the given pair of accesses is allowed to
+/// overlap. If \arg GenericTag is not null, then on return it points to the
+/// most generic access descriptor for the given two.
+static bool matchAccessTags(const MDNode *A, const MDNode *B,
+                            const MDNode **GenericTag) {
+  if (A == B) {
+    if (GenericTag)
+      *GenericTag = A;
+    return true;
+  }
+
+  // Accesses with no TBAA information may alias with any other accesses.
+  if (!A || !B) {
+    if (GenericTag)
+      *GenericTag = nullptr;
+    return true;
+  }
+
+  // Verify that both input nodes are struct-path aware. Auto-upgrade should
+  // have taken care of this.
+ assert(isStructPathTBAA(A) && "Access A is not struct-path aware!"); + assert(isStructPathTBAA(B) && "Access B is not struct-path aware!"); + + TBAAStructTagNode TagA(A), TagB(B); + const MDNode *CommonType = getLeastCommonType(TagA.getAccessType(), + TagB.getAccessType()); + + // If the final access types have different roots, they're part of different + // potentially unrelated type systems, so we must be conservative. + if (!CommonType) { + if (GenericTag) + *GenericTag = nullptr; + return true; + } + + // If one of the accessed objects may be a subobject of the other, then such + // accesses may alias. + bool MayAlias; + if (mayBeAccessToSubobjectOf(/* BaseTag= */ TagA, /* SubobjectTag= */ TagB, + CommonType, GenericTag, MayAlias) || + mayBeAccessToSubobjectOf(/* BaseTag= */ TagB, /* SubobjectTag= */ TagA, + CommonType, GenericTag, MayAlias)) + return MayAlias; + + // Otherwise, we've proved there's no alias. + if (GenericTag) + *GenericTag = createAccessTag(CommonType); + return false; +} + +/// Aliases - Test whether the access represented by tag A may alias the +/// access represented by tag B. +bool TypeBasedAAResult::Aliases(const MDNode *A, const MDNode *B) const { + return matchAccessTags(A, B); +} + +AnalysisKey TypeBasedAA::Key; + +TypeBasedAAResult TypeBasedAA::run(Function &F, FunctionAnalysisManager &AM) { + return TypeBasedAAResult(); +} + +char TypeBasedAAWrapperPass::ID = 0; +INITIALIZE_PASS(TypeBasedAAWrapperPass, "tbaa", "Type-Based Alias Analysis", + false, true) + +ImmutablePass *llvm::createTypeBasedAAWrapperPass() { + return new TypeBasedAAWrapperPass(); +} + +TypeBasedAAWrapperPass::TypeBasedAAWrapperPass() : ImmutablePass(ID) { + initializeTypeBasedAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool TypeBasedAAWrapperPass::doInitialization(Module &M) { + Result.reset(new TypeBasedAAResult()); + return false; +} + +bool TypeBasedAAWrapperPass::doFinalization(Module &M) { + Result.reset(); + return false; +} + +void TypeBasedAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); +} diff --git a/llvm/lib/Analysis/TypeMetadataUtils.cpp b/llvm/lib/Analysis/TypeMetadataUtils.cpp new file mode 100644 index 000000000000..072d291f3f93 --- /dev/null +++ b/llvm/lib/Analysis/TypeMetadataUtils.cpp @@ -0,0 +1,161 @@ +//===- TypeMetadataUtils.cpp - Utilities related to type metadata ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains functions that make it easier to manipulate type metadata +// for devirtualization. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/TypeMetadataUtils.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" + +using namespace llvm; + +// Search for virtual calls that call FPtr and add them to DevirtCalls. +static void +findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls, + bool *HasNonCallUses, Value *FPtr, uint64_t Offset, + const CallInst *CI, DominatorTree &DT) { + for (const Use &U : FPtr->uses()) { + Instruction *User = cast<Instruction>(U.getUser()); + // Ignore this instruction if it is not dominated by the type intrinsic + // being analyzed. 
Otherwise we may transform a call sharing the same + // vtable pointer incorrectly. Specifically, this situation can arise + // after indirect call promotion and inlining, where we may have uses + // of the vtable pointer guarded by a function pointer check, and a fallback + // indirect call. + if (!DT.dominates(CI, User)) + continue; + if (isa<BitCastInst>(User)) { + findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset, CI, + DT); + } else if (auto CI = dyn_cast<CallInst>(User)) { + DevirtCalls.push_back({Offset, CI}); + } else if (auto II = dyn_cast<InvokeInst>(User)) { + DevirtCalls.push_back({Offset, II}); + } else if (HasNonCallUses) { + *HasNonCallUses = true; + } + } +} + +// Search for virtual calls that load from VPtr and add them to DevirtCalls. +static void findLoadCallsAtConstantOffset( + const Module *M, SmallVectorImpl<DevirtCallSite> &DevirtCalls, Value *VPtr, + int64_t Offset, const CallInst *CI, DominatorTree &DT) { + for (const Use &U : VPtr->uses()) { + Value *User = U.getUser(); + if (isa<BitCastInst>(User)) { + findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset, CI, DT); + } else if (isa<LoadInst>(User)) { + findCallsAtConstantOffset(DevirtCalls, nullptr, User, Offset, CI, DT); + } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) { + // Take into account the GEP offset. + if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) { + SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end()); + int64_t GEPOffset = M->getDataLayout().getIndexedOffsetInType( + GEP->getSourceElementType(), Indices); + findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset + GEPOffset, + CI, DT); + } + } + } +} + +void llvm::findDevirtualizableCallsForTypeTest( + SmallVectorImpl<DevirtCallSite> &DevirtCalls, + SmallVectorImpl<CallInst *> &Assumes, const CallInst *CI, + DominatorTree &DT) { + assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_test); + + const Module *M = CI->getParent()->getParent()->getParent(); + + // Find llvm.assume intrinsics for this llvm.type.test call. + for (const Use &CIU : CI->uses()) { + if (auto *AssumeCI = dyn_cast<CallInst>(CIU.getUser())) { + Function *F = AssumeCI->getCalledFunction(); + if (F && F->getIntrinsicID() == Intrinsic::assume) + Assumes.push_back(AssumeCI); + } + } + + // If we found any, search for virtual calls based on %p and add them to + // DevirtCalls. 
+ if (!Assumes.empty()) + findLoadCallsAtConstantOffset( + M, DevirtCalls, CI->getArgOperand(0)->stripPointerCasts(), 0, CI, DT); +} + +void llvm::findDevirtualizableCallsForTypeCheckedLoad( + SmallVectorImpl<DevirtCallSite> &DevirtCalls, + SmallVectorImpl<Instruction *> &LoadedPtrs, + SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses, + const CallInst *CI, DominatorTree &DT) { + assert(CI->getCalledFunction()->getIntrinsicID() == + Intrinsic::type_checked_load); + + auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + if (!Offset) { + HasNonCallUses = true; + return; + } + + for (const Use &U : CI->uses()) { + auto CIU = U.getUser(); + if (auto EVI = dyn_cast<ExtractValueInst>(CIU)) { + if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 0) { + LoadedPtrs.push_back(EVI); + continue; + } + if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 1) { + Preds.push_back(EVI); + continue; + } + } + HasNonCallUses = true; + } + + for (Value *LoadedPtr : LoadedPtrs) + findCallsAtConstantOffset(DevirtCalls, &HasNonCallUses, LoadedPtr, + Offset->getZExtValue(), CI, DT); +} + +Constant *llvm::getPointerAtOffset(Constant *I, uint64_t Offset, Module &M) { + if (I->getType()->isPointerTy()) { + if (Offset == 0) + return I; + return nullptr; + } + + const DataLayout &DL = M.getDataLayout(); + + if (auto *C = dyn_cast<ConstantStruct>(I)) { + const StructLayout *SL = DL.getStructLayout(C->getType()); + if (Offset >= SL->getSizeInBytes()) + return nullptr; + + unsigned Op = SL->getElementContainingOffset(Offset); + return getPointerAtOffset(cast<Constant>(I->getOperand(Op)), + Offset - SL->getElementOffset(Op), M); + } + if (auto *C = dyn_cast<ConstantArray>(I)) { + ArrayType *VTableTy = C->getType(); + uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType()); + + unsigned Op = Offset / ElemSize; + if (Op >= C->getNumOperands()) + return nullptr; + + return getPointerAtOffset(cast<Constant>(I->getOperand(Op)), + Offset % ElemSize, M); + } + return nullptr; +} diff --git a/llvm/lib/Analysis/VFABIDemangling.cpp b/llvm/lib/Analysis/VFABIDemangling.cpp new file mode 100644 index 000000000000..6fd8ae63f5f0 --- /dev/null +++ b/llvm/lib/Analysis/VFABIDemangling.cpp @@ -0,0 +1,418 @@ +//===- VFABIDemangling.cpp - Vector Function ABI demangling utilities. ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/VectorUtils.h" + +using namespace llvm; + +namespace { +/// Utilities for the Vector Function ABI name parser. + +/// Return types for the parser functions. +enum class ParseRet { + OK, // Found. + None, // Not found. + Error // Syntax error. +}; + +/// Extracts the `<isa>` information from the mangled string, and +/// sets the `ISA` accordingly. 
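+/// For instance, assuming a Vector Function ABI name of the (assumed) form
+/// "_ZGV<isa><mask><vlen>..." with the "_ZGV" prefix already stripped by the
+/// caller, a leading "n" selects VFISAKind::AdvancedSIMD per the switch below.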
+ParseRet tryParseISA(StringRef &MangledName, VFISAKind &ISA) {
+  if (MangledName.empty())
+    return ParseRet::Error;
+
+  ISA = StringSwitch<VFISAKind>(MangledName.take_front(1))
+            .Case("n", VFISAKind::AdvancedSIMD)
+            .Case("s", VFISAKind::SVE)
+            .Case("b", VFISAKind::SSE)
+            .Case("c", VFISAKind::AVX)
+            .Case("d", VFISAKind::AVX2)
+            .Case("e", VFISAKind::AVX512)
+            .Default(VFISAKind::Unknown);
+
+  MangledName = MangledName.drop_front(1);
+
+  return ParseRet::OK;
+}
+
+/// Extracts the `<mask>` information from the mangled string, and
+/// sets `IsMasked` accordingly. On success, the `<mask>` token is
+/// removed from the input string `MangledName`.
+ParseRet tryParseMask(StringRef &MangledName, bool &IsMasked) {
+  if (MangledName.consume_front("M")) {
+    IsMasked = true;
+    return ParseRet::OK;
+  }
+
+  if (MangledName.consume_front("N")) {
+    IsMasked = false;
+    return ParseRet::OK;
+  }
+
+  return ParseRet::Error;
+}
+
+/// Extracts the `<vlen>` information from the mangled string, and
+/// sets `VF` accordingly. A `<vlen> == "x"` token is interpreted as a
+/// scalable vector length. On success, the `<vlen>` token is removed
+/// from the input string `ParseString`.
+ParseRet tryParseVLEN(StringRef &ParseString, unsigned &VF, bool &IsScalable) {
+  if (ParseString.consume_front("x")) {
+    VF = 0;
+    IsScalable = true;
+    return ParseRet::OK;
+  }
+
+  if (ParseString.consumeInteger(10, VF))
+    return ParseRet::Error;
+
+  IsScalable = false;
+  return ParseRet::OK;
+}
+
+/// The function looks for the following strings at the beginning of
+/// the input string `ParseString`:
+///
+/// <token> <number>
+///
+/// On success, it removes the parsed parameter from `ParseString`,
+/// sets `PKind` to the corresponding enum value, sets `Pos` to
+/// <number>, and returns success. On a syntax error, it returns a
+/// parsing error. If nothing is parsed, it returns None.
+///
+/// The function expects <token> to be one of "ls", "Rs", "Us" or
+/// "Ls".
+ParseRet tryParseLinearTokenWithRuntimeStep(StringRef &ParseString,
+                                            VFParamKind &PKind, int &Pos,
+                                            const StringRef Token) {
+  if (ParseString.consume_front(Token)) {
+    PKind = VFABI::getVFParamKindFromString(Token);
+    if (ParseString.consumeInteger(10, Pos))
+      return ParseRet::Error;
+    return ParseRet::OK;
+  }
+
+  return ParseRet::None;
+}
+
+/// The function looks for the following strings at the beginning of
+/// the input string `ParseString`:
+///
+/// <token> <number>
+///
+/// <token> is one of "ls", "Rs", "Us" or "Ls".
+///
+/// On success, it removes the parsed parameter from `ParseString`,
+/// sets `PKind` to the corresponding enum value, sets `StepOrPos` to
+/// <number>, and returns success. On a syntax error, it returns a
+/// parsing error. If nothing is parsed, it returns None.
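+///
+/// For example, given the illustrative input "ls2", it sets `PKind` to
+/// `VFParamKind::OMP_LinearPos` and `StepOrPos` to 2.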
+ParseRet tryParseLinearWithRuntimeStep(StringRef &ParseString,
+                                       VFParamKind &PKind, int &StepOrPos) {
+  ParseRet Ret;
+
+  // "ls" <RuntimeStepPos>
+  Ret = tryParseLinearTokenWithRuntimeStep(ParseString, PKind, StepOrPos, "ls");
+  if (Ret != ParseRet::None)
+    return Ret;
+
+  // "Rs" <RuntimeStepPos>
+  Ret = tryParseLinearTokenWithRuntimeStep(ParseString, PKind, StepOrPos, "Rs");
+  if (Ret != ParseRet::None)
+    return Ret;
+
+  // "Ls" <RuntimeStepPos>
+  Ret = tryParseLinearTokenWithRuntimeStep(ParseString, PKind, StepOrPos, "Ls");
+  if (Ret != ParseRet::None)
+    return Ret;
+
+  // "Us" <RuntimeStepPos>
+  Ret = tryParseLinearTokenWithRuntimeStep(ParseString, PKind, StepOrPos, "Us");
+  if (Ret != ParseRet::None)
+    return Ret;
+
+  return ParseRet::None;
+}
+
+/// The function looks for the following strings at the beginning of
+/// the input string `ParseString`:
+///
+/// <token> {"n"} <number>
+///
+/// On success, it removes the parsed parameter from `ParseString`,
+/// sets `PKind` to the corresponding enum value, sets `LinearStep` to
+/// <number>, and returns success. On a syntax error, it returns a
+/// parsing error. If nothing is parsed, it returns None.
+///
+/// The function expects <token> to be one of "l", "R", "U" or
+/// "L".
+ParseRet tryParseCompileTimeLinearToken(StringRef &ParseString,
+                                        VFParamKind &PKind, int &LinearStep,
+                                        const StringRef Token) {
+  if (ParseString.consume_front(Token)) {
+    PKind = VFABI::getVFParamKindFromString(Token);
+    const bool Negate = ParseString.consume_front("n");
+    if (ParseString.consumeInteger(10, LinearStep))
+      LinearStep = 1;
+    if (Negate)
+      LinearStep *= -1;
+    return ParseRet::OK;
+  }
+
+  return ParseRet::None;
+}
+
+/// The function looks for the following strings at the beginning of
+/// the input string `ParseString`:
+///
+/// ["l" | "R" | "U" | "L"] {"n"} <number>
+///
+/// On success, it removes the parsed parameter from `ParseString`,
+/// sets `PKind` to the corresponding enum value, sets `StepOrPos` to
+/// <number>, and returns success. On a syntax error, it returns a
+/// parsing error. If nothing is parsed, it returns None.
+ParseRet tryParseLinearWithCompileTimeStep(StringRef &ParseString,
+                                           VFParamKind &PKind, int &StepOrPos) {
+  // "l" {"n"} <CompileTimeStep>
+  if (tryParseCompileTimeLinearToken(ParseString, PKind, StepOrPos, "l") ==
+      ParseRet::OK)
+    return ParseRet::OK;
+
+  // "R" {"n"} <CompileTimeStep>
+  if (tryParseCompileTimeLinearToken(ParseString, PKind, StepOrPos, "R") ==
+      ParseRet::OK)
+    return ParseRet::OK;
+
+  // "L" {"n"} <CompileTimeStep>
+  if (tryParseCompileTimeLinearToken(ParseString, PKind, StepOrPos, "L") ==
+      ParseRet::OK)
+    return ParseRet::OK;
+
+  // "U" {"n"} <CompileTimeStep>
+  if (tryParseCompileTimeLinearToken(ParseString, PKind, StepOrPos, "U") ==
+      ParseRet::OK)
+    return ParseRet::OK;
+
+  return ParseRet::None;
+}
+
+/// The function looks for the following string at the beginning of
+/// the input string `ParseString`:
+///
+/// "u" <number>
+///
+/// On success, it removes the parsed parameter from `ParseString`,
+/// sets `PKind` to the corresponding enum value, sets `Pos` to
+/// <number>, and returns success. On a syntax error, it returns a
+/// parsing error. If nothing is parsed, it returns None.
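+///
+/// For example, given the illustrative input "u3", it sets `PKind` to
+/// `VFParamKind::OMP_Uniform` and `Pos` to 3.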
+ParseRet tryParseUniform(StringRef &ParseString, VFParamKind &PKind, int &Pos) {
+  // "u" <Pos>
+  const char *UniformToken = "u";
+  if (ParseString.consume_front(UniformToken)) {
+    PKind = VFABI::getVFParamKindFromString(UniformToken);
+    if (ParseString.consumeInteger(10, Pos))
+      return ParseRet::Error;
+
+    return ParseRet::OK;
+  }
+  return ParseRet::None;
+}
+
+/// Looks into the <parameters> part of the mangled name in search
+/// of valid parameters at the beginning of the string
+/// `ParseString`.
+///
+/// On success, it removes the parsed parameter from `ParseString`,
+/// sets `PKind` to the corresponding enum value, sets `StepOrPos`
+/// accordingly, and returns success. On a syntax error, it returns a
+/// parsing error. If nothing is parsed, it returns None.
+ParseRet tryParseParameter(StringRef &ParseString, VFParamKind &PKind,
+                           int &StepOrPos) {
+  if (ParseString.consume_front("v")) {
+    PKind = VFParamKind::Vector;
+    StepOrPos = 0;
+    return ParseRet::OK;
+  }
+
+  const ParseRet HasLinearRuntime =
+      tryParseLinearWithRuntimeStep(ParseString, PKind, StepOrPos);
+  if (HasLinearRuntime != ParseRet::None)
+    return HasLinearRuntime;
+
+  const ParseRet HasLinearCompileTime =
+      tryParseLinearWithCompileTimeStep(ParseString, PKind, StepOrPos);
+  if (HasLinearCompileTime != ParseRet::None)
+    return HasLinearCompileTime;
+
+  const ParseRet HasUniform = tryParseUniform(ParseString, PKind, StepOrPos);
+  if (HasUniform != ParseRet::None)
+    return HasUniform;
+
+  return ParseRet::None;
+}
+
+/// Looks into the <parameters> part of the mangled name in search
+/// of a valid 'aligned' clause. The function should be invoked
+/// after parsing a parameter via `tryParseParameter`.
+///
+/// On success, it removes the parsed token from `ParseString` and
+/// sets `Alignment` to the parsed value. On a syntax error, it
+/// returns a parsing error. If nothing is parsed, it returns None.
+ParseRet tryParseAlign(StringRef &ParseString, Align &Alignment) {
+  uint64_t Val;
+  // "a" <number>
+  if (ParseString.consume_front("a")) {
+    if (ParseString.consumeInteger(10, Val))
+      return ParseRet::Error;
+
+    if (!isPowerOf2_64(Val))
+      return ParseRet::Error;
+
+    Alignment = Align(Val);
+
+    return ParseRet::OK;
+  }
+
+  return ParseRet::None;
+}
+} // namespace
+
+// Format of the ABI name:
+// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
+Optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName) {
+  // Assume there is no custom name <redirection>, and therefore the
+  // vector name consists of
+  // _ZGV<isa><mask><vlen><parameters>_<scalarname>.
+  StringRef VectorName = MangledName;
+
+  // Parse the fixed-size part of the mangled name.
+  if (!MangledName.consume_front("_ZGV"))
+    return None;
+
+  // Extract ISA. An unknown ISA is also supported, so we accept all
+  // values.
+  VFISAKind ISA;
+  if (tryParseISA(MangledName, ISA) != ParseRet::OK)
+    return None;
+
+  // Extract <mask>.
+  bool IsMasked;
+  if (tryParseMask(MangledName, IsMasked) != ParseRet::OK)
+    return None;
+
+  // Parse the variable-size part, starting from <vlen>.
+  unsigned VF;
+  bool IsScalable;
+  if (tryParseVLEN(MangledName, VF, IsScalable) != ParseRet::OK)
+    return None;
+
+  // Parse the <parameters>.
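+  // For example, for the illustrative name "_ZGVnN4vl8a16_foo" the loop
+  // below consumes "vl8a16": a Vector parameter, followed by an OMP_Linear
+  // parameter with compile-time step 8 that is aligned to 16 bytes.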
+  ParseRet ParamFound;
+  SmallVector<VFParameter, 8> Parameters;
+  do {
+    const unsigned ParameterPos = Parameters.size();
+    VFParamKind PKind;
+    int StepOrPos;
+    ParamFound = tryParseParameter(MangledName, PKind, StepOrPos);
+
+    // Bail out if there is a parsing error for the parameter.
+    if (ParamFound == ParseRet::Error)
+      return None;
+
+    if (ParamFound == ParseRet::OK) {
+      Align Alignment;
+      // Look for the alignment token "a <number>".
+      const ParseRet AlignFound = tryParseAlign(MangledName, Alignment);
+      // Bail out if there is a syntax error in the align token.
+      if (AlignFound == ParseRet::Error)
+        return None;
+
+      // Add the parameter.
+      Parameters.push_back({ParameterPos, PKind, StepOrPos, Alignment});
+    }
+  } while (ParamFound == ParseRet::OK);
+
+  // A valid MangledName must have at least one valid entry in the
+  // <parameters>.
+  if (Parameters.empty())
+    return None;
+
+  // Check for the <scalarname> and the optional <redirection>, which
+  // are separated from the prefix with "_".
+  if (!MangledName.consume_front("_"))
+    return None;
+
+  // The rest of the string must be in the format:
+  // <scalarname>[(<redirection>)]
+  const StringRef ScalarName =
+      MangledName.take_while([](char In) { return In != '('; });
+
+  if (ScalarName.empty())
+    return None;
+
+  // Reduce MangledName to [(<redirection>)].
+  MangledName = MangledName.ltrim(ScalarName);
+  // Find the optional custom name redirection.
+  if (MangledName.consume_front("(")) {
+    if (!MangledName.consume_back(")"))
+      return None;
+    // Update the vector variant with the one specified by the user.
+    VectorName = MangledName;
+    // If the vector name is missing, bail out.
+    if (VectorName.empty())
+      return None;
+  }
+
+  // When <mask> is "M", we need to add a parameter that is used as
+  // global predicate for the function.
+  if (IsMasked) {
+    const unsigned Pos = Parameters.size();
+    Parameters.push_back({Pos, VFParamKind::GlobalPredicate});
+  }
+
+  // Asserts for parameters of type `VFParamKind::GlobalPredicate`, as
+  // prescribed by the Vector Function ABI specifications supported by
+  // this parser:
+  // 1. Uniqueness.
+  // 2. Must be the last in the parameter list.
+  const auto NGlobalPreds = std::count_if(
+      Parameters.begin(), Parameters.end(), [](const VFParameter PK) {
+        return PK.ParamKind == VFParamKind::GlobalPredicate;
+      });
+  assert(NGlobalPreds < 2 && "Cannot have more than one global predicate.");
+  if (NGlobalPreds)
+    assert(Parameters.back().ParamKind == VFParamKind::GlobalPredicate &&
+           "The global predicate must be the last parameter");
+
+  const VFShape Shape({VF, IsScalable, ISA, Parameters});
+  return VFInfo({Shape, ScalarName, VectorName});
+}
+
+VFParamKind VFABI::getVFParamKindFromString(const StringRef Token) {
+  const VFParamKind ParamKind = StringSwitch<VFParamKind>(Token)
+                                    .Case("v", VFParamKind::Vector)
+                                    .Case("l", VFParamKind::OMP_Linear)
+                                    .Case("R", VFParamKind::OMP_LinearRef)
+                                    .Case("L", VFParamKind::OMP_LinearVal)
+                                    .Case("U", VFParamKind::OMP_LinearUVal)
+                                    .Case("ls", VFParamKind::OMP_LinearPos)
+                                    .Case("Ls", VFParamKind::OMP_LinearValPos)
+                                    .Case("Rs", VFParamKind::OMP_LinearRefPos)
+                                    .Case("Us", VFParamKind::OMP_LinearUValPos)
+                                    .Case("u", VFParamKind::OMP_Uniform)
+                                    .Default(VFParamKind::Unknown);
+
+  if (ParamKind != VFParamKind::Unknown)
+    return ParamKind;
+
+  // This function should never be invoked with an invalid input.
+  llvm_unreachable("This function should be invoked only on parameters"
+                   " that have a textual representation in the mangled name"
+                   " of the Vector Function ABI");
+}
diff --git a/llvm/lib/Analysis/ValueLattice.cpp b/llvm/lib/Analysis/ValueLattice.cpp
new file mode 100644
index 000000000000..a0115a0eec36
--- /dev/null
+++ b/llvm/lib/Analysis/ValueLattice.cpp
@@ -0,0 +1,25 @@
+//===- ValueLattice.cpp - Value constraint analysis -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ValueLattice.h"
+
+namespace llvm {
+raw_ostream &operator<<(raw_ostream &OS, const ValueLatticeElement &Val) {
+  if (Val.isUndefined())
+    return OS << "undefined";
+  if (Val.isOverdefined())
+    return OS << "overdefined";
+
+  if (Val.isNotConstant())
+    return OS << "notconstant<" << *Val.getNotConstant() << ">";
+  if (Val.isConstantRange())
+    return OS << "constantrange<" << Val.getConstantRange().getLower() << ", "
+              << Val.getConstantRange().getUpper() << ">";
+  return OS << "constant<" << *Val.getConstant() << ">";
+}
+} // end namespace llvm
diff --git a/llvm/lib/Analysis/ValueLatticeUtils.cpp b/llvm/lib/Analysis/ValueLatticeUtils.cpp
new file mode 100644
index 000000000000..3f9287e26ce7
--- /dev/null
+++ b/llvm/lib/Analysis/ValueLatticeUtils.cpp
@@ -0,0 +1,43 @@
+//===-- ValueLatticeUtils.cpp - Utils for solving lattices ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements common functions useful for performing data-flow
+// analyses that propagate values across function boundaries.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ValueLatticeUtils.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Instructions.h"
+using namespace llvm;
+
+bool llvm::canTrackArgumentsInterprocedurally(Function *F) {
+  return F->hasLocalLinkage() && !F->hasAddressTaken();
+}
+
+bool llvm::canTrackReturnsInterprocedurally(Function *F) {
+  return F->hasExactDefinition() && !F->hasFnAttribute(Attribute::Naked);
+}
+
+bool llvm::canTrackGlobalVariableInterprocedurally(GlobalVariable *GV) {
+  if (GV->isConstant() || !GV->hasLocalLinkage() ||
+      !GV->hasDefinitiveInitializer())
+    return false;
+  return !any_of(GV->users(), [&](User *U) {
+    if (auto *Store = dyn_cast<StoreInst>(U)) {
+      if (Store->getValueOperand() == GV || Store->isVolatile())
+        return true;
+    } else if (auto *Load = dyn_cast<LoadInst>(U)) {
+      if (Load->isVolatile())
+        return true;
+    } else {
+      return true;
+    }
+    return false;
+  });
+}
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
new file mode 100644
index 000000000000..bbf389991836
--- /dev/null
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -0,0 +1,5820 @@
+//===- ValueTracking.cpp - Walk computations to compute properties --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains routines that help analyze properties that chains of +// computations have. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/GuardUtils.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Support/MathExtras.h" +#include <algorithm> +#include <array> +#include <cassert> +#include <cstdint> +#include <iterator> +#include <utility> + +using namespace llvm; +using namespace llvm::PatternMatch; + +const unsigned MaxDepth = 6; + +// Controls the number of uses of the value searched for possible +// dominating comparisons. +static cl::opt<unsigned> DomConditionsMaxUses("dom-conditions-max-uses", + cl::Hidden, cl::init(20)); + +/// Returns the bitwidth of the given scalar or pointer type. For vector types, +/// returns the element type's bitwidth. +static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { + if (unsigned BitWidth = Ty->getScalarSizeInBits()) + return BitWidth; + + return DL.getIndexTypeSizeInBits(Ty); +} + +namespace { + +// Simplifying using an assume can only be done in a particular control-flow +// context (the context instruction provides that context). If an assume and +// the context instruction are not in the same block then the DT helps in +// figuring out if we can use it. +struct Query { + const DataLayout &DL; + AssumptionCache *AC; + const Instruction *CxtI; + const DominatorTree *DT; + + // Unlike the other analyses, this may be a nullptr because not all clients + // provide it currently. 
+ OptimizationRemarkEmitter *ORE; + + /// Set of assumptions that should be excluded from further queries. + /// This is because of the potential for mutual recursion to cause + /// computeKnownBits to repeatedly visit the same assume intrinsic. The + /// classic case of this is assume(x = y), which will attempt to determine + /// bits in x from bits in y, which will attempt to determine bits in y from + /// bits in x, etc. Regarding the mutual recursion, computeKnownBits can call + /// isKnownNonZero, which calls computeKnownBits and isKnownToBeAPowerOfTwo + /// (all of which can call computeKnownBits), and so on. + std::array<const Value *, MaxDepth> Excluded; + + /// If true, it is safe to use metadata during simplification. + InstrInfoQuery IIQ; + + unsigned NumExcluded = 0; + + Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT, bool UseInstrInfo, + OptimizationRemarkEmitter *ORE = nullptr) + : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE), IIQ(UseInstrInfo) {} + + Query(const Query &Q, const Value *NewExcl) + : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), IIQ(Q.IIQ), + NumExcluded(Q.NumExcluded) { + Excluded = Q.Excluded; + Excluded[NumExcluded++] = NewExcl; + assert(NumExcluded <= Excluded.size()); + } + + bool isExcluded(const Value *Value) const { + if (NumExcluded == 0) + return false; + auto End = Excluded.begin() + NumExcluded; + return std::find(Excluded.begin(), End, Value) != End; + } +}; + +} // end anonymous namespace + +// Given the provided Value and, potentially, a context instruction, return +// the preferred context instruction (if any). +static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) { + // If we've been provided with a context instruction, then use that (provided + // it has been inserted). + if (CxtI && CxtI->getParent()) + return CxtI; + + // If the value is really an already-inserted instruction, then use that. + CxtI = dyn_cast<Instruction>(V); + if (CxtI && CxtI->getParent()) + return CxtI; + + return nullptr; +} + +static void computeKnownBits(const Value *V, KnownBits &Known, + unsigned Depth, const Query &Q); + +void llvm::computeKnownBits(const Value *V, KnownBits &Known, + const DataLayout &DL, unsigned Depth, + AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT, + OptimizationRemarkEmitter *ORE, bool UseInstrInfo) { + ::computeKnownBits(V, Known, Depth, + Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE)); +} + +static KnownBits computeKnownBits(const Value *V, unsigned Depth, + const Query &Q); + +KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL, + unsigned Depth, AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT, + OptimizationRemarkEmitter *ORE, + bool UseInstrInfo) { + return ::computeKnownBits( + V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE)); +} + +bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, + const DataLayout &DL, AssumptionCache *AC, + const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { + assert(LHS->getType() == RHS->getType() && + "LHS and RHS should have the same type"); + assert(LHS->getType()->isIntOrIntVectorTy() && + "LHS and RHS should be integers"); + // Look for an inverted mask: (X & ~M) op (Y & M). 
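+  // (X & ~M) can only have bits set where M is zero, while (Y & M) can only
+  // have bits set where M is one, so the two values cannot share a set bit.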
+  Value *M;
+  if (match(LHS, m_c_And(m_Not(m_Value(M)), m_Value())) &&
+      match(RHS, m_c_And(m_Specific(M), m_Value())))
+    return true;
+  if (match(RHS, m_c_And(m_Not(m_Value(M)), m_Value())) &&
+      match(LHS, m_c_And(m_Specific(M), m_Value())))
+    return true;
+  IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType());
+  KnownBits LHSKnown(IT->getBitWidth());
+  KnownBits RHSKnown(IT->getBitWidth());
+  computeKnownBits(LHS, LHSKnown, DL, 0, AC, CxtI, DT, nullptr, UseInstrInfo);
+  computeKnownBits(RHS, RHSKnown, DL, 0, AC, CxtI, DT, nullptr, UseInstrInfo);
+  return (LHSKnown.Zero | RHSKnown.Zero).isAllOnesValue();
+}
+
+bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *CxtI) {
+  for (const User *U : CxtI->users()) {
+    if (const ICmpInst *IC = dyn_cast<ICmpInst>(U))
+      if (IC->isEquality())
+        if (Constant *C = dyn_cast<Constant>(IC->getOperand(1)))
+          if (C->isNullValue())
+            continue;
+    return false;
+  }
+  return true;
+}
+
+static bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth,
+                                   const Query &Q);
+
+bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL,
+                                  bool OrZero, unsigned Depth,
+                                  AssumptionCache *AC, const Instruction *CxtI,
+                                  const DominatorTree *DT, bool UseInstrInfo) {
+  return ::isKnownToBeAPowerOfTwo(
+      V, OrZero, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo));
+}
+
+static bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q);
+
+bool llvm::isKnownNonZero(const Value *V, const DataLayout &DL, unsigned Depth,
+                          AssumptionCache *AC, const Instruction *CxtI,
+                          const DominatorTree *DT, bool UseInstrInfo) {
+  return ::isKnownNonZero(V, Depth,
+                          Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo));
+}
+
+bool llvm::isKnownNonNegative(const Value *V, const DataLayout &DL,
+                              unsigned Depth, AssumptionCache *AC,
+                              const Instruction *CxtI, const DominatorTree *DT,
+                              bool UseInstrInfo) {
+  KnownBits Known =
+      computeKnownBits(V, DL, Depth, AC, CxtI, DT, nullptr, UseInstrInfo);
+  return Known.isNonNegative();
+}
+
+bool llvm::isKnownPositive(const Value *V, const DataLayout &DL, unsigned Depth,
+                           AssumptionCache *AC, const Instruction *CxtI,
+                           const DominatorTree *DT, bool UseInstrInfo) {
+  if (auto *CI = dyn_cast<ConstantInt>(V))
+    return CI->getValue().isStrictlyPositive();
+
+  // TODO: We are doing two recursive queries here. We should factor this
+  // such that only a single query is needed.
+ return isKnownNonNegative(V, DL, Depth, AC, CxtI, DT, UseInstrInfo) && + isKnownNonZero(V, DL, Depth, AC, CxtI, DT, UseInstrInfo); +} + +bool llvm::isKnownNegative(const Value *V, const DataLayout &DL, unsigned Depth, + AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT, bool UseInstrInfo) { + KnownBits Known = + computeKnownBits(V, DL, Depth, AC, CxtI, DT, nullptr, UseInstrInfo); + return Known.isNegative(); +} + +static bool isKnownNonEqual(const Value *V1, const Value *V2, const Query &Q); + +bool llvm::isKnownNonEqual(const Value *V1, const Value *V2, + const DataLayout &DL, AssumptionCache *AC, + const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { + return ::isKnownNonEqual(V1, V2, + Query(DL, AC, safeCxtI(V1, safeCxtI(V2, CxtI)), DT, + UseInstrInfo, /*ORE=*/nullptr)); +} + +static bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth, + const Query &Q); + +bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask, + const DataLayout &DL, unsigned Depth, + AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT, bool UseInstrInfo) { + return ::MaskedValueIsZero( + V, Mask, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo)); +} + +static unsigned ComputeNumSignBits(const Value *V, unsigned Depth, + const Query &Q); + +unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL, + unsigned Depth, AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT, bool UseInstrInfo) { + return ::ComputeNumSignBits( + V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo)); +} + +static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1, + bool NSW, + KnownBits &KnownOut, KnownBits &Known2, + unsigned Depth, const Query &Q) { + unsigned BitWidth = KnownOut.getBitWidth(); + + // If an initial sequence of bits in the result is not needed, the + // corresponding bits in the operands are not needed. + KnownBits LHSKnown(BitWidth); + computeKnownBits(Op0, LHSKnown, Depth + 1, Q); + computeKnownBits(Op1, Known2, Depth + 1, Q); + + KnownOut = KnownBits::computeForAddSub(Add, NSW, LHSKnown, Known2); +} + +static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, + KnownBits &Known, KnownBits &Known2, + unsigned Depth, const Query &Q) { + unsigned BitWidth = Known.getBitWidth(); + computeKnownBits(Op1, Known, Depth + 1, Q); + computeKnownBits(Op0, Known2, Depth + 1, Q); + + bool isKnownNegative = false; + bool isKnownNonNegative = false; + // If the multiplication is known not to overflow, compute the sign bit. + if (NSW) { + if (Op0 == Op1) { + // The product of a number with itself is non-negative. + isKnownNonNegative = true; + } else { + bool isKnownNonNegativeOp1 = Known.isNonNegative(); + bool isKnownNonNegativeOp0 = Known2.isNonNegative(); + bool isKnownNegativeOp1 = Known.isNegative(); + bool isKnownNegativeOp0 = Known2.isNegative(); + // The product of two numbers with the same sign is non-negative. + isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) || + (isKnownNonNegativeOp1 && isKnownNonNegativeOp0); + // The product of a negative number and a non-negative number is either + // negative or zero. 
+ if (!isKnownNonNegative) + isKnownNegative = (isKnownNegativeOp1 && isKnownNonNegativeOp0 && + isKnownNonZero(Op0, Depth, Q)) || + (isKnownNegativeOp0 && isKnownNonNegativeOp1 && + isKnownNonZero(Op1, Depth, Q)); + } + } + + assert(!Known.hasConflict() && !Known2.hasConflict()); + // Compute a conservative estimate for high known-0 bits. + unsigned LeadZ = std::max(Known.countMinLeadingZeros() + + Known2.countMinLeadingZeros(), + BitWidth) - BitWidth; + LeadZ = std::min(LeadZ, BitWidth); + + // The result of the bottom bits of an integer multiply can be + // inferred by looking at the bottom bits of both operands and + // multiplying them together. + // We can infer at least the minimum number of known trailing bits + // of both operands. Depending on number of trailing zeros, we can + // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming + // a and b are divisible by m and n respectively. + // We then calculate how many of those bits are inferrable and set + // the output. For example, the i8 mul: + // a = XXXX1100 (12) + // b = XXXX1110 (14) + // We know the bottom 3 bits are zero since the first can be divided by + // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4). + // Applying the multiplication to the trimmed arguments gets: + // XX11 (3) + // X111 (7) + // ------- + // XX11 + // XX11 + // XX11 + // XX11 + // ------- + // XXXXX01 + // Which allows us to infer the 2 LSBs. Since we're multiplying the result + // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits. + // The proof for this can be described as: + // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) && + // (C7 == (1 << (umin(countTrailingZeros(C1), C5) + + // umin(countTrailingZeros(C2), C6) + + // umin(C5 - umin(countTrailingZeros(C1), C5), + // C6 - umin(countTrailingZeros(C2), C6)))) - 1) + // %aa = shl i8 %a, C5 + // %bb = shl i8 %b, C6 + // %aaa = or i8 %aa, C1 + // %bbb = or i8 %bb, C2 + // %mul = mul i8 %aaa, %bbb + // %mask = and i8 %mul, C7 + // => + // %mask = i8 ((C1*C2)&C7) + // Where C5, C6 describe the known bits of %a, %b + // C1, C2 describe the known bottom bits of %a, %b. + // C7 describes the mask of the known bits of the result. + APInt Bottom0 = Known.One; + APInt Bottom1 = Known2.One; + + // How many times we'd be able to divide each argument by 2 (shr by 1). + // This gives us the number of trailing zeros on the multiplication result. + unsigned TrailBitsKnown0 = (Known.Zero | Known.One).countTrailingOnes(); + unsigned TrailBitsKnown1 = (Known2.Zero | Known2.One).countTrailingOnes(); + unsigned TrailZero0 = Known.countMinTrailingZeros(); + unsigned TrailZero1 = Known2.countMinTrailingZeros(); + unsigned TrailZ = TrailZero0 + TrailZero1; + + // Figure out the fewest known-bits operand. + unsigned SmallestOperand = std::min(TrailBitsKnown0 - TrailZero0, + TrailBitsKnown1 - TrailZero1); + unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth); + + APInt BottomKnown = Bottom0.getLoBits(TrailBitsKnown0) * + Bottom1.getLoBits(TrailBitsKnown1); + + Known.resetAll(); + Known.Zero.setHighBits(LeadZ); + Known.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown); + Known.One |= BottomKnown.getLoBits(ResultBitsKnown); + + // Only make use of no-wrap flags if we failed to compute the sign bit + // directly. 
This matters if the multiplication always overflows, in + // which case we prefer to follow the result of the direct computation, + // though as the program is invoking undefined behaviour we can choose + // whatever we like here. + if (isKnownNonNegative && !Known.isNegative()) + Known.makeNonNegative(); + else if (isKnownNegative && !Known.isNonNegative()) + Known.makeNegative(); +} + +void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges, + KnownBits &Known) { + unsigned BitWidth = Known.getBitWidth(); + unsigned NumRanges = Ranges.getNumOperands() / 2; + assert(NumRanges >= 1); + + Known.Zero.setAllBits(); + Known.One.setAllBits(); + + for (unsigned i = 0; i < NumRanges; ++i) { + ConstantInt *Lower = + mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 0)); + ConstantInt *Upper = + mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 1)); + ConstantRange Range(Lower->getValue(), Upper->getValue()); + + // The first CommonPrefixBits of all values in Range are equal. + unsigned CommonPrefixBits = + (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countLeadingZeros(); + + APInt Mask = APInt::getHighBitsSet(BitWidth, CommonPrefixBits); + Known.One &= Range.getUnsignedMax() & Mask; + Known.Zero &= ~Range.getUnsignedMax() & Mask; + } +} + +static bool isEphemeralValueOf(const Instruction *I, const Value *E) { + SmallVector<const Value *, 16> WorkSet(1, I); + SmallPtrSet<const Value *, 32> Visited; + SmallPtrSet<const Value *, 16> EphValues; + + // The instruction defining an assumption's condition itself is always + // considered ephemeral to that assumption (even if it has other + // non-ephemeral users). See r246696's test case for an example. + if (is_contained(I->operands(), E)) + return true; + + while (!WorkSet.empty()) { + const Value *V = WorkSet.pop_back_val(); + if (!Visited.insert(V).second) + continue; + + // If all uses of this value are ephemeral, then so is this value. + if (llvm::all_of(V->users(), [&](const User *U) { + return EphValues.count(U); + })) { + if (V == E) + return true; + + if (V == I || isSafeToSpeculativelyExecute(V)) { + EphValues.insert(V); + if (const User *U = dyn_cast<User>(V)) + for (User::const_op_iterator J = U->op_begin(), JE = U->op_end(); + J != JE; ++J) + WorkSet.push_back(*J); + } + } + } + + return false; +} + +// Is this an intrinsic that cannot be speculated but also cannot trap? +bool llvm::isAssumeLikeIntrinsic(const Instruction *I) { + if (const CallInst *CI = dyn_cast<CallInst>(I)) + if (Function *F = CI->getCalledFunction()) + switch (F->getIntrinsicID()) { + default: break; + // FIXME: This list is repeated from NoTTI::getIntrinsicCost. + case Intrinsic::assume: + case Intrinsic::sideeffect: + case Intrinsic::dbg_declare: + case Intrinsic::dbg_value: + case Intrinsic::dbg_label: + case Intrinsic::invariant_start: + case Intrinsic::invariant_end: + case Intrinsic::lifetime_start: + case Intrinsic::lifetime_end: + case Intrinsic::objectsize: + case Intrinsic::ptr_annotation: + case Intrinsic::var_annotation: + return true; + } + + return false; +} + +bool llvm::isValidAssumeForContext(const Instruction *Inv, + const Instruction *CxtI, + const DominatorTree *DT) { + // There are two restrictions on the use of an assume: + // 1. The assume must dominate the context (or the control flow must + // reach the assume whenever it reaches the context). + // 2. 
The context must not be in the assume's set of ephemeral values
+  //     (otherwise we will use the assume to prove that the condition
+  //     feeding the assume is trivially true, thus causing the removal of
+  //     the assume).
+
+  if (DT) {
+    if (DT->dominates(Inv, CxtI))
+      return true;
+  } else if (Inv->getParent() == CxtI->getParent()->getSinglePredecessor()) {
+    // We don't have a DT, but this trivially dominates.
+    return true;
+  }
+
+  // With or without a DT, the only remaining case we will check is if the
+  // instructions are in the same BB. Give up if that is not the case.
+  if (Inv->getParent() != CxtI->getParent())
+    return false;
+
+  // If we have a dom tree, then we now know that the assume doesn't dominate
+  // the other instruction. If we don't have a dom tree then we can check if
+  // the assume is first in the BB.
+  if (!DT) {
+    // Search forward from the assume until we reach the context (or the end
+    // of the block); the common case is that the assume will come first.
+    for (auto I = std::next(BasicBlock::const_iterator(Inv)),
+              IE = Inv->getParent()->end(); I != IE; ++I)
+      if (&*I == CxtI)
+        return true;
+  }
+
+  // Don't let an assume affect itself - this would cause the problems
+  // `isEphemeralValueOf` is trying to prevent, and it would also make
+  // the loop below go out of bounds.
+  if (Inv == CxtI)
+    return false;
+
+  // The context comes first, but they're both in the same block. Make sure
+  // there is nothing in between that might interrupt the control flow.
+  for (BasicBlock::const_iterator I =
+           std::next(BasicBlock::const_iterator(CxtI)), IE(Inv);
+       I != IE; ++I)
+    if (!isGuaranteedToTransferExecutionToSuccessor(&*I))
+      return false;
+
+  return !isEphemeralValueOf(Inv, CxtI);
+}
+
+static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
+                                       unsigned Depth, const Query &Q) {
+  // Use of assumptions is context-sensitive. If we don't have a context, we
+  // cannot use them!
+  if (!Q.AC || !Q.CxtI)
+    return;
+
+  unsigned BitWidth = Known.getBitWidth();
+
+  // Note that the patterns below need to be kept in sync with the code
+  // in AssumptionCache::updateAffectedValues.
+
+  for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
+    if (!AssumeVH)
+      continue;
+    CallInst *I = cast<CallInst>(AssumeVH);
+    assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
+           "Got assumption for the wrong function!");
+    if (Q.isExcluded(I))
+      continue;
+
+    // Warning: This loop can end up being somewhat performance sensitive.
+    // We're running this loop once for each value queried, resulting in a
+    // runtime of ~O(#assumes * #values).
+
+    assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume &&
+           "must be an assume intrinsic");
+
+    Value *Arg = I->getArgOperand(0);
+
+    if (Arg == V && isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+      assert(BitWidth == 1 && "assume operand is not i1?");
+      Known.setAllOnes();
+      return;
+    }
+    if (match(Arg, m_Not(m_Specific(V))) &&
+        isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+      assert(BitWidth == 1 && "assume operand is not i1?");
+      Known.setAllZero();
+      return;
+    }
+
+    // The remaining tests are all recursive, so bail out if we hit the limit.
+ if (Depth == MaxDepth) + continue; + + ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg); + if (!Cmp) + continue; + + Value *A, *B; + auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V))); + + CmpInst::Predicate Pred; + uint64_t C; + switch (Cmp->getPredicate()) { + default: + break; + case ICmpInst::ICMP_EQ: + // assume(v = a) + if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + Known.Zero |= RHSKnown.Zero; + Known.One |= RHSKnown.One; + // assume(v & b = a) + } else if (match(Cmp, + m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits MaskKnown(BitWidth); + computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I)); + + // For those bits in the mask that are known to be one, we can propagate + // known bits from the RHS to V. + Known.Zero |= RHSKnown.Zero & MaskKnown.One; + Known.One |= RHSKnown.One & MaskKnown.One; + // assume(~(v & b) = a) + } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))), + m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits MaskKnown(BitWidth); + computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I)); + + // For those bits in the mask that are known to be one, we can propagate + // inverted known bits from the RHS to V. + Known.Zero |= RHSKnown.One & MaskKnown.One; + Known.One |= RHSKnown.Zero & MaskKnown.One; + // assume(v | b = a) + } else if (match(Cmp, + m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits BKnown(BitWidth); + computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + + // For those bits in B that are known to be zero, we can propagate known + // bits from the RHS to V. + Known.Zero |= RHSKnown.Zero & BKnown.Zero; + Known.One |= RHSKnown.One & BKnown.Zero; + // assume(~(v | b) = a) + } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))), + m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits BKnown(BitWidth); + computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + + // For those bits in B that are known to be zero, we can propagate + // inverted known bits from the RHS to V. + Known.Zero |= RHSKnown.One & BKnown.Zero; + Known.One |= RHSKnown.Zero & BKnown.Zero; + // assume(v ^ b = a) + } else if (match(Cmp, + m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits BKnown(BitWidth); + computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + + // For those bits in B that are known to be zero, we can propagate known + // bits from the RHS to V. For those bits in B that are known to be one, + // we can propagate inverted known bits from the RHS to V. 
+ Known.Zero |= RHSKnown.Zero & BKnown.Zero; + Known.One |= RHSKnown.One & BKnown.Zero; + Known.Zero |= RHSKnown.One & BKnown.One; + Known.One |= RHSKnown.Zero & BKnown.One; + // assume(~(v ^ b) = a) + } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))), + m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits BKnown(BitWidth); + computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + + // For those bits in B that are known to be zero, we can propagate + // inverted known bits from the RHS to V. For those bits in B that are + // known to be one, we can propagate known bits from the RHS to V. + Known.Zero |= RHSKnown.One & BKnown.Zero; + Known.One |= RHSKnown.Zero & BKnown.Zero; + Known.Zero |= RHSKnown.Zero & BKnown.One; + Known.One |= RHSKnown.One & BKnown.One; + // assume(v << c = a) + } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)), + m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them to known + // bits in V shifted to the right by C. + RHSKnown.Zero.lshrInPlace(C); + Known.Zero |= RHSKnown.Zero; + RHSKnown.One.lshrInPlace(C); + Known.One |= RHSKnown.One; + // assume(~(v << c) = a) + } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))), + m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them inverted + // to known bits in V shifted to the right by C. + RHSKnown.One.lshrInPlace(C); + Known.Zero |= RHSKnown.One; + RHSKnown.Zero.lshrInPlace(C); + Known.One |= RHSKnown.Zero; + // assume(v >> c = a) + } else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)), + m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them to known + // bits in V shifted to the right by C. + Known.Zero |= RHSKnown.Zero << C; + Known.One |= RHSKnown.One << C; + // assume(~(v >> c) = a) + } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shr(m_V, m_ConstantInt(C))), + m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them inverted + // to known bits in V shifted to the right by C. + Known.Zero |= RHSKnown.One << C; + Known.One |= RHSKnown.Zero << C; + } + break; + case ICmpInst::ICMP_SGE: + // assume(v >=_s c) where c is non-negative + if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I)); + + if (RHSKnown.isNonNegative()) { + // We know that the sign bit is zero. + Known.makeNonNegative(); + } + } + break; + case ICmpInst::ICMP_SGT: + // assume(v >_s c) where c is at least -1. 
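+      // (If v is signed-greater-than something that is at least -1, then
+      // v >= 0, so the sign bit of v must be zero.)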
+ if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I)); + + if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) { + // We know that the sign bit is zero. + Known.makeNonNegative(); + } + } + break; + case ICmpInst::ICMP_SLE: + // assume(v <=_s c) where c is negative + if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I)); + + if (RHSKnown.isNegative()) { + // We know that the sign bit is one. + Known.makeNegative(); + } + } + break; + case ICmpInst::ICMP_SLT: + // assume(v <_s c) where c is non-positive + if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + + if (RHSKnown.isZero() || RHSKnown.isNegative()) { + // We know that the sign bit is one. + Known.makeNegative(); + } + } + break; + case ICmpInst::ICMP_ULE: + // assume(v <=_u c) + if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + + // Whatever high bits in c are zero are known to be zero. + Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros()); + } + break; + case ICmpInst::ICMP_ULT: + // assume(v <_u c) + if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + + // If the RHS is known zero, then this assumption must be wrong (nothing + // is unsigned less than zero). Signal a conflict and get out of here. + if (RHSKnown.isZero()) { + Known.Zero.setAllBits(); + Known.One.setAllBits(); + break; + } + + // Whatever high bits in c are zero are known to be zero (if c is a power + // of 2, then one more). + if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I))) + Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros() + 1); + else + Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros()); + } + break; + } + } + + // If assumptions conflict with each other or previous known bits, then we + // have a logical fallacy. It's possible that the assumption is not reachable, + // so this isn't a real bug. On the other hand, the program may have undefined + // behavior, or we might have a bug in the compiler. We can't assert/crash, so + // clear out the known bits, try to warn the user, and hope for the best. + if (Known.Zero.intersects(Known.One)) { + Known.resetAll(); + + if (Q.ORE) + Q.ORE->emit([&]() { + auto *CxtI = const_cast<Instruction *>(Q.CxtI); + return OptimizationRemarkAnalysis("value-tracking", "BadAssumption", + CxtI) + << "Detected conflicting code assumptions. Program may " + "have undefined behavior, or compiler may have " + "internal error."; + }); + } +} + +/// Compute known bits from a shift operator, including those with a +/// non-constant shift amount. Known is the output of this function. Known2 is a +/// pre-allocated temporary with the same bit width as Known. KZF and KOF are +/// operator-specific functions that, given the known-zero or known-one bits +/// respectively, and a shift amount, compute the implied known-zero or +/// known-one bits of the shift operator's result respectively for that shift +/// amount. 
The results from calling KZF and KOF are conservatively combined for +/// all permitted shift amounts. +static void computeKnownBitsFromShiftOperator( + const Operator *I, KnownBits &Known, KnownBits &Known2, + unsigned Depth, const Query &Q, + function_ref<APInt(const APInt &, unsigned)> KZF, + function_ref<APInt(const APInt &, unsigned)> KOF) { + unsigned BitWidth = Known.getBitWidth(); + + if (auto *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { + unsigned ShiftAmt = SA->getLimitedValue(BitWidth-1); + + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + Known.Zero = KZF(Known.Zero, ShiftAmt); + Known.One = KOF(Known.One, ShiftAmt); + // If the known bits conflict, this must be an overflowing left shift, so + // the shift result is poison. We can return anything we want. Choose 0 for + // the best folding opportunity. + if (Known.hasConflict()) + Known.setAllZero(); + + return; + } + + computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); + + // If the shift amount could be greater than or equal to the bit-width of the + // LHS, the value could be poison, but bail out because the check below is + // expensive. TODO: Should we just carry on? + if ((~Known.Zero).uge(BitWidth)) { + Known.resetAll(); + return; + } + + // Note: We cannot use Known.Zero.getLimitedValue() here, because if + // BitWidth > 64 and any upper bits are known, we'll end up returning the + // limit value (which implies all bits are known). + uint64_t ShiftAmtKZ = Known.Zero.zextOrTrunc(64).getZExtValue(); + uint64_t ShiftAmtKO = Known.One.zextOrTrunc(64).getZExtValue(); + + // It would be more-clearly correct to use the two temporaries for this + // calculation. Reusing the APInts here to prevent unnecessary allocations. + Known.resetAll(); + + // If we know the shifter operand is nonzero, we can sometimes infer more + // known bits. However this is expensive to compute, so be lazy about it and + // only compute it when absolutely necessary. + Optional<bool> ShifterOperandIsNonZero; + + // Early exit if we can't constrain any well-defined shift amount. + if (!(ShiftAmtKZ & (PowerOf2Ceil(BitWidth) - 1)) && + !(ShiftAmtKO & (PowerOf2Ceil(BitWidth) - 1))) { + ShifterOperandIsNonZero = isKnownNonZero(I->getOperand(1), Depth + 1, Q); + if (!*ShifterOperandIsNonZero) + return; + } + + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + + Known.Zero.setAllBits(); + Known.One.setAllBits(); + for (unsigned ShiftAmt = 0; ShiftAmt < BitWidth; ++ShiftAmt) { + // Combine the shifted known input bits only for those shift amounts + // compatible with its known constraints. + if ((ShiftAmt & ~ShiftAmtKZ) != ShiftAmt) + continue; + if ((ShiftAmt | ShiftAmtKO) != ShiftAmt) + continue; + // If we know the shifter is nonzero, we may be able to infer more known + // bits. This check is sunk down as far as possible to avoid the expensive + // call to isKnownNonZero if the cheaper checks above fail. + if (ShiftAmt == 0) { + if (!ShifterOperandIsNonZero.hasValue()) + ShifterOperandIsNonZero = + isKnownNonZero(I->getOperand(1), Depth + 1, Q); + if (*ShifterOperandIsNonZero) + continue; + } + + Known.Zero &= KZF(Known2.Zero, ShiftAmt); + Known.One &= KOF(Known2.One, ShiftAmt); + } + + // If the known bits conflict, the result is poison. Return a 0 and hope the + // caller can further optimize that. 
+ if (Known.hasConflict()) + Known.setAllZero(); +} + +static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, + unsigned Depth, const Query &Q) { + unsigned BitWidth = Known.getBitWidth(); + + KnownBits Known2(Known); + switch (I->getOpcode()) { + default: break; + case Instruction::Load: + if (MDNode *MD = + Q.IIQ.getMetadata(cast<LoadInst>(I), LLVMContext::MD_range)) + computeKnownBitsFromRangeMetadata(*MD, Known); + break; + case Instruction::And: { + // If either the LHS or the RHS are Zero, the result is zero. + computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + + // Output known-1 bits are only known if set in both the LHS & RHS. + Known.One &= Known2.One; + // Output known-0 are known to be clear if zero in either the LHS | RHS. + Known.Zero |= Known2.Zero; + + // and(x, add (x, -1)) is a common idiom that always clears the low bit; + // here we handle the more general case of adding any odd number by + // matching the form add(x, add(x, y)) where y is odd. + // TODO: This could be generalized to clearing any bit set in y where the + // following bit is known to be unset in y. + Value *X = nullptr, *Y = nullptr; + if (!Known.Zero[0] && !Known.One[0] && + match(I, m_c_BinOp(m_Value(X), m_Add(m_Deferred(X), m_Value(Y))))) { + Known2.resetAll(); + computeKnownBits(Y, Known2, Depth + 1, Q); + if (Known2.countMinTrailingOnes() > 0) + Known.Zero.setBit(0); + } + break; + } + case Instruction::Or: + computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + + // Output known-0 bits are only known if clear in both the LHS & RHS. + Known.Zero &= Known2.Zero; + // Output known-1 are known to be set if set in either the LHS | RHS. + Known.One |= Known2.One; + break; + case Instruction::Xor: { + computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + + // Output known-0 bits are known if clear or set in both the LHS & RHS. + APInt KnownZeroOut = (Known.Zero & Known2.Zero) | (Known.One & Known2.One); + // Output known-1 are known to be set if set in only one of the LHS, RHS. + Known.One = (Known.Zero & Known2.One) | (Known.One & Known2.Zero); + Known.Zero = std::move(KnownZeroOut); + break; + } + case Instruction::Mul: { + bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); + computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, Known, + Known2, Depth, Q); + break; + } + case Instruction::UDiv: { + // For the purposes of computing leading zeros we can conservatively + // treat a udiv as a logical right shift by the power of 2 known to + // be less than the denominator. 
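+    // For example, with BitWidth == 32, if the denominator has at most 27
+    // leading zeros it is at least 2^4, so the quotient has at least 4 more
+    // leading zeros than the numerator.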
+ computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + unsigned LeadZ = Known2.countMinLeadingZeros(); + + Known2.resetAll(); + computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); + unsigned RHSMaxLeadingZeros = Known2.countMaxLeadingZeros(); + if (RHSMaxLeadingZeros != BitWidth) + LeadZ = std::min(BitWidth, LeadZ + BitWidth - RHSMaxLeadingZeros - 1); + + Known.Zero.setHighBits(LeadZ); + break; + } + case Instruction::Select: { + const Value *LHS = nullptr, *RHS = nullptr; + SelectPatternFlavor SPF = matchSelectPattern(I, LHS, RHS).Flavor; + if (SelectPatternResult::isMinOrMax(SPF)) { + computeKnownBits(RHS, Known, Depth + 1, Q); + computeKnownBits(LHS, Known2, Depth + 1, Q); + } else { + computeKnownBits(I->getOperand(2), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); + } + + unsigned MaxHighOnes = 0; + unsigned MaxHighZeros = 0; + if (SPF == SPF_SMAX) { + // If both sides are negative, the result is negative. + if (Known.isNegative() && Known2.isNegative()) + // We can derive a lower bound on the result by taking the max of the + // leading one bits. + MaxHighOnes = + std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes()); + // If either side is non-negative, the result is non-negative. + else if (Known.isNonNegative() || Known2.isNonNegative()) + MaxHighZeros = 1; + } else if (SPF == SPF_SMIN) { + // If both sides are non-negative, the result is non-negative. + if (Known.isNonNegative() && Known2.isNonNegative()) + // We can derive an upper bound on the result by taking the max of the + // leading zero bits. + MaxHighZeros = std::max(Known.countMinLeadingZeros(), + Known2.countMinLeadingZeros()); + // If either side is negative, the result is negative. + else if (Known.isNegative() || Known2.isNegative()) + MaxHighOnes = 1; + } else if (SPF == SPF_UMAX) { + // We can derive a lower bound on the result by taking the max of the + // leading one bits. + MaxHighOnes = + std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes()); + } else if (SPF == SPF_UMIN) { + // We can derive an upper bound on the result by taking the max of the + // leading zero bits. + MaxHighZeros = + std::max(Known.countMinLeadingZeros(), Known2.countMinLeadingZeros()); + } else if (SPF == SPF_ABS) { + // RHS from matchSelectPattern returns the negation part of abs pattern. + // If the negate has an NSW flag we can assume the sign bit of the result + // will be 0 because that makes abs(INT_MIN) undefined. + if (match(RHS, m_Neg(m_Specific(LHS))) && + Q.IIQ.hasNoSignedWrap(cast<Instruction>(RHS))) + MaxHighZeros = 1; + } + + // Only known if known in both the LHS and RHS. + Known.One &= Known2.One; + Known.Zero &= Known2.Zero; + if (MaxHighOnes > 0) + Known.One.setHighBits(MaxHighOnes); + if (MaxHighZeros > 0) + Known.Zero.setHighBits(MaxHighZeros); + break; + } + case Instruction::FPTrunc: + case Instruction::FPExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::SIToFP: + case Instruction::UIToFP: + break; // Can't work with floating point. + case Instruction::PtrToInt: + case Instruction::IntToPtr: + // Fall through and handle them the same as zext/trunc. + LLVM_FALLTHROUGH; + case Instruction::ZExt: + case Instruction::Trunc: { + Type *SrcTy = I->getOperand(0)->getType(); + + unsigned SrcBitWidth; + // Note that we handle pointer operands here because of inttoptr/ptrtoint + // which fall through here. + Type *ScalarTy = SrcTy->getScalarType(); + SrcBitWidth = ScalarTy->isPointerTy() ? 
+ Q.DL.getIndexTypeSizeInBits(ScalarTy) : + Q.DL.getTypeSizeInBits(ScalarTy); + + assert(SrcBitWidth && "SrcBitWidth can't be zero"); + Known = Known.zextOrTrunc(SrcBitWidth, false); + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + Known = Known.zextOrTrunc(BitWidth, true /* ExtendedBitsAreKnownZero */); + break; + } + case Instruction::BitCast: { + Type *SrcTy = I->getOperand(0)->getType(); + if (SrcTy->isIntOrPtrTy() && + // TODO: For now, not handling conversions like: + // (bitcast i64 %x to <2 x i32>) + !I->getType()->isVectorTy()) { + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + break; + } + break; + } + case Instruction::SExt: { + // Compute the bits in the result that are not present in the input. + unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); + + Known = Known.trunc(SrcBitWidth); + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + // If the sign bit of the input is known set or clear, then we know the + // top bits of the result. + Known = Known.sext(BitWidth); + break; + } + case Instruction::Shl: { + // (shl X, C1) & C2 == 0 iff (X & C2 >>u C1) == 0 + bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); + auto KZF = [NSW](const APInt &KnownZero, unsigned ShiftAmt) { + APInt KZResult = KnownZero << ShiftAmt; + KZResult.setLowBits(ShiftAmt); // Low bits known 0. + // If this shift has "nsw" keyword, then the result is either a poison + // value or has the same sign bit as the first operand. + if (NSW && KnownZero.isSignBitSet()) + KZResult.setSignBit(); + return KZResult; + }; + + auto KOF = [NSW](const APInt &KnownOne, unsigned ShiftAmt) { + APInt KOResult = KnownOne << ShiftAmt; + if (NSW && KnownOne.isSignBitSet()) + KOResult.setSignBit(); + return KOResult; + }; + + computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF); + break; + } + case Instruction::LShr: { + // (lshr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0 + auto KZF = [](const APInt &KnownZero, unsigned ShiftAmt) { + APInt KZResult = KnownZero.lshr(ShiftAmt); + // High bits known zero. + KZResult.setHighBits(ShiftAmt); + return KZResult; + }; + + auto KOF = [](const APInt &KnownOne, unsigned ShiftAmt) { + return KnownOne.lshr(ShiftAmt); + }; + + computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF); + break; + } + case Instruction::AShr: { + // (ashr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0 + auto KZF = [](const APInt &KnownZero, unsigned ShiftAmt) { + return KnownZero.ashr(ShiftAmt); + }; + + auto KOF = [](const APInt &KnownOne, unsigned ShiftAmt) { + return KnownOne.ashr(ShiftAmt); + }; + + computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF); + break; + } + case Instruction::Sub: { + bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); + computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, + Known, Known2, Depth, Q); + break; + } + case Instruction::Add: { + bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); + computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, + Known, Known2, Depth, Q); + break; + } + case Instruction::SRem: + if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) { + APInt RA = Rem->getValue().abs(); + if (RA.isPowerOf2()) { + APInt LowBits = RA - 1; + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + + // The low bits of the first operand are unchanged by the srem. 
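+        // For example, for "srem X, 8" (LowBits == 7) the low 3 bits of the
+        // result equal the low 3 bits of X, and if X is also known
+        // non-negative, every bit above the low 3 is known zero.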
+ Known.Zero = Known2.Zero & LowBits; + Known.One = Known2.One & LowBits; + + // If the first operand is non-negative or has all low bits zero, then + // the upper bits are all zero. + if (Known2.isNonNegative() || LowBits.isSubsetOf(Known2.Zero)) + Known.Zero |= ~LowBits; + + // If the first operand is negative and not all low bits are zero, then + // the upper bits are all one. + if (Known2.isNegative() && LowBits.intersects(Known2.One)) + Known.One |= ~LowBits; + + assert((Known.Zero & Known.One) == 0 && "Bits known to be one AND zero?"); + break; + } + } + + // The sign bit is the LHS's sign bit, except when the result of the + // remainder is zero. + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + // If it's known zero, our sign bit is also zero. + if (Known2.isNonNegative()) + Known.makeNonNegative(); + + break; + case Instruction::URem: { + if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) { + const APInt &RA = Rem->getValue(); + if (RA.isPowerOf2()) { + APInt LowBits = (RA - 1); + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + Known.Zero |= ~LowBits; + Known.One &= LowBits; + break; + } + } + + // Since the result is less than or equal to either operand, any leading + // zero bits in either operand must also exist in the result. + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); + + unsigned Leaders = + std::max(Known.countMinLeadingZeros(), Known2.countMinLeadingZeros()); + Known.resetAll(); + Known.Zero.setHighBits(Leaders); + break; + } + + case Instruction::Alloca: { + const AllocaInst *AI = cast<AllocaInst>(I); + unsigned Align = AI->getAlignment(); + if (Align == 0) + Align = Q.DL.getABITypeAlignment(AI->getAllocatedType()); + + if (Align > 0) + Known.Zero.setLowBits(countTrailingZeros(Align)); + break; + } + case Instruction::GetElementPtr: { + // Analyze all of the subscripts of this getelementptr instruction + // to determine if we can prove known low zero bits. + KnownBits LocalKnown(BitWidth); + computeKnownBits(I->getOperand(0), LocalKnown, Depth + 1, Q); + unsigned TrailZ = LocalKnown.countMinTrailingZeros(); + + gep_type_iterator GTI = gep_type_begin(I); + for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) { + Value *Index = I->getOperand(i); + if (StructType *STy = GTI.getStructTypeOrNull()) { + // Handle struct member offset arithmetic. + + // Handle case when index is vector zeroinitializer + Constant *CIndex = cast<Constant>(Index); + if (CIndex->isZeroValue()) + continue; + + if (CIndex->getType()->isVectorTy()) + Index = CIndex->getSplatValue(); + + unsigned Idx = cast<ConstantInt>(Index)->getZExtValue(); + const StructLayout *SL = Q.DL.getStructLayout(STy); + uint64_t Offset = SL->getElementOffset(Idx); + TrailZ = std::min<unsigned>(TrailZ, + countTrailingZeros(Offset)); + } else { + // Handle array index arithmetic. + Type *IndexedTy = GTI.getIndexedType(); + if (!IndexedTy->isSized()) { + TrailZ = 0; + break; + } + unsigned GEPOpiBits = Index->getType()->getScalarSizeInBits(); + uint64_t TypeSize = Q.DL.getTypeAllocSize(IndexedTy); + LocalKnown.Zero = LocalKnown.One = APInt(GEPOpiBits, 0); + computeKnownBits(Index, LocalKnown, Depth + 1, Q); + TrailZ = std::min(TrailZ, + unsigned(countTrailingZeros(TypeSize) + + LocalKnown.countMinTrailingZeros())); + } + } + + Known.Zero.setLowBits(TrailZ); + break; + } + case Instruction::PHI: { + const PHINode *P = cast<PHINode>(I); + // Handle the case of a simple two-predecessor recurrence PHI. 
+    // There's a lot more that could theoretically be done here, but
+    // this is sufficient to catch some interesting cases.
+    if (P->getNumIncomingValues() == 2) {
+      for (unsigned i = 0; i != 2; ++i) {
+        Value *L = P->getIncomingValue(i);
+        Value *R = P->getIncomingValue(!i);
+        Operator *LU = dyn_cast<Operator>(L);
+        if (!LU)
+          continue;
+        unsigned Opcode = LU->getOpcode();
+        // Check for operations that have the property that if
+        // both their operands have low zero bits, the result
+        // will have low zero bits.
+        if (Opcode == Instruction::Add ||
+            Opcode == Instruction::Sub ||
+            Opcode == Instruction::And ||
+            Opcode == Instruction::Or ||
+            Opcode == Instruction::Mul) {
+          Value *LL = LU->getOperand(0);
+          Value *LR = LU->getOperand(1);
+          // Find a recurrence.
+          if (LL == I)
+            L = LR;
+          else if (LR == I)
+            L = LL;
+          else
+            continue; // Check for recurrence with L and R flipped.
+          // Ok, we have a PHI of the form L op= R. Check for low
+          // zero bits.
+          computeKnownBits(R, Known2, Depth + 1, Q);
+
+          // We need to take the minimum number of known bits.
+          KnownBits Known3(Known);
+          computeKnownBits(L, Known3, Depth + 1, Q);
+
+          Known.Zero.setLowBits(std::min(Known2.countMinTrailingZeros(),
+                                         Known3.countMinTrailingZeros()));
+
+          auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(LU);
+          if (OverflowOp && Q.IIQ.hasNoSignedWrap(OverflowOp)) {
+            // If the initial value of the recurrence is nonnegative, and we
+            // are adding a nonnegative number with nsw, the result can only
+            // be nonnegative or a poison value regardless of the number of
+            // times we execute the add in the phi recurrence. If the initial
+            // value is negative and we are adding a negative number with
+            // nsw, the result can only be negative or a poison value.
+            // Similar arguments apply to sub and mul.
+            //
+            // (add non-negative, non-negative) --> non-negative
+            // (add negative, negative) --> negative
+            if (Opcode == Instruction::Add) {
+              if (Known2.isNonNegative() && Known3.isNonNegative())
+                Known.makeNonNegative();
+              else if (Known2.isNegative() && Known3.isNegative())
+                Known.makeNegative();
+            }
+
+            // (sub nsw non-negative, negative) --> non-negative
+            // (sub nsw negative, non-negative) --> negative
+            else if (Opcode == Instruction::Sub && LL == I) {
+              if (Known2.isNonNegative() && Known3.isNegative())
+                Known.makeNonNegative();
+              else if (Known2.isNegative() && Known3.isNonNegative())
+                Known.makeNegative();
+            }
+
+            // (mul nsw non-negative, non-negative) --> non-negative
+            else if (Opcode == Instruction::Mul && Known2.isNonNegative() &&
+                     Known3.isNonNegative())
+              Known.makeNonNegative();
+          }
+
+          break;
+        }
+      }
+    }
+
+    // Unreachable blocks may have zero-operand PHI nodes.
+    if (P->getNumIncomingValues() == 0)
+      break;
+
+    // Otherwise take the intersection of the known bit sets of the operands,
+    // taking conservative care to avoid excessive recursion.
+    if (Depth < MaxDepth - 1 && !Known.Zero && !Known.One) {
+      // Skip if every incoming value refers back to this PHI.
+      if (dyn_cast_or_null<UndefValue>(P->hasConstantValue()))
+        break;
+
+      Known.Zero.setAllBits();
+      Known.One.setAllBits();
+      for (Value *IncValue : P->incoming_values()) {
+        // Skip direct self references.
+        if (IncValue == P) continue;
+
+        Known2 = KnownBits(BitWidth);
+        // Recurse, but cap the recursion to one level, because we don't
+        // want to waste time spinning around in loops.
+        computeKnownBits(IncValue, Known2, MaxDepth - 1, Q);
+        Known.Zero &= Known2.Zero;
+        Known.One &= Known2.One;
+        // If all bits have been ruled out, there's no need to check
+        // more operands.
+ if (!Known.Zero && !Known.One) + break; + } + } + break; + } + case Instruction::Call: + case Instruction::Invoke: + // If range metadata is attached to this call, set known bits from that, + // and then intersect with known bits based on other properties of the + // function. + if (MDNode *MD = + Q.IIQ.getMetadata(cast<Instruction>(I), LLVMContext::MD_range)) + computeKnownBitsFromRangeMetadata(*MD, Known); + if (const Value *RV = ImmutableCallSite(I).getReturnedArgOperand()) { + computeKnownBits(RV, Known2, Depth + 1, Q); + Known.Zero |= Known2.Zero; + Known.One |= Known2.One; + } + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + switch (II->getIntrinsicID()) { + default: break; + case Intrinsic::bitreverse: + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + Known.Zero |= Known2.Zero.reverseBits(); + Known.One |= Known2.One.reverseBits(); + break; + case Intrinsic::bswap: + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + Known.Zero |= Known2.Zero.byteSwap(); + Known.One |= Known2.One.byteSwap(); + break; + case Intrinsic::ctlz: { + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + // If we have a known 1, its position is our upper bound. + unsigned PossibleLZ = Known2.One.countLeadingZeros(); + // If this call is undefined for 0, the result will be less than 2^n. + if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext())) + PossibleLZ = std::min(PossibleLZ, BitWidth - 1); + unsigned LowBits = Log2_32(PossibleLZ)+1; + Known.Zero.setBitsFrom(LowBits); + break; + } + case Intrinsic::cttz: { + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + // If we have a known 1, its position is our upper bound. + unsigned PossibleTZ = Known2.One.countTrailingZeros(); + // If this call is undefined for 0, the result will be less than 2^n. + if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext())) + PossibleTZ = std::min(PossibleTZ, BitWidth - 1); + unsigned LowBits = Log2_32(PossibleTZ)+1; + Known.Zero.setBitsFrom(LowBits); + break; + } + case Intrinsic::ctpop: { + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + // We can bound the space the count needs. Also, bits known to be zero + // can't contribute to the population. + unsigned BitsPossiblySet = Known2.countMaxPopulation(); + unsigned LowBits = Log2_32(BitsPossiblySet)+1; + Known.Zero.setBitsFrom(LowBits); + // TODO: we could bound KnownOne using the lower bound on the number + // of bits which might be set provided by popcnt KnownOne2. + break; + } + case Intrinsic::fshr: + case Intrinsic::fshl: { + const APInt *SA; + if (!match(I->getOperand(2), m_APInt(SA))) + break; + + // Normalize to funnel shift left. + uint64_t ShiftAmt = SA->urem(BitWidth); + if (II->getIntrinsicID() == Intrinsic::fshr) + ShiftAmt = BitWidth - ShiftAmt; + + KnownBits Known3(Known); + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known3, Depth + 1, Q); + + Known.Zero = + Known2.Zero.shl(ShiftAmt) | Known3.Zero.lshr(BitWidth - ShiftAmt); + Known.One = + Known2.One.shl(ShiftAmt) | Known3.One.lshr(BitWidth - ShiftAmt); + break; + } + case Intrinsic::uadd_sat: + case Intrinsic::usub_sat: { + bool IsAdd = II->getIntrinsicID() == Intrinsic::uadd_sat; + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); + + // Add: Leading ones of either operand are preserved. + // Sub: Leading zeros of LHS and leading ones of RHS are preserved + // as leading zeros in the result. 
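+      // As an illustration: for i8 "uadd.sat(X, Y)" where X is known to have
+      // its top two bits set (X u>= 0xC0), the result keeps those bits: the
+      // sum is u>= X when it does not saturate, and saturation yields 0xFF.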
+ unsigned LeadingKnown; + if (IsAdd) + LeadingKnown = std::max(Known.countMinLeadingOnes(), + Known2.countMinLeadingOnes()); + else + LeadingKnown = std::max(Known.countMinLeadingZeros(), + Known2.countMinLeadingOnes()); + + Known = KnownBits::computeForAddSub( + IsAdd, /* NSW */ false, Known, Known2); + + // We select between the operation result and all-ones/zero + // respectively, so we can preserve known ones/zeros. + if (IsAdd) { + Known.One.setHighBits(LeadingKnown); + Known.Zero.clearAllBits(); + } else { + Known.Zero.setHighBits(LeadingKnown); + Known.One.clearAllBits(); + } + break; + } + case Intrinsic::x86_sse42_crc32_64_64: + Known.Zero.setBitsFrom(32); + break; + } + } + break; + case Instruction::ExtractElement: + // Look through extract element. At the moment we keep this simple and skip + // tracking the specific element. But at least we might find information + // valid for all elements of the vector (for example if vector is sign + // extended, shifted, etc). + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + break; + case Instruction::ExtractValue: + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I->getOperand(0))) { + const ExtractValueInst *EVI = cast<ExtractValueInst>(I); + if (EVI->getNumIndices() != 1) break; + if (EVI->getIndices()[0] == 0) { + switch (II->getIntrinsicID()) { + default: break; + case Intrinsic::uadd_with_overflow: + case Intrinsic::sadd_with_overflow: + computeKnownBitsAddSub(true, II->getArgOperand(0), + II->getArgOperand(1), false, Known, Known2, + Depth, Q); + break; + case Intrinsic::usub_with_overflow: + case Intrinsic::ssub_with_overflow: + computeKnownBitsAddSub(false, II->getArgOperand(0), + II->getArgOperand(1), false, Known, Known2, + Depth, Q); + break; + case Intrinsic::umul_with_overflow: + case Intrinsic::smul_with_overflow: + computeKnownBitsMul(II->getArgOperand(0), II->getArgOperand(1), false, + Known, Known2, Depth, Q); + break; + } + } + } + } +} + +/// Determine which bits of V are known to be either zero or one and return +/// them. +KnownBits computeKnownBits(const Value *V, unsigned Depth, const Query &Q) { + KnownBits Known(getBitWidth(V->getType(), Q.DL)); + computeKnownBits(V, Known, Depth, Q); + return Known; +} + +/// Determine which bits of V are known to be either zero or one and return +/// them in the Known bit set. +/// +/// NOTE: we cannot consider 'undef' to be "IsZero" here. The problem is that +/// we cannot optimize based on the assumption that it is zero without changing +/// it to be an explicit zero. If we don't change it to zero, other code could +/// optimized based on the contradictory assumption that it is non-zero. +/// Because instcombine aggressively folds operations with undef args anyway, +/// this won't lose us code quality. +/// +/// This function is defined on values with integer type, values with pointer +/// type, and vectors of integers. In the case +/// where V is a vector, known zero, and known one values are the +/// same width as the vector element, and the bit is set only if it is true +/// for all of the elements in the vector. +void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth, + const Query &Q) { + assert(V && "No Value?"); + assert(Depth <= MaxDepth && "Limit Search Depth"); + unsigned BitWidth = Known.getBitWidth(); + + assert((V->getType()->isIntOrIntVectorTy(BitWidth) || + V->getType()->isPtrOrPtrVectorTy()) && + "Not integer or pointer type!"); + + Type *ScalarTy = V->getType()->getScalarType(); + unsigned ExpectedWidth = ScalarTy->isPointerTy() ? 
+ Q.DL.getIndexTypeSizeInBits(ScalarTy) : Q.DL.getTypeSizeInBits(ScalarTy); + assert(ExpectedWidth == BitWidth && "V and Known should have same BitWidth"); + (void)BitWidth; + (void)ExpectedWidth; + + const APInt *C; + if (match(V, m_APInt(C))) { + // We know all of the bits for a scalar constant or a splat vector constant! + Known.One = *C; + Known.Zero = ~Known.One; + return; + } + // Null and aggregate-zero are all-zeros. + if (isa<ConstantPointerNull>(V) || isa<ConstantAggregateZero>(V)) { + Known.setAllZero(); + return; + } + // Handle a constant vector by taking the intersection of the known bits of + // each element. + if (const ConstantDataSequential *CDS = dyn_cast<ConstantDataSequential>(V)) { + // We know that CDS must be a vector of integers. Take the intersection of + // each element. + Known.Zero.setAllBits(); Known.One.setAllBits(); + for (unsigned i = 0, e = CDS->getNumElements(); i != e; ++i) { + APInt Elt = CDS->getElementAsAPInt(i); + Known.Zero &= ~Elt; + Known.One &= Elt; + } + return; + } + + if (const auto *CV = dyn_cast<ConstantVector>(V)) { + // We know that CV must be a vector of integers. Take the intersection of + // each element. + Known.Zero.setAllBits(); Known.One.setAllBits(); + for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) { + Constant *Element = CV->getAggregateElement(i); + auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element); + if (!ElementCI) { + Known.resetAll(); + return; + } + const APInt &Elt = ElementCI->getValue(); + Known.Zero &= ~Elt; + Known.One &= Elt; + } + return; + } + + // Start out not knowing anything. + Known.resetAll(); + + // We can't imply anything about undefs. + if (isa<UndefValue>(V)) + return; + + // There's no point in looking through other users of ConstantData for + // assumptions. Confirm that we've handled them all. + assert(!isa<ConstantData>(V) && "Unhandled constant data!"); + + // Limit search depth. + // All recursive calls that increase depth must come after this. + if (Depth == MaxDepth) + return; + + // A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has + // the bits of its aliasee. + if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { + if (!GA->isInterposable()) + computeKnownBits(GA->getAliasee(), Known, Depth + 1, Q); + return; + } + + if (const Operator *I = dyn_cast<Operator>(V)) + computeKnownBitsFromOperator(I, Known, Depth, Q); + + // Aligned pointers have trailing zeros - refine Known.Zero set + if (V->getType()->isPointerTy()) { + const MaybeAlign Align = V->getPointerAlignment(Q.DL); + if (Align) + Known.Zero.setLowBits(countTrailingZeros(Align->value())); + } + + // computeKnownBitsFromAssume strictly refines Known. + // Therefore, we run them after computeKnownBitsFromOperator. + + // Check whether a nearby assume intrinsic can determine some known bits. + computeKnownBitsFromAssume(V, Known, Depth, Q); + + assert((Known.Zero & Known.One) == 0 && "Bits known to be one AND zero?"); +} + +/// Return true if the given value is known to have exactly one +/// bit set when defined. For vectors return true if every element is known to +/// be a power of two when defined. Supports values with integer or pointer +/// types and vectors of integers. +bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, + const Query &Q) { + assert(Depth <= MaxDepth && "Limit Search Depth"); + + // Attempt to match against constants. 
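+  // For example, i8 16 (0b00010000) matches m_Power2() directly, while i8 0
+  // is accepted only via m_Power2OrZero() when OrZero is set.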
+ if (OrZero && match(V, m_Power2OrZero())) + return true; + if (match(V, m_Power2())) + return true; + + // 1 << X is clearly a power of two if the one is not shifted off the end. If + // it is shifted off the end then the result is undefined. + if (match(V, m_Shl(m_One(), m_Value()))) + return true; + + // (signmask) >>l X is clearly a power of two if the one is not shifted off + // the bottom. If it is shifted off the bottom then the result is undefined. + if (match(V, m_LShr(m_SignMask(), m_Value()))) + return true; + + // The remaining tests are all recursive, so bail out if we hit the limit. + if (Depth++ == MaxDepth) + return false; + + Value *X = nullptr, *Y = nullptr; + // A shift left or a logical shift right of a power of two is a power of two + // or zero. + if (OrZero && (match(V, m_Shl(m_Value(X), m_Value())) || + match(V, m_LShr(m_Value(X), m_Value())))) + return isKnownToBeAPowerOfTwo(X, /*OrZero*/ true, Depth, Q); + + if (const ZExtInst *ZI = dyn_cast<ZExtInst>(V)) + return isKnownToBeAPowerOfTwo(ZI->getOperand(0), OrZero, Depth, Q); + + if (const SelectInst *SI = dyn_cast<SelectInst>(V)) + return isKnownToBeAPowerOfTwo(SI->getTrueValue(), OrZero, Depth, Q) && + isKnownToBeAPowerOfTwo(SI->getFalseValue(), OrZero, Depth, Q); + + if (OrZero && match(V, m_And(m_Value(X), m_Value(Y)))) { + // A power of two and'd with anything is a power of two or zero. + if (isKnownToBeAPowerOfTwo(X, /*OrZero*/ true, Depth, Q) || + isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, Depth, Q)) + return true; + // X & (-X) is always a power of two or zero. + if (match(X, m_Neg(m_Specific(Y))) || match(Y, m_Neg(m_Specific(X)))) + return true; + return false; + } + + // Adding a power-of-two or zero to the same power-of-two or zero yields + // either the original power-of-two, a larger power-of-two or zero. + if (match(V, m_Add(m_Value(X), m_Value(Y)))) { + const OverflowingBinaryOperator *VOBO = cast<OverflowingBinaryOperator>(V); + if (OrZero || Q.IIQ.hasNoUnsignedWrap(VOBO) || + Q.IIQ.hasNoSignedWrap(VOBO)) { + if (match(X, m_And(m_Specific(Y), m_Value())) || + match(X, m_And(m_Value(), m_Specific(Y)))) + if (isKnownToBeAPowerOfTwo(Y, OrZero, Depth, Q)) + return true; + if (match(Y, m_And(m_Specific(X), m_Value())) || + match(Y, m_And(m_Value(), m_Specific(X)))) + if (isKnownToBeAPowerOfTwo(X, OrZero, Depth, Q)) + return true; + + unsigned BitWidth = V->getType()->getScalarSizeInBits(); + KnownBits LHSBits(BitWidth); + computeKnownBits(X, LHSBits, Depth, Q); + + KnownBits RHSBits(BitWidth); + computeKnownBits(Y, RHSBits, Depth, Q); + // If i8 V is a power of two or zero: + // ZeroBits: 1 1 1 0 1 1 1 1 + // ~ZeroBits: 0 0 0 1 0 0 0 0 + if ((~(LHSBits.Zero & RHSBits.Zero)).isPowerOf2()) + // If OrZero isn't set, we cannot give back a zero result. + // Make sure either the LHS or RHS has a bit set. + if (OrZero || RHSBits.One.getBoolValue() || LHSBits.One.getBoolValue()) + return true; + } + } + + // An exact divide or right shift can only shift off zero bits, so the result + // is a power of two only if the first operand is a power of two and not + // copying a sign bit (sdiv int_min, 2). + if (match(V, m_Exact(m_LShr(m_Value(), m_Value()))) || + match(V, m_Exact(m_UDiv(m_Value(), m_Value())))) { + return isKnownToBeAPowerOfTwo(cast<Operator>(V)->getOperand(0), OrZero, + Depth, Q); + } + + return false; +} + +/// Test whether a GEP's result is known to be non-null. +/// +/// Uses properties inherent in a GEP to try to determine whether it is known +/// to be non-null. 
+/// +/// Currently this routine does not support vector GEPs. +static bool isGEPKnownNonNull(const GEPOperator *GEP, unsigned Depth, + const Query &Q) { + const Function *F = nullptr; + if (const Instruction *I = dyn_cast<Instruction>(GEP)) + F = I->getFunction(); + + if (!GEP->isInBounds() || + NullPointerIsDefined(F, GEP->getPointerAddressSpace())) + return false; + + // FIXME: Support vector-GEPs. + assert(GEP->getType()->isPointerTy() && "We only support plain pointer GEP"); + + // If the base pointer is non-null, we cannot walk to a null address with an + // inbounds GEP in address space zero. + if (isKnownNonZero(GEP->getPointerOperand(), Depth, Q)) + return true; + + // Walk the GEP operands and see if any operand introduces a non-zero offset. + // If so, then the GEP cannot produce a null pointer, as doing so would + // inherently violate the inbounds contract within address space zero. + for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP); + GTI != GTE; ++GTI) { + // Struct types are easy -- they must always be indexed by a constant. + if (StructType *STy = GTI.getStructTypeOrNull()) { + ConstantInt *OpC = cast<ConstantInt>(GTI.getOperand()); + unsigned ElementIdx = OpC->getZExtValue(); + const StructLayout *SL = Q.DL.getStructLayout(STy); + uint64_t ElementOffset = SL->getElementOffset(ElementIdx); + if (ElementOffset > 0) + return true; + continue; + } + + // If we have a zero-sized type, the index doesn't matter. Keep looping. + if (Q.DL.getTypeAllocSize(GTI.getIndexedType()) == 0) + continue; + + // Fast path the constant operand case both for efficiency and so we don't + // increment Depth when just zipping down an all-constant GEP. + if (ConstantInt *OpC = dyn_cast<ConstantInt>(GTI.getOperand())) { + if (!OpC->isZero()) + return true; + continue; + } + + // We post-increment Depth here because while isKnownNonZero increments it + // as well, when we pop back up that increment won't persist. We don't want + // to recurse 10k times just because we have 10k GEP operands. We don't + // bail completely out because we want to handle constant GEPs regardless + // of depth. + if (Depth++ >= MaxDepth) + continue; + + if (isKnownNonZero(GTI.getOperand(), Depth, Q)) + return true; + } + + return false; +} + +static bool isKnownNonNullFromDominatingCondition(const Value *V, + const Instruction *CtxI, + const DominatorTree *DT) { + assert(V->getType()->isPointerTy() && "V must be pointer type"); + assert(!isa<ConstantData>(V) && "Did not expect ConstantPointerNull"); + + if (!CtxI || !DT) + return false; + + unsigned NumUsesExplored = 0; + for (auto *U : V->users()) { + // Avoid massive lists + if (NumUsesExplored >= DomConditionsMaxUses) + break; + NumUsesExplored++; + + // If the value is used as an argument to a call or invoke, then argument + // attributes may provide an answer about null-ness. 
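+    // For example, given "call void @f(i8* nonnull %p)", a context
+    // instruction dominated by the call may treat %p as non-null, since the
+    // call itself would be undefined if %p were null.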
+    if (auto CS = ImmutableCallSite(U))
+      if (auto *CalledFunc = CS.getCalledFunction())
+        for (const Argument &Arg : CalledFunc->args())
+          if (CS.getArgOperand(Arg.getArgNo()) == V &&
+              Arg.hasNonNullAttr() && DT->dominates(CS.getInstruction(), CtxI))
+            return true;
+
+    // Consider only compare instructions uniquely controlling a branch.
+    CmpInst::Predicate Pred;
+    if (!match(const_cast<User *>(U),
+               m_c_ICmp(Pred, m_Specific(V), m_Zero())) ||
+        (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE))
+      continue;
+
+    SmallVector<const User *, 4> WorkList;
+    SmallPtrSet<const User *, 4> Visited;
+    for (auto *CmpU : U->users()) {
+      assert(WorkList.empty() && "Should be!");
+      if (Visited.insert(CmpU).second)
+        WorkList.push_back(CmpU);
+
+      while (!WorkList.empty()) {
+        auto *Curr = WorkList.pop_back_val();
+
+        // If a user is an AND, add all its users to the work list. We only
+        // propagate the "pred != null" condition through AND because it is
+        // only correct to assume that all conditions of the AND are met in
+        // the true branch.
+        // TODO: Support similar logic for OR and the EQ predicate?
+        if (Pred == ICmpInst::ICMP_NE)
+          if (auto *BO = dyn_cast<BinaryOperator>(Curr))
+            if (BO->getOpcode() == Instruction::And) {
+              for (auto *BOU : BO->users())
+                if (Visited.insert(BOU).second)
+                  WorkList.push_back(BOU);
+              continue;
+            }
+
+        if (const BranchInst *BI = dyn_cast<BranchInst>(Curr)) {
+          assert(BI->isConditional() && "uses a comparison!");
+
+          BasicBlock *NonNullSuccessor =
+              BI->getSuccessor(Pred == ICmpInst::ICMP_EQ ? 1 : 0);
+          BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor);
+          if (Edge.isSingleEdge() && DT->dominates(Edge, CtxI->getParent()))
+            return true;
+        } else if (Pred == ICmpInst::ICMP_NE && isGuard(Curr) &&
+                   DT->dominates(cast<Instruction>(Curr), CtxI)) {
+          return true;
+        }
+      }
+    }
+  }
+
+  return false;
+}
+
+/// Does the 'Range' metadata (which must be a valid MD_range operand list)
+/// ensure that the value it's attached to is never equal to \p Value?
+static bool rangeMetadataExcludesValue(const MDNode* Ranges, const APInt& Value) {
+  const unsigned NumRanges = Ranges->getNumOperands() / 2;
+  assert(NumRanges >= 1);
+  for (unsigned i = 0; i < NumRanges; ++i) {
+    ConstantInt *Lower =
+        mdconst::extract<ConstantInt>(Ranges->getOperand(2 * i + 0));
+    ConstantInt *Upper =
+        mdconst::extract<ConstantInt>(Ranges->getOperand(2 * i + 1));
+    ConstantRange Range(Lower->getValue(), Upper->getValue());
+    if (Range.contains(Value))
+      return false;
+  }
+  return true;
+}
+
+/// Return true if the given value is known to be non-zero when defined. For
+/// vectors, return true if every element is known to be non-zero when
+/// defined. For pointers, if the context instruction and dominator tree are
+/// specified, perform context-sensitive analysis and return true if the
+/// pointer couldn't possibly be null at the specified instruction.
+/// Supports values with integer or pointer type and vectors of integers.
+bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {
+  if (auto *C = dyn_cast<Constant>(V)) {
+    if (C->isNullValue())
+      return false;
+    if (isa<ConstantInt>(C))
+      // Must be non-zero due to null test above.
+      return true;
+
+    if (auto *CE = dyn_cast<ConstantExpr>(C)) {
+      // See the comment for IntToPtr/PtrToInt instructions below.
+ if (CE->getOpcode() == Instruction::IntToPtr || + CE->getOpcode() == Instruction::PtrToInt) + if (Q.DL.getTypeSizeInBits(CE->getOperand(0)->getType()) <= + Q.DL.getTypeSizeInBits(CE->getType())) + return isKnownNonZero(CE->getOperand(0), Depth, Q); + } + + // For constant vectors, check that all elements are undefined or known + // non-zero to determine that the whole vector is known non-zero. + if (auto *VecTy = dyn_cast<VectorType>(C->getType())) { + for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) { + Constant *Elt = C->getAggregateElement(i); + if (!Elt || Elt->isNullValue()) + return false; + if (!isa<UndefValue>(Elt) && !isa<ConstantInt>(Elt)) + return false; + } + return true; + } + + // A global variable in address space 0 is non null unless extern weak + // or an absolute symbol reference. Other address spaces may have null as a + // valid address for a global, so we can't assume anything. + if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) { + if (!GV->isAbsoluteSymbolRef() && !GV->hasExternalWeakLinkage() && + GV->getType()->getAddressSpace() == 0) + return true; + } else + return false; + } + + if (auto *I = dyn_cast<Instruction>(V)) { + if (MDNode *Ranges = Q.IIQ.getMetadata(I, LLVMContext::MD_range)) { + // If the possible ranges don't contain zero, then the value is + // definitely non-zero. + if (auto *Ty = dyn_cast<IntegerType>(V->getType())) { + const APInt ZeroValue(Ty->getBitWidth(), 0); + if (rangeMetadataExcludesValue(Ranges, ZeroValue)) + return true; + } + } + } + + // Some of the tests below are recursive, so bail out if we hit the limit. + if (Depth++ >= MaxDepth) + return false; + + // Check for pointer simplifications. + if (V->getType()->isPointerTy()) { + // Alloca never returns null, malloc might. + if (isa<AllocaInst>(V) && Q.DL.getAllocaAddrSpace() == 0) + return true; + + // A byval, inalloca, or nonnull argument is never null. + if (const Argument *A = dyn_cast<Argument>(V)) + if (A->hasByValOrInAllocaAttr() || A->hasNonNullAttr()) + return true; + + // A Load tagged with nonnull metadata is never null. + if (const LoadInst *LI = dyn_cast<LoadInst>(V)) + if (Q.IIQ.getMetadata(LI, LLVMContext::MD_nonnull)) + return true; + + if (const auto *Call = dyn_cast<CallBase>(V)) { + if (Call->isReturnNonNull()) + return true; + if (const auto *RP = getArgumentAliasingToReturnedPointer(Call, true)) + return isKnownNonZero(RP, Depth, Q); + } + } + + + // Check for recursive pointer simplifications. + if (V->getType()->isPointerTy()) { + if (isKnownNonNullFromDominatingCondition(V, Q.CxtI, Q.DT)) + return true; + + // Look through bitcast operations, GEPs, and int2ptr instructions as they + // do not alter the value, or at least not the nullness property of the + // value, e.g., int2ptr is allowed to zero/sign extend the value. + // + // Note that we have to take special care to avoid looking through + // truncating casts, e.g., int2ptr/ptr2int with appropriate sizes, as well + // as casts that can alter the value, e.g., AddrSpaceCasts. 
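+  // For example, the nullness of "bitcast i32* %p to i8*" reduces to the
+  // nullness of %p, whereas an addrspacecast of %p is not looked through:
+  // a target may map a non-null pointer to null in another address space.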
+ if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) + if (isGEPKnownNonNull(GEP, Depth, Q)) + return true; + + if (auto *BCO = dyn_cast<BitCastOperator>(V)) + return isKnownNonZero(BCO->getOperand(0), Depth, Q); + + if (auto *I2P = dyn_cast<IntToPtrInst>(V)) + if (Q.DL.getTypeSizeInBits(I2P->getSrcTy()) <= + Q.DL.getTypeSizeInBits(I2P->getDestTy())) + return isKnownNonZero(I2P->getOperand(0), Depth, Q); + } + + // Similar to int2ptr above, we can look through ptr2int here if the cast + // is a no-op or an extend and not a truncate. + if (auto *P2I = dyn_cast<PtrToIntInst>(V)) + if (Q.DL.getTypeSizeInBits(P2I->getSrcTy()) <= + Q.DL.getTypeSizeInBits(P2I->getDestTy())) + return isKnownNonZero(P2I->getOperand(0), Depth, Q); + + unsigned BitWidth = getBitWidth(V->getType()->getScalarType(), Q.DL); + + // X | Y != 0 if X != 0 or Y != 0. + Value *X = nullptr, *Y = nullptr; + if (match(V, m_Or(m_Value(X), m_Value(Y)))) + return isKnownNonZero(X, Depth, Q) || isKnownNonZero(Y, Depth, Q); + + // ext X != 0 if X != 0. + if (isa<SExtInst>(V) || isa<ZExtInst>(V)) + return isKnownNonZero(cast<Instruction>(V)->getOperand(0), Depth, Q); + + // shl X, Y != 0 if X is odd. Note that the value of the shift is undefined + // if the lowest bit is shifted off the end. + if (match(V, m_Shl(m_Value(X), m_Value(Y)))) { + // shl nuw can't remove any non-zero bits. + const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); + if (Q.IIQ.hasNoUnsignedWrap(BO)) + return isKnownNonZero(X, Depth, Q); + + KnownBits Known(BitWidth); + computeKnownBits(X, Known, Depth, Q); + if (Known.One[0]) + return true; + } + // shr X, Y != 0 if X is negative. Note that the value of the shift is not + // defined if the sign bit is shifted off the end. + else if (match(V, m_Shr(m_Value(X), m_Value(Y)))) { + // shr exact can only shift out zero bits. + const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(V); + if (BO->isExact()) + return isKnownNonZero(X, Depth, Q); + + KnownBits Known = computeKnownBits(X, Depth, Q); + if (Known.isNegative()) + return true; + + // If the shifter operand is a constant, and all of the bits shifted + // out are known to be zero, and X is known non-zero then at least one + // non-zero bit must remain. + if (ConstantInt *Shift = dyn_cast<ConstantInt>(Y)) { + auto ShiftVal = Shift->getLimitedValue(BitWidth - 1); + // Is there a known one in the portion not shifted out? + if (Known.countMaxLeadingZeros() < BitWidth - ShiftVal) + return true; + // Are all the bits to be shifted out known zero? + if (Known.countMinTrailingZeros() >= ShiftVal) + return isKnownNonZero(X, Depth, Q); + } + } + // div exact can only produce a zero if the dividend is zero. + else if (match(V, m_Exact(m_IDiv(m_Value(X), m_Value())))) { + return isKnownNonZero(X, Depth, Q); + } + // X + Y. + else if (match(V, m_Add(m_Value(X), m_Value(Y)))) { + KnownBits XKnown = computeKnownBits(X, Depth, Q); + KnownBits YKnown = computeKnownBits(Y, Depth, Q); + + // If X and Y are both non-negative (as signed values) then their sum is not + // zero unless both X and Y are zero. + if (XKnown.isNonNegative() && YKnown.isNonNegative()) + if (isKnownNonZero(X, Depth, Q) || isKnownNonZero(Y, Depth, Q)) + return true; + + // If X and Y are both negative (as signed values) then their sum is not + // zero unless both X and Y equal INT_MIN. + if (XKnown.isNegative() && YKnown.isNegative()) { + APInt Mask = APInt::getSignedMaxValue(BitWidth); + // The sign bit of X is set. 
If some other bit is set then X is not equal + // to INT_MIN. + if (XKnown.One.intersects(Mask)) + return true; + // The sign bit of Y is set. If some other bit is set then Y is not equal + // to INT_MIN. + if (YKnown.One.intersects(Mask)) + return true; + } + + // The sum of a non-negative number and a power of two is not zero. + if (XKnown.isNonNegative() && + isKnownToBeAPowerOfTwo(Y, /*OrZero*/ false, Depth, Q)) + return true; + if (YKnown.isNonNegative() && + isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Depth, Q)) + return true; + } + // X * Y. + else if (match(V, m_Mul(m_Value(X), m_Value(Y)))) { + const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); + // If X and Y are non-zero then so is X * Y as long as the multiplication + // does not overflow. + if ((Q.IIQ.hasNoSignedWrap(BO) || Q.IIQ.hasNoUnsignedWrap(BO)) && + isKnownNonZero(X, Depth, Q) && isKnownNonZero(Y, Depth, Q)) + return true; + } + // (C ? X : Y) != 0 if X != 0 and Y != 0. + else if (const SelectInst *SI = dyn_cast<SelectInst>(V)) { + if (isKnownNonZero(SI->getTrueValue(), Depth, Q) && + isKnownNonZero(SI->getFalseValue(), Depth, Q)) + return true; + } + // PHI + else if (const PHINode *PN = dyn_cast<PHINode>(V)) { + // Try and detect a recurrence that monotonically increases from a + // starting value, as these are common as induction variables. + if (PN->getNumIncomingValues() == 2) { + Value *Start = PN->getIncomingValue(0); + Value *Induction = PN->getIncomingValue(1); + if (isa<ConstantInt>(Induction) && !isa<ConstantInt>(Start)) + std::swap(Start, Induction); + if (ConstantInt *C = dyn_cast<ConstantInt>(Start)) { + if (!C->isZero() && !C->isNegative()) { + ConstantInt *X; + if (Q.IIQ.UseInstrInfo && + (match(Induction, m_NSWAdd(m_Specific(PN), m_ConstantInt(X))) || + match(Induction, m_NUWAdd(m_Specific(PN), m_ConstantInt(X)))) && + !X->isNegative()) + return true; + } + } + } + // Check if all incoming values are non-zero constant. + bool AllNonZeroConstants = llvm::all_of(PN->operands(), [](Value *V) { + return isa<ConstantInt>(V) && !cast<ConstantInt>(V)->isZero(); + }); + if (AllNonZeroConstants) + return true; + } + + KnownBits Known(BitWidth); + computeKnownBits(V, Known, Depth, Q); + return Known.One != 0; +} + +/// Return true if V2 == V1 + X, where X is known non-zero. +static bool isAddOfNonZero(const Value *V1, const Value *V2, const Query &Q) { + const BinaryOperator *BO = dyn_cast<BinaryOperator>(V1); + if (!BO || BO->getOpcode() != Instruction::Add) + return false; + Value *Op = nullptr; + if (V2 == BO->getOperand(0)) + Op = BO->getOperand(1); + else if (V2 == BO->getOperand(1)) + Op = BO->getOperand(0); + else + return false; + return isKnownNonZero(Op, 0, Q); +} + +/// Return true if it is known that V1 != V2. +static bool isKnownNonEqual(const Value *V1, const Value *V2, const Query &Q) { + if (V1 == V2) + return false; + if (V1->getType() != V2->getType()) + // We can't look through casts yet. + return false; + if (isAddOfNonZero(V1, V2, Q) || isAddOfNonZero(V2, V1, Q)) + return true; + + if (V1->getType()->isIntOrIntVectorTy()) { + // Are any known bits in V1 contradictory to known bits in V2? If V1 + // has a known zero where V2 has a known one, they must not be equal. + KnownBits Known1 = computeKnownBits(V1, 0, Q); + KnownBits Known2 = computeKnownBits(V2, 0, Q); + + if (Known1.Zero.intersects(Known2.One) || + Known2.Zero.intersects(Known1.One)) + return true; + } + return false; +} + +/// Return true if 'V & Mask' is known to be zero. 
We use this predicate to +/// simplify operations downstream. Mask is known to be zero for bits that V +/// cannot have. +/// +/// This function is defined on values with integer type, values with pointer +/// type, and vectors of integers. In the case +/// where V is a vector, the mask, known zero, and known one values are the +/// same width as the vector element, and the bit is set only if it is true +/// for all of the elements in the vector. +bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth, + const Query &Q) { + KnownBits Known(Mask.getBitWidth()); + computeKnownBits(V, Known, Depth, Q); + return Mask.isSubsetOf(Known.Zero); +} + +// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow). +// Returns the input and lower/upper bounds. +static bool isSignedMinMaxClamp(const Value *Select, const Value *&In, + const APInt *&CLow, const APInt *&CHigh) { + assert(isa<Operator>(Select) && + cast<Operator>(Select)->getOpcode() == Instruction::Select && + "Input should be a Select!"); + + const Value *LHS = nullptr, *RHS = nullptr; + SelectPatternFlavor SPF = matchSelectPattern(Select, LHS, RHS).Flavor; + if (SPF != SPF_SMAX && SPF != SPF_SMIN) + return false; + + if (!match(RHS, m_APInt(CLow))) + return false; + + const Value *LHS2 = nullptr, *RHS2 = nullptr; + SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor; + if (getInverseMinMaxFlavor(SPF) != SPF2) + return false; + + if (!match(RHS2, m_APInt(CHigh))) + return false; + + if (SPF == SPF_SMIN) + std::swap(CLow, CHigh); + + In = LHS2; + return CLow->sle(*CHigh); +} + +/// For vector constants, loop over the elements and find the constant with the +/// minimum number of sign bits. Return 0 if the value is not a vector constant +/// or if any element was not analyzed; otherwise, return the count for the +/// element with the minimum number of sign bits. +static unsigned computeNumSignBitsVectorConstant(const Value *V, + unsigned TyBits) { + const auto *CV = dyn_cast<Constant>(V); + if (!CV || !CV->getType()->isVectorTy()) + return 0; + + unsigned MinSignBits = TyBits; + unsigned NumElts = CV->getType()->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + // If we find a non-ConstantInt, bail out. + auto *Elt = dyn_cast_or_null<ConstantInt>(CV->getAggregateElement(i)); + if (!Elt) + return 0; + + MinSignBits = std::min(MinSignBits, Elt->getValue().getNumSignBits()); + } + + return MinSignBits; +} + +static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, + const Query &Q); + +static unsigned ComputeNumSignBits(const Value *V, unsigned Depth, + const Query &Q) { + unsigned Result = ComputeNumSignBitsImpl(V, Depth, Q); + assert(Result > 0 && "At least one sign bit needs to be present!"); + return Result; +} + +/// Return the number of times the sign bit of the register is replicated into +/// the other bits. We know that at least 1 bit is always equal to the sign bit +/// (itself), but other cases can give us information. For example, immediately +/// after an "ashr X, 2", we know that the top 3 bits are all equal to each +/// other, so we return 3. For vectors, return the number of sign bits for the +/// vector element with the minimum number of known sign bits. +static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, + const Query &Q) { + assert(Depth <= MaxDepth && "Limit Search Depth"); + + // We return the minimum number of sign bits that are guaranteed to be present + // in V, so for undef we have to conservatively return 1. 
We don't have the
+  // same behavior for poison though -- that's a FIXME today.
+
+  Type *ScalarTy = V->getType()->getScalarType();
+  unsigned TyBits = ScalarTy->isPointerTy() ?
+    Q.DL.getIndexTypeSizeInBits(ScalarTy) :
+    Q.DL.getTypeSizeInBits(ScalarTy);
+
+  unsigned Tmp, Tmp2;
+  unsigned FirstAnswer = 1;
+
+  // Note that ConstantInt is handled by the general computeKnownBits case
+  // below.
+
+  if (Depth == MaxDepth)
+    return 1;  // Limit search depth.
+
+  if (auto *U = dyn_cast<Operator>(V)) {
+    switch (Operator::getOpcode(V)) {
+    default: break;
+    case Instruction::SExt:
+      Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits();
+      return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q) + Tmp;
+
+    case Instruction::SDiv: {
+      const APInt *Denominator;
+      // sdiv X, C -> adds log(C) sign bits.
+      if (match(U->getOperand(1), m_APInt(Denominator))) {
+
+        // Ignore non-positive denominator.
+        if (!Denominator->isStrictlyPositive())
+          break;
+
+        // Calculate the incoming numerator bits.
+        unsigned NumBits = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+
+        // Add floor(log(C)) bits to the numerator bits.
+        return std::min(TyBits, NumBits + Denominator->logBase2());
+      }
+      break;
+    }
+
+    case Instruction::SRem: {
+      const APInt *Denominator;
+      // srem X, C -> we know that the result is within [-C+1,C) when C is a
+      // positive constant. This lets us put a lower bound on the number of
+      // sign bits.
+      if (match(U->getOperand(1), m_APInt(Denominator))) {
+
+        // Ignore non-positive denominator.
+        if (!Denominator->isStrictlyPositive())
+          break;
+
+        // Calculate the incoming numerator bits. SRem by a positive constant
+        // can't lower the number of sign bits.
+        unsigned NumrBits = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+
+        // Calculate the leading sign bit constraints by examining the
+        // denominator. Given that the denominator is positive, there are two
+        // cases:
+        //
+        //  1. The numerator is positive. The result range is [0,C), and
+        //     [0,C) u< (1 << ceilLogBase2(C)).
+        //
+        //  2. The numerator is negative. Then the result range is (-C,0],
+        //     and integers in (-C,0] are either 0 or >u (-1 << ceilLogBase2(C)).
+        //
+        // Thus a lower bound on the number of sign bits is `TyBits -
+        // ceilLogBase2(C)`.
+
+        unsigned ResBits = TyBits - Denominator->ceilLogBase2();
+        return std::max(NumrBits, ResBits);
+      }
+      break;
+    }
+
+    case Instruction::AShr: {
+      Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+      // ashr X, C -> adds C sign bits. Vectors too.
+      const APInt *ShAmt;
+      if (match(U->getOperand(1), m_APInt(ShAmt))) {
+        if (ShAmt->uge(TyBits))
+          break; // Bad shift.
+        unsigned ShAmtLimited = ShAmt->getZExtValue();
+        Tmp += ShAmtLimited;
+        if (Tmp > TyBits) Tmp = TyBits;
+      }
+      return Tmp;
+    }
+    case Instruction::Shl: {
+      const APInt *ShAmt;
+      if (match(U->getOperand(1), m_APInt(ShAmt))) {
+        // shl destroys sign bits.
+        Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+        if (ShAmt->uge(TyBits) ||  // Bad shift.
+            ShAmt->uge(Tmp)) break; // Shifted all sign bits out.
+        Tmp2 = ShAmt->getZExtValue();
+        return Tmp - Tmp2;
+      }
+      break;
+    }
+    case Instruction::And:
+    case Instruction::Or:
+    case Instruction::Xor: // NOT is handled here.
+      // Logical binary ops preserve the number of sign bits at the worst.
+      Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+      if (Tmp != 1) {
+        Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
+        FirstAnswer = std::min(Tmp, Tmp2);
+        // We computed what we know about the sign bits as our first
+        // answer.
Now proceed to the generic code that uses + // computeKnownBits, and pick whichever answer is better. + } + break; + + case Instruction::Select: { + // If we have a clamp pattern, we know that the number of sign bits will + // be the minimum of the clamp min/max range. + const Value *X; + const APInt *CLow, *CHigh; + if (isSignedMinMaxClamp(U, X, CLow, CHigh)) + return std::min(CLow->getNumSignBits(), CHigh->getNumSignBits()); + + Tmp = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); + if (Tmp == 1) break; + Tmp2 = ComputeNumSignBits(U->getOperand(2), Depth + 1, Q); + return std::min(Tmp, Tmp2); + } + + case Instruction::Add: + // Add can have at most one carry bit. Thus we know that the output + // is, at worst, one more bit than the inputs. + Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + if (Tmp == 1) break; + + // Special case decrementing a value (ADD X, -1): + if (const auto *CRHS = dyn_cast<Constant>(U->getOperand(1))) + if (CRHS->isAllOnesValue()) { + KnownBits Known(TyBits); + computeKnownBits(U->getOperand(0), Known, Depth + 1, Q); + + // If the input is known to be 0 or 1, the output is 0/-1, which is + // all sign bits set. + if ((Known.Zero | 1).isAllOnesValue()) + return TyBits; + + // If we are subtracting one from a positive number, there is no carry + // out of the result. + if (Known.isNonNegative()) + return Tmp; + } + + Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); + if (Tmp2 == 1) break; + return std::min(Tmp, Tmp2) - 1; + + case Instruction::Sub: + Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); + if (Tmp2 == 1) break; + + // Handle NEG. + if (const auto *CLHS = dyn_cast<Constant>(U->getOperand(0))) + if (CLHS->isNullValue()) { + KnownBits Known(TyBits); + computeKnownBits(U->getOperand(1), Known, Depth + 1, Q); + // If the input is known to be 0 or 1, the output is 0/-1, which is + // all sign bits set. + if ((Known.Zero | 1).isAllOnesValue()) + return TyBits; + + // If the input is known to be positive (the sign bit is known clear), + // the output of the NEG has the same number of sign bits as the + // input. + if (Known.isNonNegative()) + return Tmp2; + + // Otherwise, we treat this like a SUB. + } + + // Sub can have at most one carry bit. Thus we know that the output + // is, at worst, one more bit than the inputs. + Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + if (Tmp == 1) break; + return std::min(Tmp, Tmp2) - 1; + + case Instruction::Mul: { + // The output of the Mul can be at most twice the valid bits in the + // inputs. + unsigned SignBitsOp0 = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + if (SignBitsOp0 == 1) break; + unsigned SignBitsOp1 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); + if (SignBitsOp1 == 1) break; + unsigned OutValidBits = + (TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1); + return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1; + } + + case Instruction::PHI: { + const PHINode *PN = cast<PHINode>(U); + unsigned NumIncomingValues = PN->getNumIncomingValues(); + // Don't analyze large in-degree PHIs. + if (NumIncomingValues > 4) break; + // Unreachable blocks may have zero-operand PHI nodes. + if (NumIncomingValues == 0) break; + + // Take the minimum of all incoming values. This can't infinitely loop + // because of our depth threshold. 
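+      // For example, a PHI merging one incoming value with 5 known sign bits
+      // and another with only 2 can be credited no more than 2 sign bits.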
+ Tmp = ComputeNumSignBits(PN->getIncomingValue(0), Depth + 1, Q); + for (unsigned i = 1, e = NumIncomingValues; i != e; ++i) { + if (Tmp == 1) return Tmp; + Tmp = std::min( + Tmp, ComputeNumSignBits(PN->getIncomingValue(i), Depth + 1, Q)); + } + return Tmp; + } + + case Instruction::Trunc: + // FIXME: it's tricky to do anything useful for this, but it is an + // important case for targets like X86. + break; + + case Instruction::ExtractElement: + // Look through extract element. At the moment we keep this simple and + // skip tracking the specific element. But at least we might find + // information valid for all elements of the vector (for example if vector + // is sign extended, shifted, etc). + return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + + case Instruction::ShuffleVector: { + // TODO: This is copied almost directly from the SelectionDAG version of + // ComputeNumSignBits. It would be better if we could share common + // code. If not, make sure that changes are translated to the DAG. + + // Collect the minimum number of sign bits that are shared by every vector + // element referenced by the shuffle. + auto *Shuf = cast<ShuffleVectorInst>(U); + int NumElts = Shuf->getOperand(0)->getType()->getVectorNumElements(); + int NumMaskElts = Shuf->getMask()->getType()->getVectorNumElements(); + APInt DemandedLHS(NumElts, 0), DemandedRHS(NumElts, 0); + for (int i = 0; i != NumMaskElts; ++i) { + int M = Shuf->getMaskValue(i); + assert(M < NumElts * 2 && "Invalid shuffle mask constant"); + // For undef elements, we don't know anything about the common state of + // the shuffle result. + if (M == -1) + return 1; + if (M < NumElts) + DemandedLHS.setBit(M % NumElts); + else + DemandedRHS.setBit(M % NumElts); + } + Tmp = std::numeric_limits<unsigned>::max(); + if (!!DemandedLHS) + Tmp = ComputeNumSignBits(Shuf->getOperand(0), Depth + 1, Q); + if (!!DemandedRHS) { + Tmp2 = ComputeNumSignBits(Shuf->getOperand(1), Depth + 1, Q); + Tmp = std::min(Tmp, Tmp2); + } + // If we don't know anything, early out and try computeKnownBits + // fall-back. + if (Tmp == 1) + break; + assert(Tmp <= V->getType()->getScalarSizeInBits() && + "Failed to determine minimum sign bits"); + return Tmp; + } + } + } + + // Finally, if we can prove that the top bits of the result are 0's or 1's, + // use this information. + + // If we can examine all elements of a vector constant successfully, we're + // done (we can't do any better than that). If not, keep trying. + if (unsigned VecSignBits = computeNumSignBitsVectorConstant(V, TyBits)) + return VecSignBits; + + KnownBits Known(TyBits); + computeKnownBits(V, Known, Depth, Q); + + // If we know that the sign bit is either zero or one, determine the number of + // identical bits in the top of the input value. + return std::max(FirstAnswer, Known.countMinSignBits()); +} + +/// This function computes the integer multiple of Base that equals V. +/// If successful, it returns true and returns the multiple in +/// Multiple. If unsuccessful, it returns false. It looks +/// through SExt instructions only if LookThroughSExt is true. 
+bool llvm::ComputeMultiple(Value *V, unsigned Base, Value *&Multiple, + bool LookThroughSExt, unsigned Depth) { + assert(V && "No Value?"); + assert(Depth <= MaxDepth && "Limit Search Depth"); + assert(V->getType()->isIntegerTy() && "Not integer or pointer type!"); + + Type *T = V->getType(); + + ConstantInt *CI = dyn_cast<ConstantInt>(V); + + if (Base == 0) + return false; + + if (Base == 1) { + Multiple = V; + return true; + } + + ConstantExpr *CO = dyn_cast<ConstantExpr>(V); + Constant *BaseVal = ConstantInt::get(T, Base); + if (CO && CO == BaseVal) { + // Multiple is 1. + Multiple = ConstantInt::get(T, 1); + return true; + } + + if (CI && CI->getZExtValue() % Base == 0) { + Multiple = ConstantInt::get(T, CI->getZExtValue() / Base); + return true; + } + + if (Depth == MaxDepth) return false; // Limit search depth. + + Operator *I = dyn_cast<Operator>(V); + if (!I) return false; + + switch (I->getOpcode()) { + default: break; + case Instruction::SExt: + if (!LookThroughSExt) return false; + // otherwise fall through to ZExt + LLVM_FALLTHROUGH; + case Instruction::ZExt: + return ComputeMultiple(I->getOperand(0), Base, Multiple, + LookThroughSExt, Depth+1); + case Instruction::Shl: + case Instruction::Mul: { + Value *Op0 = I->getOperand(0); + Value *Op1 = I->getOperand(1); + + if (I->getOpcode() == Instruction::Shl) { + ConstantInt *Op1CI = dyn_cast<ConstantInt>(Op1); + if (!Op1CI) return false; + // Turn Op0 << Op1 into Op0 * 2^Op1 + APInt Op1Int = Op1CI->getValue(); + uint64_t BitToSet = Op1Int.getLimitedValue(Op1Int.getBitWidth() - 1); + APInt API(Op1Int.getBitWidth(), 0); + API.setBit(BitToSet); + Op1 = ConstantInt::get(V->getContext(), API); + } + + Value *Mul0 = nullptr; + if (ComputeMultiple(Op0, Base, Mul0, LookThroughSExt, Depth+1)) { + if (Constant *Op1C = dyn_cast<Constant>(Op1)) + if (Constant *MulC = dyn_cast<Constant>(Mul0)) { + if (Op1C->getType()->getPrimitiveSizeInBits() < + MulC->getType()->getPrimitiveSizeInBits()) + Op1C = ConstantExpr::getZExt(Op1C, MulC->getType()); + if (Op1C->getType()->getPrimitiveSizeInBits() > + MulC->getType()->getPrimitiveSizeInBits()) + MulC = ConstantExpr::getZExt(MulC, Op1C->getType()); + + // V == Base * (Mul0 * Op1), so return (Mul0 * Op1) + Multiple = ConstantExpr::getMul(MulC, Op1C); + return true; + } + + if (ConstantInt *Mul0CI = dyn_cast<ConstantInt>(Mul0)) + if (Mul0CI->getValue() == 1) { + // V == Base * Op1, so return Op1 + Multiple = Op1; + return true; + } + } + + Value *Mul1 = nullptr; + if (ComputeMultiple(Op1, Base, Mul1, LookThroughSExt, Depth+1)) { + if (Constant *Op0C = dyn_cast<Constant>(Op0)) + if (Constant *MulC = dyn_cast<Constant>(Mul1)) { + if (Op0C->getType()->getPrimitiveSizeInBits() < + MulC->getType()->getPrimitiveSizeInBits()) + Op0C = ConstantExpr::getZExt(Op0C, MulC->getType()); + if (Op0C->getType()->getPrimitiveSizeInBits() > + MulC->getType()->getPrimitiveSizeInBits()) + MulC = ConstantExpr::getZExt(MulC, Op0C->getType()); + + // V == Base * (Mul1 * Op0), so return (Mul1 * Op0) + Multiple = ConstantExpr::getMul(MulC, Op0C); + return true; + } + + if (ConstantInt *Mul1CI = dyn_cast<ConstantInt>(Mul1)) + if (Mul1CI->getValue() == 1) { + // V == Base * Op0, so return Op0 + Multiple = Op0; + return true; + } + } + } + } + + // We could not determine if V is a multiple of Base. 
+  return false;
+}
+
+Intrinsic::ID llvm::getIntrinsicForCallSite(ImmutableCallSite ICS,
+                                            const TargetLibraryInfo *TLI) {
+  const Function *F = ICS.getCalledFunction();
+  if (!F)
+    return Intrinsic::not_intrinsic;
+
+  if (F->isIntrinsic())
+    return F->getIntrinsicID();
+
+  if (!TLI)
+    return Intrinsic::not_intrinsic;
+
+  LibFunc Func;
+  // We're going to make assumptions about the function's semantics, so check
+  // that the target knows the function is available in this environment and
+  // that it does not have local linkage.
+  if (F->hasLocalLinkage() || !TLI->getLibFunc(*F, Func))
+    return Intrinsic::not_intrinsic;
+
+  if (!ICS.onlyReadsMemory())
+    return Intrinsic::not_intrinsic;
+
+  // Otherwise check if we have a call to a function that can be turned into a
+  // vector intrinsic.
+  switch (Func) {
+  default:
+    break;
+  case LibFunc_sin:
+  case LibFunc_sinf:
+  case LibFunc_sinl:
+    return Intrinsic::sin;
+  case LibFunc_cos:
+  case LibFunc_cosf:
+  case LibFunc_cosl:
+    return Intrinsic::cos;
+  case LibFunc_exp:
+  case LibFunc_expf:
+  case LibFunc_expl:
+    return Intrinsic::exp;
+  case LibFunc_exp2:
+  case LibFunc_exp2f:
+  case LibFunc_exp2l:
+    return Intrinsic::exp2;
+  case LibFunc_log:
+  case LibFunc_logf:
+  case LibFunc_logl:
+    return Intrinsic::log;
+  case LibFunc_log10:
+  case LibFunc_log10f:
+  case LibFunc_log10l:
+    return Intrinsic::log10;
+  case LibFunc_log2:
+  case LibFunc_log2f:
+  case LibFunc_log2l:
+    return Intrinsic::log2;
+  case LibFunc_fabs:
+  case LibFunc_fabsf:
+  case LibFunc_fabsl:
+    return Intrinsic::fabs;
+  case LibFunc_fmin:
+  case LibFunc_fminf:
+  case LibFunc_fminl:
+    return Intrinsic::minnum;
+  case LibFunc_fmax:
+  case LibFunc_fmaxf:
+  case LibFunc_fmaxl:
+    return Intrinsic::maxnum;
+  case LibFunc_copysign:
+  case LibFunc_copysignf:
+  case LibFunc_copysignl:
+    return Intrinsic::copysign;
+  case LibFunc_floor:
+  case LibFunc_floorf:
+  case LibFunc_floorl:
+    return Intrinsic::floor;
+  case LibFunc_ceil:
+  case LibFunc_ceilf:
+  case LibFunc_ceill:
+    return Intrinsic::ceil;
+  case LibFunc_trunc:
+  case LibFunc_truncf:
+  case LibFunc_truncl:
+    return Intrinsic::trunc;
+  case LibFunc_rint:
+  case LibFunc_rintf:
+  case LibFunc_rintl:
+    return Intrinsic::rint;
+  case LibFunc_nearbyint:
+  case LibFunc_nearbyintf:
+  case LibFunc_nearbyintl:
+    return Intrinsic::nearbyint;
+  case LibFunc_round:
+  case LibFunc_roundf:
+  case LibFunc_roundl:
+    return Intrinsic::round;
+  case LibFunc_pow:
+  case LibFunc_powf:
+  case LibFunc_powl:
+    return Intrinsic::pow;
+  case LibFunc_sqrt:
+  case LibFunc_sqrtf:
+  case LibFunc_sqrtl:
+    return Intrinsic::sqrt;
+  }
+
+  return Intrinsic::not_intrinsic;
+}
+
+/// Return true if we can prove that the specified FP value is never equal to
+/// -0.0.
+///
+/// NOTE: this function will need to be revisited when we support non-default
+/// rounding modes!
+bool llvm::CannotBeNegativeZero(const Value *V, const TargetLibraryInfo *TLI,
+                                unsigned Depth) {
+  if (auto *CFP = dyn_cast<ConstantFP>(V))
+    return !CFP->getValueAPF().isNegZero();
+
+  // Limit search depth.
+  if (Depth == MaxDepth)
+    return false;
+
+  auto *Op = dyn_cast<Operator>(V);
+  if (!Op)
+    return false;
+
+  // Check if the nsz fast-math flag is set.
+  if (auto *FPO = dyn_cast<FPMathOperator>(Op))
+    if (FPO->hasNoSignedZeros())
+      return true;
+
+  // (fadd x, 0.0) is guaranteed to return +0.0, not -0.0.
+  if (match(Op, m_FAdd(m_Value(), m_PosZeroFP())))
+    return true;
+
+  // sitofp and uitofp turn into +0.0 for zero.
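+  // (The only zero an int-to-FP conversion can produce comes from the input
+  // 0, which converts exactly to +0.0; nonzero integers convert to values of
+  // magnitude at least 1.0.)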
+ if (isa<SIToFPInst>(Op) || isa<UIToFPInst>(Op)) + return true; + + if (auto *Call = dyn_cast<CallInst>(Op)) { + Intrinsic::ID IID = getIntrinsicForCallSite(Call, TLI); + switch (IID) { + default: + break; + // sqrt(-0.0) = -0.0, no other negative results are possible. + case Intrinsic::sqrt: + case Intrinsic::canonicalize: + return CannotBeNegativeZero(Call->getArgOperand(0), TLI, Depth + 1); + // fabs(x) != -0.0 + case Intrinsic::fabs: + return true; + } + } + + return false; +} + +/// If \p SignBitOnly is true, test for a known 0 sign bit rather than a +/// standard ordered compare. e.g. make -0.0 olt 0.0 be true because of the sign +/// bit despite comparing equal. +static bool cannotBeOrderedLessThanZeroImpl(const Value *V, + const TargetLibraryInfo *TLI, + bool SignBitOnly, + unsigned Depth) { + // TODO: This function does not do the right thing when SignBitOnly is true + // and we're lowering to a hypothetical IEEE 754-compliant-but-evil platform + // which flips the sign bits of NaNs. See + // https://llvm.org/bugs/show_bug.cgi?id=31702. + + if (const ConstantFP *CFP = dyn_cast<ConstantFP>(V)) { + return !CFP->getValueAPF().isNegative() || + (!SignBitOnly && CFP->getValueAPF().isZero()); + } + + // Handle vector of constants. + if (auto *CV = dyn_cast<Constant>(V)) { + if (CV->getType()->isVectorTy()) { + unsigned NumElts = CV->getType()->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i)); + if (!CFP) + return false; + if (CFP->getValueAPF().isNegative() && + (SignBitOnly || !CFP->getValueAPF().isZero())) + return false; + } + + // All non-negative ConstantFPs. + return true; + } + } + + if (Depth == MaxDepth) + return false; // Limit search depth. + + const Operator *I = dyn_cast<Operator>(V); + if (!I) + return false; + + switch (I->getOpcode()) { + default: + break; + // Unsigned integers are always nonnegative. + case Instruction::UIToFP: + return true; + case Instruction::FMul: + // x*x is always non-negative or a NaN. + if (I->getOperand(0) == I->getOperand(1) && + (!SignBitOnly || cast<FPMathOperator>(I)->hasNoNaNs())) + return true; + + LLVM_FALLTHROUGH; + case Instruction::FAdd: + case Instruction::FDiv: + case Instruction::FRem: + return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, + Depth + 1) && + cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, + Depth + 1); + case Instruction::Select: + return cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, + Depth + 1) && + cannotBeOrderedLessThanZeroImpl(I->getOperand(2), TLI, SignBitOnly, + Depth + 1); + case Instruction::FPExt: + case Instruction::FPTrunc: + // Widening/narrowing never change sign. + return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, + Depth + 1); + case Instruction::ExtractElement: + // Look through extract element. At the moment we keep this simple and skip + // tracking the specific element. But at least we might find information + // valid for all elements of the vector. 
+ return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, + Depth + 1); + case Instruction::Call: + const auto *CI = cast<CallInst>(I); + Intrinsic::ID IID = getIntrinsicForCallSite(CI, TLI); + switch (IID) { + default: + break; + case Intrinsic::maxnum: + return (isKnownNeverNaN(I->getOperand(0), TLI) && + cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, + SignBitOnly, Depth + 1)) || + (isKnownNeverNaN(I->getOperand(1), TLI) && + cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, + SignBitOnly, Depth + 1)); + + case Intrinsic::maximum: + return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, + Depth + 1) || + cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, + Depth + 1); + case Intrinsic::minnum: + case Intrinsic::minimum: + return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, + Depth + 1) && + cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, + Depth + 1); + case Intrinsic::exp: + case Intrinsic::exp2: + case Intrinsic::fabs: + return true; + + case Intrinsic::sqrt: + // sqrt(x) is always >= -0 or NaN. Moreover, sqrt(x) == -0 iff x == -0. + if (!SignBitOnly) + return true; + return CI->hasNoNaNs() && (CI->hasNoSignedZeros() || + CannotBeNegativeZero(CI->getOperand(0), TLI)); + + case Intrinsic::powi: + if (ConstantInt *Exponent = dyn_cast<ConstantInt>(I->getOperand(1))) { + // powi(x,n) is non-negative if n is even. + if (Exponent->getBitWidth() <= 64 && Exponent->getSExtValue() % 2u == 0) + return true; + } + // TODO: This is not correct. Given that exp is an integer, here are the + // ways that pow can return a negative value: + // + // pow(x, exp) --> negative if exp is odd and x is negative. + // pow(-0, exp) --> -inf if exp is negative odd. + // pow(-0, exp) --> -0 if exp is positive odd. + // pow(-inf, exp) --> -0 if exp is negative odd. + // pow(-inf, exp) --> -inf if exp is positive odd. + // + // Therefore, if !SignBitOnly, we can return true if x >= +0 or x is NaN, + // but we must return false if x == -0. Unfortunately we do not currently + // have a way of expressing this constraint. See details in + // https://llvm.org/bugs/show_bug.cgi?id=31702. + return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, + Depth + 1); + + case Intrinsic::fma: + case Intrinsic::fmuladd: + // x*x+y is non-negative if y is non-negative. + return I->getOperand(0) == I->getOperand(1) && + (!SignBitOnly || cast<FPMathOperator>(I)->hasNoNaNs()) && + cannotBeOrderedLessThanZeroImpl(I->getOperand(2), TLI, SignBitOnly, + Depth + 1); + } + break; + } + return false; +} + +bool llvm::CannotBeOrderedLessThanZero(const Value *V, + const TargetLibraryInfo *TLI) { + return cannotBeOrderedLessThanZeroImpl(V, TLI, false, 0); +} + +bool llvm::SignBitMustBeZero(const Value *V, const TargetLibraryInfo *TLI) { + return cannotBeOrderedLessThanZeroImpl(V, TLI, true, 0); +} + +bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI, + unsigned Depth) { + assert(V->getType()->isFPOrFPVectorTy() && "Querying for NaN on non-FP type"); + + // If we're told that NaNs won't happen, assume they won't. + if (auto *FPMathOp = dyn_cast<FPMathOperator>(V)) + if (FPMathOp->hasNoNaNs()) + return true; + + // Handle scalar constants. 
+ if (auto *CFP = dyn_cast<ConstantFP>(V)) + return !CFP->isNaN(); + + if (Depth == MaxDepth) + return false; + + if (auto *Inst = dyn_cast<Instruction>(V)) { + switch (Inst->getOpcode()) { + case Instruction::FAdd: + case Instruction::FMul: + case Instruction::FSub: + case Instruction::FDiv: + case Instruction::FRem: { + // TODO: Need isKnownNeverInfinity + return false; + } + case Instruction::Select: { + return isKnownNeverNaN(Inst->getOperand(1), TLI, Depth + 1) && + isKnownNeverNaN(Inst->getOperand(2), TLI, Depth + 1); + } + case Instruction::SIToFP: + case Instruction::UIToFP: + return true; + case Instruction::FPTrunc: + case Instruction::FPExt: + return isKnownNeverNaN(Inst->getOperand(0), TLI, Depth + 1); + default: + break; + } + } + + if (const auto *II = dyn_cast<IntrinsicInst>(V)) { + switch (II->getIntrinsicID()) { + case Intrinsic::canonicalize: + case Intrinsic::fabs: + case Intrinsic::copysign: + case Intrinsic::exp: + case Intrinsic::exp2: + case Intrinsic::floor: + case Intrinsic::ceil: + case Intrinsic::trunc: + case Intrinsic::rint: + case Intrinsic::nearbyint: + case Intrinsic::round: + return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1); + case Intrinsic::sqrt: + return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1) && + CannotBeOrderedLessThanZero(II->getArgOperand(0), TLI); + case Intrinsic::minnum: + case Intrinsic::maxnum: + // If either operand is not NaN, the result is not NaN. + return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1) || + isKnownNeverNaN(II->getArgOperand(1), TLI, Depth + 1); + default: + return false; + } + } + + // Bail out for constant expressions, but try to handle vector constants. + if (!V->getType()->isVectorTy() || !isa<Constant>(V)) + return false; + + // For vectors, verify that each element is not NaN. + unsigned NumElts = V->getType()->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = cast<Constant>(V)->getAggregateElement(i); + if (!Elt) + return false; + if (isa<UndefValue>(Elt)) + continue; + auto *CElt = dyn_cast<ConstantFP>(Elt); + if (!CElt || CElt->isNaN()) + return false; + } + // All elements were confirmed not-NaN or undefined. + return true; +} + +Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) { + + // All byte-wide stores are splatable, even of arbitrary variables. + if (V->getType()->isIntegerTy(8)) + return V; + + LLVMContext &Ctx = V->getContext(); + + // Undef don't care. + auto *UndefInt8 = UndefValue::get(Type::getInt8Ty(Ctx)); + if (isa<UndefValue>(V)) + return UndefInt8; + + const uint64_t Size = DL.getTypeStoreSize(V->getType()); + if (!Size) + return UndefInt8; + + Constant *C = dyn_cast<Constant>(V); + if (!C) { + // Conceptually, we could handle things like: + // %a = zext i8 %X to i16 + // %b = shl i16 %a, 8 + // %c = or i16 %a, %b + // but until there is an example that actually needs this, it doesn't seem + // worth worrying about. + return nullptr; + } + + // Handle 'null' ConstantArrayZero etc. + if (C->isNullValue()) + return Constant::getNullValue(Type::getInt8Ty(Ctx)); + + // Constant floating-point values can be handled as integer values if the + // corresponding integer value is "byteable". An important case is 0.0. 
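+  // For example, double 0.0 has the i64 bit pattern 0x0000000000000000, so
+  // it splats to the byte 0x00.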
+ if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) { + Type *Ty = nullptr; + if (CFP->getType()->isHalfTy()) + Ty = Type::getInt16Ty(Ctx); + else if (CFP->getType()->isFloatTy()) + Ty = Type::getInt32Ty(Ctx); + else if (CFP->getType()->isDoubleTy()) + Ty = Type::getInt64Ty(Ctx); + // Don't handle long double formats, which have strange constraints. + return Ty ? isBytewiseValue(ConstantExpr::getBitCast(CFP, Ty), DL) + : nullptr; + } + + // We can handle constant integers that are multiple of 8 bits. + if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) { + if (CI->getBitWidth() % 8 == 0) { + assert(CI->getBitWidth() > 8 && "8 bits should be handled above!"); + if (!CI->getValue().isSplat(8)) + return nullptr; + return ConstantInt::get(Ctx, CI->getValue().trunc(8)); + } + } + + if (auto *CE = dyn_cast<ConstantExpr>(C)) { + if (CE->getOpcode() == Instruction::IntToPtr) { + auto PS = DL.getPointerSizeInBits( + cast<PointerType>(CE->getType())->getAddressSpace()); + return isBytewiseValue( + ConstantExpr::getIntegerCast(CE->getOperand(0), + Type::getIntNTy(Ctx, PS), false), + DL); + } + } + + auto Merge = [&](Value *LHS, Value *RHS) -> Value * { + if (LHS == RHS) + return LHS; + if (!LHS || !RHS) + return nullptr; + if (LHS == UndefInt8) + return RHS; + if (RHS == UndefInt8) + return LHS; + return nullptr; + }; + + if (ConstantDataSequential *CA = dyn_cast<ConstantDataSequential>(C)) { + Value *Val = UndefInt8; + for (unsigned I = 0, E = CA->getNumElements(); I != E; ++I) + if (!(Val = Merge(Val, isBytewiseValue(CA->getElementAsConstant(I), DL)))) + return nullptr; + return Val; + } + + if (isa<ConstantAggregate>(C)) { + Value *Val = UndefInt8; + for (unsigned I = 0, E = C->getNumOperands(); I != E; ++I) + if (!(Val = Merge(Val, isBytewiseValue(C->getOperand(I), DL)))) + return nullptr; + return Val; + } + + // Don't try to handle the handful of other constants. + return nullptr; +} + +// This is the recursive version of BuildSubAggregate. It takes a few different +// arguments. Idxs is the index within the nested struct From that we are +// looking at now (which is of type IndexedType). IdxSkip is the number of +// indices from Idxs that should be left out when inserting into the resulting +// struct. To is the result struct built so far, new insertvalue instructions +// build on that. +static Value *BuildSubAggregate(Value *From, Value* To, Type *IndexedType, + SmallVectorImpl<unsigned> &Idxs, + unsigned IdxSkip, + Instruction *InsertBefore) { + StructType *STy = dyn_cast<StructType>(IndexedType); + if (STy) { + // Save the original To argument so we can modify it + Value *OrigTo = To; + // General case, the type indexed by Idxs is a struct + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + // Process each struct element recursively + Idxs.push_back(i); + Value *PrevTo = To; + To = BuildSubAggregate(From, To, STy->getElementType(i), Idxs, IdxSkip, + InsertBefore); + Idxs.pop_back(); + if (!To) { + // Couldn't find any inserted value for this index? Cleanup + while (PrevTo != OrigTo) { + InsertValueInst* Del = cast<InsertValueInst>(PrevTo); + PrevTo = Del->getAggregateOperand(); + Del->eraseFromParent(); + } + // Stop processing elements + break; + } + } + // If we successfully found a value for each of our subaggregates + if (To) + return To; + } + // Base case, the type indexed by SourceIdxs is not a struct, or not all of + // the struct's elements had a value that was inserted directly. 
In the latter + // case, perhaps we can't determine each of the subelements individually, but + // we might be able to find the complete struct somewhere. + + // Find the value that is at that particular spot + Value *V = FindInsertedValue(From, Idxs); + + if (!V) + return nullptr; + + // Insert the value in the new (sub) aggregate + return InsertValueInst::Create(To, V, makeArrayRef(Idxs).slice(IdxSkip), + "tmp", InsertBefore); +} + +// This helper takes a nested struct and extracts a part of it (which is again a +// struct) into a new value. For example, given the struct: +// { a, { b, { c, d }, e } } +// and the indices "1, 1" this returns +// { c, d }. +// +// It does this by inserting an insertvalue for each element in the resulting +// struct, as opposed to just inserting a single struct. This will only work if +// each of the elements of the substruct are known (ie, inserted into From by an +// insertvalue instruction somewhere). +// +// All inserted insertvalue instructions are inserted before InsertBefore +static Value *BuildSubAggregate(Value *From, ArrayRef<unsigned> idx_range, + Instruction *InsertBefore) { + assert(InsertBefore && "Must have someplace to insert!"); + Type *IndexedType = ExtractValueInst::getIndexedType(From->getType(), + idx_range); + Value *To = UndefValue::get(IndexedType); + SmallVector<unsigned, 10> Idxs(idx_range.begin(), idx_range.end()); + unsigned IdxSkip = Idxs.size(); + + return BuildSubAggregate(From, To, IndexedType, Idxs, IdxSkip, InsertBefore); +} + +/// Given an aggregate and a sequence of indices, see if the scalar value +/// indexed is already around as a register, for example if it was inserted +/// directly into the aggregate. +/// +/// If InsertBefore is not null, this function will duplicate (modified) +/// insertvalues when a part of a nested struct is extracted. +Value *llvm::FindInsertedValue(Value *V, ArrayRef<unsigned> idx_range, + Instruction *InsertBefore) { + // Nothing to index? Just return V then (this is useful at the end of our + // recursion). + if (idx_range.empty()) + return V; + // We have indices, so V should have an indexable type. + assert((V->getType()->isStructTy() || V->getType()->isArrayTy()) && + "Not looking at a struct or array?"); + assert(ExtractValueInst::getIndexedType(V->getType(), idx_range) && + "Invalid indices for type?"); + + if (Constant *C = dyn_cast<Constant>(V)) { + C = C->getAggregateElement(idx_range[0]); + if (!C) return nullptr; + return FindInsertedValue(C, idx_range.slice(1), InsertBefore); + } + + if (InsertValueInst *I = dyn_cast<InsertValueInst>(V)) { + // Loop the indices for the insertvalue instruction in parallel with the + // requested indices + const unsigned *req_idx = idx_range.begin(); + for (const unsigned *i = I->idx_begin(), *e = I->idx_end(); + i != e; ++i, ++req_idx) { + if (req_idx == idx_range.end()) { + // We can't handle this without inserting insertvalues + if (!InsertBefore) + return nullptr; + + // The requested index identifies a part of a nested aggregate. Handle + // this specially. For example, + // %A = insertvalue { i32, {i32, i32 } } undef, i32 10, 1, 0 + // %B = insertvalue { i32, {i32, i32 } } %A, i32 11, 1, 1 + // %C = extractvalue {i32, { i32, i32 } } %B, 1 + // This can be changed into + // %A = insertvalue {i32, i32 } undef, i32 10, 0 + // %C = insertvalue {i32, i32 } %A, i32 11, 1 + // which allows the unused 0,0 element from the nested struct to be + // removed. 
+      return BuildSubAggregate(V, makeArrayRef(idx_range.begin(), req_idx),
+                               InsertBefore);
+    }
+
+    // This insertvalue inserts something other than what we are looking for.
+    // See if the (aggregate) value inserted into has the value we are
+    // looking for, then.
+    if (*req_idx != *i)
+      return FindInsertedValue(I->getAggregateOperand(), idx_range,
+                               InsertBefore);
+  }
+  // If we end up here, the indices of the insertvalue match with those
+  // requested (though possibly only partially). Now we recursively look at
+  // the inserted value, passing any remaining indices.
+  return FindInsertedValue(I->getInsertedValueOperand(),
+                           makeArrayRef(req_idx, idx_range.end()),
+                           InsertBefore);
+}
+
+if (ExtractValueInst *I = dyn_cast<ExtractValueInst>(V)) {
+  // If we're extracting a value from an aggregate that was itself extracted
+  // from something else, we can extract from that something else directly
+  // instead. However, we will need to chain I's indices with the requested
+  // indices.
+
+  // Calculate the number of indices required.
+  unsigned size = I->getNumIndices() + idx_range.size();
+  // Allocate some space to put the new indices in.
+  SmallVector<unsigned, 5> Idxs;
+  Idxs.reserve(size);
+  // Add indices from the extract value instruction.
+  Idxs.append(I->idx_begin(), I->idx_end());
+
+  // Add requested indices.
+  Idxs.append(idx_range.begin(), idx_range.end());
+
+  assert(Idxs.size() == size && "Number of indices added not correct?");
+
+  return FindInsertedValue(I->getAggregateOperand(), Idxs, InsertBefore);
+}
+// Otherwise, we don't know (e.g., extracting from a function return value
+// or a load instruction).
+return nullptr;
+}
+
+bool llvm::isGEPBasedOnPointerToString(const GEPOperator *GEP,
+                                       unsigned CharSize) {
+  // Make sure the GEP has exactly three arguments.
+  if (GEP->getNumOperands() != 3)
+    return false;
+
+  // Make sure the index-ee is a pointer to an array of \p CharSize integers.
+  ArrayType *AT = dyn_cast<ArrayType>(GEP->getSourceElementType());
+  if (!AT || !AT->getElementType()->isIntegerTy(CharSize))
+    return false;
+
+  // Check to make sure that the first operand of the GEP is an integer and
+  // has value 0 so that we are sure we're indexing into the initializer.
+  const ConstantInt *FirstIdx = dyn_cast<ConstantInt>(GEP->getOperand(1));
+  if (!FirstIdx || !FirstIdx->isZero())
+    return false;
+
+  return true;
+}
+
+bool llvm::getConstantDataArrayInfo(const Value *V,
+                                    ConstantDataArraySlice &Slice,
+                                    unsigned ElementSize, uint64_t Offset) {
+  assert(V);
+
+  // Look through bitcast instructions and geps.
+  V = V->stripPointerCasts();
+
+  // If the value is a GEP instruction or constant expression, treat it as an
+  // offset.
+  if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
+    // The GEP operator should be based on a pointer to a string constant,
+    // and is indexing into the string constant.
+    if (!isGEPBasedOnPointerToString(GEP, ElementSize))
+      return false;
+
+    // If the second index isn't a ConstantInt, then this is a variable index
+    // into the array. If this occurs, we can't say anything meaningful about
+    // the string.
+    uint64_t StartIdx = 0;
+    if (const ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(2)))
+      StartIdx = CI->getZExtValue();
+    else
+      return false;
+    return getConstantDataArrayInfo(GEP->getOperand(0), Slice, ElementSize,
+                                    StartIdx + Offset);
+  }
+
+  // The GEP, whether a constant expression or an instruction, must reference
+  // a global variable that is a constant and is initialized.
The referenced constant + // initializer is the array that we'll use for optimization. + const GlobalVariable *GV = dyn_cast<GlobalVariable>(V); + if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) + return false; + + const ConstantDataArray *Array; + ArrayType *ArrayTy; + if (GV->getInitializer()->isNullValue()) { + Type *GVTy = GV->getValueType(); + if ( (ArrayTy = dyn_cast<ArrayType>(GVTy)) ) { + // A zeroinitializer for the array; there is no ConstantDataArray. + Array = nullptr; + } else { + const DataLayout &DL = GV->getParent()->getDataLayout(); + uint64_t SizeInBytes = DL.getTypeStoreSize(GVTy); + uint64_t Length = SizeInBytes / (ElementSize / 8); + if (Length <= Offset) + return false; + + Slice.Array = nullptr; + Slice.Offset = 0; + Slice.Length = Length - Offset; + return true; + } + } else { + // This must be a ConstantDataArray. + Array = dyn_cast<ConstantDataArray>(GV->getInitializer()); + if (!Array) + return false; + ArrayTy = Array->getType(); + } + if (!ArrayTy->getElementType()->isIntegerTy(ElementSize)) + return false; + + uint64_t NumElts = ArrayTy->getArrayNumElements(); + if (Offset > NumElts) + return false; + + Slice.Array = Array; + Slice.Offset = Offset; + Slice.Length = NumElts - Offset; + return true; +} + +/// This function computes the length of a null-terminated C string pointed to +/// by V. If successful, it returns true and returns the string in Str. +/// If unsuccessful, it returns false. +bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, + uint64_t Offset, bool TrimAtNul) { + ConstantDataArraySlice Slice; + if (!getConstantDataArrayInfo(V, Slice, 8, Offset)) + return false; + + if (Slice.Array == nullptr) { + if (TrimAtNul) { + Str = StringRef(); + return true; + } + if (Slice.Length == 1) { + Str = StringRef("", 1); + return true; + } + // We cannot instantiate a StringRef as we do not have an appropriate string + // of 0s at hand. + return false; + } + + // Start out with the entire array in the StringRef. + Str = Slice.Array->getAsString(); + // Skip over 'offset' bytes. + Str = Str.substr(Slice.Offset); + + if (TrimAtNul) { + // Trim off the \0 and anything after it. If the array is not nul + // terminated, we just return the whole end of string. The client may know + // some other way that the string is length-bound. + Str = Str.substr(0, Str.find('\0')); + } + return true; +} + +// These next two are very similar to the above, but also look through PHI +// nodes. +// TODO: See if we can integrate these two together. + +/// If we can compute the length of the string pointed to by +/// the specified pointer, return 'len+1'. If we can't, return 0. +static uint64_t GetStringLengthH(const Value *V, + SmallPtrSetImpl<const PHINode*> &PHIs, + unsigned CharSize) { + // Look through noop bitcast instructions. + V = V->stripPointerCasts(); + + // If this is a PHI node, there are two cases: either we have already seen it + // or we haven't. + if (const PHINode *PN = dyn_cast<PHINode>(V)) { + if (!PHIs.insert(PN).second) + return ~0ULL; // already in the set. + + // If it was new, see if all the input strings are the same length. + uint64_t LenSoFar = ~0ULL; + for (Value *IncValue : PN->incoming_values()) { + uint64_t Len = GetStringLengthH(IncValue, PHIs, CharSize); + if (Len == 0) return 0; // Unknown length -> unknown. + + if (Len == ~0ULL) continue; + + if (Len != LenSoFar && LenSoFar != ~0ULL) + return 0; // Disagree -> unknown. + LenSoFar = Len; + } + + // Success, all agree. 
+    return LenSoFar;
+  }
+
+  // strlen(select(c,x,y)) -> strlen(x) if it agrees with strlen(y); the two
+  // arms must have the same (known) length for the result to be known.
+  if (const SelectInst *SI = dyn_cast<SelectInst>(V)) {
+    uint64_t Len1 = GetStringLengthH(SI->getTrueValue(), PHIs, CharSize);
+    if (Len1 == 0) return 0;
+    uint64_t Len2 = GetStringLengthH(SI->getFalseValue(), PHIs, CharSize);
+    if (Len2 == 0) return 0;
+    if (Len1 == ~0ULL) return Len2;
+    if (Len2 == ~0ULL) return Len1;
+    if (Len1 != Len2) return 0;
+    return Len1;
+  }
+
+  // Otherwise, see if we can read the string.
+  ConstantDataArraySlice Slice;
+  if (!getConstantDataArrayInfo(V, Slice, CharSize))
+    return 0;
+
+  if (Slice.Array == nullptr)
+    return 1;
+
+  // Search for nul characters.
+  unsigned NullIndex = 0;
+  for (unsigned E = Slice.Length; NullIndex < E; ++NullIndex) {
+    if (Slice.Array->getElementAsInteger(Slice.Offset + NullIndex) == 0)
+      break;
+  }
+
+  return NullIndex + 1;
+}
+
+/// If we can compute the length of the string pointed to by
+/// the specified pointer, return 'len+1'. If we can't, return 0.
+uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) {
+  if (!V->getType()->isPointerTy())
+    return 0;
+
+  SmallPtrSet<const PHINode*, 32> PHIs;
+  uint64_t Len = GetStringLengthH(V, PHIs, CharSize);
+  // If Len is ~0ULL, we had an infinite phi cycle: this is dead code, so
+  // return an empty string as a length.
+  return Len == ~0ULL ? 1 : Len;
+}
+
+const Value *
+llvm::getArgumentAliasingToReturnedPointer(const CallBase *Call,
+                                           bool MustPreserveNullness) {
+  assert(Call &&
+         "getArgumentAliasingToReturnedPointer only works on nonnull calls");
+  if (const Value *RV = Call->getReturnedArgOperand())
+    return RV;
+  // This can be used only as an aliasing property.
+  if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
+          Call, MustPreserveNullness))
+    return Call->getArgOperand(0);
+  return nullptr;
+}
+
+bool llvm::isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
+    const CallBase *Call, bool MustPreserveNullness) {
+  return Call->getIntrinsicID() == Intrinsic::launder_invariant_group ||
+         Call->getIntrinsicID() == Intrinsic::strip_invariant_group ||
+         Call->getIntrinsicID() == Intrinsic::aarch64_irg ||
+         Call->getIntrinsicID() == Intrinsic::aarch64_tagp ||
+         (!MustPreserveNullness &&
+          Call->getIntrinsicID() == Intrinsic::ptrmask);
+}
+
+/// \p PN defines a loop-variant pointer to an object. Check if the
+/// previous iteration of the loop was referring to the same object as \p PN.
+static bool isSameUnderlyingObjectInLoop(const PHINode *PN,
+                                         const LoopInfo *LI) {
+  // Find the loop-defined value.
+  Loop *L = LI->getLoopFor(PN->getParent());
+  if (PN->getNumIncomingValues() != 2)
+    return true;
+
+  // Find the value from previous iteration.
+  auto *PrevValue = dyn_cast<Instruction>(PN->getIncomingValue(0));
+  if (!PrevValue || LI->getLoopFor(PrevValue->getParent()) != L)
+    PrevValue = dyn_cast<Instruction>(PN->getIncomingValue(1));
+  if (!PrevValue || LI->getLoopFor(PrevValue->getParent()) != L)
+    return true;
+
+  // If a new pointer is loaded in the loop, the pointer references a different
+  // object in every iteration. E.g.:
+  //   for (i)
+  //     int *p = a[i];
+  //     ...
+  if (auto *Load = dyn_cast<LoadInst>(PrevValue))
+    if (!L->isLoopInvariant(Load->getPointerOperand()))
+      return false;
+  return true;
+}
+
+Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL,
+                                 unsigned MaxLookup) {
+  if (!V->getType()->isPointerTy())
+    return V;
+  for (unsigned Count = 0; MaxLookup == 0 || Count < MaxLookup; ++Count) {
+    if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
+      V = GEP->getPointerOperand();
+    } else if (Operator::getOpcode(V) == Instruction::BitCast ||
+               Operator::getOpcode(V) == Instruction::AddrSpaceCast) {
+      V = cast<Operator>(V)->getOperand(0);
+    } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
+      if (GA->isInterposable())
+        return V;
+      V = GA->getAliasee();
+    } else if (isa<AllocaInst>(V)) {
+      // An alloca can't be further simplified.
+      return V;
+    } else {
+      if (auto *Call = dyn_cast<CallBase>(V)) {
+        // CaptureTracking can know about special capturing properties of some
+        // intrinsics like launder.invariant.group that can't be expressed with
+        // attributes but that do have properties such as returning an aliasing
+        // pointer. Because some analyses may assume that a nocapture pointer
+        // is not returned from such a special intrinsic (the function would
+        // have to be marked with the returned attribute), it is crucial to use
+        // this helper so that we stay in sync with CaptureTracking. Not using
+        // it may cause weird miscompilations where two aliasing pointers are
+        // assumed not to alias.
+        if (auto *RP = getArgumentAliasingToReturnedPointer(Call, false)) {
+          V = RP;
+          continue;
+        }
+      }
+
+      // See if InstructionSimplify knows any relevant tricks.
+      if (Instruction *I = dyn_cast<Instruction>(V))
+        // TODO: Acquire a DominatorTree and AssumptionCache and use them.
+        if (Value *Simplified = SimplifyInstruction(I, {DL, I})) {
+          V = Simplified;
+          continue;
+        }
+
+      return V;
+    }
+    assert(V->getType()->isPointerTy() && "Unexpected operand type!");
+  }
+  return V;
+}
+
+void llvm::GetUnderlyingObjects(const Value *V,
+                                SmallVectorImpl<const Value *> &Objects,
+                                const DataLayout &DL, LoopInfo *LI,
+                                unsigned MaxLookup) {
+  SmallPtrSet<const Value *, 4> Visited;
+  SmallVector<const Value *, 4> Worklist;
+  Worklist.push_back(V);
+  do {
+    const Value *P = Worklist.pop_back_val();
+    P = GetUnderlyingObject(P, DL, MaxLookup);
+
+    if (!Visited.insert(P).second)
+      continue;
+
+    if (auto *SI = dyn_cast<SelectInst>(P)) {
+      Worklist.push_back(SI->getTrueValue());
+      Worklist.push_back(SI->getFalseValue());
+      continue;
+    }
+
+    if (auto *PN = dyn_cast<PHINode>(P)) {
+      // If this PHI changes the underlying object in every iteration of the
+      // loop, don't look through it. Consider:
+      //   int **A;
+      //   for (i) {
+      //     Prev = Curr;     // Prev = PHI (Prev_0, Curr)
+      //     Curr = A[i];
+      //     *Prev, *Curr;
+      //
+      // Prev is tracking Curr one iteration behind so they refer to different
+      // underlying objects.
+      if (!LI || !LI->isLoopHeader(PN->getParent()) ||
+          isSameUnderlyingObjectInLoop(PN, LI))
+        for (Value *IncValue : PN->incoming_values())
+          Worklist.push_back(IncValue);
+      continue;
+    }
+
+    Objects.push_back(P);
+  } while (!Worklist.empty());
+}
+
+/// This is the function that does the work of looking through basic
+/// ptrtoint+arithmetic+inttoptr sequences.
+static const Value *getUnderlyingObjectFromInt(const Value *V) {
+  do {
+    if (const Operator *U = dyn_cast<Operator>(V)) {
+      // If we find a ptrtoint, we can transfer control back to the regular
+      // pointer-based GetUnderlyingObjects machinery by returning its
+      // pointer operand.
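+      // For example, starting from (add (ptrtoint %obj), 8), the loop below
+      // strips the add and then returns %obj once it reaches the ptrtoint.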
+ if (U->getOpcode() == Instruction::PtrToInt) + return U->getOperand(0); + // If we find an add of a constant, a multiplied value, or a phi, it's + // likely that the other operand will lead us to the base + // object. We don't have to worry about the case where the + // object address is somehow being computed by the multiply, + // because our callers only care when the result is an + // identifiable object. + if (U->getOpcode() != Instruction::Add || + (!isa<ConstantInt>(U->getOperand(1)) && + Operator::getOpcode(U->getOperand(1)) != Instruction::Mul && + !isa<PHINode>(U->getOperand(1)))) + return V; + V = U->getOperand(0); + } else { + return V; + } + assert(V->getType()->isIntegerTy() && "Unexpected operand type!"); + } while (true); +} + +/// This is a wrapper around GetUnderlyingObjects and adds support for basic +/// ptrtoint+arithmetic+inttoptr sequences. +/// It returns false if unidentified object is found in GetUnderlyingObjects. +bool llvm::getUnderlyingObjectsForCodeGen(const Value *V, + SmallVectorImpl<Value *> &Objects, + const DataLayout &DL) { + SmallPtrSet<const Value *, 16> Visited; + SmallVector<const Value *, 4> Working(1, V); + do { + V = Working.pop_back_val(); + + SmallVector<const Value *, 4> Objs; + GetUnderlyingObjects(V, Objs, DL); + + for (const Value *V : Objs) { + if (!Visited.insert(V).second) + continue; + if (Operator::getOpcode(V) == Instruction::IntToPtr) { + const Value *O = + getUnderlyingObjectFromInt(cast<User>(V)->getOperand(0)); + if (O->getType()->isPointerTy()) { + Working.push_back(O); + continue; + } + } + // If GetUnderlyingObjects fails to find an identifiable object, + // getUnderlyingObjectsForCodeGen also fails for safety. + if (!isIdentifiedObject(V)) { + Objects.clear(); + return false; + } + Objects.push_back(const_cast<Value *>(V)); + } + } while (!Working.empty()); + return true; +} + +/// Return true if the only users of this pointer are lifetime markers. +bool llvm::onlyUsedByLifetimeMarkers(const Value *V) { + for (const User *U : V->users()) { + const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U); + if (!II) return false; + + if (!II->isLifetimeStartOrEnd()) + return false; + } + return true; +} + +bool llvm::mustSuppressSpeculation(const LoadInst &LI) { + if (!LI.isUnordered()) + return true; + const Function &F = *LI.getFunction(); + // Speculative load may create a race that did not exist in the source. + return F.hasFnAttribute(Attribute::SanitizeThread) || + // Speculative load may load data from dirty regions. + F.hasFnAttribute(Attribute::SanitizeAddress) || + F.hasFnAttribute(Attribute::SanitizeHWAddress); +} + + +bool llvm::isSafeToSpeculativelyExecute(const Value *V, + const Instruction *CtxI, + const DominatorTree *DT) { + const Operator *Inst = dyn_cast<Operator>(V); + if (!Inst) + return false; + + for (unsigned i = 0, e = Inst->getNumOperands(); i != e; ++i) + if (Constant *C = dyn_cast<Constant>(Inst->getOperand(i))) + if (C->canTrap()) + return false; + + switch (Inst->getOpcode()) { + default: + return true; + case Instruction::UDiv: + case Instruction::URem: { + // x / y is undefined if y == 0. + const APInt *V; + if (match(Inst->getOperand(1), m_APInt(V))) + return *V != 0; + return false; + } + case Instruction::SDiv: + case Instruction::SRem: { + // x / y is undefined if y == 0 or x == INT_MIN and y == -1 + const APInt *Numerator, *Denominator; + if (!match(Inst->getOperand(1), m_APInt(Denominator))) + return false; + // We cannot hoist this division if the denominator is 0. 
+    if (*Denominator == 0)
+      return false;
+    // It's safe to hoist if the denominator is not 0 or -1.
+    if (*Denominator != -1)
+      return true;
+    // At this point we know that the denominator is -1. It is safe to hoist as
+    // long as we know that the numerator is not INT_MIN.
+    if (match(Inst->getOperand(0), m_APInt(Numerator)))
+      return !Numerator->isMinSignedValue();
+    // The numerator *might* be MinSignedValue.
+    return false;
+  }
+  case Instruction::Load: {
+    const LoadInst *LI = cast<LoadInst>(Inst);
+    if (mustSuppressSpeculation(*LI))
+      return false;
+    const DataLayout &DL = LI->getModule()->getDataLayout();
+    return isDereferenceableAndAlignedPointer(
+        LI->getPointerOperand(), LI->getType(), MaybeAlign(LI->getAlignment()),
+        DL, CtxI, DT);
+  }
+  case Instruction::Call: {
+    auto *CI = cast<const CallInst>(Inst);
+    const Function *Callee = CI->getCalledFunction();
+
+    // The called function could have undefined behavior or side-effects, even
+    // if marked readnone nounwind.
+    return Callee && Callee->isSpeculatable();
+  }
+  case Instruction::VAArg:
+  case Instruction::Alloca:
+  case Instruction::Invoke:
+  case Instruction::CallBr:
+  case Instruction::PHI:
+  case Instruction::Store:
+  case Instruction::Ret:
+  case Instruction::Br:
+  case Instruction::IndirectBr:
+  case Instruction::Switch:
+  case Instruction::Unreachable:
+  case Instruction::Fence:
+  case Instruction::AtomicRMW:
+  case Instruction::AtomicCmpXchg:
+  case Instruction::LandingPad:
+  case Instruction::Resume:
+  case Instruction::CatchSwitch:
+  case Instruction::CatchPad:
+  case Instruction::CatchRet:
+  case Instruction::CleanupPad:
+  case Instruction::CleanupRet:
+    return false; // Misc instructions which have effects
+  }
+}
+
+bool llvm::mayBeMemoryDependent(const Instruction &I) {
+  return I.mayReadOrWriteMemory() || !isSafeToSpeculativelyExecute(&I);
+}
+
+/// Convert ConstantRange OverflowResult into ValueTracking OverflowResult.
+static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {
+  switch (OR) {
+  case ConstantRange::OverflowResult::MayOverflow:
+    return OverflowResult::MayOverflow;
+  case ConstantRange::OverflowResult::AlwaysOverflowsLow:
+    return OverflowResult::AlwaysOverflowsLow;
+  case ConstantRange::OverflowResult::AlwaysOverflowsHigh:
+    return OverflowResult::AlwaysOverflowsHigh;
+  case ConstantRange::OverflowResult::NeverOverflows:
+    return OverflowResult::NeverOverflows;
+  }
+  llvm_unreachable("Unknown OverflowResult");
+}
+
+/// Combine constant ranges from computeConstantRange() and computeKnownBits().
+static ConstantRange computeConstantRangeIncludingKnownBits(
+    const Value *V, bool ForSigned, const DataLayout &DL, unsigned Depth,
+    AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT,
+    OptimizationRemarkEmitter *ORE = nullptr, bool UseInstrInfo = true) {
+  KnownBits Known = computeKnownBits(
+      V, DL, Depth, AC, CxtI, DT, ORE, UseInstrInfo);
+  ConstantRange CR1 = ConstantRange::fromKnownBits(Known, ForSigned);
+  ConstantRange CR2 = computeConstantRange(V, UseInstrInfo);
+  ConstantRange::PreferredRangeType RangeType =
+      ForSigned ?
ConstantRange::Signed : ConstantRange::Unsigned; + return CR1.intersectWith(CR2, RangeType); +} + +OverflowResult llvm::computeOverflowForUnsignedMul( + const Value *LHS, const Value *RHS, const DataLayout &DL, + AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { + KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT, + nullptr, UseInstrInfo); + KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT, + nullptr, UseInstrInfo); + ConstantRange LHSRange = ConstantRange::fromKnownBits(LHSKnown, false); + ConstantRange RHSRange = ConstantRange::fromKnownBits(RHSKnown, false); + return mapOverflowResult(LHSRange.unsignedMulMayOverflow(RHSRange)); +} + +OverflowResult +llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS, + const DataLayout &DL, AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT, bool UseInstrInfo) { + // Multiplying n * m significant bits yields a result of n + m significant + // bits. If the total number of significant bits does not exceed the + // result bit width (minus 1), there is no overflow. + // This means if we have enough leading sign bits in the operands + // we can guarantee that the result does not overflow. + // Ref: "Hacker's Delight" by Henry Warren + unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); + + // Note that underestimating the number of sign bits gives a more + // conservative answer. + unsigned SignBits = ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) + + ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT); + + // First handle the easy case: if we have enough sign bits there's + // definitely no overflow. + if (SignBits > BitWidth + 1) + return OverflowResult::NeverOverflows; + + // There are two ambiguous cases where there can be no overflow: + // SignBits == BitWidth + 1 and + // SignBits == BitWidth + // The second case is difficult to check, therefore we only handle the + // first case. + if (SignBits == BitWidth + 1) { + // It overflows only when both arguments are negative and the true + // product is exactly the minimum negative number. + // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 + // For simplicity we just check if at least one side is not negative. 
+ KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT, + nullptr, UseInstrInfo); + KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT, + nullptr, UseInstrInfo); + if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative()) + return OverflowResult::NeverOverflows; + } + return OverflowResult::MayOverflow; +} + +OverflowResult llvm::computeOverflowForUnsignedAdd( + const Value *LHS, const Value *RHS, const DataLayout &DL, + AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { + ConstantRange LHSRange = computeConstantRangeIncludingKnownBits( + LHS, /*ForSigned=*/false, DL, /*Depth=*/0, AC, CxtI, DT, + nullptr, UseInstrInfo); + ConstantRange RHSRange = computeConstantRangeIncludingKnownBits( + RHS, /*ForSigned=*/false, DL, /*Depth=*/0, AC, CxtI, DT, + nullptr, UseInstrInfo); + return mapOverflowResult(LHSRange.unsignedAddMayOverflow(RHSRange)); +} + +static OverflowResult computeOverflowForSignedAdd(const Value *LHS, + const Value *RHS, + const AddOperator *Add, + const DataLayout &DL, + AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT) { + if (Add && Add->hasNoSignedWrap()) { + return OverflowResult::NeverOverflows; + } + + // If LHS and RHS each have at least two sign bits, the addition will look + // like + // + // XX..... + + // YY..... + // + // If the carry into the most significant position is 0, X and Y can't both + // be 1 and therefore the carry out of the addition is also 0. + // + // If the carry into the most significant position is 1, X and Y can't both + // be 0 and therefore the carry out of the addition is also 1. + // + // Since the carry into the most significant position is always equal to + // the carry out of the addition, there is no signed overflow. + if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 && + ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT) > 1) + return OverflowResult::NeverOverflows; + + ConstantRange LHSRange = computeConstantRangeIncludingKnownBits( + LHS, /*ForSigned=*/true, DL, /*Depth=*/0, AC, CxtI, DT); + ConstantRange RHSRange = computeConstantRangeIncludingKnownBits( + RHS, /*ForSigned=*/true, DL, /*Depth=*/0, AC, CxtI, DT); + OverflowResult OR = + mapOverflowResult(LHSRange.signedAddMayOverflow(RHSRange)); + if (OR != OverflowResult::MayOverflow) + return OR; + + // The remaining code needs Add to be available. Early returns if not so. + if (!Add) + return OverflowResult::MayOverflow; + + // If the sign of Add is the same as at least one of the operands, this add + // CANNOT overflow. If this can be determined from the known bits of the + // operands the above signedAddMayOverflow() check will have already done so. + // The only other way to improve on the known bits is from an assumption, so + // call computeKnownBitsFromAssume() directly. 
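+  // (Signed addition overflows only when both operands have the same sign
+  // and the result's sign differs from it, so a sum whose sign matches a
+  // known sign of one operand cannot have wrapped.)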
+  bool LHSOrRHSKnownNonNegative =
+      (LHSRange.isAllNonNegative() || RHSRange.isAllNonNegative());
+  bool LHSOrRHSKnownNegative =
+      (LHSRange.isAllNegative() || RHSRange.isAllNegative());
+  if (LHSOrRHSKnownNonNegative || LHSOrRHSKnownNegative) {
+    KnownBits AddKnown(LHSRange.getBitWidth());
+    computeKnownBitsFromAssume(
+        Add, AddKnown, /*Depth=*/0, Query(DL, AC, CxtI, DT, true));
+    if ((AddKnown.isNonNegative() && LHSOrRHSKnownNonNegative) ||
+        (AddKnown.isNegative() && LHSOrRHSKnownNegative))
+      return OverflowResult::NeverOverflows;
+  }
+
+  return OverflowResult::MayOverflow;
+}
+
+OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
+                                                   const Value *RHS,
+                                                   const DataLayout &DL,
+                                                   AssumptionCache *AC,
+                                                   const Instruction *CxtI,
+                                                   const DominatorTree *DT) {
+  ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
+      LHS, /*ForSigned=*/false, DL, /*Depth=*/0, AC, CxtI, DT);
+  ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
+      RHS, /*ForSigned=*/false, DL, /*Depth=*/0, AC, CxtI, DT);
+  return mapOverflowResult(LHSRange.unsignedSubMayOverflow(RHSRange));
+}
+
+OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
+                                                 const Value *RHS,
+                                                 const DataLayout &DL,
+                                                 AssumptionCache *AC,
+                                                 const Instruction *CxtI,
+                                                 const DominatorTree *DT) {
+  // If LHS and RHS each have at least two sign bits, the subtraction
+  // cannot overflow.
+  if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 &&
+      ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT) > 1)
+    return OverflowResult::NeverOverflows;
+
+  ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
+      LHS, /*ForSigned=*/true, DL, /*Depth=*/0, AC, CxtI, DT);
+  ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
+      RHS, /*ForSigned=*/true, DL, /*Depth=*/0, AC, CxtI, DT);
+  return mapOverflowResult(LHSRange.signedSubMayOverflow(RHSRange));
+}
+
+bool llvm::isOverflowIntrinsicNoWrap(const WithOverflowInst *WO,
+                                     const DominatorTree &DT) {
+  SmallVector<const BranchInst *, 2> GuardingBranches;
+  SmallVector<const ExtractValueInst *, 2> Results;
+
+  for (const User *U : WO->users()) {
+    if (const auto *EVI = dyn_cast<ExtractValueInst>(U)) {
+      assert(EVI->getNumIndices() == 1 && "Obvious from CI's type");
+
+      if (EVI->getIndices()[0] == 0)
+        Results.push_back(EVI);
+      else {
+        assert(EVI->getIndices()[0] == 1 && "Obvious from CI's type");
+
+        for (const auto *U : EVI->users())
+          if (const auto *B = dyn_cast<BranchInst>(U)) {
+            assert(B->isConditional() && "How else is it using an i1?");
+            GuardingBranches.push_back(B);
+          }
+      }
+    } else {
+      // We are using the aggregate directly in a way we don't want to analyze
+      // here (storing it to a global, say).
+      return false;
+    }
+  }
+
+  auto AllUsesGuardedByBranch = [&](const BranchInst *BI) {
+    BasicBlockEdge NoWrapEdge(BI->getParent(), BI->getSuccessor(1));
+    if (!NoWrapEdge.isSingleEdge())
+      return false;
+
+    // Check if all users of the add are provably no-wrap.
+    for (const auto *Result : Results) {
+      // If the extractvalue itself is not executed on overflow, then we don't
+      // need to check each use separately, since domination is transitive.
+      if (DT.dominates(NoWrapEdge, Result->getParent()))
+        continue;
+
+      for (auto &RU : Result->uses())
+        if (!DT.dominates(NoWrapEdge, RU))
+          return false;
+    }
+
+    return true;
+  };
+
+  return llvm::any_of(GuardingBranches, AllUsesGuardedByBranch);
+}
+
+
+OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add,
+                                                 const DataLayout &DL,
+                                                 AssumptionCache *AC,
+                                                 const Instruction *CxtI,
+                                                 const DominatorTree *DT) {
+  return ::computeOverflowForSignedAdd(Add->getOperand(0), Add->getOperand(1),
+                                       Add, DL, AC, CxtI, DT);
+}
+
+OverflowResult llvm::computeOverflowForSignedAdd(const Value *LHS,
+                                                 const Value *RHS,
+                                                 const DataLayout &DL,
+                                                 AssumptionCache *AC,
+                                                 const Instruction *CxtI,
+                                                 const DominatorTree *DT) {
+  return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, DL, AC, CxtI, DT);
+}
+
+bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) {
+  // Note: An atomic operation isn't guaranteed to return in a reasonable
+  // amount of time because it's possible for another thread to interfere with
+  // it for an arbitrary length of time, but programs aren't allowed to rely
+  // on that.
+
+  // If there is no successor, then execution can't transfer to it.
+  if (const auto *CRI = dyn_cast<CleanupReturnInst>(I))
+    return !CRI->unwindsToCaller();
+  if (const auto *CatchSwitch = dyn_cast<CatchSwitchInst>(I))
+    return !CatchSwitch->unwindsToCaller();
+  if (isa<ResumeInst>(I))
+    return false;
+  if (isa<ReturnInst>(I))
+    return false;
+  if (isa<UnreachableInst>(I))
+    return false;
+
+  // Calls can throw, or contain an infinite loop, or kill the process.
+  if (auto CS = ImmutableCallSite(I)) {
+    // Call sites that throw have implicit non-local control flow.
+    if (!CS.doesNotThrow())
+      return false;
+
+    // A function which doesn't throw and has the "willreturn" attribute will
+    // always return.
+    if (CS.hasFnAttr(Attribute::WillReturn))
+      return true;
+
+    // Non-throwing call sites can loop infinitely, call exit/pthread_exit
+    // etc. and thus not return. However, LLVM already assumes that
+    //
+    //  - Thread exiting actions are modeled as writes to memory invisible to
+    //    the program.
+    //
+    //  - Loops that don't have side effects (side effects are volatile/atomic
+    //    stores and IO) always terminate (see http://llvm.org/PR965).
+    //    Furthermore IO itself is also modeled as writes to memory invisible
+    //    to the program.
+    //
+    // We rely on those assumptions here, and use the memory effects of the
+    // call target as a proxy for checking that it always returns.
+
+    // FIXME: This isn't aggressive enough; a call which only writes to a
+    // global is guaranteed to return.
+    return CS.onlyReadsMemory() || CS.onlyAccessesArgMemory();
+  }
+
+  // Other instructions return normally.
+  return true;
+}
+
+bool llvm::isGuaranteedToTransferExecutionToSuccessor(const BasicBlock *BB) {
+  // TODO: This is slightly conservative for invoke instructions since exiting
+  // via an exception *is* normal control flow for them.
+  for (auto I = BB->begin(), E = BB->end(); I != E; ++I)
+    if (!isGuaranteedToTransferExecutionToSuccessor(&*I))
+      return false;
+  return true;
+}
+
+bool llvm::isGuaranteedToExecuteForEveryIteration(const Instruction *I,
+                                                  const Loop *L) {
+  // The loop header is guaranteed to be executed for every iteration.
+  //
+  // FIXME: Relax this constraint to cover all basic blocks that are
+  // guaranteed to be executed at every iteration.
+ if (I->getParent() != L->getHeader()) return false; + + for (const Instruction &LI : *L->getHeader()) { + if (&LI == I) return true; + if (!isGuaranteedToTransferExecutionToSuccessor(&LI)) return false; + } + llvm_unreachable("Instruction not contained in its own parent basic block."); +} + +bool llvm::propagatesFullPoison(const Instruction *I) { + // TODO: This should include all instructions apart from phis, selects and + // call-like instructions. + switch (I->getOpcode()) { + case Instruction::Add: + case Instruction::Sub: + case Instruction::Xor: + case Instruction::Trunc: + case Instruction::BitCast: + case Instruction::AddrSpaceCast: + case Instruction::Mul: + case Instruction::Shl: + case Instruction::GetElementPtr: + // These operations all propagate poison unconditionally. Note that poison + // is not any particular value, so xor or subtraction of poison with + // itself still yields poison, not zero. + return true; + + case Instruction::AShr: + case Instruction::SExt: + // For these operations, one bit of the input is replicated across + // multiple output bits. A replicated poison bit is still poison. + return true; + + case Instruction::ICmp: + // Comparing poison with any value yields poison. This is why, for + // instance, x s< (x +nsw 1) can be folded to true. + return true; + + default: + return false; + } +} + +const Value *llvm::getGuaranteedNonFullPoisonOp(const Instruction *I) { + switch (I->getOpcode()) { + case Instruction::Store: + return cast<StoreInst>(I)->getPointerOperand(); + + case Instruction::Load: + return cast<LoadInst>(I)->getPointerOperand(); + + case Instruction::AtomicCmpXchg: + return cast<AtomicCmpXchgInst>(I)->getPointerOperand(); + + case Instruction::AtomicRMW: + return cast<AtomicRMWInst>(I)->getPointerOperand(); + + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::URem: + case Instruction::SRem: + return I->getOperand(1); + + default: + // Note: It's really tempting to think that a conditional branch or + // switch should be listed here, but that's incorrect. It's not + // branching off of poison which is UB, it is executing a side effecting + // instruction which follows the branch. + return nullptr; + } +} + +bool llvm::mustTriggerUB(const Instruction *I, + const SmallSet<const Value *, 16>& KnownPoison) { + auto *NotPoison = getGuaranteedNonFullPoisonOp(I); + return (NotPoison && KnownPoison.count(NotPoison)); +} + + +bool llvm::programUndefinedIfFullPoison(const Instruction *PoisonI) { + // We currently only look for uses of poison values within the same basic + // block, as that makes it easier to guarantee that the uses will be + // executed given that PoisonI is executed. + // + // FIXME: Expand this to consider uses beyond the same basic block. To do + // this, look out for the distinction between post-dominance and strong + // post-dominance. + const BasicBlock *BB = PoisonI->getParent(); + + // Set of instructions that we have proved will yield poison if PoisonI + // does. + SmallSet<const Value *, 16> YieldsPoison; + SmallSet<const BasicBlock *, 4> Visited; + YieldsPoison.insert(PoisonI); + Visited.insert(PoisonI->getParent()); + + BasicBlock::const_iterator Begin = PoisonI->getIterator(), End = BB->end(); + + unsigned Iter = 0; + while (Iter++ < MaxDepth) { + for (auto &I : make_range(Begin, End)) { + if (&I != PoisonI) { + if (mustTriggerUB(&I, YieldsPoison)) + return true; + if (!isGuaranteedToTransferExecutionToSuccessor(&I)) + return false; + } + + // Mark poison that propagates from I through uses of I. 
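+      // (For example, if %a yields poison and %b = add %a, 1 appears later in
+      // this block, %b is added to the set as well.)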
+      if (YieldsPoison.count(&I)) {
+        for (const User *User : I.users()) {
+          const Instruction *UserI = cast<Instruction>(User);
+          if (propagatesFullPoison(UserI))
+            YieldsPoison.insert(User);
+        }
+      }
+    }
+
+    if (auto *NextBB = BB->getSingleSuccessor()) {
+      if (Visited.insert(NextBB).second) {
+        BB = NextBB;
+        Begin = BB->getFirstNonPHI()->getIterator();
+        End = BB->end();
+        continue;
+      }
+    }
+
+    break;
+  }
+  return false;
+}
+
+static bool isKnownNonNaN(const Value *V, FastMathFlags FMF) {
+  if (FMF.noNaNs())
+    return true;
+
+  if (auto *C = dyn_cast<ConstantFP>(V))
+    return !C->isNaN();
+
+  if (auto *C = dyn_cast<ConstantDataVector>(V)) {
+    if (!C->getElementType()->isFloatingPointTy())
+      return false;
+    for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) {
+      if (C->getElementAsAPFloat(I).isNaN())
+        return false;
+    }
+    return true;
+  }
+
+  return false;
+}
+
+static bool isKnownNonZero(const Value *V) {
+  if (auto *C = dyn_cast<ConstantFP>(V))
+    return !C->isZero();
+
+  if (auto *C = dyn_cast<ConstantDataVector>(V)) {
+    if (!C->getElementType()->isFloatingPointTy())
+      return false;
+    for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) {
+      if (C->getElementAsAPFloat(I).isZero())
+        return false;
+    }
+    return true;
+  }
+
+  return false;
+}
+
+/// Match clamp pattern for float types without regard to NaNs or signed
+/// zeros. Given a non-min/max outer cmp/select from the clamp pattern, this
+/// function recognizes whether it can be substituted by a "canonical" min/max
+/// pattern.
+static SelectPatternResult matchFastFloatClamp(CmpInst::Predicate Pred,
+                                               Value *CmpLHS, Value *CmpRHS,
+                                               Value *TrueVal, Value *FalseVal,
+                                               Value *&LHS, Value *&RHS) {
+  // Try to match
+  //   X < C1 ? C1 : Min(X, C2) --> Max(C1, Min(X, C2))
+  //   X > C1 ? C1 : Max(X, C2) --> Min(C1, Max(X, C2))
+  // and return the description of the outer Max/Min.
+
+  // First, check if select has inverse order:
+  if (CmpRHS == FalseVal) {
+    std::swap(TrueVal, FalseVal);
+    Pred = CmpInst::getInversePredicate(Pred);
+  }
+
+  // Assume success now. If there's no match, callers should not use these
+  // anyway.
+  LHS = TrueVal;
+  RHS = FalseVal;
+
+  const APFloat *FC1;
+  if (CmpRHS != TrueVal || !match(CmpRHS, m_APFloat(FC1)) || !FC1->isFinite())
+    return {SPF_UNKNOWN, SPNB_NA, false};
+
+  const APFloat *FC2;
+  switch (Pred) {
+  case CmpInst::FCMP_OLT:
+  case CmpInst::FCMP_OLE:
+  case CmpInst::FCMP_ULT:
+  case CmpInst::FCMP_ULE:
+    if (match(FalseVal,
+              m_CombineOr(m_OrdFMin(m_Specific(CmpLHS), m_APFloat(FC2)),
+                          m_UnordFMin(m_Specific(CmpLHS), m_APFloat(FC2)))) &&
+        FC1->compare(*FC2) == APFloat::cmpResult::cmpLessThan)
+      return {SPF_FMAXNUM, SPNB_RETURNS_ANY, false};
+    break;
+  case CmpInst::FCMP_OGT:
+  case CmpInst::FCMP_OGE:
+  case CmpInst::FCMP_UGT:
+  case CmpInst::FCMP_UGE:
+    if (match(FalseVal,
+              m_CombineOr(m_OrdFMax(m_Specific(CmpLHS), m_APFloat(FC2)),
+                          m_UnordFMax(m_Specific(CmpLHS), m_APFloat(FC2)))) &&
+        FC1->compare(*FC2) == APFloat::cmpResult::cmpGreaterThan)
+      return {SPF_FMINNUM, SPNB_RETURNS_ANY, false};
+    break;
+  default:
+    break;
+  }
+
+  return {SPF_UNKNOWN, SPNB_NA, false};
+}
+
+/// Recognize variations of:
+///   CLAMP(v,l,h) ==> ((v) < (l) ? (l) : ((v) > (h) ? (h) : (v)))
+static SelectPatternResult matchClamp(CmpInst::Predicate Pred,
+                                      Value *CmpLHS, Value *CmpRHS,
+                                      Value *TrueVal, Value *FalseVal) {
+  // Swap the select operands and predicate to match the patterns below.
+ if (CmpRHS != TrueVal) { + Pred = ICmpInst::getSwappedPredicate(Pred); + std::swap(TrueVal, FalseVal); + } + const APInt *C1; + if (CmpRHS == TrueVal && match(CmpRHS, m_APInt(C1))) { + const APInt *C2; + // (X <s C1) ? C1 : SMIN(X, C2) ==> SMAX(SMIN(X, C2), C1) + if (match(FalseVal, m_SMin(m_Specific(CmpLHS), m_APInt(C2))) && + C1->slt(*C2) && Pred == CmpInst::ICMP_SLT) + return {SPF_SMAX, SPNB_NA, false}; + + // (X >s C1) ? C1 : SMAX(X, C2) ==> SMIN(SMAX(X, C2), C1) + if (match(FalseVal, m_SMax(m_Specific(CmpLHS), m_APInt(C2))) && + C1->sgt(*C2) && Pred == CmpInst::ICMP_SGT) + return {SPF_SMIN, SPNB_NA, false}; + + // (X <u C1) ? C1 : UMIN(X, C2) ==> UMAX(UMIN(X, C2), C1) + if (match(FalseVal, m_UMin(m_Specific(CmpLHS), m_APInt(C2))) && + C1->ult(*C2) && Pred == CmpInst::ICMP_ULT) + return {SPF_UMAX, SPNB_NA, false}; + + // (X >u C1) ? C1 : UMAX(X, C2) ==> UMIN(UMAX(X, C2), C1) + if (match(FalseVal, m_UMax(m_Specific(CmpLHS), m_APInt(C2))) && + C1->ugt(*C2) && Pred == CmpInst::ICMP_UGT) + return {SPF_UMIN, SPNB_NA, false}; + } + return {SPF_UNKNOWN, SPNB_NA, false}; +} + +/// Recognize variations of: +/// a < c ? min(a,b) : min(b,c) ==> min(min(a,b),min(b,c)) +static SelectPatternResult matchMinMaxOfMinMax(CmpInst::Predicate Pred, + Value *CmpLHS, Value *CmpRHS, + Value *TVal, Value *FVal, + unsigned Depth) { + // TODO: Allow FP min/max with nnan/nsz. + assert(CmpInst::isIntPredicate(Pred) && "Expected integer comparison"); + + Value *A = nullptr, *B = nullptr; + SelectPatternResult L = matchSelectPattern(TVal, A, B, nullptr, Depth + 1); + if (!SelectPatternResult::isMinOrMax(L.Flavor)) + return {SPF_UNKNOWN, SPNB_NA, false}; + + Value *C = nullptr, *D = nullptr; + SelectPatternResult R = matchSelectPattern(FVal, C, D, nullptr, Depth + 1); + if (L.Flavor != R.Flavor) + return {SPF_UNKNOWN, SPNB_NA, false}; + + // We have something like: x Pred y ? min(a, b) : min(c, d). + // Try to match the compare to the min/max operations of the select operands. + // First, make sure we have the right compare predicate. + switch (L.Flavor) { + case SPF_SMIN: + if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) { + Pred = ICmpInst::getSwappedPredicate(Pred); + std::swap(CmpLHS, CmpRHS); + } + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) + break; + return {SPF_UNKNOWN, SPNB_NA, false}; + case SPF_SMAX: + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) { + Pred = ICmpInst::getSwappedPredicate(Pred); + std::swap(CmpLHS, CmpRHS); + } + if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) + break; + return {SPF_UNKNOWN, SPNB_NA, false}; + case SPF_UMIN: + if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) { + Pred = ICmpInst::getSwappedPredicate(Pred); + std::swap(CmpLHS, CmpRHS); + } + if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) + break; + return {SPF_UNKNOWN, SPNB_NA, false}; + case SPF_UMAX: + if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) { + Pred = ICmpInst::getSwappedPredicate(Pred); + std::swap(CmpLHS, CmpRHS); + } + if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) + break; + return {SPF_UNKNOWN, SPNB_NA, false}; + default: + return {SPF_UNKNOWN, SPNB_NA, false}; + } + + // If there is a common operand in the already matched min/max and the other + // min/max operands match the compare operands (either directly or inverted), + // then this is min/max of the same flavor. + + // a pred c ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b)) + // ~c pred ~a ? 
m(a, b) : m(c, b) --> m(m(a, b), m(c, b)) + if (D == B) { + if ((CmpLHS == A && CmpRHS == C) || (match(C, m_Not(m_Specific(CmpLHS))) && + match(A, m_Not(m_Specific(CmpRHS))))) + return {L.Flavor, SPNB_NA, false}; + } + // a pred d ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d)) + // ~d pred ~a ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d)) + if (C == B) { + if ((CmpLHS == A && CmpRHS == D) || (match(D, m_Not(m_Specific(CmpLHS))) && + match(A, m_Not(m_Specific(CmpRHS))))) + return {L.Flavor, SPNB_NA, false}; + } + // b pred c ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a)) + // ~c pred ~b ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a)) + if (D == A) { + if ((CmpLHS == B && CmpRHS == C) || (match(C, m_Not(m_Specific(CmpLHS))) && + match(B, m_Not(m_Specific(CmpRHS))))) + return {L.Flavor, SPNB_NA, false}; + } + // b pred d ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d)) + // ~d pred ~b ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d)) + if (C == A) { + if ((CmpLHS == B && CmpRHS == D) || (match(D, m_Not(m_Specific(CmpLHS))) && + match(B, m_Not(m_Specific(CmpRHS))))) + return {L.Flavor, SPNB_NA, false}; + } + + return {SPF_UNKNOWN, SPNB_NA, false}; +} + +/// Match non-obvious integer minimum and maximum sequences. +static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, + Value *CmpLHS, Value *CmpRHS, + Value *TrueVal, Value *FalseVal, + Value *&LHS, Value *&RHS, + unsigned Depth) { + // Assume success. If there's no match, callers should not use these anyway. + LHS = TrueVal; + RHS = FalseVal; + + SelectPatternResult SPR = matchClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal); + if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN) + return SPR; + + SPR = matchMinMaxOfMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, Depth); + if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN) + return SPR; + + if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT) + return {SPF_UNKNOWN, SPNB_NA, false}; + + // Z = X -nsw Y + // (X >s Y) ? 0 : Z ==> (Z >s 0) ? 0 : Z ==> SMIN(Z, 0) + // (X <s Y) ? 0 : Z ==> (Z <s 0) ? 0 : Z ==> SMAX(Z, 0) + if (match(TrueVal, m_Zero()) && + match(FalseVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) + return {Pred == CmpInst::ICMP_SGT ? SPF_SMIN : SPF_SMAX, SPNB_NA, false}; + + // Z = X -nsw Y + // (X >s Y) ? Z : 0 ==> (Z >s 0) ? Z : 0 ==> SMAX(Z, 0) + // (X <s Y) ? Z : 0 ==> (Z <s 0) ? Z : 0 ==> SMIN(Z, 0) + if (match(FalseVal, m_Zero()) && + match(TrueVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) + return {Pred == CmpInst::ICMP_SGT ? SPF_SMAX : SPF_SMIN, SPNB_NA, false}; + + const APInt *C1; + if (!match(CmpRHS, m_APInt(C1))) + return {SPF_UNKNOWN, SPNB_NA, false}; + + // An unsigned min/max can be written with a signed compare. + const APInt *C2; + if ((CmpLHS == TrueVal && match(FalseVal, m_APInt(C2))) || + (CmpLHS == FalseVal && match(TrueVal, m_APInt(C2)))) { + // Is the sign bit set? + // (X <s 0) ? X : MAXVAL ==> (X >u MAXVAL) ? X : MAXVAL ==> UMAX + // (X <s 0) ? MAXVAL : X ==> (X >u MAXVAL) ? MAXVAL : X ==> UMIN + if (Pred == CmpInst::ICMP_SLT && C1->isNullValue() && + C2->isMaxSignedValue()) + return {CmpLHS == TrueVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false}; + + // Is the sign bit clear? + // (X >s -1) ? MINVAL : X ==> (X <u MINVAL) ? MINVAL : X ==> UMAX + // (X >s -1) ? X : MINVAL ==> (X <u MINVAL) ? X : MINVAL ==> UMIN + if (Pred == CmpInst::ICMP_SGT && C1->isAllOnesValue() && + C2->isMinSignedValue()) + return {CmpLHS == FalseVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false}; + } + + // Look through 'not' ops to find disguised signed min/max. + // (X >s C) ? 
~X : ~C ==> (~X <s ~C) ? ~X : ~C ==> SMIN(~X, ~C) + // (X <s C) ? ~X : ~C ==> (~X >s ~C) ? ~X : ~C ==> SMAX(~X, ~C) + if (match(TrueVal, m_Not(m_Specific(CmpLHS))) && + match(FalseVal, m_APInt(C2)) && ~(*C1) == *C2) + return {Pred == CmpInst::ICMP_SGT ? SPF_SMIN : SPF_SMAX, SPNB_NA, false}; + + // (X >s C) ? ~C : ~X ==> (~X <s ~C) ? ~C : ~X ==> SMAX(~C, ~X) + // (X <s C) ? ~C : ~X ==> (~X >s ~C) ? ~C : ~X ==> SMIN(~C, ~X) + if (match(FalseVal, m_Not(m_Specific(CmpLHS))) && + match(TrueVal, m_APInt(C2)) && ~(*C1) == *C2) + return {Pred == CmpInst::ICMP_SGT ? SPF_SMAX : SPF_SMIN, SPNB_NA, false}; + + return {SPF_UNKNOWN, SPNB_NA, false}; +} + +bool llvm::isKnownNegation(const Value *X, const Value *Y, bool NeedNSW) { + assert(X && Y && "Invalid operand"); + + // X = sub (0, Y) || X = sub nsw (0, Y) + if ((!NeedNSW && match(X, m_Sub(m_ZeroInt(), m_Specific(Y)))) || + (NeedNSW && match(X, m_NSWSub(m_ZeroInt(), m_Specific(Y))))) + return true; + + // Y = sub (0, X) || Y = sub nsw (0, X) + if ((!NeedNSW && match(Y, m_Sub(m_ZeroInt(), m_Specific(X)))) || + (NeedNSW && match(Y, m_NSWSub(m_ZeroInt(), m_Specific(X))))) + return true; + + // X = sub (A, B), Y = sub (B, A) || X = sub nsw (A, B), Y = sub nsw (B, A) + Value *A, *B; + return (!NeedNSW && (match(X, m_Sub(m_Value(A), m_Value(B))) && + match(Y, m_Sub(m_Specific(B), m_Specific(A))))) || + (NeedNSW && (match(X, m_NSWSub(m_Value(A), m_Value(B))) && + match(Y, m_NSWSub(m_Specific(B), m_Specific(A))))); +} + +static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, + FastMathFlags FMF, + Value *CmpLHS, Value *CmpRHS, + Value *TrueVal, Value *FalseVal, + Value *&LHS, Value *&RHS, + unsigned Depth) { + if (CmpInst::isFPPredicate(Pred)) { + // IEEE-754 ignores the sign of 0.0 in comparisons. So if the select has one + // 0.0 operand, set the compare's 0.0 operands to that same value for the + // purpose of identifying min/max. Disregard vector constants with undefined + // elements because those can not be back-propagated for analysis. + Value *OutputZeroVal = nullptr; + if (match(TrueVal, m_AnyZeroFP()) && !match(FalseVal, m_AnyZeroFP()) && + !cast<Constant>(TrueVal)->containsUndefElement()) + OutputZeroVal = TrueVal; + else if (match(FalseVal, m_AnyZeroFP()) && !match(TrueVal, m_AnyZeroFP()) && + !cast<Constant>(FalseVal)->containsUndefElement()) + OutputZeroVal = FalseVal; + + if (OutputZeroVal) { + if (match(CmpLHS, m_AnyZeroFP())) + CmpLHS = OutputZeroVal; + if (match(CmpRHS, m_AnyZeroFP())) + CmpRHS = OutputZeroVal; + } + } + + LHS = CmpLHS; + RHS = CmpRHS; + + // Signed zero may return inconsistent results between implementations. + // (0.0 <= -0.0) ? 0.0 : -0.0 // Returns 0.0 + // minNum(0.0, -0.0) // May return -0.0 or 0.0 (IEEE 754-2008 5.3.1) + // Therefore, we behave conservatively and only proceed if at least one of the + // operands is known to not be zero or if we don't care about signed zero. + switch (Pred) { + default: break; + // FIXME: Include OGT/OLT/UGT/ULT. + case CmpInst::FCMP_OGE: case CmpInst::FCMP_OLE: + case CmpInst::FCMP_UGE: case CmpInst::FCMP_ULE: + if (!FMF.noSignedZeros() && !isKnownNonZero(CmpLHS) && + !isKnownNonZero(CmpRHS)) + return {SPF_UNKNOWN, SPNB_NA, false}; + } + + SelectPatternNaNBehavior NaNBehavior = SPNB_NA; + bool Ordered = false; + + // When given one NaN and one non-NaN input: + // - maxnum/minnum (C99 fmaxf()/fminf()) return the non-NaN input. + // - A simple C99 (a < b ? a : b) construction will return 'b' (as the + // ordered comparison fails), which could be NaN or non-NaN. 
+ // so here we discover exactly what NaN behavior is required/accepted. + if (CmpInst::isFPPredicate(Pred)) { + bool LHSSafe = isKnownNonNaN(CmpLHS, FMF); + bool RHSSafe = isKnownNonNaN(CmpRHS, FMF); + + if (LHSSafe && RHSSafe) { + // Both operands are known non-NaN. + NaNBehavior = SPNB_RETURNS_ANY; + } else if (CmpInst::isOrdered(Pred)) { + // An ordered comparison will return false when given a NaN, so it + // returns the RHS. + Ordered = true; + if (LHSSafe) + // LHS is non-NaN, so if RHS is NaN then NaN will be returned. + NaNBehavior = SPNB_RETURNS_NAN; + else if (RHSSafe) + NaNBehavior = SPNB_RETURNS_OTHER; + else + // Completely unsafe. + return {SPF_UNKNOWN, SPNB_NA, false}; + } else { + Ordered = false; + // An unordered comparison will return true when given a NaN, so it + // returns the LHS. + if (LHSSafe) + // LHS is non-NaN, so if RHS is NaN then non-NaN will be returned. + NaNBehavior = SPNB_RETURNS_OTHER; + else if (RHSSafe) + NaNBehavior = SPNB_RETURNS_NAN; + else + // Completely unsafe. + return {SPF_UNKNOWN, SPNB_NA, false}; + } + } + + if (TrueVal == CmpRHS && FalseVal == CmpLHS) { + std::swap(CmpLHS, CmpRHS); + Pred = CmpInst::getSwappedPredicate(Pred); + if (NaNBehavior == SPNB_RETURNS_NAN) + NaNBehavior = SPNB_RETURNS_OTHER; + else if (NaNBehavior == SPNB_RETURNS_OTHER) + NaNBehavior = SPNB_RETURNS_NAN; + Ordered = !Ordered; + } + + // ([if]cmp X, Y) ? X : Y + if (TrueVal == CmpLHS && FalseVal == CmpRHS) { + switch (Pred) { + default: return {SPF_UNKNOWN, SPNB_NA, false}; // Equality. + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: return {SPF_UMAX, SPNB_NA, false}; + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: return {SPF_SMAX, SPNB_NA, false}; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: return {SPF_UMIN, SPNB_NA, false}; + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: return {SPF_SMIN, SPNB_NA, false}; + case FCmpInst::FCMP_UGT: + case FCmpInst::FCMP_UGE: + case FCmpInst::FCMP_OGT: + case FCmpInst::FCMP_OGE: return {SPF_FMAXNUM, NaNBehavior, Ordered}; + case FCmpInst::FCMP_ULT: + case FCmpInst::FCMP_ULE: + case FCmpInst::FCMP_OLT: + case FCmpInst::FCMP_OLE: return {SPF_FMINNUM, NaNBehavior, Ordered}; + } + } + + if (isKnownNegation(TrueVal, FalseVal)) { + // Sign-extending LHS does not change its sign, so TrueVal/FalseVal can + // match against either LHS or sext(LHS). + auto MaybeSExtCmpLHS = + m_CombineOr(m_Specific(CmpLHS), m_SExt(m_Specific(CmpLHS))); + auto ZeroOrAllOnes = m_CombineOr(m_ZeroInt(), m_AllOnes()); + auto ZeroOrOne = m_CombineOr(m_ZeroInt(), m_One()); + if (match(TrueVal, MaybeSExtCmpLHS)) { + // Set the return values. If the compare uses the negated value (-X >s 0), + // swap the return values because the negated value is always 'RHS'. + LHS = TrueVal; + RHS = FalseVal; + if (match(CmpLHS, m_Neg(m_Specific(FalseVal)))) + std::swap(LHS, RHS); + + // (X >s 0) ? X : -X or (X >s -1) ? X : -X --> ABS(X) + // (-X >s 0) ? -X : X or (-X >s -1) ? -X : X --> ABS(X) + if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, ZeroOrAllOnes)) + return {SPF_ABS, SPNB_NA, false}; + + // (X >=s 0) ? X : -X or (X >=s 1) ? X : -X --> ABS(X) + if (Pred == ICmpInst::ICMP_SGE && match(CmpRHS, ZeroOrOne)) + return {SPF_ABS, SPNB_NA, false}; + + // (X <s 0) ? X : -X or (X <s 1) ? X : -X --> NABS(X) + // (-X <s 0) ? -X : X or (-X <s 1) ? -X : X --> NABS(X) + if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, ZeroOrOne)) + return {SPF_NABS, SPNB_NA, false}; + } + else if (match(FalseVal, MaybeSExtCmpLHS)) { + // Set the return values. 
If the compare uses the negated value (-X >s 0), + // swap the return values because the negated value is always 'RHS'. + LHS = FalseVal; + RHS = TrueVal; + if (match(CmpLHS, m_Neg(m_Specific(TrueVal)))) + std::swap(LHS, RHS); + + // (X >s 0) ? -X : X or (X >s -1) ? -X : X --> NABS(X) + // (-X >s 0) ? X : -X or (-X >s -1) ? X : -X --> NABS(X) + if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, ZeroOrAllOnes)) + return {SPF_NABS, SPNB_NA, false}; + + // (X <s 0) ? -X : X or (X <s 1) ? -X : X --> ABS(X) + // (-X <s 0) ? X : -X or (-X <s 1) ? X : -X --> ABS(X) + if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, ZeroOrOne)) + return {SPF_ABS, SPNB_NA, false}; + } + } + + if (CmpInst::isIntPredicate(Pred)) + return matchMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS, Depth); + + // According to (IEEE 754-2008 5.3.1), minNum(0.0, -0.0) and similar + // may return either -0.0 or 0.0, so fcmp/select pair has stricter + // semantics than minNum. Be conservative in such case. + if (NaNBehavior != SPNB_RETURNS_ANY || + (!FMF.noSignedZeros() && !isKnownNonZero(CmpLHS) && + !isKnownNonZero(CmpRHS))) + return {SPF_UNKNOWN, SPNB_NA, false}; + + return matchFastFloatClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS); +} + +/// Helps to match a select pattern in case of a type mismatch. +/// +/// The function processes the case when type of true and false values of a +/// select instruction differs from type of the cmp instruction operands because +/// of a cast instruction. The function checks if it is legal to move the cast +/// operation after "select". If yes, it returns the new second value of +/// "select" (with the assumption that cast is moved): +/// 1. As operand of cast instruction when both values of "select" are same cast +/// instructions. +/// 2. As restored constant (by applying reverse cast operation) when the first +/// value of the "select" is a cast operation and the second value is a +/// constant. +/// NOTE: We return only the new second value because the first value could be +/// accessed as operand of cast instruction. +static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2, + Instruction::CastOps *CastOp) { + auto *Cast1 = dyn_cast<CastInst>(V1); + if (!Cast1) + return nullptr; + + *CastOp = Cast1->getOpcode(); + Type *SrcTy = Cast1->getSrcTy(); + if (auto *Cast2 = dyn_cast<CastInst>(V2)) { + // If V1 and V2 are both the same cast from the same type, look through V1. + if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy()) + return Cast2->getOperand(0); + return nullptr; + } + + auto *C = dyn_cast<Constant>(V2); + if (!C) + return nullptr; + + Constant *CastedTo = nullptr; + switch (*CastOp) { + case Instruction::ZExt: + if (CmpI->isUnsigned()) + CastedTo = ConstantExpr::getTrunc(C, SrcTy); + break; + case Instruction::SExt: + if (CmpI->isSigned()) + CastedTo = ConstantExpr::getTrunc(C, SrcTy, true); + break; + case Instruction::Trunc: + Constant *CmpConst; + if (match(CmpI->getOperand(1), m_Constant(CmpConst)) && + CmpConst->getType() == SrcTy) { + // Here we have the following case: + // + // %cond = cmp iN %x, CmpConst + // %tr = trunc iN %x to iK + // %narrowsel = select i1 %cond, iK %t, iK C + // + // We can always move trunc after select operation: + // + // %cond = cmp iN %x, CmpConst + // %widesel = select i1 %cond, iN %x, iN CmpConst + // %tr = trunc iN %widesel to iK + // + // Note that C could be extended in any way because we don't care about + // upper bits after truncation. 
It can't be abs pattern, because it would + // look like: + // + // select i1 %cond, x, -x. + // + // So only min/max pattern could be matched. Such match requires widened C + // == CmpConst. That is why set widened C = CmpConst, condition trunc + // CmpConst == C is checked below. + CastedTo = CmpConst; + } else { + CastedTo = ConstantExpr::getIntegerCast(C, SrcTy, CmpI->isSigned()); + } + break; + case Instruction::FPTrunc: + CastedTo = ConstantExpr::getFPExtend(C, SrcTy, true); + break; + case Instruction::FPExt: + CastedTo = ConstantExpr::getFPTrunc(C, SrcTy, true); + break; + case Instruction::FPToUI: + CastedTo = ConstantExpr::getUIToFP(C, SrcTy, true); + break; + case Instruction::FPToSI: + CastedTo = ConstantExpr::getSIToFP(C, SrcTy, true); + break; + case Instruction::UIToFP: + CastedTo = ConstantExpr::getFPToUI(C, SrcTy, true); + break; + case Instruction::SIToFP: + CastedTo = ConstantExpr::getFPToSI(C, SrcTy, true); + break; + default: + break; + } + + if (!CastedTo) + return nullptr; + + // Make sure the cast doesn't lose any information. + Constant *CastedBack = + ConstantExpr::getCast(*CastOp, CastedTo, C->getType(), true); + if (CastedBack != C) + return nullptr; + + return CastedTo; +} + +SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS, + Instruction::CastOps *CastOp, + unsigned Depth) { + if (Depth >= MaxDepth) + return {SPF_UNKNOWN, SPNB_NA, false}; + + SelectInst *SI = dyn_cast<SelectInst>(V); + if (!SI) return {SPF_UNKNOWN, SPNB_NA, false}; + + CmpInst *CmpI = dyn_cast<CmpInst>(SI->getCondition()); + if (!CmpI) return {SPF_UNKNOWN, SPNB_NA, false}; + + Value *TrueVal = SI->getTrueValue(); + Value *FalseVal = SI->getFalseValue(); + + return llvm::matchDecomposedSelectPattern(CmpI, TrueVal, FalseVal, LHS, RHS, + CastOp, Depth); +} + +SelectPatternResult llvm::matchDecomposedSelectPattern( + CmpInst *CmpI, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS, + Instruction::CastOps *CastOp, unsigned Depth) { + CmpInst::Predicate Pred = CmpI->getPredicate(); + Value *CmpLHS = CmpI->getOperand(0); + Value *CmpRHS = CmpI->getOperand(1); + FastMathFlags FMF; + if (isa<FPMathOperator>(CmpI)) + FMF = CmpI->getFastMathFlags(); + + // Bail out early. + if (CmpI->isEquality()) + return {SPF_UNKNOWN, SPNB_NA, false}; + + // Deal with type mismatches. + if (CastOp && CmpLHS->getType() != TrueVal->getType()) { + if (Value *C = lookThroughCast(CmpI, TrueVal, FalseVal, CastOp)) { + // If this is a potential fmin/fmax with a cast to integer, then ignore + // -0.0 because there is no corresponding integer value. + if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI) + FMF.setNoSignedZeros(); + return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, + cast<CastInst>(TrueVal)->getOperand(0), C, + LHS, RHS, Depth); + } + if (Value *C = lookThroughCast(CmpI, FalseVal, TrueVal, CastOp)) { + // If this is a potential fmin/fmax with a cast to integer, then ignore + // -0.0 because there is no corresponding integer value. 
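Aside: the trunc-past-select rewrite described in lookThroughCast above is easy to confirm numerically; a minimal sketch in plain C++, with int32_t to int8_t standing in for iN to iK (modular narrowing is assumed, as on all mainstream targets):

#include <cassert>
#include <cstdint>

int main() {
  const int32_t CmpConst = 300;
  for (int32_t X = -1000; X <= 1000; ++X) {
    // A select on independently truncated operands ...
    int8_t Narrow = (X < CmpConst) ? int8_t(X) : int8_t(CmpConst);
    // ... equals the wide select truncated afterwards.
    int32_t Wide = (X < CmpConst) ? X : CmpConst;
    assert(Narrow == int8_t(Wide));
  }
  return 0;
}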
+ if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI) + FMF.setNoSignedZeros(); + return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, + C, cast<CastInst>(FalseVal)->getOperand(0), + LHS, RHS, Depth); + } + } + return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, TrueVal, FalseVal, + LHS, RHS, Depth); +} + +CmpInst::Predicate llvm::getMinMaxPred(SelectPatternFlavor SPF, bool Ordered) { + if (SPF == SPF_SMIN) return ICmpInst::ICMP_SLT; + if (SPF == SPF_UMIN) return ICmpInst::ICMP_ULT; + if (SPF == SPF_SMAX) return ICmpInst::ICMP_SGT; + if (SPF == SPF_UMAX) return ICmpInst::ICMP_UGT; + if (SPF == SPF_FMINNUM) + return Ordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT; + if (SPF == SPF_FMAXNUM) + return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT; + llvm_unreachable("unhandled!"); +} + +SelectPatternFlavor llvm::getInverseMinMaxFlavor(SelectPatternFlavor SPF) { + if (SPF == SPF_SMIN) return SPF_SMAX; + if (SPF == SPF_UMIN) return SPF_UMAX; + if (SPF == SPF_SMAX) return SPF_SMIN; + if (SPF == SPF_UMAX) return SPF_UMIN; + llvm_unreachable("unhandled!"); +} + +CmpInst::Predicate llvm::getInverseMinMaxPred(SelectPatternFlavor SPF) { + return getMinMaxPred(getInverseMinMaxFlavor(SPF)); +} + +/// Return true if "icmp Pred LHS RHS" is always true. +static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS, + const Value *RHS, const DataLayout &DL, + unsigned Depth) { + assert(!LHS->getType()->isVectorTy() && "TODO: extend to handle vectors!"); + if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS) + return true; + + switch (Pred) { + default: + return false; + + case CmpInst::ICMP_SLE: { + const APInt *C; + + // LHS s<= LHS +_{nsw} C if C >= 0 + if (match(RHS, m_NSWAdd(m_Specific(LHS), m_APInt(C)))) + return !C->isNegative(); + return false; + } + + case CmpInst::ICMP_ULE: { + const APInt *C; + + // LHS u<= LHS +_{nuw} C for any C + if (match(RHS, m_NUWAdd(m_Specific(LHS), m_APInt(C)))) + return true; + + // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB) + auto MatchNUWAddsToSameValue = [&](const Value *A, const Value *B, + const Value *&X, + const APInt *&CA, const APInt *&CB) { + if (match(A, m_NUWAdd(m_Value(X), m_APInt(CA))) && + match(B, m_NUWAdd(m_Specific(X), m_APInt(CB)))) + return true; + + // If X & C == 0 then (X | C) == X +_{nuw} C + if (match(A, m_Or(m_Value(X), m_APInt(CA))) && + match(B, m_Or(m_Specific(X), m_APInt(CB)))) { + KnownBits Known(CA->getBitWidth()); + computeKnownBits(X, Known, DL, Depth + 1, /*AC*/ nullptr, + /*CxtI*/ nullptr, /*DT*/ nullptr); + if (CA->isSubsetOf(Known.Zero) && CB->isSubsetOf(Known.Zero)) + return true; + } + + return false; + }; + + const Value *X; + const APInt *CLHS, *CRHS; + if (MatchNUWAddsToSameValue(LHS, RHS, X, CLHS, CRHS)) + return CLHS->ule(*CRHS); + + return false; + } + } +} + +/// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred +/// ALHS ARHS" is true. Otherwise, return None. 
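Aside: the bit-level fact used by the MatchNUWAddsToSameValue lambda above (if X & C == 0, then X | C equals the no-wrap sum X + C) checks out by brute force; a plain C++ sketch:

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C = 0x1010;
  for (uint32_t X = 0; X <= 0xFFFF; ++X)
    if ((X & C) == 0)           // if X and C share no set bits ...
      assert((X | C) == X + C); // ... the 'or' is an add that cannot carry
  return 0;
}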
+static Optional<bool> +isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, + const Value *ARHS, const Value *BLHS, const Value *BRHS, + const DataLayout &DL, unsigned Depth) { + switch (Pred) { + default: + return None; + + case CmpInst::ICMP_SLT: + case CmpInst::ICMP_SLE: + if (isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS, DL, Depth) && + isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS, DL, Depth)) + return true; + return None; + + case CmpInst::ICMP_ULT: + case CmpInst::ICMP_ULE: + if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS, DL, Depth) && + isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS, DL, Depth)) + return true; + return None; + } +} + +/// Return true if the operands of the two compares match. IsSwappedOps is true +/// when the operands match, but are swapped. +static bool isMatchingOps(const Value *ALHS, const Value *ARHS, + const Value *BLHS, const Value *BRHS, + bool &IsSwappedOps) { + + bool IsMatchingOps = (ALHS == BLHS && ARHS == BRHS); + IsSwappedOps = (ALHS == BRHS && ARHS == BLHS); + return IsMatchingOps || IsSwappedOps; +} + +/// Return true if "icmp1 APred X, Y" implies "icmp2 BPred X, Y" is true. +/// Return false if "icmp1 APred X, Y" implies "icmp2 BPred X, Y" is false. +/// Otherwise, return None if we can't infer anything. +static Optional<bool> isImpliedCondMatchingOperands(CmpInst::Predicate APred, + CmpInst::Predicate BPred, + bool AreSwappedOps) { + // Canonicalize the predicate as if the operands were not commuted. + if (AreSwappedOps) + BPred = ICmpInst::getSwappedPredicate(BPred); + + if (CmpInst::isImpliedTrueByMatchingCmp(APred, BPred)) + return true; + if (CmpInst::isImpliedFalseByMatchingCmp(APred, BPred)) + return false; + + return None; +} + +/// Return true if "icmp APred X, C1" implies "icmp BPred X, C2" is true. +/// Return false if "icmp APred X, C1" implies "icmp BPred X, C2" is false. +/// Otherwise, return None if we can't infer anything. +static Optional<bool> +isImpliedCondMatchingImmOperands(CmpInst::Predicate APred, + const ConstantInt *C1, + CmpInst::Predicate BPred, + const ConstantInt *C2) { + ConstantRange DomCR = + ConstantRange::makeExactICmpRegion(APred, C1->getValue()); + ConstantRange CR = + ConstantRange::makeAllowedICmpRegion(BPred, C2->getValue()); + ConstantRange Intersection = DomCR.intersectWith(CR); + ConstantRange Difference = DomCR.difference(CR); + if (Intersection.isEmptySet()) + return false; + if (Difference.isEmptySet()) + return true; + return None; +} + +/// Return true if LHS implies RHS is true. Return false if LHS implies RHS is +/// false. Otherwise, return None if we can't infer anything. +static Optional<bool> isImpliedCondICmps(const ICmpInst *LHS, + const ICmpInst *RHS, + const DataLayout &DL, bool LHSIsTrue, + unsigned Depth) { + Value *ALHS = LHS->getOperand(0); + Value *ARHS = LHS->getOperand(1); + // The rest of the logic assumes the LHS condition is true. If that's not the + // case, invert the predicate to make it so. + ICmpInst::Predicate APred = + LHSIsTrue ? LHS->getPredicate() : LHS->getInversePredicate(); + + Value *BLHS = RHS->getOperand(0); + Value *BRHS = RHS->getOperand(1); + ICmpInst::Predicate BPred = RHS->getPredicate(); + + // Can we infer anything when the two compares have matching operands? 
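Aside: the constant-range reasoning in isImpliedCondMatchingImmOperands can be spot-checked concretely; in this plain C++ sketch, uint8_t stands in for an arbitrary bit width:

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned V = 0; V <= 0xFF; ++V) {
    uint8_t X = uint8_t(V);
    if (X < 5) {         // dominating condition: x u< 5
      assert(X < 10);    // region [0,5) lies inside [0,10): implied true
      assert(!(X > 20)); // region [0,5) misses (20,255]: implied false
    }
  }
  return 0;
}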
+ bool AreSwappedOps; + if (isMatchingOps(ALHS, ARHS, BLHS, BRHS, AreSwappedOps)) { + if (Optional<bool> Implication = isImpliedCondMatchingOperands( + APred, BPred, AreSwappedOps)) + return Implication; + // No amount of additional analysis will infer the second condition, so + // early exit. + return None; + } + + // Can we infer anything when the LHS operands match and the RHS operands are + // constants (not necessarily matching)? + if (ALHS == BLHS && isa<ConstantInt>(ARHS) && isa<ConstantInt>(BRHS)) { + if (Optional<bool> Implication = isImpliedCondMatchingImmOperands( + APred, cast<ConstantInt>(ARHS), BPred, cast<ConstantInt>(BRHS))) + return Implication; + // No amount of additional analysis will infer the second condition, so + // early exit. + return None; + } + + if (APred == BPred) + return isImpliedCondOperands(APred, ALHS, ARHS, BLHS, BRHS, DL, Depth); + return None; +} + +/// Return true if LHS implies RHS is true. Return false if LHS implies RHS is +/// false. Otherwise, return None if we can't infer anything. We expect the +/// RHS to be an icmp and the LHS to be an 'and' or an 'or' instruction. +static Optional<bool> isImpliedCondAndOr(const BinaryOperator *LHS, + const ICmpInst *RHS, + const DataLayout &DL, bool LHSIsTrue, + unsigned Depth) { + // The LHS must be an 'or' or an 'and' instruction. + assert((LHS->getOpcode() == Instruction::And || + LHS->getOpcode() == Instruction::Or) && + "Expected LHS to be 'and' or 'or'."); + + assert(Depth <= MaxDepth && "Hit recursion limit"); + + // If the result of an 'or' is false, then we know both legs of the 'or' are + // false. Similarly, if the result of an 'and' is true, then we know both + // legs of the 'and' are true. + Value *ALHS, *ARHS; + if ((!LHSIsTrue && match(LHS, m_Or(m_Value(ALHS), m_Value(ARHS)))) || + (LHSIsTrue && match(LHS, m_And(m_Value(ALHS), m_Value(ARHS))))) { + // FIXME: Make this non-recursion. + if (Optional<bool> Implication = + isImpliedCondition(ALHS, RHS, DL, LHSIsTrue, Depth + 1)) + return Implication; + if (Optional<bool> Implication = + isImpliedCondition(ARHS, RHS, DL, LHSIsTrue, Depth + 1)) + return Implication; + return None; + } + return None; +} + +Optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS, + const DataLayout &DL, bool LHSIsTrue, + unsigned Depth) { + // Bail out when we hit the limit. + if (Depth == MaxDepth) + return None; + + // A mismatch occurs when we compare a scalar cmp to a vector cmp, for + // example. + if (LHS->getType() != RHS->getType()) + return None; + + Type *OpTy = LHS->getType(); + assert(OpTy->isIntOrIntVectorTy(1) && "Expected integer type only!"); + + // LHS ==> RHS by definition + if (LHS == RHS) + return LHSIsTrue; + + // FIXME: Extending the code below to handle vectors. + if (OpTy->isVectorTy()) + return None; + + assert(OpTy->isIntegerTy(1) && "implied by above"); + + // Both LHS and RHS are icmps. + const ICmpInst *LHSCmp = dyn_cast<ICmpInst>(LHS); + const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(RHS); + if (LHSCmp && RHSCmp) + return isImpliedCondICmps(LHSCmp, RHSCmp, DL, LHSIsTrue, Depth); + + // The LHS should be an 'or' or an 'and' instruction. We expect the RHS to be + // an icmp. FIXME: Add support for and/or on the RHS. 
+ const BinaryOperator *LHSBO = dyn_cast<BinaryOperator>(LHS); + if (LHSBO && RHSCmp) { + if ((LHSBO->getOpcode() == Instruction::And || + LHSBO->getOpcode() == Instruction::Or)) + return isImpliedCondAndOr(LHSBO, RHSCmp, DL, LHSIsTrue, Depth); + } + return None; +} + +Optional<bool> llvm::isImpliedByDomCondition(const Value *Cond, + const Instruction *ContextI, + const DataLayout &DL) { + assert(Cond->getType()->isIntOrIntVectorTy(1) && "Condition must be bool"); + if (!ContextI || !ContextI->getParent()) + return None; + + // TODO: This is a poor/cheap way to determine dominance. Should we use a + // dominator tree (eg, from a SimplifyQuery) instead? + const BasicBlock *ContextBB = ContextI->getParent(); + const BasicBlock *PredBB = ContextBB->getSinglePredecessor(); + if (!PredBB) + return None; + + // We need a conditional branch in the predecessor. + Value *PredCond; + BasicBlock *TrueBB, *FalseBB; + if (!match(PredBB->getTerminator(), m_Br(m_Value(PredCond), TrueBB, FalseBB))) + return None; + + // The branch should get simplified. Don't bother simplifying this condition. + if (TrueBB == FalseBB) + return None; + + assert((TrueBB == ContextBB || FalseBB == ContextBB) && + "Predecessor block does not point to successor?"); + + // Is this condition implied by the predecessor condition? + bool CondIsTrue = TrueBB == ContextBB; + return isImpliedCondition(PredCond, Cond, DL, CondIsTrue); +} + +static void setLimitsForBinOp(const BinaryOperator &BO, APInt &Lower, + APInt &Upper, const InstrInfoQuery &IIQ) { + unsigned Width = Lower.getBitWidth(); + const APInt *C; + switch (BO.getOpcode()) { + case Instruction::Add: + if (match(BO.getOperand(1), m_APInt(C)) && !C->isNullValue()) { + // FIXME: If we have both nuw and nsw, we should reduce the range further. + if (IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(&BO))) { + // 'add nuw x, C' produces [C, UINT_MAX]. + Lower = *C; + } else if (IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(&BO))) { + if (C->isNegative()) { + // 'add nsw x, -C' produces [SINT_MIN, SINT_MAX - C]. + Lower = APInt::getSignedMinValue(Width); + Upper = APInt::getSignedMaxValue(Width) + *C + 1; + } else { + // 'add nsw x, +C' produces [SINT_MIN + C, SINT_MAX]. + Lower = APInt::getSignedMinValue(Width) + *C; + Upper = APInt::getSignedMaxValue(Width) + 1; + } + } + } + break; + + case Instruction::And: + if (match(BO.getOperand(1), m_APInt(C))) + // 'and x, C' produces [0, C]. + Upper = *C + 1; + break; + + case Instruction::Or: + if (match(BO.getOperand(1), m_APInt(C))) + // 'or x, C' produces [C, UINT_MAX]. + Lower = *C; + break; + + case Instruction::AShr: + if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) { + // 'ashr x, C' produces [INT_MIN >> C, INT_MAX >> C]. + Lower = APInt::getSignedMinValue(Width).ashr(*C); + Upper = APInt::getSignedMaxValue(Width).ashr(*C) + 1; + } else if (match(BO.getOperand(0), m_APInt(C))) { + unsigned ShiftAmount = Width - 1; + if (!C->isNullValue() && IIQ.isExact(&BO)) + ShiftAmount = C->countTrailingZeros(); + if (C->isNegative()) { + // 'ashr C, x' produces [C, C >> (Width-1)] + Lower = *C; + Upper = C->ashr(ShiftAmount) + 1; + } else { + // 'ashr C, x' produces [C >> (Width-1), C] + Lower = C->ashr(ShiftAmount); + Upper = *C + 1; + } + } + break; + + case Instruction::LShr: + if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) { + // 'lshr x, C' produces [0, UINT_MAX >> C]. 
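Aside: two of the ranges derived above ('and x, C' and 'or x, C'), verified exhaustively at 8 bits in plain C++ (a sketch, not LLVM code):

#include <cassert>
#include <cstdint>

int main() {
  const uint8_t C = 0x2A;
  for (unsigned V = 0; V <= 0xFF; ++V) {
    uint8_t X = uint8_t(V);
    assert(uint8_t(X & C) <= C); // 'and x, C' stays in [0, C]
    assert(uint8_t(X | C) >= C); // 'or x, C' stays in [C, UINT_MAX]
  }
  return 0;
}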
+ Upper = APInt::getAllOnesValue(Width).lshr(*C) + 1; + } else if (match(BO.getOperand(0), m_APInt(C))) { + // 'lshr C, x' produces [C >> (Width-1), C]. + unsigned ShiftAmount = Width - 1; + if (!C->isNullValue() && IIQ.isExact(&BO)) + ShiftAmount = C->countTrailingZeros(); + Lower = C->lshr(ShiftAmount); + Upper = *C + 1; + } + break; + + case Instruction::Shl: + if (match(BO.getOperand(0), m_APInt(C))) { + if (IIQ.hasNoUnsignedWrap(&BO)) { + // 'shl nuw C, x' produces [C, C << CLZ(C)] + Lower = *C; + Upper = Lower.shl(Lower.countLeadingZeros()) + 1; + } else if (BO.hasNoSignedWrap()) { // TODO: What if both nuw+nsw? + if (C->isNegative()) { + // 'shl nsw C, x' produces [C << CLO(C)-1, C] + unsigned ShiftAmount = C->countLeadingOnes() - 1; + Lower = C->shl(ShiftAmount); + Upper = *C + 1; + } else { + // 'shl nsw C, x' produces [C, C << CLZ(C)-1] + unsigned ShiftAmount = C->countLeadingZeros() - 1; + Lower = *C; + Upper = C->shl(ShiftAmount) + 1; + } + } + } + break; + + case Instruction::SDiv: + if (match(BO.getOperand(1), m_APInt(C))) { + APInt IntMin = APInt::getSignedMinValue(Width); + APInt IntMax = APInt::getSignedMaxValue(Width); + if (C->isAllOnesValue()) { + // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] + // where C != -1 and C != 0 and C != 1 + Lower = IntMin + 1; + Upper = IntMax + 1; + } else if (C->countLeadingZeros() < Width - 1) { + // 'sdiv x, C' produces [INT_MIN / C, INT_MAX / C] + // where C != -1 and C != 0 and C != 1 + Lower = IntMin.sdiv(*C); + Upper = IntMax.sdiv(*C); + if (Lower.sgt(Upper)) + std::swap(Lower, Upper); + Upper = Upper + 1; + assert(Upper != Lower && "Upper part of range has wrapped!"); + } + } else if (match(BO.getOperand(0), m_APInt(C))) { + if (C->isMinSignedValue()) { + // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. + Lower = *C; + Upper = Lower.lshr(1) + 1; + } else { + // 'sdiv C, x' produces [-|C|, |C|]. + Upper = C->abs() + 1; + Lower = (-Upper) + 1; + } + } + break; + + case Instruction::UDiv: + if (match(BO.getOperand(1), m_APInt(C)) && !C->isNullValue()) { + // 'udiv x, C' produces [0, UINT_MAX / C]. + Upper = APInt::getMaxValue(Width).udiv(*C) + 1; + } else if (match(BO.getOperand(0), m_APInt(C))) { + // 'udiv C, x' produces [0, C]. + Upper = *C + 1; + } + break; + + case Instruction::SRem: + if (match(BO.getOperand(1), m_APInt(C))) { + // 'srem x, C' produces (-|C|, |C|). + Upper = C->abs(); + Lower = (-Upper) + 1; + } + break; + + case Instruction::URem: + if (match(BO.getOperand(1), m_APInt(C))) + // 'urem x, C' produces [0, C). + Upper = *C; + break; + + default: + break; + } +} + +static void setLimitsForIntrinsic(const IntrinsicInst &II, APInt &Lower, + APInt &Upper) { + unsigned Width = Lower.getBitWidth(); + const APInt *C; + switch (II.getIntrinsicID()) { + case Intrinsic::uadd_sat: + // uadd.sat(x, C) produces [C, UINT_MAX]. + if (match(II.getOperand(0), m_APInt(C)) || + match(II.getOperand(1), m_APInt(C))) + Lower = *C; + break; + case Intrinsic::sadd_sat: + if (match(II.getOperand(0), m_APInt(C)) || + match(II.getOperand(1), m_APInt(C))) { + if (C->isNegative()) { + // sadd.sat(x, -C) produces [SINT_MIN, SINT_MAX + (-C)]. + Lower = APInt::getSignedMinValue(Width); + Upper = APInt::getSignedMaxValue(Width) + *C + 1; + } else { + // sadd.sat(x, +C) produces [SINT_MIN + C, SINT_MAX]. + Lower = APInt::getSignedMinValue(Width) + *C; + Upper = APInt::getSignedMaxValue(Width) + 1; + } + } + break; + case Intrinsic::usub_sat: + // usub.sat(C, x) produces [0, C]. 
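Aside: the uadd.sat range above is easy to confirm with a hand-rolled 8-bit saturating add (a plain C++ sketch, not the real intrinsic):

#include <cassert>
#include <cstdint>

static uint8_t uadd_sat8(uint8_t A, uint8_t B) {
  unsigned S = unsigned(A) + B;
  return S > 0xFF ? 0xFF : uint8_t(S); // clamp at UINT8_MAX
}

int main() {
  const uint8_t C = 100;
  for (unsigned V = 0; V <= 0xFF; ++V)
    assert(uadd_sat8(uint8_t(V), C) >= C); // uadd.sat(x, C) is in [C, UINT_MAX]
  return 0;
}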
+ if (match(II.getOperand(0), m_APInt(C))) + Upper = *C + 1; + // usub.sat(x, C) produces [0, UINT_MAX - C]. + else if (match(II.getOperand(1), m_APInt(C))) + Upper = APInt::getMaxValue(Width) - *C + 1; + break; + case Intrinsic::ssub_sat: + if (match(II.getOperand(0), m_APInt(C))) { + if (C->isNegative()) { + // ssub.sat(-C, x) produces [SINT_MIN, -SINT_MIN + (-C)]. + Lower = APInt::getSignedMinValue(Width); + Upper = *C - APInt::getSignedMinValue(Width) + 1; + } else { + // ssub.sat(+C, x) produces [-SINT_MAX + C, SINT_MAX]. + Lower = *C - APInt::getSignedMaxValue(Width); + Upper = APInt::getSignedMaxValue(Width) + 1; + } + } else if (match(II.getOperand(1), m_APInt(C))) { + if (C->isNegative()) { + // ssub.sat(x, -C) produces [SINT_MIN - (-C), SINT_MAX]: + Lower = APInt::getSignedMinValue(Width) - *C; + Upper = APInt::getSignedMaxValue(Width) + 1; + } else { + // ssub.sat(x, +C) produces [SINT_MIN, SINT_MAX - C]. + Lower = APInt::getSignedMinValue(Width); + Upper = APInt::getSignedMaxValue(Width) - *C + 1; + } + } + break; + default: + break; + } +} + +static void setLimitsForSelectPattern(const SelectInst &SI, APInt &Lower, + APInt &Upper, const InstrInfoQuery &IIQ) { + const Value *LHS = nullptr, *RHS = nullptr; + SelectPatternResult R = matchSelectPattern(&SI, LHS, RHS); + if (R.Flavor == SPF_UNKNOWN) + return; + + unsigned BitWidth = SI.getType()->getScalarSizeInBits(); + + if (R.Flavor == SelectPatternFlavor::SPF_ABS) { + // If the negation part of the abs (in RHS) has the NSW flag, + // then the result of abs(X) is [0..SIGNED_MAX], + // otherwise it is [0..SIGNED_MIN], as -SIGNED_MIN == SIGNED_MIN. + Lower = APInt::getNullValue(BitWidth); + if (match(RHS, m_Neg(m_Specific(LHS))) && + IIQ.hasNoSignedWrap(cast<Instruction>(RHS))) + Upper = APInt::getSignedMaxValue(BitWidth) + 1; + else + Upper = APInt::getSignedMinValue(BitWidth) + 1; + return; + } + + if (R.Flavor == SelectPatternFlavor::SPF_NABS) { + // The result of -abs(X) is <= 0. 
+    Lower = APInt::getSignedMinValue(BitWidth);
+    Upper = APInt(BitWidth, 1);
+    return;
+  }
+
+  const APInt *C;
+  if (!match(LHS, m_APInt(C)) && !match(RHS, m_APInt(C)))
+    return;
+
+  switch (R.Flavor) {
+  case SPF_UMIN:
+    Upper = *C + 1;
+    break;
+  case SPF_UMAX:
+    Lower = *C;
+    break;
+  case SPF_SMIN:
+    Lower = APInt::getSignedMinValue(BitWidth);
+    Upper = *C + 1;
+    break;
+  case SPF_SMAX:
+    Lower = *C;
+    Upper = APInt::getSignedMaxValue(BitWidth) + 1;
+    break;
+  default:
+    break;
+  }
+}
+
+ConstantRange llvm::computeConstantRange(const Value *V, bool UseInstrInfo) {
+  assert(V->getType()->isIntOrIntVectorTy() && "Expected integer instruction");
+
+  const APInt *C;
+  if (match(V, m_APInt(C)))
+    return ConstantRange(*C);
+
+  InstrInfoQuery IIQ(UseInstrInfo);
+  unsigned BitWidth = V->getType()->getScalarSizeInBits();
+  APInt Lower = APInt(BitWidth, 0);
+  APInt Upper = APInt(BitWidth, 0);
+  if (auto *BO = dyn_cast<BinaryOperator>(V))
+    setLimitsForBinOp(*BO, Lower, Upper, IIQ);
+  else if (auto *II = dyn_cast<IntrinsicInst>(V))
+    setLimitsForIntrinsic(*II, Lower, Upper);
+  else if (auto *SI = dyn_cast<SelectInst>(V))
+    setLimitsForSelectPattern(*SI, Lower, Upper, IIQ);
+
+  ConstantRange CR = ConstantRange::getNonEmpty(Lower, Upper);
+
+  if (auto *I = dyn_cast<Instruction>(V))
+    if (auto *Range = IIQ.getMetadata(I, LLVMContext::MD_range))
+      CR = CR.intersectWith(getConstantRangeFromMetadata(*Range));
+
+  return CR;
+}
+
+static Optional<int64_t>
+getOffsetFromIndex(const GEPOperator *GEP, unsigned Idx, const DataLayout &DL) {
+  // Skip over the first indices.
+  gep_type_iterator GTI = gep_type_begin(GEP);
+  for (unsigned i = 1; i != Idx; ++i, ++GTI)
+    /*skip along*/;
+
+  // Compute the offset implied by the rest of the indices.
+  int64_t Offset = 0;
+  for (unsigned i = Idx, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
+    ConstantInt *OpC = dyn_cast<ConstantInt>(GEP->getOperand(i));
+    if (!OpC)
+      return None;
+    if (OpC->isZero())
+      continue; // No offset.
+
+    // Handle struct indices, which add their field offset to the pointer.
+    if (StructType *STy = GTI.getStructTypeOrNull()) {
+      Offset += DL.getStructLayout(STy)->getElementOffset(OpC->getZExtValue());
+      continue;
+    }
+
+    // Otherwise, we have a sequential type like an array or vector. Multiply
+    // the index by the ElementSize.
+    uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType());
+    Offset += Size * OpC->getSExtValue();
+  }
+
+  return Offset;
+}
+
+Optional<int64_t> llvm::isPointerOffset(const Value *Ptr1, const Value *Ptr2,
+                                        const DataLayout &DL) {
+  Ptr1 = Ptr1->stripPointerCasts();
+  Ptr2 = Ptr2->stripPointerCasts();
+
+  // Handle the trivial case first.
+  if (Ptr1 == Ptr2) {
+    return 0;
+  }
+
+  const GEPOperator *GEP1 = dyn_cast<GEPOperator>(Ptr1);
+  const GEPOperator *GEP2 = dyn_cast<GEPOperator>(Ptr2);
+
+  // If one pointer is a GEP, see if the GEP is a constant offset from the
+  // base, as in "P" and "gep P, 1".
+  // Also do this iteratively to handle the following case:
+  //   Ptr_t1 = GEP Ptr1, c1
+  //   Ptr_t2 = GEP Ptr_t1, c2
+  //   Ptr2 = GEP Ptr_t2, c3
+  // where we will return c1+c2+c3.
+  // TODO: Handle the case when both Ptr1 and Ptr2 are GEPs of some common base
+  // -- replace getOffsetFromBase with getOffsetAndBase, check that the bases
+  // are the same, and return the difference between offsets.
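Aside: in plain C++ terms, the quantity isPointerOffset computes is the constant byte distance between two addresses derived from the same base (illustrative sketch):

#include <cassert>
#include <cstdint>

int main() {
  int32_t A[16] = {};
  // Two GEP-like derivations from the same base ...
  char *P1 = reinterpret_cast<char *>(&A[3]);
  char *P2 = reinterpret_cast<char *>(&A[7]);
  // ... differ by a constant byte offset: (7 - 3) * sizeof(int32_t) = 16.
  assert(P2 - P1 == 16);
  return 0;
}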
+ auto getOffsetFromBase = [&DL](const GEPOperator *GEP, + const Value *Ptr) -> Optional<int64_t> { + const GEPOperator *GEP_T = GEP; + int64_t OffsetVal = 0; + bool HasSameBase = false; + while (GEP_T) { + auto Offset = getOffsetFromIndex(GEP_T, 1, DL); + if (!Offset) + return None; + OffsetVal += *Offset; + auto Op0 = GEP_T->getOperand(0)->stripPointerCasts(); + if (Op0 == Ptr) { + HasSameBase = true; + break; + } + GEP_T = dyn_cast<GEPOperator>(Op0); + } + if (!HasSameBase) + return None; + return OffsetVal; + }; + + if (GEP1) { + auto Offset = getOffsetFromBase(GEP1, Ptr2); + if (Offset) + return -*Offset; + } + if (GEP2) { + auto Offset = getOffsetFromBase(GEP2, Ptr1); + if (Offset) + return Offset; + } + + // Right now we handle the case when Ptr1/Ptr2 are both GEPs with an identical + // base. After that base, they may have some number of common (and + // potentially variable) indices. After that they handle some constant + // offset, which determines their offset from each other. At this point, we + // handle no other case. + if (!GEP1 || !GEP2 || GEP1->getOperand(0) != GEP2->getOperand(0)) + return None; + + // Skip any common indices and track the GEP types. + unsigned Idx = 1; + for (; Idx != GEP1->getNumOperands() && Idx != GEP2->getNumOperands(); ++Idx) + if (GEP1->getOperand(Idx) != GEP2->getOperand(Idx)) + break; + + auto Offset1 = getOffsetFromIndex(GEP1, Idx, DL); + auto Offset2 = getOffsetFromIndex(GEP2, Idx, DL); + if (!Offset1 || !Offset2) + return None; + return *Offset2 - *Offset1; +} diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp new file mode 100644 index 000000000000..600f57ab9d71 --- /dev/null +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -0,0 +1,1161 @@ +//===----------- VectorUtils.cpp - Vectorizer utility functions -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines vectorizer utilities. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/Analysis/DemandedBits.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Value.h" + +#define DEBUG_TYPE "vectorutils" + +using namespace llvm; +using namespace llvm::PatternMatch; + +/// Maximum factor for an interleaved memory access. +static cl::opt<unsigned> MaxInterleaveGroupFactor( + "max-interleave-group-factor", cl::Hidden, + cl::desc("Maximum factor for an interleaved access group (default = 8)"), + cl::init(8)); + +/// Return true if all of the intrinsic's arguments and return type are scalars +/// for the scalar form of the intrinsic, and vectors for the vector form of the +/// intrinsic (except operands that are marked as always being scalar by +/// hasVectorInstrinsicScalarOpd). 
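Aside, before the classification below: "trivially vectorizable" means the vector form is just the scalar operation applied per lane. A plain C++ sketch, with C++20 std::popcount standing in for a ctpop lane:

#include <array>
#include <bit>
#include <cassert>
#include <cstddef>
#include <cstdint>

int main() {
  std::array<uint32_t, 4> V = {0u, 1u, 0xFFu, 0xFFFFFFFFu};
  std::array<int, 4> R;
  for (std::size_t I = 0; I < V.size(); ++I)
    R[I] = std::popcount(V[I]); // lanewise scalar ctpop == vector ctpop
  assert(R[0] == 0 && R[1] == 1 && R[2] == 8 && R[3] == 32);
  return 0;
}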
+bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
+  switch (ID) {
+  case Intrinsic::bswap: // Begin integer bit-manipulation.
+  case Intrinsic::bitreverse:
+  case Intrinsic::ctpop:
+  case Intrinsic::ctlz:
+  case Intrinsic::cttz:
+  case Intrinsic::fshl:
+  case Intrinsic::fshr:
+  case Intrinsic::sadd_sat:
+  case Intrinsic::ssub_sat:
+  case Intrinsic::uadd_sat:
+  case Intrinsic::usub_sat:
+  case Intrinsic::smul_fix:
+  case Intrinsic::smul_fix_sat:
+  case Intrinsic::umul_fix:
+  case Intrinsic::umul_fix_sat:
+  case Intrinsic::sqrt: // Begin floating-point.
+  case Intrinsic::sin:
+  case Intrinsic::cos:
+  case Intrinsic::exp:
+  case Intrinsic::exp2:
+  case Intrinsic::log:
+  case Intrinsic::log10:
+  case Intrinsic::log2:
+  case Intrinsic::fabs:
+  case Intrinsic::minnum:
+  case Intrinsic::maxnum:
+  case Intrinsic::minimum:
+  case Intrinsic::maximum:
+  case Intrinsic::copysign:
+  case Intrinsic::floor:
+  case Intrinsic::ceil:
+  case Intrinsic::trunc:
+  case Intrinsic::rint:
+  case Intrinsic::nearbyint:
+  case Intrinsic::round:
+  case Intrinsic::pow:
+  case Intrinsic::fma:
+  case Intrinsic::fmuladd:
+  case Intrinsic::powi:
+  case Intrinsic::canonicalize:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Identifies if the vector form of the intrinsic has a scalar operand.
+bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID,
+                                        unsigned ScalarOpdIdx) {
+  switch (ID) {
+  case Intrinsic::ctlz:
+  case Intrinsic::cttz:
+  case Intrinsic::powi:
+    return (ScalarOpdIdx == 1);
+  case Intrinsic::smul_fix:
+  case Intrinsic::smul_fix_sat:
+  case Intrinsic::umul_fix:
+  case Intrinsic::umul_fix_sat:
+    return (ScalarOpdIdx == 2);
+  default:
+    return false;
+  }
+}
+
+/// Returns the intrinsic ID for a call.
+/// For the input call instruction it finds the mapping intrinsic and returns
+/// its ID; if no such intrinsic is found, it returns not_intrinsic.
+Intrinsic::ID llvm::getVectorIntrinsicIDForCall(const CallInst *CI,
+                                                const TargetLibraryInfo *TLI) {
+  Intrinsic::ID ID = getIntrinsicForCallSite(CI, TLI);
+  if (ID == Intrinsic::not_intrinsic)
+    return Intrinsic::not_intrinsic;
+
+  if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start ||
+      ID == Intrinsic::lifetime_end || ID == Intrinsic::assume ||
+      ID == Intrinsic::sideeffect)
+    return ID;
+  return Intrinsic::not_intrinsic;
+}
+
+/// Find the operand of the GEP that should be checked for consecutive
+/// stores. This ignores trailing indices that have no effect on the final
+/// pointer.
+unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) {
+  const DataLayout &DL = Gep->getModule()->getDataLayout();
+  unsigned LastOperand = Gep->getNumOperands() - 1;
+  unsigned GEPAllocSize = DL.getTypeAllocSize(Gep->getResultElementType());
+
+  // Walk backwards and try to peel off zeros.
+  while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) {
+    // Find the type we're currently indexing into.
+    gep_type_iterator GEPTI = gep_type_begin(Gep);
+    std::advance(GEPTI, LastOperand - 2);
+
+    // If it's a type with the same allocation size as the result of the GEP we
+    // can peel off the zero index.
+    if (DL.getTypeAllocSize(GEPTI.getIndexedType()) != GEPAllocSize)
+      break;
+    --LastOperand;
+  }
+
+  return LastOperand;
+}
+
+/// If the argument is a GEP, then returns the operand identified by
+/// getGEPInductionOperand. However, if there is some other non-loop-invariant
+/// operand, it returns that instead.
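Aside: the zero-index peeling rule in getGEPInductionOperand above has a direct C++ analogue, namely that a trailing zero index does not move the pointer (sketch):

#include <cassert>

int main() {
  int A[8][4] = {};
  for (int I = 0; I < 8; ++I)
    // &A[I][0] and &A[I] name the same address, so the trailing 0 index
    // can be ignored when looking for the induction operand.
    assert(static_cast<void *>(&A[I][0]) == static_cast<void *>(&A[I]));
  return 0;
}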
+Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { + GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); + if (!GEP) + return Ptr; + + unsigned InductionOperand = getGEPInductionOperand(GEP); + + // Check that all of the gep indices are uniform except for our induction + // operand. + for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) + if (i != InductionOperand && + !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp)) + return Ptr; + return GEP->getOperand(InductionOperand); +} + +/// If a value has only one user that is a CastInst, return it. +Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) { + Value *UniqueCast = nullptr; + for (User *U : Ptr->users()) { + CastInst *CI = dyn_cast<CastInst>(U); + if (CI && CI->getType() == Ty) { + if (!UniqueCast) + UniqueCast = CI; + else + return nullptr; + } + } + return UniqueCast; +} + +/// Get the stride of a pointer access in a loop. Looks for symbolic +/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise. +Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { + auto *PtrTy = dyn_cast<PointerType>(Ptr->getType()); + if (!PtrTy || PtrTy->isAggregateType()) + return nullptr; + + // Try to remove a gep instruction to make the pointer (actually index at this + // point) easier analyzable. If OrigPtr is equal to Ptr we are analyzing the + // pointer, otherwise, we are analyzing the index. + Value *OrigPtr = Ptr; + + // The size of the pointer access. + int64_t PtrAccessSize = 1; + + Ptr = stripGetElementPtr(Ptr, SE, Lp); + const SCEV *V = SE->getSCEV(Ptr); + + if (Ptr != OrigPtr) + // Strip off casts. + while (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) + V = C->getOperand(); + + const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V); + if (!S) + return nullptr; + + V = S->getStepRecurrence(*SE); + if (!V) + return nullptr; + + // Strip off the size of access multiplication if we are still analyzing the + // pointer. + if (OrigPtr == Ptr) { + if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) { + if (M->getOperand(0)->getSCEVType() != scConstant) + return nullptr; + + const APInt &APStepVal = cast<SCEVConstant>(M->getOperand(0))->getAPInt(); + + // Huge step value - give up. + if (APStepVal.getBitWidth() > 64) + return nullptr; + + int64_t StepVal = APStepVal.getSExtValue(); + if (PtrAccessSize != StepVal) + return nullptr; + V = M->getOperand(1); + } + } + + // Strip off casts. + Type *StripedOffRecurrenceCast = nullptr; + if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) { + StripedOffRecurrenceCast = C->getType(); + V = C->getOperand(); + } + + // Look for the loop invariant symbolic value. + const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V); + if (!U) + return nullptr; + + Value *Stride = U->getValue(); + if (!Lp->isLoopInvariant(Stride)) + return nullptr; + + // If we have stripped off the recurrence cast we have to make sure that we + // return the value that is used in this loop so that we can replace it later. + if (StripedOffRecurrenceCast) + Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast); + + return Stride; +} + +/// Given a vector and an element number, see if the scalar value is +/// already around as a register, for example if it were inserted then extracted +/// from the vector. 
+Value *llvm::findScalarElement(Value *V, unsigned EltNo) { + assert(V->getType()->isVectorTy() && "Not looking at a vector?"); + VectorType *VTy = cast<VectorType>(V->getType()); + unsigned Width = VTy->getNumElements(); + if (EltNo >= Width) // Out of range access. + return UndefValue::get(VTy->getElementType()); + + if (Constant *C = dyn_cast<Constant>(V)) + return C->getAggregateElement(EltNo); + + if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) { + // If this is an insert to a variable element, we don't know what it is. + if (!isa<ConstantInt>(III->getOperand(2))) + return nullptr; + unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue(); + + // If this is an insert to the element we are looking for, return the + // inserted value. + if (EltNo == IIElt) + return III->getOperand(1); + + // Otherwise, the insertelement doesn't modify the value, recurse on its + // vector input. + return findScalarElement(III->getOperand(0), EltNo); + } + + if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) { + unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements(); + int InEl = SVI->getMaskValue(EltNo); + if (InEl < 0) + return UndefValue::get(VTy->getElementType()); + if (InEl < (int)LHSWidth) + return findScalarElement(SVI->getOperand(0), InEl); + return findScalarElement(SVI->getOperand(1), InEl - LHSWidth); + } + + // Extract a value from a vector add operation with a constant zero. + // TODO: Use getBinOpIdentity() to generalize this. + Value *Val; Constant *C; + if (match(V, m_Add(m_Value(Val), m_Constant(C)))) + if (Constant *Elt = C->getAggregateElement(EltNo)) + if (Elt->isNullValue()) + return findScalarElement(Val, EltNo); + + // Otherwise, we don't know. + return nullptr; +} + +/// Get splat value if the input is a splat vector or return nullptr. +/// This function is not fully general. It checks only 2 cases: +/// the input value is (1) a splat constant vector or (2) a sequence +/// of instructions that broadcasts a scalar at element 0. +const llvm::Value *llvm::getSplatValue(const Value *V) { + if (isa<VectorType>(V->getType())) + if (auto *C = dyn_cast<Constant>(V)) + return C->getSplatValue(); + + // shuf (inselt ?, Splat, 0), ?, <0, undef, 0, ...> + Value *Splat; + if (match(V, m_ShuffleVector(m_InsertElement(m_Value(), m_Value(Splat), + m_ZeroInt()), + m_Value(), m_ZeroInt()))) + return Splat; + + return nullptr; +} + +// This setting is based on its counterpart in value tracking, but it could be +// adjusted if needed. +const unsigned MaxDepth = 6; + +bool llvm::isSplatValue(const Value *V, unsigned Depth) { + assert(Depth <= MaxDepth && "Limit Search Depth"); + + if (isa<VectorType>(V->getType())) { + if (isa<UndefValue>(V)) + return true; + // FIXME: Constant splat analysis does not allow undef elements. + if (auto *C = dyn_cast<Constant>(V)) + return C->getSplatValue() != nullptr; + } + + // FIXME: Constant splat analysis does not allow undef elements. + Constant *Mask; + if (match(V, m_ShuffleVector(m_Value(), m_Value(), m_Constant(Mask)))) + return Mask->getSplatValue() != nullptr; + + // The remaining tests are all recursive, so bail out if we hit the limit. + if (Depth++ == MaxDepth) + return false; + + // If both operands of a binop are splats, the result is a splat. + Value *X, *Y, *Z; + if (match(V, m_BinOp(m_Value(X), m_Value(Y)))) + return isSplatValue(X, Depth) && isSplatValue(Y, Depth); + + // If all operands of a select are splats, the result is a splat. 
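Aside: the recursive splat rules used here are easy to picture lanewise; a plain C++ sketch with 4-lane arrays standing in for vectors:

#include <array>
#include <cassert>
#include <cstddef>

int main() {
  std::array<int, 4> X, Y, R;
  X.fill(7); // splat(7)
  Y.fill(5); // splat(5)
  for (std::size_t I = 0; I < R.size(); ++I)
    R[I] = X[I] * Y[I]; // a binop of two splats ...
  for (int V : R)
    assert(V == 35);    // ... is itself a splat
  return 0;
}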
+ if (match(V, m_Select(m_Value(X), m_Value(Y), m_Value(Z)))) + return isSplatValue(X, Depth) && isSplatValue(Y, Depth) && + isSplatValue(Z, Depth); + + // TODO: Add support for unary ops (fneg), casts, intrinsics (overflow ops). + + return false; +} + +MapVector<Instruction *, uint64_t> +llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB, + const TargetTransformInfo *TTI) { + + // DemandedBits will give us every value's live-out bits. But we want + // to ensure no extra casts would need to be inserted, so every DAG + // of connected values must have the same minimum bitwidth. + EquivalenceClasses<Value *> ECs; + SmallVector<Value *, 16> Worklist; + SmallPtrSet<Value *, 4> Roots; + SmallPtrSet<Value *, 16> Visited; + DenseMap<Value *, uint64_t> DBits; + SmallPtrSet<Instruction *, 4> InstructionSet; + MapVector<Instruction *, uint64_t> MinBWs; + + // Determine the roots. We work bottom-up, from truncs or icmps. + bool SeenExtFromIllegalType = false; + for (auto *BB : Blocks) + for (auto &I : *BB) { + InstructionSet.insert(&I); + + if (TTI && (isa<ZExtInst>(&I) || isa<SExtInst>(&I)) && + !TTI->isTypeLegal(I.getOperand(0)->getType())) + SeenExtFromIllegalType = true; + + // Only deal with non-vector integers up to 64-bits wide. + if ((isa<TruncInst>(&I) || isa<ICmpInst>(&I)) && + !I.getType()->isVectorTy() && + I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) { + // Don't make work for ourselves. If we know the loaded type is legal, + // don't add it to the worklist. + if (TTI && isa<TruncInst>(&I) && TTI->isTypeLegal(I.getType())) + continue; + + Worklist.push_back(&I); + Roots.insert(&I); + } + } + // Early exit. + if (Worklist.empty() || (TTI && !SeenExtFromIllegalType)) + return MinBWs; + + // Now proceed breadth-first, unioning values together. + while (!Worklist.empty()) { + Value *Val = Worklist.pop_back_val(); + Value *Leader = ECs.getOrInsertLeaderValue(Val); + + if (Visited.count(Val)) + continue; + Visited.insert(Val); + + // Non-instructions terminate a chain successfully. + if (!isa<Instruction>(Val)) + continue; + Instruction *I = cast<Instruction>(Val); + + // If we encounter a type that is larger than 64 bits, we can't represent + // it so bail out. + if (DB.getDemandedBits(I).getBitWidth() > 64) + return MapVector<Instruction *, uint64_t>(); + + uint64_t V = DB.getDemandedBits(I).getZExtValue(); + DBits[Leader] |= V; + DBits[I] = V; + + // Casts, loads and instructions outside of our range terminate a chain + // successfully. + if (isa<SExtInst>(I) || isa<ZExtInst>(I) || isa<LoadInst>(I) || + !InstructionSet.count(I)) + continue; + + // Unsafe casts terminate a chain unsuccessfully. We can't do anything + // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to + // transform anything that relies on them. + if (isa<BitCastInst>(I) || isa<PtrToIntInst>(I) || isa<IntToPtrInst>(I) || + !I->getType()->isIntegerTy()) { + DBits[Leader] |= ~0ULL; + continue; + } + + // We don't modify the types of PHIs. Reductions will already have been + // truncated if possible, and inductions' sizes will have been chosen by + // indvars. + if (isa<PHINode>(I)) + continue; + + if (DBits[Leader] == ~0ULL) + // All bits demanded, no point continuing. + continue; + + for (Value *O : cast<User>(I)->operands()) { + ECs.unionSets(Leader, O); + Worklist.push_back(O); + } + } + + // Now we've discovered all values, walk them to see if there are + // any users we didn't see. If there are, we can't optimize that + // chain. 
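Aside: the per-class minimum bitwidth computed below can be sketched standalone (plain C++; __builtin_clzll assumes GCC or Clang):

#include <cassert>
#include <cstdint>

static unsigned minBitWidth(uint64_t DemandedBits) {
  unsigned MinBW = 64 - __builtin_clzll(DemandedBits | 1);
  unsigned P = 1;   // round up to a power of two,
  while (P < MinBW) // mirroring the NextPowerOf2 step above
    P *= 2;
  return P;
}

int main() {
  assert(minBitWidth(0xFF) == 8);   // 8 live bits -> i8
  assert(minBitWidth(0x1FF) == 16); // 9 live bits -> round up to i16
  assert(minBitWidth(0x1) == 1);
  return 0;
}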
+ for (auto &I : DBits) + for (auto *U : I.first->users()) + if (U->getType()->isIntegerTy() && DBits.count(U) == 0) + DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL; + + for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) { + uint64_t LeaderDemandedBits = 0; + for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) + LeaderDemandedBits |= DBits[*MI]; + + uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) - + llvm::countLeadingZeros(LeaderDemandedBits); + // Round up to a power of 2 + if (!isPowerOf2_64((uint64_t)MinBW)) + MinBW = NextPowerOf2(MinBW); + + // We don't modify the types of PHIs. Reductions will already have been + // truncated if possible, and inductions' sizes will have been chosen by + // indvars. + // If we are required to shrink a PHI, abandon this entire equivalence class. + bool Abort = false; + for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) + if (isa<PHINode>(*MI) && MinBW < (*MI)->getType()->getScalarSizeInBits()) { + Abort = true; + break; + } + if (Abort) + continue; + + for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) { + if (!isa<Instruction>(*MI)) + continue; + Type *Ty = (*MI)->getType(); + if (Roots.count(*MI)) + Ty = cast<Instruction>(*MI)->getOperand(0)->getType(); + if (MinBW < Ty->getScalarSizeInBits()) + MinBWs[cast<Instruction>(*MI)] = MinBW; + } + } + + return MinBWs; +} + +/// Add all access groups in @p AccGroups to @p List. +template <typename ListT> +static void addToAccessGroupList(ListT &List, MDNode *AccGroups) { + // Interpret an access group as a list containing itself. + if (AccGroups->getNumOperands() == 0) { + assert(isValidAsAccessGroup(AccGroups) && "Node must be an access group"); + List.insert(AccGroups); + return; + } + + for (auto &AccGroupListOp : AccGroups->operands()) { + auto *Item = cast<MDNode>(AccGroupListOp.get()); + assert(isValidAsAccessGroup(Item) && "List item must be an access group"); + List.insert(Item); + } +} + +MDNode *llvm::uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2) { + if (!AccGroups1) + return AccGroups2; + if (!AccGroups2) + return AccGroups1; + if (AccGroups1 == AccGroups2) + return AccGroups1; + + SmallSetVector<Metadata *, 4> Union; + addToAccessGroupList(Union, AccGroups1); + addToAccessGroupList(Union, AccGroups2); + + if (Union.size() == 0) + return nullptr; + if (Union.size() == 1) + return cast<MDNode>(Union.front()); + + LLVMContext &Ctx = AccGroups1->getContext(); + return MDNode::get(Ctx, Union.getArrayRef()); +} + +MDNode *llvm::intersectAccessGroups(const Instruction *Inst1, + const Instruction *Inst2) { + bool MayAccessMem1 = Inst1->mayReadOrWriteMemory(); + bool MayAccessMem2 = Inst2->mayReadOrWriteMemory(); + + if (!MayAccessMem1 && !MayAccessMem2) + return nullptr; + if (!MayAccessMem1) + return Inst2->getMetadata(LLVMContext::MD_access_group); + if (!MayAccessMem2) + return Inst1->getMetadata(LLVMContext::MD_access_group); + + MDNode *MD1 = Inst1->getMetadata(LLVMContext::MD_access_group); + MDNode *MD2 = Inst2->getMetadata(LLVMContext::MD_access_group); + if (!MD1 || !MD2) + return nullptr; + if (MD1 == MD2) + return MD1; + + // Use set for scalable 'contains' check. 
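+ // (Sketch with hypothetical access groups !A, !B and !C: intersecting the
+ // lists MD1 = !{!A, !B} and MD2 = !{!B, !C} leaves only !B, which is then
+ // returned directly as a single node rather than as a one-element list.)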
+ SmallPtrSet<Metadata *, 4> AccGroupSet2;
+ addToAccessGroupList(AccGroupSet2, MD2);
+
+ SmallVector<Metadata *, 4> Intersection;
+ if (MD1->getNumOperands() == 0) {
+ assert(isValidAsAccessGroup(MD1) && "Node must be an access group");
+ if (AccGroupSet2.count(MD1))
+ Intersection.push_back(MD1);
+ } else {
+ for (const MDOperand &Node : MD1->operands()) {
+ auto *Item = cast<MDNode>(Node.get());
+ assert(isValidAsAccessGroup(Item) && "List item must be an access group");
+ if (AccGroupSet2.count(Item))
+ Intersection.push_back(Item);
+ }
+ }
+
+ if (Intersection.size() == 0)
+ return nullptr;
+ if (Intersection.size() == 1)
+ return cast<MDNode>(Intersection.front());
+
+ LLVMContext &Ctx = Inst1->getContext();
+ return MDNode::get(Ctx, Intersection);
+}
+
+/// \returns \p Inst after propagating metadata from \p VL.
+Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) {
+ Instruction *I0 = cast<Instruction>(VL[0]);
+ SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata;
+ I0->getAllMetadataOtherThanDebugLoc(Metadata);
+
+ for (auto Kind : {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
+ LLVMContext::MD_noalias, LLVMContext::MD_fpmath,
+ LLVMContext::MD_nontemporal, LLVMContext::MD_invariant_load,
+ LLVMContext::MD_access_group}) {
+ MDNode *MD = I0->getMetadata(Kind);
+
+ for (int J = 1, E = VL.size(); MD && J != E; ++J) {
+ const Instruction *IJ = cast<Instruction>(VL[J]);
+ MDNode *IMD = IJ->getMetadata(Kind);
+ switch (Kind) {
+ case LLVMContext::MD_tbaa:
+ MD = MDNode::getMostGenericTBAA(MD, IMD);
+ break;
+ case LLVMContext::MD_alias_scope:
+ MD = MDNode::getMostGenericAliasScope(MD, IMD);
+ break;
+ case LLVMContext::MD_fpmath:
+ MD = MDNode::getMostGenericFPMath(MD, IMD);
+ break;
+ case LLVMContext::MD_noalias:
+ case LLVMContext::MD_nontemporal:
+ case LLVMContext::MD_invariant_load:
+ MD = MDNode::intersect(MD, IMD);
+ break;
+ case LLVMContext::MD_access_group:
+ MD = intersectAccessGroups(Inst, IJ);
+ break;
+ default:
+ llvm_unreachable("unhandled metadata");
+ }
+ }
+
+ Inst->setMetadata(Kind, MD);
+ }
+
+ return Inst;
+}
+
+Constant *
+llvm::createBitMaskForGaps(IRBuilder<> &Builder, unsigned VF,
+ const InterleaveGroup<Instruction> &Group) {
+ // All 1's means mask is not needed.
+ if (Group.getNumMembers() == Group.getFactor())
+ return nullptr;
+
+ // TODO: support reversed access.
+ assert(!Group.isReverse() && "Reversed group not supported.");
+
+ SmallVector<Constant *, 16> Mask;
+ for (unsigned i = 0; i < VF; i++)
+ for (unsigned j = 0; j < Group.getFactor(); ++j) {
+ unsigned HasMember = Group.getMember(j) ? 1 : 0;
+ Mask.push_back(Builder.getInt1(HasMember));
+ }
+
+ return ConstantVector::get(Mask);
+}
+
+Constant *llvm::createReplicatedMask(IRBuilder<> &Builder,
+ unsigned ReplicationFactor, unsigned VF) {
+ SmallVector<Constant *, 16> MaskVec;
+ for (unsigned i = 0; i < VF; i++)
+ for (unsigned j = 0; j < ReplicationFactor; j++)
+ MaskVec.push_back(Builder.getInt32(i));
+
+ return ConstantVector::get(MaskVec);
+}
+
+Constant *llvm::createInterleaveMask(IRBuilder<> &Builder, unsigned VF,
+ unsigned NumVecs) {
+ SmallVector<Constant *, 16> Mask;
+ for (unsigned i = 0; i < VF; i++)
+ for (unsigned j = 0; j < NumVecs; j++)
+ Mask.push_back(Builder.getInt32(j * VF + i));
+
+ return ConstantVector::get(Mask);
+}
+
+Constant *llvm::createStrideMask(IRBuilder<> &Builder, unsigned Start,
+ unsigned Stride, unsigned VF) {
+ SmallVector<Constant *, 16> Mask;
+ for (unsigned i = 0; i < VF; i++)
+ Mask.push_back(Builder.getInt32(Start + i * Stride));
+
+ return ConstantVector::get(Mask);
+}
+
+Constant *llvm::createSequentialMask(IRBuilder<> &Builder, unsigned Start,
+ unsigned NumInts, unsigned NumUndefs) {
+ SmallVector<Constant *, 16> Mask;
+ for (unsigned i = 0; i < NumInts; i++)
+ Mask.push_back(Builder.getInt32(Start + i));
+
+ Constant *Undef = UndefValue::get(Builder.getInt32Ty());
+ for (unsigned i = 0; i < NumUndefs; i++)
+ Mask.push_back(Undef);
+
+ return ConstantVector::get(Mask);
+}
+
+/// A helper function for concatenating vectors. This function concatenates two
+/// vectors having the same element type. If the second vector has fewer
+/// elements than the first, it is padded with undefs.
+static Value *concatenateTwoVectors(IRBuilder<> &Builder, Value *V1,
+ Value *V2) {
+ VectorType *VecTy1 = dyn_cast<VectorType>(V1->getType());
+ VectorType *VecTy2 = dyn_cast<VectorType>(V2->getType());
+ assert(VecTy1 && VecTy2 &&
+ VecTy1->getScalarType() == VecTy2->getScalarType() &&
+ "Expect two vectors with the same element type");
+
+ unsigned NumElts1 = VecTy1->getNumElements();
+ unsigned NumElts2 = VecTy2->getNumElements();
+ assert(NumElts1 >= NumElts2 &&
+ "Expected the first vector to have at least as many elements");
+
+ if (NumElts1 > NumElts2) {
+ // Extend with UNDEFs.
+ Constant *ExtMask =
+ createSequentialMask(Builder, 0, NumElts2, NumElts1 - NumElts2);
+ V2 = Builder.CreateShuffleVector(V2, UndefValue::get(VecTy2), ExtMask);
+ }
+
+ Constant *Mask = createSequentialMask(Builder, 0, NumElts1 + NumElts2, 0);
+ return Builder.CreateShuffleVector(V1, V2, Mask);
+}
+
+Value *llvm::concatenateVectors(IRBuilder<> &Builder, ArrayRef<Value *> Vecs) {
+ unsigned NumVecs = Vecs.size();
+ assert(NumVecs > 1 && "Should be at least two vectors");
+
+ SmallVector<Value *, 8> ResList;
+ ResList.append(Vecs.begin(), Vecs.end());
+ do {
+ SmallVector<Value *, 8> TmpList;
+ for (unsigned i = 0; i < NumVecs - 1; i += 2) {
+ Value *V0 = ResList[i], *V1 = ResList[i + 1];
+ assert((V0->getType() == V1->getType() || i == NumVecs - 2) &&
+ "Only the last vector may have a different type");
+
+ TmpList.push_back(concatenateTwoVectors(Builder, V0, V1));
+ }
+
+ // Push the last vector if the total number of vectors is odd.
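+ // (Worked example, purely illustrative: for five inputs V0..V4 the rounds
+ // proceed [V0 V1 V2 V3 V4] -> [V01 V23 V4] -> [V0123 V4] -> [V01234]; the
+ // odd tail is carried into the next round until it can be paired.)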
+ if (NumVecs % 2 != 0) + TmpList.push_back(ResList[NumVecs - 1]); + + ResList = TmpList; + NumVecs = ResList.size(); + } while (NumVecs > 1); + + return ResList[0]; +} + +bool llvm::maskIsAllZeroOrUndef(Value *Mask) { + auto *ConstMask = dyn_cast<Constant>(Mask); + if (!ConstMask) + return false; + if (ConstMask->isNullValue() || isa<UndefValue>(ConstMask)) + return true; + for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E; + ++I) { + if (auto *MaskElt = ConstMask->getAggregateElement(I)) + if (MaskElt->isNullValue() || isa<UndefValue>(MaskElt)) + continue; + return false; + } + return true; +} + + +bool llvm::maskIsAllOneOrUndef(Value *Mask) { + auto *ConstMask = dyn_cast<Constant>(Mask); + if (!ConstMask) + return false; + if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask)) + return true; + for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E; + ++I) { + if (auto *MaskElt = ConstMask->getAggregateElement(I)) + if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt)) + continue; + return false; + } + return true; +} + +/// TODO: This is a lot like known bits, but for +/// vectors. Is there something we can common this with? +APInt llvm::possiblyDemandedEltsInMask(Value *Mask) { + + const unsigned VWidth = cast<VectorType>(Mask->getType())->getNumElements(); + APInt DemandedElts = APInt::getAllOnesValue(VWidth); + if (auto *CV = dyn_cast<ConstantVector>(Mask)) + for (unsigned i = 0; i < VWidth; i++) + if (CV->getAggregateElement(i)->isNullValue()) + DemandedElts.clearBit(i); + return DemandedElts; +} + +bool InterleavedAccessInfo::isStrided(int Stride) { + unsigned Factor = std::abs(Stride); + return Factor >= 2 && Factor <= MaxInterleaveGroupFactor; +} + +void InterleavedAccessInfo::collectConstStrideAccesses( + MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo, + const ValueToValueMap &Strides) { + auto &DL = TheLoop->getHeader()->getModule()->getDataLayout(); + + // Since it's desired that the load/store instructions be maintained in + // "program order" for the interleaved access analysis, we have to visit the + // blocks in the loop in reverse postorder (i.e., in a topological order). + // Such an ordering will ensure that any load/store that may be executed + // before a second load/store will precede the second load/store in + // AccessStrideInfo. + LoopBlocksDFS DFS(TheLoop); + DFS.perform(LI); + for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) + for (auto &I : *BB) { + auto *LI = dyn_cast<LoadInst>(&I); + auto *SI = dyn_cast<StoreInst>(&I); + if (!LI && !SI) + continue; + + Value *Ptr = getLoadStorePointerOperand(&I); + // We don't check wrapping here because we don't know yet if Ptr will be + // part of a full group or a group with gaps. Checking wrapping for all + // pointers (even those that end up in groups with no gaps) will be overly + // conservative. For full groups, wrapping should be ok since if we would + // wrap around the address space we would do a memory access at nullptr + // even without the transformation. The wrapping checks are therefore + // deferred until after we've formed the interleaved groups. + int64_t Stride = getPtrStride(PSE, Ptr, TheLoop, Strides, + /*Assume=*/true, /*ShouldCheckWrap=*/false); + + const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); + PointerType *PtrTy = cast<PointerType>(Ptr->getType()); + uint64_t Size = DL.getTypeAllocSize(PtrTy->getElementType()); + + // An alignment of 0 means target ABI alignment. 
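+ // (For a hypothetical i32 access to A[3*i], the descriptor built below
+ // would record Stride = 3 in elements, Size = 4 bytes, and, absent an
+ // explicit alignment, the ABI alignment of i32.)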
+ MaybeAlign Alignment = MaybeAlign(getLoadStoreAlignment(&I)); + if (!Alignment) + Alignment = Align(DL.getABITypeAlignment(PtrTy->getElementType())); + + AccessStrideInfo[&I] = StrideDescriptor(Stride, Scev, Size, *Alignment); + } +} + +// Analyze interleaved accesses and collect them into interleaved load and +// store groups. +// +// When generating code for an interleaved load group, we effectively hoist all +// loads in the group to the location of the first load in program order. When +// generating code for an interleaved store group, we sink all stores to the +// location of the last store. This code motion can change the order of load +// and store instructions and may break dependences. +// +// The code generation strategy mentioned above ensures that we won't violate +// any write-after-read (WAR) dependences. +// +// E.g., for the WAR dependence: a = A[i]; // (1) +// A[i] = b; // (2) +// +// The store group of (2) is always inserted at or below (2), and the load +// group of (1) is always inserted at or above (1). Thus, the instructions will +// never be reordered. All other dependences are checked to ensure the +// correctness of the instruction reordering. +// +// The algorithm visits all memory accesses in the loop in bottom-up program +// order. Program order is established by traversing the blocks in the loop in +// reverse postorder when collecting the accesses. +// +// We visit the memory accesses in bottom-up order because it can simplify the +// construction of store groups in the presence of write-after-write (WAW) +// dependences. +// +// E.g., for the WAW dependence: A[i] = a; // (1) +// A[i] = b; // (2) +// A[i + 1] = c; // (3) +// +// We will first create a store group with (3) and (2). (1) can't be added to +// this group because it and (2) are dependent. However, (1) can be grouped +// with other accesses that may precede it in program order. Note that a +// bottom-up order does not imply that WAW dependences should not be checked. +void InterleavedAccessInfo::analyzeInterleaving( + bool EnablePredicatedInterleavedMemAccesses) { + LLVM_DEBUG(dbgs() << "LV: Analyzing interleaved accesses...\n"); + const ValueToValueMap &Strides = LAI->getSymbolicStrides(); + + // Holds all accesses with a constant stride. + MapVector<Instruction *, StrideDescriptor> AccessStrideInfo; + collectConstStrideAccesses(AccessStrideInfo, Strides); + + if (AccessStrideInfo.empty()) + return; + + // Collect the dependences in the loop. + collectDependences(); + + // Holds all interleaved store groups temporarily. + SmallSetVector<InterleaveGroup<Instruction> *, 4> StoreGroups; + // Holds all interleaved load groups temporarily. + SmallSetVector<InterleaveGroup<Instruction> *, 4> LoadGroups; + + // Search in bottom-up program order for pairs of accesses (A and B) that can + // form interleaved load or store groups. In the algorithm below, access A + // precedes access B in program order. We initialize a group for B in the + // outer loop of the algorithm, and then in the inner loop, we attempt to + // insert each A into B's group if: + // + // 1. A and B have the same stride, + // 2. A and B have the same memory object size, and + // 3. A belongs in B's group according to its distance from B. + // + // Special care is taken to ensure group formation will not break any + // dependences. 
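+ // (Minimal sketch of rules 1-3, with hypothetical i32 accesses in a
+ // stride-2 loop: B = A[2*i] and A = A[2*i+1] share Stride = 2 and Size = 4,
+ // and their byte distance is 4, a multiple of the size, so A would land at
+ // index getIndex(B) + 4/4 = getIndex(B) + 1 within B's group.)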
+ for (auto BI = AccessStrideInfo.rbegin(), E = AccessStrideInfo.rend();
+ BI != E; ++BI) {
+ Instruction *B = BI->first;
+ StrideDescriptor DesB = BI->second;
+
+ // Initialize a group for B if it has an allowable stride. Even if we don't
+ // create a group for B, we continue with the bottom-up algorithm to ensure
+ // we don't break any of B's dependences.
+ InterleaveGroup<Instruction> *Group = nullptr;
+ if (isStrided(DesB.Stride) &&
+ (!isPredicated(B->getParent()) || EnablePredicatedInterleavedMemAccesses)) {
+ Group = getInterleaveGroup(B);
+ if (!Group) {
+ LLVM_DEBUG(dbgs() << "LV: Creating an interleave group with:" << *B
+ << '\n');
+ Group = createInterleaveGroup(B, DesB.Stride, DesB.Alignment);
+ }
+ if (B->mayWriteToMemory())
+ StoreGroups.insert(Group);
+ else
+ LoadGroups.insert(Group);
+ }
+
+ for (auto AI = std::next(BI); AI != E; ++AI) {
+ Instruction *A = AI->first;
+ StrideDescriptor DesA = AI->second;
+
+ // Our code motion strategy implies that we can't have dependences
+ // between accesses in an interleaved group and other accesses located
+ // between the first and last member of the group. Note that this also
+ // means that a group can't have more than one member at a given offset.
+ // The accesses in a group can have dependences with other accesses, but
+ // we must ensure we don't extend the boundaries of the group such that
+ // we encompass those dependent accesses.
+ //
+ // For example, assume we have the sequence of accesses shown below in a
+ // stride-2 loop:
+ //
+ // (1, 2) is a group | A[i] = a; // (1)
+ // | A[i-1] = b; // (2) |
+ // A[i-3] = c; // (3)
+ // A[i] = d; // (4) | (2, 4) is not a group
+ //
+ // Because accesses (2) and (3) are dependent, we can group (2) with (1)
+ // but not with (4). If we did, the dependent access (3) would be within
+ // the boundaries of the (2, 4) group.
+ if (!canReorderMemAccessesForInterleavedGroups(&*AI, &*BI)) {
+ // If a dependence exists and A is already in a group, we know that A
+ // must be a store since A precedes B and WAR dependences are allowed.
+ // Thus, A would be sunk below B. We release A's group to prevent this
+ // illegal code motion. A will then be free to form another group with
+ // instructions that precede it.
+ if (isInterleaved(A)) {
+ InterleaveGroup<Instruction> *StoreGroup = getInterleaveGroup(A);
+
+ LLVM_DEBUG(dbgs() << "LV: Invalidated store group due to "
+ "dependence between " << *A << " and " << *B << '\n');
+
+ StoreGroups.remove(StoreGroup);
+ releaseGroup(StoreGroup);
+ }
+
+ // If a dependence exists and A is not already in a group (or it was
+ // and we just released it), B might be hoisted above A (if B is a
+ // load) or another store might be sunk below A (if B is a store). In
+ // either case, we can't add additional instructions to B's group. B
+ // will only form a group with instructions that it precedes.
+ break;
+ }
+
+ // At this point, we've checked for illegal code motion. If either A or B
+ // isn't strided, there's nothing left to do.
+ if (!isStrided(DesA.Stride) || !isStrided(DesB.Stride))
+ continue;
+
+ // Ignore A if it's already in a group or isn't the same kind of memory
+ // operation as B.
+ // Note that mayReadFromMemory() isn't mutually exclusive with
+ // mayWriteToMemory() in the case of atomic loads. We shouldn't see those
+ // here; canVectorizeMemory() should have returned false - except when we
+ // asked for optimization remarks.
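+ // (Example of this filter: a load A never groups with a store B, and an
+ // ordered atomic load, for which mayWriteToMemory() is also true, fails
+ // the equality checks below instead of being mis-grouped.)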
+ if (isInterleaved(A) ||
+ (A->mayReadFromMemory() != B->mayReadFromMemory()) ||
+ (A->mayWriteToMemory() != B->mayWriteToMemory()))
+ continue;
+
+ // Check rules 1 and 2. Ignore A if its stride or size is different from
+ // that of B.
+ if (DesA.Stride != DesB.Stride || DesA.Size != DesB.Size)
+ continue;
+
+ // Ignore A if the memory objects of A and B don't belong to the same
+ // address space.
+ if (getLoadStoreAddressSpace(A) != getLoadStoreAddressSpace(B))
+ continue;
+
+ // Calculate the distance from A to B.
+ const SCEVConstant *DistToB = dyn_cast<SCEVConstant>(
+ PSE.getSE()->getMinusSCEV(DesA.Scev, DesB.Scev));
+ if (!DistToB)
+ continue;
+ int64_t DistanceToB = DistToB->getAPInt().getSExtValue();
+
+ // Check rule 3. Ignore A if its distance to B is not a multiple of the
+ // size.
+ if (DistanceToB % static_cast<int64_t>(DesB.Size))
+ continue;
+
+ // All members of a predicated interleave-group must have the same predicate,
+ // and currently must reside in the same BB.
+ BasicBlock *BlockA = A->getParent();
+ BasicBlock *BlockB = B->getParent();
+ if ((isPredicated(BlockA) || isPredicated(BlockB)) &&
+ (!EnablePredicatedInterleavedMemAccesses || BlockA != BlockB))
+ continue;
+
+ // The index of A is the index of B plus A's distance to B in multiples
+ // of the size.
+ int IndexA =
+ Group->getIndex(B) + DistanceToB / static_cast<int64_t>(DesB.Size);
+
+ // Try to insert A into B's group.
+ if (Group->insertMember(A, IndexA, DesA.Alignment)) {
+ LLVM_DEBUG(dbgs() << "LV: Inserted:" << *A << '\n'
+ << " into the interleave group with" << *B
+ << '\n');
+ InterleaveGroupMap[A] = Group;
+
+ // Set the first load in program order as the insert position.
+ if (A->mayReadFromMemory())
+ Group->setInsertPos(A);
+ }
+ } // Iteration over A accesses.
+ } // Iteration over B accesses.
+
+ // Remove interleaved store groups with gaps.
+ for (auto *Group : StoreGroups)
+ if (Group->getNumMembers() != Group->getFactor()) {
+ LLVM_DEBUG(
+ dbgs() << "LV: Invalidate candidate interleaved store group due "
+ "to gaps.\n");
+ releaseGroup(Group);
+ }
+ // Remove interleaved groups with gaps (currently only loads) whose memory
+ // accesses may wrap around. We have to revisit the getPtrStride analysis,
+ // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does
+ // not check wrapping (see documentation there).
+ // FORNOW we use Assume=false;
+ // TODO: Change to Assume=true, making sure we don't exceed the threshold
+ // of runtime SCEV assumption checks (thereby potentially failing to
+ // vectorize altogether).
+ // Additional optional optimizations:
+ // TODO: If we are peeling the loop and we know that the first pointer doesn't
+ // wrap then we can deduce that all pointers in the group don't wrap.
+ // This means that we can forcefully peel the loop in order to only have to
+ // check the first pointer for no-wrap. When we change to use Assume=true
+ // we'll only need at most one runtime check per interleaved group.
+ for (auto *Group : LoadGroups) {
+ // Case 1: A full group. We can skip the checks; for full groups, if the
+ // wide load would wrap around the address space we would do a memory
+ // access at nullptr even without the transformation.
+ if (Group->getNumMembers() == Group->getFactor())
+ continue;
+
+ // Case 2: If the first and last members of the group don't wrap, this
+ // implies that all the pointers in the group don't wrap.
+ // So we check only group member 0 (which is always guaranteed to exist),
+ // and group member Factor - 1; if the latter doesn't exist we rely on
+ // peeling (if it is a non-reversed access -- see Case 3).
+ Value *FirstMemberPtr = getLoadStorePointerOperand(Group->getMember(0));
+ if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false,
+ /*ShouldCheckWrap=*/true)) {
+ LLVM_DEBUG(
+ dbgs() << "LV: Invalidate candidate interleaved group due to "
+ "first group member potentially pointer-wrapping.\n");
+ releaseGroup(Group);
+ continue;
+ }
+ Instruction *LastMember = Group->getMember(Group->getFactor() - 1);
+ if (LastMember) {
+ Value *LastMemberPtr = getLoadStorePointerOperand(LastMember);
+ if (!getPtrStride(PSE, LastMemberPtr, TheLoop, Strides, /*Assume=*/false,
+ /*ShouldCheckWrap=*/true)) {
+ LLVM_DEBUG(
+ dbgs() << "LV: Invalidate candidate interleaved group due to "
+ "last group member potentially pointer-wrapping.\n");
+ releaseGroup(Group);
+ }
+ } else {
+ // Case 3: A non-reversed interleaved load group with gaps: we need
+ // to execute at least one scalar epilogue iteration. This will ensure
+ // we don't speculatively access memory out-of-bounds. We only need
+ // to look for a member at index factor - 1, since every group must have
+ // a member at index zero.
+ if (Group->isReverse()) {
+ LLVM_DEBUG(
+ dbgs() << "LV: Invalidate candidate interleaved group due to "
+ "a reverse access with gaps.\n");
+ releaseGroup(Group);
+ continue;
+ }
+ LLVM_DEBUG(
+ dbgs() << "LV: Interleaved group requires epilogue iteration.\n");
+ RequiresScalarEpilogue = true;
+ }
+ }
+}
+
+void InterleavedAccessInfo::invalidateGroupsRequiringScalarEpilogue() {
+ // If no group had triggered the requirement to create an epilogue loop,
+ // there is nothing to do.
+ if (!requiresScalarEpilogue())
+ return;
+
+ // Avoid releasing a Group twice.
+ SmallPtrSet<InterleaveGroup<Instruction> *, 4> DelSet;
+ for (auto &I : InterleaveGroupMap) {
+ InterleaveGroup<Instruction> *Group = I.second;
+ if (Group->requiresScalarEpilogue())
+ DelSet.insert(Group);
+ }
+ for (auto *Ptr : DelSet) {
+ LLVM_DEBUG(
+ dbgs()
+ << "LV: Invalidate candidate interleaved group due to gaps that "
+ "require a scalar epilogue (not allowed under optsize) and cannot "
+ "be masked (not enabled).\n");
+ releaseGroup(Ptr);
+ }
+
+ RequiresScalarEpilogue = false;
+}
+
+template <typename InstT>
+void InterleaveGroup<InstT>::addMetadata(InstT *NewInst) const {
+ llvm_unreachable("addMetadata can only be used for Instruction");
+}
+
+namespace llvm {
+template <>
+void InterleaveGroup<Instruction>::addMetadata(Instruction *NewInst) const {
+ SmallVector<Value *, 4> VL;
+ std::transform(Members.begin(), Members.end(), std::back_inserter(VL),
+ [](std::pair<int, Instruction *> p) { return p.second; });
+ propagateMetadata(NewInst, VL);
+}
+}
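+
+// Usage sketch (illustrative only, deliberately compiled out): the constants
+// produced by the mask helpers above follow directly from their loops. The
+// function name and the concrete values are examples, not part of the API.
+#if 0
+static void maskHelperExamples(IRBuilder<> &Builder) {
+ // Lane i of vector j maps to j * VF + i, so VF = 4, NumVecs = 2 yields
+ // <0, 4, 1, 5, 2, 6, 3, 7>.
+ Constant *Interleave =
+ llvm::createInterleaveMask(Builder, /*VF=*/4, /*NumVecs=*/2);
+ // Start = 0, Stride = 2, VF = 4 yields <0, 2, 4, 6>.
+ Constant *Strided =
+ llvm::createStrideMask(Builder, /*Start=*/0, /*Stride=*/2, /*VF=*/4);
+ // Start = 1, NumInts = 3, NumUndefs = 1 yields <1, 2, 3, undef>.
+ Constant *Sequential =
+ llvm::createSequentialMask(Builder, /*Start=*/1, /*NumInts=*/3,
+ /*NumUndefs=*/1);
+ (void)Interleave; (void)Strided; (void)Sequential;
+}
+#endif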