Diffstat (limited to 'llvm/lib/Analysis')
-rw-r--r--  llvm/lib/Analysis/AliasAnalysis.cpp  907
-rw-r--r--  llvm/lib/Analysis/AliasAnalysisEvaluator.cpp  434
-rw-r--r--  llvm/lib/Analysis/AliasAnalysisSummary.cpp  103
-rw-r--r--  llvm/lib/Analysis/AliasAnalysisSummary.h  265
-rw-r--r--  llvm/lib/Analysis/AliasSetTracker.cpp  776
-rw-r--r--  llvm/lib/Analysis/Analysis.cpp  138
-rw-r--r--  llvm/lib/Analysis/AssumptionCache.cpp  302
-rw-r--r--  llvm/lib/Analysis/BasicAliasAnalysis.cpp  2100
-rw-r--r--  llvm/lib/Analysis/BlockFrequencyInfo.cpp  342
-rw-r--r--  llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp  851
-rw-r--r--  llvm/lib/Analysis/BranchProbabilityInfo.cpp  1057
-rw-r--r--  llvm/lib/Analysis/CFG.cpp  279
-rw-r--r--  llvm/lib/Analysis/CFGPrinter.cpp  200
-rw-r--r--  llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp  931
-rw-r--r--  llvm/lib/Analysis/CFLGraph.h  660
-rw-r--r--  llvm/lib/Analysis/CFLSteensAliasAnalysis.cpp  363
-rw-r--r--  llvm/lib/Analysis/CGSCCPassManager.cpp  709
-rw-r--r--  llvm/lib/Analysis/CallGraph.cpp  326
-rw-r--r--  llvm/lib/Analysis/CallGraphSCCPass.cpp  711
-rw-r--r--  llvm/lib/Analysis/CallPrinter.cpp  91
-rw-r--r--  llvm/lib/Analysis/CaptureTracking.cpp  389
-rw-r--r--  llvm/lib/Analysis/CmpInstAnalysis.cpp  143
-rw-r--r--  llvm/lib/Analysis/CodeMetrics.cpp  195
-rw-r--r--  llvm/lib/Analysis/ConstantFolding.cpp  2637
-rw-r--r--  llvm/lib/Analysis/CostModel.cpp  111
-rw-r--r--  llvm/lib/Analysis/DDG.cpp  203
-rw-r--r--  llvm/lib/Analysis/Delinearization.cpp  129
-rw-r--r--  llvm/lib/Analysis/DemandedBits.cpp  488
-rw-r--r--  llvm/lib/Analysis/DependenceAnalysis.cpp  4009
-rw-r--r--  llvm/lib/Analysis/DependenceGraphBuilder.cpp  228
-rw-r--r--  llvm/lib/Analysis/DivergenceAnalysis.cpp  466
-rw-r--r--  llvm/lib/Analysis/DomPrinter.cpp  297
-rw-r--r--  llvm/lib/Analysis/DomTreeUpdater.cpp  533
-rw-r--r--  llvm/lib/Analysis/DominanceFrontier.cpp  96
-rw-r--r--  llvm/lib/Analysis/EHPersonalities.cpp  135
-rw-r--r--  llvm/lib/Analysis/GlobalsModRef.cpp  1026
-rw-r--r--  llvm/lib/Analysis/GuardUtils.cpp  49
-rw-r--r--  llvm/lib/Analysis/IVDescriptors.cpp  1101
-rw-r--r--  llvm/lib/Analysis/IVUsers.cpp  426
-rw-r--r--  llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp  106
-rw-r--r--  llvm/lib/Analysis/InlineCost.cpp  2223
-rw-r--r--  llvm/lib/Analysis/InstCount.cpp  78
-rw-r--r--  llvm/lib/Analysis/InstructionPrecedenceTracking.cpp  160
-rw-r--r--  llvm/lib/Analysis/InstructionSimplify.cpp  5518
-rw-r--r--  llvm/lib/Analysis/Interval.cpp  51
-rw-r--r--  llvm/lib/Analysis/IntervalPartition.cpp  113
-rw-r--r--  llvm/lib/Analysis/LazyBlockFrequencyInfo.cpp  71
-rw-r--r--  llvm/lib/Analysis/LazyBranchProbabilityInfo.cpp  74
-rw-r--r--  llvm/lib/Analysis/LazyCallGraph.cpp  1816
-rw-r--r--  llvm/lib/Analysis/LazyValueInfo.cpp  2042
-rw-r--r--  llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp  404
-rw-r--r--  llvm/lib/Analysis/Lint.cpp  756
-rw-r--r--  llvm/lib/Analysis/Loads.cpp  481
-rw-r--r--  llvm/lib/Analysis/LoopAccessAnalysis.cpp  2464
-rw-r--r--  llvm/lib/Analysis/LoopAnalysisManager.cpp  151
-rw-r--r--  llvm/lib/Analysis/LoopCacheAnalysis.cpp  625
-rw-r--r--  llvm/lib/Analysis/LoopInfo.cpp  1109
-rw-r--r--  llvm/lib/Analysis/LoopPass.cpp  414
-rw-r--r--  llvm/lib/Analysis/LoopUnrollAnalyzer.cpp  214
-rw-r--r--  llvm/lib/Analysis/MemDepPrinter.cpp  164
-rw-r--r--  llvm/lib/Analysis/MemDerefPrinter.cpp  76
-rw-r--r--  llvm/lib/Analysis/MemoryBuiltins.cpp  1051
-rw-r--r--  llvm/lib/Analysis/MemoryDependenceAnalysis.cpp  1824
-rw-r--r--  llvm/lib/Analysis/MemoryLocation.cpp  211
-rw-r--r--  llvm/lib/Analysis/MemorySSA.cpp  2475
-rw-r--r--  llvm/lib/Analysis/MemorySSAUpdater.cpp  1440
-rw-r--r--  llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp  127
-rw-r--r--  llvm/lib/Analysis/ModuleSummaryAnalysis.cpp  881
-rw-r--r--  llvm/lib/Analysis/MustExecute.cpp  516
-rw-r--r--  llvm/lib/Analysis/ObjCARCAliasAnalysis.cpp  164
-rw-r--r--  llvm/lib/Analysis/ObjCARCAnalysisUtils.cpp  25
-rw-r--r--  llvm/lib/Analysis/ObjCARCInstKind.cpp  705
-rw-r--r--  llvm/lib/Analysis/OptimizationRemarkEmitter.cpp  133
-rw-r--r--  llvm/lib/Analysis/OrderedBasicBlock.cpp  111
-rw-r--r--  llvm/lib/Analysis/OrderedInstructions.cpp  50
-rw-r--r--  llvm/lib/Analysis/PHITransAddr.cpp  439
-rw-r--r--  llvm/lib/Analysis/PhiValues.cpp  212
-rw-r--r--  llvm/lib/Analysis/PostDominators.cpp  84
-rw-r--r--  llvm/lib/Analysis/ProfileSummaryInfo.cpp  392
-rw-r--r--  llvm/lib/Analysis/PtrUseVisitor.cpp  44
-rw-r--r--  llvm/lib/Analysis/RegionInfo.cpp  215
-rw-r--r--  llvm/lib/Analysis/RegionPass.cpp  299
-rw-r--r--  llvm/lib/Analysis/RegionPrinter.cpp  266
-rw-r--r--  llvm/lib/Analysis/ScalarEvolution.cpp  12530
-rw-r--r--  llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp  147
-rw-r--r--  llvm/lib/Analysis/ScalarEvolutionExpander.cpp  2452
-rw-r--r--  llvm/lib/Analysis/ScalarEvolutionNormalization.cpp  117
-rw-r--r--  llvm/lib/Analysis/ScopedNoAliasAA.cpp  210
-rw-r--r--  llvm/lib/Analysis/StackSafetyAnalysis.cpp  676
-rw-r--r--  llvm/lib/Analysis/StratifiedSets.h  596
-rw-r--r--  llvm/lib/Analysis/SyncDependenceAnalysis.cpp  380
-rw-r--r--  llvm/lib/Analysis/SyntheticCountsUtils.cpp  104
-rw-r--r--  llvm/lib/Analysis/TargetLibraryInfo.cpp  1638
-rw-r--r--  llvm/lib/Analysis/TargetTransformInfo.cpp  1386
-rw-r--r--  llvm/lib/Analysis/Trace.cpp  53
-rw-r--r--  llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp  740
-rw-r--r--  llvm/lib/Analysis/TypeMetadataUtils.cpp  161
-rw-r--r--  llvm/lib/Analysis/VFABIDemangling.cpp  418
-rw-r--r--  llvm/lib/Analysis/ValueLattice.cpp  25
-rw-r--r--  llvm/lib/Analysis/ValueLatticeUtils.cpp  43
-rw-r--r--  llvm/lib/Analysis/ValueTracking.cpp  5820
-rw-r--r--  llvm/lib/Analysis/VectorUtils.cpp  1161
102 files changed, 83637 insertions, 0 deletions
diff --git a/llvm/lib/Analysis/AliasAnalysis.cpp b/llvm/lib/Analysis/AliasAnalysis.cpp
new file mode 100644
index 000000000000..55dd9a4cda08
--- /dev/null
+++ b/llvm/lib/Analysis/AliasAnalysis.cpp
@@ -0,0 +1,907 @@
+//==- AliasAnalysis.cpp - Generic Alias Analysis Interface Implementation --==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the generic AliasAnalysis interface, which is the
+// common interface used by all clients and implementations of alias analysis.
+//
+// This file also implements the default version of the AliasAnalysis interface
+// that is to be used when no other implementation is specified. This does some
+// simple tests that detect obvious cases: two different global pointers cannot
+// alias, a global cannot alias a malloc, two different mallocs cannot alias,
+// etc.
+//
+// This alias analysis implementation really isn't very good for anything, but
+// it is very fast, and makes a nice clean default implementation. Because it
+// handles lots of little corner cases, other, more complex, alias analysis
+// implementations may choose to rely on this pass to resolve these simple and
+// easy cases.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/CFLAndersAliasAnalysis.h"
+#include "llvm/Analysis/CFLSteensAliasAnalysis.h"
+#include "llvm/Analysis/CaptureTracking.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/ObjCARCAliasAnalysis.h"
+#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
+#include "llvm/Analysis/ScopedNoAliasAA.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TypeBasedAliasAnalysis.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/AtomicOrdering.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include <algorithm>
+#include <cassert>
+#include <functional>
+#include <iterator>
+
+using namespace llvm;
+
+/// Allow disabling BasicAA from the AA results. This is particularly useful
+/// when testing to isolate a single AA implementation.
+static cl::opt<bool> DisableBasicAA("disable-basicaa", cl::Hidden,
+ cl::init(false));
+
+AAResults::AAResults(AAResults &&Arg)
+ : TLI(Arg.TLI), AAs(std::move(Arg.AAs)), AADeps(std::move(Arg.AADeps)) {
+ for (auto &AA : AAs)
+ AA->setAAResults(this);
+}
+
+AAResults::~AAResults() {
+// FIXME: It would be nice to at least clear out the pointers back to this
+// aggregation here, but we end up with non-nesting lifetimes in the legacy
+// pass manager that prevent this from working. In the legacy pass manager
+// we'll end up with dangling references here in some cases.
+#if 0
+ for (auto &AA : AAs)
+ AA->setAAResults(nullptr);
+#endif
+}
+
+bool AAResults::invalidate(Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &Inv) {
+ // AAResults preserves the AAManager by default, due to the stateless nature
+ // of AliasAnalysis. There is no need to check whether it has been preserved
+ // explicitly. Check if any module dependency was invalidated and caused the
+ // AAManager to be invalidated. Invalidate ourselves in that case.
+ auto PAC = PA.getChecker<AAManager>();
+ if (!PAC.preservedWhenStateless())
+ return true;
+
+ // Check if any of the function dependencies were invalidated, and invalidate
+ // ourselves in that case.
+ for (AnalysisKey *ID : AADeps)
+ if (Inv.invalidate(ID, F, PA))
+ return true;
+
+ // Everything we depend on is still fine, so are we. Nothing to invalidate.
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Default chaining methods
+//===----------------------------------------------------------------------===//
+
+AliasResult AAResults::alias(const MemoryLocation &LocA,
+ const MemoryLocation &LocB) {
+ AAQueryInfo AAQIP;
+ return alias(LocA, LocB, AAQIP);
+}
+
+AliasResult AAResults::alias(const MemoryLocation &LocA,
+ const MemoryLocation &LocB, AAQueryInfo &AAQI) {
+ for (const auto &AA : AAs) {
+ auto Result = AA->alias(LocA, LocB, AAQI);
+ if (Result != MayAlias)
+ return Result;
+ }
+ return MayAlias;
+}
+
+bool AAResults::pointsToConstantMemory(const MemoryLocation &Loc,
+ bool OrLocal) {
+ AAQueryInfo AAQIP;
+ return pointsToConstantMemory(Loc, AAQIP, OrLocal);
+}
+
+bool AAResults::pointsToConstantMemory(const MemoryLocation &Loc,
+ AAQueryInfo &AAQI, bool OrLocal) {
+ for (const auto &AA : AAs)
+ if (AA->pointsToConstantMemory(Loc, AAQI, OrLocal))
+ return true;
+
+ return false;
+}
+
+ModRefInfo AAResults::getArgModRefInfo(const CallBase *Call, unsigned ArgIdx) {
+ ModRefInfo Result = ModRefInfo::ModRef;
+
+ for (const auto &AA : AAs) {
+ Result = intersectModRef(Result, AA->getArgModRefInfo(Call, ArgIdx));
+
+ // Early-exit the moment we reach the bottom of the lattice.
+ if (isNoModRef(Result))
+ return ModRefInfo::NoModRef;
+ }
+
+ return Result;
+}
+
+ModRefInfo AAResults::getModRefInfo(Instruction *I, const CallBase *Call2) {
+ AAQueryInfo AAQIP;
+ return getModRefInfo(I, Call2, AAQIP);
+}
+
+ModRefInfo AAResults::getModRefInfo(Instruction *I, const CallBase *Call2,
+ AAQueryInfo &AAQI) {
+ // We may have two calls.
+ if (const auto *Call1 = dyn_cast<CallBase>(I)) {
+ // Check if the two calls modify the same memory.
+ return getModRefInfo(Call1, Call2, AAQI);
+ } else if (I->isFenceLike()) {
+ // If this is a fence, just return ModRef.
+ return ModRefInfo::ModRef;
+ } else {
+ // Otherwise, check if the call modifies or references the
+ // location this memory access defines. The best we can say
+ // is that if the call references what this instruction
+ // defines, it must be clobbered by this location.
+ const MemoryLocation DefLoc = MemoryLocation::get(I);
+ ModRefInfo MR = getModRefInfo(Call2, DefLoc, AAQI);
+ if (isModOrRefSet(MR))
+ return setModAndRef(MR);
+ }
+ return ModRefInfo::NoModRef;
+}
+
+ModRefInfo AAResults::getModRefInfo(const CallBase *Call,
+ const MemoryLocation &Loc) {
+ AAQueryInfo AAQIP;
+ return getModRefInfo(Call, Loc, AAQIP);
+}
+
+ModRefInfo AAResults::getModRefInfo(const CallBase *Call,
+ const MemoryLocation &Loc,
+ AAQueryInfo &AAQI) {
+ ModRefInfo Result = ModRefInfo::ModRef;
+
+ for (const auto &AA : AAs) {
+ Result = intersectModRef(Result, AA->getModRefInfo(Call, Loc, AAQI));
+
+ // Early-exit the moment we reach the bottom of the lattice.
+ if (isNoModRef(Result))
+ return ModRefInfo::NoModRef;
+ }
+
+ // Try to refine the mod-ref info further using other API entry points to the
+ // aggregate set of AA results.
+ auto MRB = getModRefBehavior(Call);
+ if (MRB == FMRB_DoesNotAccessMemory ||
+ MRB == FMRB_OnlyAccessesInaccessibleMem)
+ return ModRefInfo::NoModRef;
+
+ if (onlyReadsMemory(MRB))
+ Result = clearMod(Result);
+ else if (doesNotReadMemory(MRB))
+ Result = clearRef(Result);
+
+ if (onlyAccessesArgPointees(MRB) || onlyAccessesInaccessibleOrArgMem(MRB)) {
+ bool IsMustAlias = true;
+ ModRefInfo AllArgsMask = ModRefInfo::NoModRef;
+ if (doesAccessArgPointees(MRB)) {
+ for (auto AI = Call->arg_begin(), AE = Call->arg_end(); AI != AE; ++AI) {
+ const Value *Arg = *AI;
+ if (!Arg->getType()->isPointerTy())
+ continue;
+ unsigned ArgIdx = std::distance(Call->arg_begin(), AI);
+ MemoryLocation ArgLoc =
+ MemoryLocation::getForArgument(Call, ArgIdx, TLI);
+ AliasResult ArgAlias = alias(ArgLoc, Loc);
+ if (ArgAlias != NoAlias) {
+ ModRefInfo ArgMask = getArgModRefInfo(Call, ArgIdx);
+ AllArgsMask = unionModRef(AllArgsMask, ArgMask);
+ }
+ // Conservatively clear IsMustAlias unless only MustAlias is found.
+ IsMustAlias &= (ArgAlias == MustAlias);
+ }
+ }
+ // Return NoModRef if no alias found with any argument.
+ if (isNoModRef(AllArgsMask))
+ return ModRefInfo::NoModRef;
+ // Logical & between other AA analyses and argument analysis.
+ Result = intersectModRef(Result, AllArgsMask);
+ // If only MustAlias found above, set Must bit.
+ Result = IsMustAlias ? setMust(Result) : clearMust(Result);
+ }
+
+ // If Loc is a constant memory location, the call definitely could not
+ // modify the memory location.
+ if (isModSet(Result) && pointsToConstantMemory(Loc, /*OrLocal*/ false))
+ Result = clearMod(Result);
+
+ return Result;
+}
+
+ModRefInfo AAResults::getModRefInfo(const CallBase *Call1,
+ const CallBase *Call2) {
+ AAQueryInfo AAQIP;
+ return getModRefInfo(Call1, Call2, AAQIP);
+}
+
+ModRefInfo AAResults::getModRefInfo(const CallBase *Call1,
+ const CallBase *Call2, AAQueryInfo &AAQI) {
+ ModRefInfo Result = ModRefInfo::ModRef;
+
+ for (const auto &AA : AAs) {
+ Result = intersectModRef(Result, AA->getModRefInfo(Call1, Call2, AAQI));
+
+ // Early-exit the moment we reach the bottom of the lattice.
+ if (isNoModRef(Result))
+ return ModRefInfo::NoModRef;
+ }
+
+ // Try to refine the mod-ref info further using other API entry points to the
+ // aggregate set of AA results.
+
+ // If Call1 or Call2 are readnone, they don't interact.
+ auto Call1B = getModRefBehavior(Call1);
+ if (Call1B == FMRB_DoesNotAccessMemory)
+ return ModRefInfo::NoModRef;
+
+ auto Call2B = getModRefBehavior(Call2);
+ if (Call2B == FMRB_DoesNotAccessMemory)
+ return ModRefInfo::NoModRef;
+
+ // If they both only read from memory, there is no dependence.
+ if (onlyReadsMemory(Call1B) && onlyReadsMemory(Call2B))
+ return ModRefInfo::NoModRef;
+
+ // If Call1 only reads memory, the only dependence on Call2 can be
+ // from Call1 reading memory written by Call2.
+ if (onlyReadsMemory(Call1B))
+ Result = clearMod(Result);
+ else if (doesNotReadMemory(Call1B))
+ Result = clearRef(Result);
+
+ // If Call2 only access memory through arguments, accumulate the mod/ref
+ // information from Call1's references to the memory referenced by
+ // Call2's arguments.
+ if (onlyAccessesArgPointees(Call2B)) {
+ if (!doesAccessArgPointees(Call2B))
+ return ModRefInfo::NoModRef;
+ ModRefInfo R = ModRefInfo::NoModRef;
+ bool IsMustAlias = true;
+ for (auto I = Call2->arg_begin(), E = Call2->arg_end(); I != E; ++I) {
+ const Value *Arg = *I;
+ if (!Arg->getType()->isPointerTy())
+ continue;
+ unsigned Call2ArgIdx = std::distance(Call2->arg_begin(), I);
+ auto Call2ArgLoc =
+ MemoryLocation::getForArgument(Call2, Call2ArgIdx, TLI);
+
+ // ArgModRefC2 indicates what Call2 might do to Call2ArgLoc, and the
+ // dependence of Call1 on that location is the inverse:
+ // - If Call2 modifies location, dependence exists if Call1 reads or
+ // writes.
+ // - If Call2 only reads location, dependence exists if Call1 writes.
+ ModRefInfo ArgModRefC2 = getArgModRefInfo(Call2, Call2ArgIdx);
+ ModRefInfo ArgMask = ModRefInfo::NoModRef;
+ if (isModSet(ArgModRefC2))
+ ArgMask = ModRefInfo::ModRef;
+ else if (isRefSet(ArgModRefC2))
+ ArgMask = ModRefInfo::Mod;
+
+ // ModRefC1 indicates what Call1 might do to Call2ArgLoc, and we use
+ // above ArgMask to update dependence info.
+ ModRefInfo ModRefC1 = getModRefInfo(Call1, Call2ArgLoc);
+ ArgMask = intersectModRef(ArgMask, ModRefC1);
+
+ // Conservatively clear IsMustAlias unless only MustAlias is found.
+ IsMustAlias &= isMustSet(ModRefC1);
+
+ R = intersectModRef(unionModRef(R, ArgMask), Result);
+ if (R == Result) {
+ // On early exit, not all args were checked, cannot set Must.
+ if (I + 1 != E)
+ IsMustAlias = false;
+ break;
+ }
+ }
+
+ if (isNoModRef(R))
+ return ModRefInfo::NoModRef;
+
+ // If MustAlias found above, set Must bit.
+ return IsMustAlias ? setMust(R) : clearMust(R);
+ }
+
+ // If Call1 only accesses memory through arguments, check if Call2 references
+ // any of the memory referenced by Call1's arguments. If not, return NoModRef.
+ if (onlyAccessesArgPointees(Call1B)) {
+ if (!doesAccessArgPointees(Call1B))
+ return ModRefInfo::NoModRef;
+ ModRefInfo R = ModRefInfo::NoModRef;
+ bool IsMustAlias = true;
+ for (auto I = Call1->arg_begin(), E = Call1->arg_end(); I != E; ++I) {
+ const Value *Arg = *I;
+ if (!Arg->getType()->isPointerTy())
+ continue;
+ unsigned Call1ArgIdx = std::distance(Call1->arg_begin(), I);
+ auto Call1ArgLoc =
+ MemoryLocation::getForArgument(Call1, Call1ArgIdx, TLI);
+
+ // ArgModRefC1 indicates what Call1 might do to Call1ArgLoc; if Call1
+ // might Mod Call1ArgLoc, then we care about either a Mod or a Ref by
+ // Call2. If Call1 might Ref, then we care only about a Mod by Call2.
+ ModRefInfo ArgModRefC1 = getArgModRefInfo(Call1, Call1ArgIdx);
+ ModRefInfo ModRefC2 = getModRefInfo(Call2, Call1ArgLoc);
+ if ((isModSet(ArgModRefC1) && isModOrRefSet(ModRefC2)) ||
+ (isRefSet(ArgModRefC1) && isModSet(ModRefC2)))
+ R = intersectModRef(unionModRef(R, ArgModRefC1), Result);
+
+ // Conservatively clear IsMustAlias unless only MustAlias is found.
+ IsMustAlias &= isMustSet(ModRefC2);
+
+ if (R == Result) {
+ // On early exit, not all args were checked, cannot set Must.
+ if (I + 1 != E)
+ IsMustAlias = false;
+ break;
+ }
+ }
+
+ if (isNoModRef(R))
+ return ModRefInfo::NoModRef;
+
+ // If MustAlias found above, set Must bit.
+ return IsMustAlias ? setMust(R) : clearMust(R);
+ }
+
+ return Result;
+}
+
+FunctionModRefBehavior AAResults::getModRefBehavior(const CallBase *Call) {
+ FunctionModRefBehavior Result = FMRB_UnknownModRefBehavior;
+
+ for (const auto &AA : AAs) {
+ Result = FunctionModRefBehavior(Result & AA->getModRefBehavior(Call));
+
+ // Early-exit the moment we reach the bottom of the lattice.
+ if (Result == FMRB_DoesNotAccessMemory)
+ return Result;
+ }
+
+ return Result;
+}
+
+FunctionModRefBehavior AAResults::getModRefBehavior(const Function *F) {
+ FunctionModRefBehavior Result = FMRB_UnknownModRefBehavior;
+
+ for (const auto &AA : AAs) {
+ Result = FunctionModRefBehavior(Result & AA->getModRefBehavior(F));
+
+ // Early-exit the moment we reach the bottom of the lattice.
+ if (Result == FMRB_DoesNotAccessMemory)
+ return Result;
+ }
+
+ return Result;
+}
+
+raw_ostream &llvm::operator<<(raw_ostream &OS, AliasResult AR) {
+ switch (AR) {
+ case NoAlias:
+ OS << "NoAlias";
+ break;
+ case MustAlias:
+ OS << "MustAlias";
+ break;
+ case MayAlias:
+ OS << "MayAlias";
+ break;
+ case PartialAlias:
+ OS << "PartialAlias";
+ break;
+ }
+ return OS;
+}
+
+//===----------------------------------------------------------------------===//
+// Helper method implementation
+//===----------------------------------------------------------------------===//
+
+ModRefInfo AAResults::getModRefInfo(const LoadInst *L,
+ const MemoryLocation &Loc) {
+ AAQueryInfo AAQIP;
+ return getModRefInfo(L, Loc, AAQIP);
+}
+ModRefInfo AAResults::getModRefInfo(const LoadInst *L,
+ const MemoryLocation &Loc,
+ AAQueryInfo &AAQI) {
+ // Be conservative in the face of atomic.
+ if (isStrongerThan(L->getOrdering(), AtomicOrdering::Unordered))
+ return ModRefInfo::ModRef;
+
+ // If the load address doesn't alias the given address, it doesn't read
+ // or write the specified memory.
+ if (Loc.Ptr) {
+ AliasResult AR = alias(MemoryLocation::get(L), Loc, AAQI);
+ if (AR == NoAlias)
+ return ModRefInfo::NoModRef;
+ if (AR == MustAlias)
+ return ModRefInfo::MustRef;
+ }
+ // Otherwise, a load just reads.
+ return ModRefInfo::Ref;
+}
+
+ModRefInfo AAResults::getModRefInfo(const StoreInst *S,
+ const MemoryLocation &Loc) {
+ AAQueryInfo AAQIP;
+ return getModRefInfo(S, Loc, AAQIP);
+}
+ModRefInfo AAResults::getModRefInfo(const StoreInst *S,
+ const MemoryLocation &Loc,
+ AAQueryInfo &AAQI) {
+ // Be conservative in the face of atomic.
+ if (isStrongerThan(S->getOrdering(), AtomicOrdering::Unordered))
+ return ModRefInfo::ModRef;
+
+ if (Loc.Ptr) {
+ AliasResult AR = alias(MemoryLocation::get(S), Loc, AAQI);
+ // If the store address cannot alias the pointer in question, then the
+ // specified memory cannot be modified by the store.
+ if (AR == NoAlias)
+ return ModRefInfo::NoModRef;
+
+ // If the pointer is a pointer to constant memory, then it could not have
+ // been modified by this store.
+ if (pointsToConstantMemory(Loc, AAQI))
+ return ModRefInfo::NoModRef;
+
+ // If the store address aliases the pointer as must alias, set Must.
+ if (AR == MustAlias)
+ return ModRefInfo::MustMod;
+ }
+
+ // Otherwise, a store just writes.
+ return ModRefInfo::Mod;
+}
+
+ModRefInfo AAResults::getModRefInfo(const FenceInst *S, const MemoryLocation &Loc) {
+ AAQueryInfo AAQIP;
+ return getModRefInfo(S, Loc, AAQIP);
+}
+
+ModRefInfo AAResults::getModRefInfo(const FenceInst *S,
+ const MemoryLocation &Loc,
+ AAQueryInfo &AAQI) {
+ // If we know that the location is a constant memory location, the fence
+ // cannot modify this location.
+ if (Loc.Ptr && pointsToConstantMemory(Loc, AAQI))
+ return ModRefInfo::Ref;
+ return ModRefInfo::ModRef;
+}
+
+ModRefInfo AAResults::getModRefInfo(const VAArgInst *V,
+ const MemoryLocation &Loc) {
+ AAQueryInfo AAQIP;
+ return getModRefInfo(V, Loc, AAQIP);
+}
+
+ModRefInfo AAResults::getModRefInfo(const VAArgInst *V,
+ const MemoryLocation &Loc,
+ AAQueryInfo &AAQI) {
+ if (Loc.Ptr) {
+ AliasResult AR = alias(MemoryLocation::get(V), Loc, AAQI);
+ // If the va_arg address cannot alias the pointer in question, then the
+ // specified memory cannot be accessed by the va_arg.
+ if (AR == NoAlias)
+ return ModRefInfo::NoModRef;
+
+ // If the pointer is a pointer to constant memory, then it could not have
+ // been modified by this va_arg.
+ if (pointsToConstantMemory(Loc, AAQI))
+ return ModRefInfo::NoModRef;
+
+ // If the va_arg aliases the pointer as must alias, set Must.
+ if (AR == MustAlias)
+ return ModRefInfo::MustModRef;
+ }
+
+ // Otherwise, a va_arg reads and writes.
+ return ModRefInfo::ModRef;
+}
+
+ModRefInfo AAResults::getModRefInfo(const CatchPadInst *CatchPad,
+ const MemoryLocation &Loc) {
+ AAQueryInfo AAQIP;
+ return getModRefInfo(CatchPad, Loc, AAQIP);
+}
+
+ModRefInfo AAResults::getModRefInfo(const CatchPadInst *CatchPad,
+ const MemoryLocation &Loc,
+ AAQueryInfo &AAQI) {
+ if (Loc.Ptr) {
+ // If the pointer is a pointer to constant memory,
+ // then it could not have been modified by this catchpad.
+ if (pointsToConstantMemory(Loc, AAQI))
+ return ModRefInfo::NoModRef;
+ }
+
+ // Otherwise, a catchpad reads and writes.
+ return ModRefInfo::ModRef;
+}
+
+ModRefInfo AAResults::getModRefInfo(const CatchReturnInst *CatchRet,
+ const MemoryLocation &Loc) {
+ AAQueryInfo AAQIP;
+ return getModRefInfo(CatchRet, Loc, AAQIP);
+}
+
+ModRefInfo AAResults::getModRefInfo(const CatchReturnInst *CatchRet,
+ const MemoryLocation &Loc,
+ AAQueryInfo &AAQI) {
+ if (Loc.Ptr) {
+ // If the pointer is a pointer to constant memory,
+ // then it could not have been modified by this catchpad.
+ if (pointsToConstantMemory(Loc, AAQI))
+ return ModRefInfo::NoModRef;
+ }
+
+ // Otherwise, a catchret reads and writes.
+ return ModRefInfo::ModRef;
+}
+
+ModRefInfo AAResults::getModRefInfo(const AtomicCmpXchgInst *CX,
+ const MemoryLocation &Loc) {
+ AAQueryInfo AAQIP;
+ return getModRefInfo(CX, Loc, AAQIP);
+}
+
+ModRefInfo AAResults::getModRefInfo(const AtomicCmpXchgInst *CX,
+ const MemoryLocation &Loc,
+ AAQueryInfo &AAQI) {
+ // Acquire/Release cmpxchg has properties that matter for arbitrary addresses.
+ if (isStrongerThanMonotonic(CX->getSuccessOrdering()))
+ return ModRefInfo::ModRef;
+
+ if (Loc.Ptr) {
+ AliasResult AR = alias(MemoryLocation::get(CX), Loc, AAQI);
+ // If the cmpxchg address does not alias the location, it does not access
+ // it.
+ if (AR == NoAlias)
+ return ModRefInfo::NoModRef;
+
+ // If the cmpxchg address aliases the pointer as must alias, set Must.
+ if (AR == MustAlias)
+ return ModRefInfo::MustModRef;
+ }
+
+ return ModRefInfo::ModRef;
+}
+
+ModRefInfo AAResults::getModRefInfo(const AtomicRMWInst *RMW,
+ const MemoryLocation &Loc) {
+ AAQueryInfo AAQIP;
+ return getModRefInfo(RMW, Loc, AAQIP);
+}
+
+ModRefInfo AAResults::getModRefInfo(const AtomicRMWInst *RMW,
+ const MemoryLocation &Loc,
+ AAQueryInfo &AAQI) {
+ // Acquire/Release atomicrmw has properties that matter for arbitrary addresses.
+ if (isStrongerThanMonotonic(RMW->getOrdering()))
+ return ModRefInfo::ModRef;
+
+ if (Loc.Ptr) {
+ AliasResult AR = alias(MemoryLocation::get(RMW), Loc, AAQI);
+ // If the atomicrmw address does not alias the location, it does not access
+ // it.
+ if (AR == NoAlias)
+ return ModRefInfo::NoModRef;
+
+ // If the atomicrmw address aliases the pointer as must alias, set Must.
+ if (AR == MustAlias)
+ return ModRefInfo::MustModRef;
+ }
+
+ return ModRefInfo::ModRef;
+}
+
+/// Return information about whether a particular call site modifies
+/// or reads the specified memory location \p MemLoc before instruction \p I
+/// in a BasicBlock. An ordered basic block \p OBB can be used to speed up
+/// instruction-ordering queries inside the BasicBlock containing \p I.
+/// FIXME: this is really just shoring-up a deficiency in alias analysis.
+/// BasicAA isn't willing to spend linear time determining whether an alloca
+/// was captured before or after this particular call, while we are. However,
+/// with a smarter AA in place, this test is just wasting compile time.
+ModRefInfo AAResults::callCapturesBefore(const Instruction *I,
+ const MemoryLocation &MemLoc,
+ DominatorTree *DT,
+ OrderedBasicBlock *OBB) {
+ if (!DT)
+ return ModRefInfo::ModRef;
+
+ const Value *Object =
+ GetUnderlyingObject(MemLoc.Ptr, I->getModule()->getDataLayout());
+ if (!isIdentifiedObject(Object) || isa<GlobalValue>(Object) ||
+ isa<Constant>(Object))
+ return ModRefInfo::ModRef;
+
+ const auto *Call = dyn_cast<CallBase>(I);
+ if (!Call || Call == Object)
+ return ModRefInfo::ModRef;
+
+ if (PointerMayBeCapturedBefore(Object, /* ReturnCaptures */ true,
+ /* StoreCaptures */ true, I, DT,
+ /* include Object */ true,
+ /* OrderedBasicBlock */ OBB))
+ return ModRefInfo::ModRef;
+
+ unsigned ArgNo = 0;
+ ModRefInfo R = ModRefInfo::NoModRef;
+ bool IsMustAlias = true;
+ // Set flag only if no May found and all operands processed.
+ for (auto CI = Call->data_operands_begin(), CE = Call->data_operands_end();
+ CI != CE; ++CI, ++ArgNo) {
+ // Only look at the no-capture or byval pointer arguments. If this
+ // pointer were passed to arguments that were neither of these, then it
+ // couldn't be no-capture.
+ if (!(*CI)->getType()->isPointerTy() ||
+ (!Call->doesNotCapture(ArgNo) && ArgNo < Call->getNumArgOperands() &&
+ !Call->isByValArgument(ArgNo)))
+ continue;
+
+ AliasResult AR = alias(MemoryLocation(*CI), MemoryLocation(Object));
+ // If this is a no-capture pointer argument, see if we can tell that it
+ // is impossible to alias the pointer we're checking. If not, we have to
+ // assume that the call could touch the pointer, even though it doesn't
+ // escape.
+ if (AR != MustAlias)
+ IsMustAlias = false;
+ if (AR == NoAlias)
+ continue;
+ if (Call->doesNotAccessMemory(ArgNo))
+ continue;
+ if (Call->onlyReadsMemory(ArgNo)) {
+ R = ModRefInfo::Ref;
+ continue;
+ }
+ // Not returning MustModRef since we have not seen all the arguments.
+ return ModRefInfo::ModRef;
+ }
+ return IsMustAlias ? setMust(R) : clearMust(R);
+}
+
+/// canBasicBlockModify - Return true if it is possible for execution of the
+/// specified basic block to modify the location Loc.
+///
+bool AAResults::canBasicBlockModify(const BasicBlock &BB,
+ const MemoryLocation &Loc) {
+ return canInstructionRangeModRef(BB.front(), BB.back(), Loc, ModRefInfo::Mod);
+}
+
+/// canInstructionRangeModRef - Return true if it is possible for the
+/// execution of the specified instructions to mod\ref (according to the
+/// mode) the location Loc. The instructions to consider are all
+/// of the instructions in the range of [I1,I2] INCLUSIVE.
+/// I1 and I2 must be in the same basic block.
+bool AAResults::canInstructionRangeModRef(const Instruction &I1,
+ const Instruction &I2,
+ const MemoryLocation &Loc,
+ const ModRefInfo Mode) {
+ assert(I1.getParent() == I2.getParent() &&
+ "Instructions not in same basic block!");
+ BasicBlock::const_iterator I = I1.getIterator();
+ BasicBlock::const_iterator E = I2.getIterator();
+ ++E; // Convert from inclusive to exclusive range.
+
+ for (; I != E; ++I) // Check every instruction in range
+ if (isModOrRefSet(intersectModRef(getModRefInfo(&*I, Loc), Mode)))
+ return true;
+ return false;
+}
+
+// Provide a definition for the root virtual destructor.
+AAResults::Concept::~Concept() = default;
+
+// Provide a definition for the static object used to identify passes.
+AnalysisKey AAManager::Key;
+
+namespace {
+
+
+} // end anonymous namespace
+
+char ExternalAAWrapperPass::ID = 0;
+
+INITIALIZE_PASS(ExternalAAWrapperPass, "external-aa", "External Alias Analysis",
+ false, true)
+
+ImmutablePass *
+llvm::createExternalAAWrapperPass(ExternalAAWrapperPass::CallbackT Callback) {
+ return new ExternalAAWrapperPass(std::move(Callback));
+}
+
+AAResultsWrapperPass::AAResultsWrapperPass() : FunctionPass(ID) {
+ initializeAAResultsWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+char AAResultsWrapperPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(AAResultsWrapperPass, "aa",
+ "Function Alias Analysis Results", false, true)
+INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(CFLAndersAAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(CFLSteensAAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ExternalAAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ObjCARCAAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScopedNoAliasAAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TypeBasedAAWrapperPass)
+INITIALIZE_PASS_END(AAResultsWrapperPass, "aa",
+ "Function Alias Analysis Results", false, true)
+
+FunctionPass *llvm::createAAResultsWrapperPass() {
+ return new AAResultsWrapperPass();
+}
+
+/// Run the wrapper pass to rebuild an aggregation over known AA passes.
+///
+/// This is the legacy pass manager's interface to the new-style AA results
+/// aggregation object. Because this is somewhat shoe-horned into the legacy
+/// pass manager, we hard code all the specific alias analyses available into
+/// it. While the particular set enabled is configured via commandline flags,
+/// adding a new alias analysis to LLVM will require adding support for it to
+/// this list.
+bool AAResultsWrapperPass::runOnFunction(Function &F) {
+ // NB! This *must* be reset before adding new AA results to the new
+ // AAResults object because in the legacy pass manager, each instance
+ // of these will refer to the *same* immutable analyses, registering and
+ // unregistering themselves with them. We need to carefully tear down the
+ // previous object first, in this case replacing it with an empty one, before
+ // registering new results.
+ AAR.reset(
+ new AAResults(getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F)));
+
+ // BasicAA is always available for function analyses. Also, we add it first
+ // so that it can trump TBAA results when it proves MustAlias.
+ // FIXME: TBAA should have an explicit mode to support this and then we
+ // should reconsider the ordering here.
+ if (!DisableBasicAA)
+ AAR->addAAResult(getAnalysis<BasicAAWrapperPass>().getResult());
+
+ // Populate the results with the currently available AAs.
+ if (auto *WrapperPass = getAnalysisIfAvailable<ScopedNoAliasAAWrapperPass>())
+ AAR->addAAResult(WrapperPass->getResult());
+ if (auto *WrapperPass = getAnalysisIfAvailable<TypeBasedAAWrapperPass>())
+ AAR->addAAResult(WrapperPass->getResult());
+ if (auto *WrapperPass =
+ getAnalysisIfAvailable<objcarc::ObjCARCAAWrapperPass>())
+ AAR->addAAResult(WrapperPass->getResult());
+ if (auto *WrapperPass = getAnalysisIfAvailable<GlobalsAAWrapperPass>())
+ AAR->addAAResult(WrapperPass->getResult());
+ if (auto *WrapperPass = getAnalysisIfAvailable<SCEVAAWrapperPass>())
+ AAR->addAAResult(WrapperPass->getResult());
+ if (auto *WrapperPass = getAnalysisIfAvailable<CFLAndersAAWrapperPass>())
+ AAR->addAAResult(WrapperPass->getResult());
+ if (auto *WrapperPass = getAnalysisIfAvailable<CFLSteensAAWrapperPass>())
+ AAR->addAAResult(WrapperPass->getResult());
+
+ // If available, run an external AA providing callback over the results as
+ // well.
+ if (auto *WrapperPass = getAnalysisIfAvailable<ExternalAAWrapperPass>())
+ if (WrapperPass->CB)
+ WrapperPass->CB(*this, F, *AAR);
+
+ // Analyses don't mutate the IR, so return false.
+ return false;
+}
+
+void AAResultsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequired<BasicAAWrapperPass>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
+
+ // We also need to mark all the alias analysis passes we will potentially
+ // probe in runOnFunction as used here to ensure the legacy pass manager
+ // preserves them. This hard coding of lists of alias analyses is specific to
+ // the legacy pass manager.
+ AU.addUsedIfAvailable<ScopedNoAliasAAWrapperPass>();
+ AU.addUsedIfAvailable<TypeBasedAAWrapperPass>();
+ AU.addUsedIfAvailable<objcarc::ObjCARCAAWrapperPass>();
+ AU.addUsedIfAvailable<GlobalsAAWrapperPass>();
+ AU.addUsedIfAvailable<SCEVAAWrapperPass>();
+ AU.addUsedIfAvailable<CFLAndersAAWrapperPass>();
+ AU.addUsedIfAvailable<CFLSteensAAWrapperPass>();
+}
+
+AAResults llvm::createLegacyPMAAResults(Pass &P, Function &F,
+ BasicAAResult &BAR) {
+ AAResults AAR(P.getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F));
+
+ // Add in our explicitly constructed BasicAA results.
+ if (!DisableBasicAA)
+ AAR.addAAResult(BAR);
+
+ // Populate the results with the other currently available AAs.
+ if (auto *WrapperPass =
+ P.getAnalysisIfAvailable<ScopedNoAliasAAWrapperPass>())
+ AAR.addAAResult(WrapperPass->getResult());
+ if (auto *WrapperPass = P.getAnalysisIfAvailable<TypeBasedAAWrapperPass>())
+ AAR.addAAResult(WrapperPass->getResult());
+ if (auto *WrapperPass =
+ P.getAnalysisIfAvailable<objcarc::ObjCARCAAWrapperPass>())
+ AAR.addAAResult(WrapperPass->getResult());
+ if (auto *WrapperPass = P.getAnalysisIfAvailable<GlobalsAAWrapperPass>())
+ AAR.addAAResult(WrapperPass->getResult());
+ if (auto *WrapperPass = P.getAnalysisIfAvailable<CFLAndersAAWrapperPass>())
+ AAR.addAAResult(WrapperPass->getResult());
+ if (auto *WrapperPass = P.getAnalysisIfAvailable<CFLSteensAAWrapperPass>())
+ AAR.addAAResult(WrapperPass->getResult());
+
+ return AAR;
+}
+
+bool llvm::isNoAliasCall(const Value *V) {
+ if (const auto *Call = dyn_cast<CallBase>(V))
+ return Call->hasRetAttr(Attribute::NoAlias);
+ return false;
+}
+
+bool llvm::isNoAliasArgument(const Value *V) {
+ if (const Argument *A = dyn_cast<Argument>(V))
+ return A->hasNoAliasAttr();
+ return false;
+}
+
+bool llvm::isIdentifiedObject(const Value *V) {
+ if (isa<AllocaInst>(V))
+ return true;
+ if (isa<GlobalValue>(V) && !isa<GlobalAlias>(V))
+ return true;
+ if (isNoAliasCall(V))
+ return true;
+ if (const Argument *A = dyn_cast<Argument>(V))
+ return A->hasNoAliasAttr() || A->hasByValAttr();
+ return false;
+}
+
+bool llvm::isIdentifiedFunctionLocal(const Value *V) {
+ return isa<AllocaInst>(V) || isNoAliasCall(V) || isNoAliasArgument(V);
+}
+
+void llvm::getAAResultsAnalysisUsage(AnalysisUsage &AU) {
+ // This function needs to be in sync with llvm::createLegacyPMAAResults -- if
+ // more alias analyses are added to llvm::createLegacyPMAAResults, they need
+ // to be added here also.
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
+ AU.addUsedIfAvailable<ScopedNoAliasAAWrapperPass>();
+ AU.addUsedIfAvailable<TypeBasedAAWrapperPass>();
+ AU.addUsedIfAvailable<objcarc::ObjCARCAAWrapperPass>();
+ AU.addUsedIfAvailable<GlobalsAAWrapperPass>();
+ AU.addUsedIfAvailable<CFLAndersAAWrapperPass>();
+ AU.addUsedIfAvailable<CFLSteensAAWrapperPass>();
+}
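For orientation, here is a minimal sketch (not part of the commit above) of how a client
typically queries the aggregated AAResults. The free function and its name are
hypothetical; the AAResults entry points are the ones implemented in the file above.

    #include "llvm/Analysis/AliasAnalysis.h"
    #include "llvm/Analysis/MemoryLocation.h"
    #include "llvm/IR/Instructions.h"

    using namespace llvm;

    // Hypothetical helper: may `Call` write the memory read by `LI`?
    // AAResults::getModRefInfo chains every registered AA implementation and
    // then refines the answer via getModRefBehavior, as shown above.
    static bool callMayClobberLoad(AAResults &AA, const CallBase &Call,
                                   const LoadInst &LI) {
      ModRefInfo MRI = AA.getModRefInfo(&Call, MemoryLocation::get(&LI));
      return isModSet(MRI);
    }

In the legacy pass manager the AAResults reference comes from
getAnalysis<AAResultsWrapperPass>().getAAResults(); under the new pass manager it is
obtained as the result of the AAManager analysis.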
diff --git a/llvm/lib/Analysis/AliasAnalysisEvaluator.cpp b/llvm/lib/Analysis/AliasAnalysisEvaluator.cpp
new file mode 100644
index 000000000000..e83703867e09
--- /dev/null
+++ b/llvm/lib/Analysis/AliasAnalysisEvaluator.cpp
@@ -0,0 +1,434 @@
+//===- AliasAnalysisEvaluator.cpp - Alias Analysis Accuracy Evaluator -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/AliasAnalysisEvaluator.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace llvm;
+
+static cl::opt<bool> PrintAll("print-all-alias-modref-info", cl::ReallyHidden);
+
+static cl::opt<bool> PrintNoAlias("print-no-aliases", cl::ReallyHidden);
+static cl::opt<bool> PrintMayAlias("print-may-aliases", cl::ReallyHidden);
+static cl::opt<bool> PrintPartialAlias("print-partial-aliases", cl::ReallyHidden);
+static cl::opt<bool> PrintMustAlias("print-must-aliases", cl::ReallyHidden);
+
+static cl::opt<bool> PrintNoModRef("print-no-modref", cl::ReallyHidden);
+static cl::opt<bool> PrintRef("print-ref", cl::ReallyHidden);
+static cl::opt<bool> PrintMod("print-mod", cl::ReallyHidden);
+static cl::opt<bool> PrintModRef("print-modref", cl::ReallyHidden);
+static cl::opt<bool> PrintMust("print-must", cl::ReallyHidden);
+static cl::opt<bool> PrintMustRef("print-mustref", cl::ReallyHidden);
+static cl::opt<bool> PrintMustMod("print-mustmod", cl::ReallyHidden);
+static cl::opt<bool> PrintMustModRef("print-mustmodref", cl::ReallyHidden);
+
+static cl::opt<bool> EvalAAMD("evaluate-aa-metadata", cl::ReallyHidden);
+
+static void PrintResults(AliasResult AR, bool P, const Value *V1,
+ const Value *V2, const Module *M) {
+ if (PrintAll || P) {
+ std::string o1, o2;
+ {
+ raw_string_ostream os1(o1), os2(o2);
+ V1->printAsOperand(os1, true, M);
+ V2->printAsOperand(os2, true, M);
+ }
+
+ if (o2 < o1)
+ std::swap(o1, o2);
+ errs() << " " << AR << ":\t" << o1 << ", " << o2 << "\n";
+ }
+}
+
+static inline void PrintModRefResults(const char *Msg, bool P, Instruction *I,
+ Value *Ptr, Module *M) {
+ if (PrintAll || P) {
+ errs() << " " << Msg << ": Ptr: ";
+ Ptr->printAsOperand(errs(), true, M);
+ errs() << "\t<->" << *I << '\n';
+ }
+}
+
+static inline void PrintModRefResults(const char *Msg, bool P, CallBase *CallA,
+ CallBase *CallB, Module *M) {
+ if (PrintAll || P) {
+ errs() << " " << Msg << ": " << *CallA << " <-> " << *CallB << '\n';
+ }
+}
+
+static inline void PrintLoadStoreResults(AliasResult AR, bool P,
+ const Value *V1, const Value *V2,
+ const Module *M) {
+ if (PrintAll || P) {
+ errs() << " " << AR << ": " << *V1 << " <-> " << *V2 << '\n';
+ }
+}
+
+static inline bool isInterestingPointer(Value *V) {
+ return V->getType()->isPointerTy()
+ && !isa<ConstantPointerNull>(V);
+}
+
+PreservedAnalyses AAEvaluator::run(Function &F, FunctionAnalysisManager &AM) {
+ runInternal(F, AM.getResult<AAManager>(F));
+ return PreservedAnalyses::all();
+}
+
+void AAEvaluator::runInternal(Function &F, AAResults &AA) {
+ const DataLayout &DL = F.getParent()->getDataLayout();
+
+ ++FunctionCount;
+
+ SetVector<Value *> Pointers;
+ SmallSetVector<CallBase *, 16> Calls;
+ SetVector<Value *> Loads;
+ SetVector<Value *> Stores;
+
+ for (auto &I : F.args())
+ if (I.getType()->isPointerTy()) // Add all pointer arguments.
+ Pointers.insert(&I);
+
+ for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
+ if (I->getType()->isPointerTy()) // Add all pointer instructions.
+ Pointers.insert(&*I);
+ if (EvalAAMD && isa<LoadInst>(&*I))
+ Loads.insert(&*I);
+ if (EvalAAMD && isa<StoreInst>(&*I))
+ Stores.insert(&*I);
+ Instruction &Inst = *I;
+ if (auto *Call = dyn_cast<CallBase>(&Inst)) {
+ Value *Callee = Call->getCalledValue();
+ // Skip actual functions for direct function calls.
+ if (!isa<Function>(Callee) && isInterestingPointer(Callee))
+ Pointers.insert(Callee);
+ // Consider formals.
+ for (Use &DataOp : Call->data_ops())
+ if (isInterestingPointer(DataOp))
+ Pointers.insert(DataOp);
+ Calls.insert(Call);
+ } else {
+ // Consider all operands.
+ for (Instruction::op_iterator OI = Inst.op_begin(), OE = Inst.op_end();
+ OI != OE; ++OI)
+ if (isInterestingPointer(*OI))
+ Pointers.insert(*OI);
+ }
+ }
+
+ if (PrintAll || PrintNoAlias || PrintMayAlias || PrintPartialAlias ||
+ PrintMustAlias || PrintNoModRef || PrintMod || PrintRef || PrintModRef)
+ errs() << "Function: " << F.getName() << ": " << Pointers.size()
+ << " pointers, " << Calls.size() << " call sites\n";
+
+ // iterate over the worklist, and run the full (n^2)/2 disambiguations
+ for (SetVector<Value *>::iterator I1 = Pointers.begin(), E = Pointers.end();
+ I1 != E; ++I1) {
+ auto I1Size = LocationSize::unknown();
+ Type *I1ElTy = cast<PointerType>((*I1)->getType())->getElementType();
+ if (I1ElTy->isSized())
+ I1Size = LocationSize::precise(DL.getTypeStoreSize(I1ElTy));
+
+ for (SetVector<Value *>::iterator I2 = Pointers.begin(); I2 != I1; ++I2) {
+ auto I2Size = LocationSize::unknown();
+ Type *I2ElTy = cast<PointerType>((*I2)->getType())->getElementType();
+ if (I2ElTy->isSized())
+ I2Size = LocationSize::precise(DL.getTypeStoreSize(I2ElTy));
+
+ AliasResult AR = AA.alias(*I1, I1Size, *I2, I2Size);
+ switch (AR) {
+ case NoAlias:
+ PrintResults(AR, PrintNoAlias, *I1, *I2, F.getParent());
+ ++NoAliasCount;
+ break;
+ case MayAlias:
+ PrintResults(AR, PrintMayAlias, *I1, *I2, F.getParent());
+ ++MayAliasCount;
+ break;
+ case PartialAlias:
+ PrintResults(AR, PrintPartialAlias, *I1, *I2, F.getParent());
+ ++PartialAliasCount;
+ break;
+ case MustAlias:
+ PrintResults(AR, PrintMustAlias, *I1, *I2, F.getParent());
+ ++MustAliasCount;
+ break;
+ }
+ }
+ }
+
+ if (EvalAAMD) {
+ // iterate over all pairs of load, store
+ for (Value *Load : Loads) {
+ for (Value *Store : Stores) {
+ AliasResult AR = AA.alias(MemoryLocation::get(cast<LoadInst>(Load)),
+ MemoryLocation::get(cast<StoreInst>(Store)));
+ switch (AR) {
+ case NoAlias:
+ PrintLoadStoreResults(AR, PrintNoAlias, Load, Store, F.getParent());
+ ++NoAliasCount;
+ break;
+ case MayAlias:
+ PrintLoadStoreResults(AR, PrintMayAlias, Load, Store, F.getParent());
+ ++MayAliasCount;
+ break;
+ case PartialAlias:
+ PrintLoadStoreResults(AR, PrintPartialAlias, Load, Store, F.getParent());
+ ++PartialAliasCount;
+ break;
+ case MustAlias:
+ PrintLoadStoreResults(AR, PrintMustAlias, Load, Store, F.getParent());
+ ++MustAliasCount;
+ break;
+ }
+ }
+ }
+
+ // iterate over all pairs of store, store
+ for (SetVector<Value *>::iterator I1 = Stores.begin(), E = Stores.end();
+ I1 != E; ++I1) {
+ for (SetVector<Value *>::iterator I2 = Stores.begin(); I2 != I1; ++I2) {
+ AliasResult AR = AA.alias(MemoryLocation::get(cast<StoreInst>(*I1)),
+ MemoryLocation::get(cast<StoreInst>(*I2)));
+ switch (AR) {
+ case NoAlias:
+ PrintLoadStoreResults(AR, PrintNoAlias, *I1, *I2, F.getParent());
+ ++NoAliasCount;
+ break;
+ case MayAlias:
+ PrintLoadStoreResults(AR, PrintMayAlias, *I1, *I2, F.getParent());
+ ++MayAliasCount;
+ break;
+ case PartialAlias:
+ PrintLoadStoreResults(AR, PrintPartialAlias, *I1, *I2, F.getParent());
+ ++PartialAliasCount;
+ break;
+ case MustAlias:
+ PrintLoadStoreResults(AR, PrintMustAlias, *I1, *I2, F.getParent());
+ ++MustAliasCount;
+ break;
+ }
+ }
+ }
+ }
+
+ // Mod/ref alias analysis: compare all pairs of calls and values
+ for (CallBase *Call : Calls) {
+ for (auto Pointer : Pointers) {
+ auto Size = LocationSize::unknown();
+ Type *ElTy = cast<PointerType>(Pointer->getType())->getElementType();
+ if (ElTy->isSized())
+ Size = LocationSize::precise(DL.getTypeStoreSize(ElTy));
+
+ switch (AA.getModRefInfo(Call, Pointer, Size)) {
+ case ModRefInfo::NoModRef:
+ PrintModRefResults("NoModRef", PrintNoModRef, Call, Pointer,
+ F.getParent());
+ ++NoModRefCount;
+ break;
+ case ModRefInfo::Mod:
+ PrintModRefResults("Just Mod", PrintMod, Call, Pointer, F.getParent());
+ ++ModCount;
+ break;
+ case ModRefInfo::Ref:
+ PrintModRefResults("Just Ref", PrintRef, Call, Pointer, F.getParent());
+ ++RefCount;
+ break;
+ case ModRefInfo::ModRef:
+ PrintModRefResults("Both ModRef", PrintModRef, Call, Pointer,
+ F.getParent());
+ ++ModRefCount;
+ break;
+ case ModRefInfo::Must:
+ PrintModRefResults("Must", PrintMust, Call, Pointer, F.getParent());
+ ++MustCount;
+ break;
+ case ModRefInfo::MustMod:
+ PrintModRefResults("Just Mod (MustAlias)", PrintMustMod, Call, Pointer,
+ F.getParent());
+ ++MustModCount;
+ break;
+ case ModRefInfo::MustRef:
+ PrintModRefResults("Just Ref (MustAlias)", PrintMustRef, Call, Pointer,
+ F.getParent());
+ ++MustRefCount;
+ break;
+ case ModRefInfo::MustModRef:
+ PrintModRefResults("Both ModRef (MustAlias)", PrintMustModRef, Call,
+ Pointer, F.getParent());
+ ++MustModRefCount;
+ break;
+ }
+ }
+ }
+
+ // Mod/ref alias analysis: compare all pairs of calls
+ for (CallBase *CallA : Calls) {
+ for (CallBase *CallB : Calls) {
+ if (CallA == CallB)
+ continue;
+ switch (AA.getModRefInfo(CallA, CallB)) {
+ case ModRefInfo::NoModRef:
+ PrintModRefResults("NoModRef", PrintNoModRef, CallA, CallB,
+ F.getParent());
+ ++NoModRefCount;
+ break;
+ case ModRefInfo::Mod:
+ PrintModRefResults("Just Mod", PrintMod, CallA, CallB, F.getParent());
+ ++ModCount;
+ break;
+ case ModRefInfo::Ref:
+ PrintModRefResults("Just Ref", PrintRef, CallA, CallB, F.getParent());
+ ++RefCount;
+ break;
+ case ModRefInfo::ModRef:
+ PrintModRefResults("Both ModRef", PrintModRef, CallA, CallB,
+ F.getParent());
+ ++ModRefCount;
+ break;
+ case ModRefInfo::Must:
+ PrintModRefResults("Must", PrintMust, CallA, CallB, F.getParent());
+ ++MustCount;
+ break;
+ case ModRefInfo::MustMod:
+ PrintModRefResults("Just Mod (MustAlias)", PrintMustMod, CallA, CallB,
+ F.getParent());
+ ++MustModCount;
+ break;
+ case ModRefInfo::MustRef:
+ PrintModRefResults("Just Ref (MustAlias)", PrintMustRef, CallA, CallB,
+ F.getParent());
+ ++MustRefCount;
+ break;
+ case ModRefInfo::MustModRef:
+ PrintModRefResults("Both ModRef (MustAlias)", PrintMustModRef, CallA,
+ CallB, F.getParent());
+ ++MustModRefCount;
+ break;
+ }
+ }
+ }
+}
+
+static void PrintPercent(int64_t Num, int64_t Sum) {
+ errs() << "(" << Num * 100LL / Sum << "." << ((Num * 1000LL / Sum) % 10)
+ << "%)\n";
+}
+
+AAEvaluator::~AAEvaluator() {
+ if (FunctionCount == 0)
+ return;
+
+ int64_t AliasSum =
+ NoAliasCount + MayAliasCount + PartialAliasCount + MustAliasCount;
+ errs() << "===== Alias Analysis Evaluator Report =====\n";
+ if (AliasSum == 0) {
+ errs() << " Alias Analysis Evaluator Summary: No pointers!\n";
+ } else {
+ errs() << " " << AliasSum << " Total Alias Queries Performed\n";
+ errs() << " " << NoAliasCount << " no alias responses ";
+ PrintPercent(NoAliasCount, AliasSum);
+ errs() << " " << MayAliasCount << " may alias responses ";
+ PrintPercent(MayAliasCount, AliasSum);
+ errs() << " " << PartialAliasCount << " partial alias responses ";
+ PrintPercent(PartialAliasCount, AliasSum);
+ errs() << " " << MustAliasCount << " must alias responses ";
+ PrintPercent(MustAliasCount, AliasSum);
+ errs() << " Alias Analysis Evaluator Pointer Alias Summary: "
+ << NoAliasCount * 100 / AliasSum << "%/"
+ << MayAliasCount * 100 / AliasSum << "%/"
+ << PartialAliasCount * 100 / AliasSum << "%/"
+ << MustAliasCount * 100 / AliasSum << "%\n";
+ }
+
+ // Display the summary for mod/ref analysis
+ int64_t ModRefSum = NoModRefCount + RefCount + ModCount + ModRefCount +
+ MustCount + MustRefCount + MustModCount + MustModRefCount;
+ if (ModRefSum == 0) {
+ errs() << " Alias Analysis Mod/Ref Evaluator Summary: no "
+ "mod/ref!\n";
+ } else {
+ errs() << " " << ModRefSum << " Total ModRef Queries Performed\n";
+ errs() << " " << NoModRefCount << " no mod/ref responses ";
+ PrintPercent(NoModRefCount, ModRefSum);
+ errs() << " " << ModCount << " mod responses ";
+ PrintPercent(ModCount, ModRefSum);
+ errs() << " " << RefCount << " ref responses ";
+ PrintPercent(RefCount, ModRefSum);
+ errs() << " " << ModRefCount << " mod & ref responses ";
+ PrintPercent(ModRefCount, ModRefSum);
+ errs() << " " << MustCount << " must responses ";
+ PrintPercent(MustCount, ModRefSum);
+ errs() << " " << MustModCount << " must mod responses ";
+ PrintPercent(MustModCount, ModRefSum);
+ errs() << " " << MustRefCount << " must ref responses ";
+ PrintPercent(MustRefCount, ModRefSum);
+ errs() << " " << MustModRefCount << " must mod & ref responses ";
+ PrintPercent(MustModRefCount, ModRefSum);
+ errs() << " Alias Analysis Evaluator Mod/Ref Summary: "
+ << NoModRefCount * 100 / ModRefSum << "%/"
+ << ModCount * 100 / ModRefSum << "%/" << RefCount * 100 / ModRefSum
+ << "%/" << ModRefCount * 100 / ModRefSum << "%/"
+ << MustCount * 100 / ModRefSum << "%/"
+ << MustRefCount * 100 / ModRefSum << "%/"
+ << MustModCount * 100 / ModRefSum << "%/"
+ << MustModRefCount * 100 / ModRefSum << "%\n";
+ }
+}
+
+namespace llvm {
+class AAEvalLegacyPass : public FunctionPass {
+ std::unique_ptr<AAEvaluator> P;
+
+public:
+ static char ID; // Pass identification, replacement for typeid
+ AAEvalLegacyPass() : FunctionPass(ID) {
+ initializeAAEvalLegacyPassPass(*PassRegistry::getPassRegistry());
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<AAResultsWrapperPass>();
+ AU.setPreservesAll();
+ }
+
+ bool doInitialization(Module &M) override {
+ P.reset(new AAEvaluator());
+ return false;
+ }
+
+ bool runOnFunction(Function &F) override {
+ P->runInternal(F, getAnalysis<AAResultsWrapperPass>().getAAResults());
+ return false;
+ }
+ bool doFinalization(Module &M) override {
+ P.reset();
+ return false;
+ }
+};
+}
+
+char AAEvalLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(AAEvalLegacyPass, "aa-eval",
+ "Exhaustive Alias Analysis Precision Evaluator", false,
+ true)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_END(AAEvalLegacyPass, "aa-eval",
+ "Exhaustive Alias Analysis Precision Evaluator", false,
+ true)
+
+FunctionPass *llvm::createAAEvalPass() { return new AAEvalLegacyPass(); }
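A sketch, again not part of the commit, of driving the evaluator from C++ through the
legacy pass manager. The wrapper function is hypothetical and the header for
createAAEvalPass() is an assumption; the legacy pass itself is registered above under
the name "aa-eval", so the same evaluation is also reachable via opt with the
print flags defined at the top of this file.

    #include "llvm/Analysis/Passes.h"        // assumed declaration of createAAEvalPass()
    #include "llvm/IR/LegacyPassManager.h"
    #include "llvm/IR/Module.h"

    // Run aa-eval over every function in M. Required analyses such as
    // AAResultsWrapperPass are scheduled automatically by the legacy
    // PassManager, and the summary report is printed when the AAEvaluator is
    // destroyed in doFinalization().
    static void runAAEval(llvm::Module &M) {
      llvm::legacy::PassManager PM;
      PM.add(llvm::createAAEvalPass());
      PM.run(M);
    }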
diff --git a/llvm/lib/Analysis/AliasAnalysisSummary.cpp b/llvm/lib/Analysis/AliasAnalysisSummary.cpp
new file mode 100644
index 000000000000..2f3396a44117
--- /dev/null
+++ b/llvm/lib/Analysis/AliasAnalysisSummary.cpp
@@ -0,0 +1,103 @@
+#include "AliasAnalysisSummary.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Compiler.h"
+
+namespace llvm {
+namespace cflaa {
+
+namespace {
+const unsigned AttrEscapedIndex = 0;
+const unsigned AttrUnknownIndex = 1;
+const unsigned AttrGlobalIndex = 2;
+const unsigned AttrCallerIndex = 3;
+const unsigned AttrFirstArgIndex = 4;
+const unsigned AttrLastArgIndex = NumAliasAttrs;
+const unsigned AttrMaxNumArgs = AttrLastArgIndex - AttrFirstArgIndex;
+
+// It would be *slightly* prettier if we changed these to AliasAttrs, but it
+// seems that both GCC and MSVC emit dynamic initializers for const bitsets.
+using AliasAttr = unsigned;
+const AliasAttr AttrNone = 0;
+const AliasAttr AttrEscaped = 1 << AttrEscapedIndex;
+const AliasAttr AttrUnknown = 1 << AttrUnknownIndex;
+const AliasAttr AttrGlobal = 1 << AttrGlobalIndex;
+const AliasAttr AttrCaller = 1 << AttrCallerIndex;
+const AliasAttr ExternalAttrMask = AttrEscaped | AttrUnknown | AttrGlobal;
+}
+
+AliasAttrs getAttrNone() { return AttrNone; }
+
+AliasAttrs getAttrUnknown() { return AttrUnknown; }
+bool hasUnknownAttr(AliasAttrs Attr) { return Attr.test(AttrUnknownIndex); }
+
+AliasAttrs getAttrCaller() { return AttrCaller; }
+bool hasCallerAttr(AliasAttrs Attr) { return Attr.test(AttrCallerIndex); }
+bool hasUnknownOrCallerAttr(AliasAttrs Attr) {
+ return Attr.test(AttrUnknownIndex) || Attr.test(AttrCallerIndex);
+}
+
+AliasAttrs getAttrEscaped() { return AttrEscaped; }
+bool hasEscapedAttr(AliasAttrs Attr) { return Attr.test(AttrEscapedIndex); }
+
+static AliasAttr argNumberToAttr(unsigned ArgNum) {
+ if (ArgNum >= AttrMaxNumArgs)
+ return AttrUnknown;
+ // N.B. MSVC complains if we use `1U` here, since AliasAttr's ctor takes
+ // an unsigned long long.
+ return AliasAttr(1ULL << (ArgNum + AttrFirstArgIndex));
+}
+
+AliasAttrs getGlobalOrArgAttrFromValue(const Value &Val) {
+ if (isa<GlobalValue>(Val))
+ return AttrGlobal;
+
+ if (auto *Arg = dyn_cast<Argument>(&Val))
+ // Only pointer arguments should have the argument attribute,
+ // because things can't escape through scalars without us seeing a
+ // cast, and thus, interaction with them doesn't matter.
+ if (!Arg->hasNoAliasAttr() && Arg->getType()->isPointerTy())
+ return argNumberToAttr(Arg->getArgNo());
+ return AttrNone;
+}
+
+bool isGlobalOrArgAttr(AliasAttrs Attr) {
+ return Attr.reset(AttrEscapedIndex)
+ .reset(AttrUnknownIndex)
+ .reset(AttrCallerIndex)
+ .any();
+}
+
+AliasAttrs getExternallyVisibleAttrs(AliasAttrs Attr) {
+ return Attr & AliasAttrs(ExternalAttrMask);
+}
+
+Optional<InstantiatedValue> instantiateInterfaceValue(InterfaceValue IValue,
+ CallBase &Call) {
+ auto Index = IValue.Index;
+ auto *V = (Index == 0) ? &Call : Call.getArgOperand(Index - 1);
+ if (V->getType()->isPointerTy())
+ return InstantiatedValue{V, IValue.DerefLevel};
+ return None;
+}
+
+Optional<InstantiatedRelation>
+instantiateExternalRelation(ExternalRelation ERelation, CallBase &Call) {
+ auto From = instantiateInterfaceValue(ERelation.From, Call);
+ if (!From)
+ return None;
+ auto To = instantiateInterfaceValue(ERelation.To, Call);
+ if (!To)
+ return None;
+ return InstantiatedRelation{*From, *To, ERelation.Offset};
+}
+
+Optional<InstantiatedAttr> instantiateExternalAttribute(ExternalAttribute EAttr,
+ CallBase &Call) {
+ auto Value = instantiateInterfaceValue(EAttr.IValue, Call);
+ if (!Value)
+ return None;
+ return InstantiatedAttr{*Value, EAttr.Attr};
+}
+}
+}
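The AliasAttrs helpers above are thin wrappers over a std::bitset. As a small sketch
(the helper function is hypothetical, not part of the commit), CFL-style code composes
them like this:

    #include "AliasAnalysisSummary.h"
    #include "llvm/IR/Value.h"

    using namespace llvm;
    using namespace llvm::cflaa;

    // Attribute a pointer that is visible outside the current function:
    // globals and pointer arguments get their dedicated bits via
    // getGlobalOrArgAttrFromValue(); anything else is conservatively
    // tagged unknown.
    static AliasAttrs attrsForExternalPointer(const Value &V) {
      AliasAttrs Attrs = getGlobalOrArgAttrFromValue(V);
      if (!isGlobalOrArgAttr(Attrs))
        Attrs |= getAttrUnknown();
      return Attrs;
    }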
diff --git a/llvm/lib/Analysis/AliasAnalysisSummary.h b/llvm/lib/Analysis/AliasAnalysisSummary.h
new file mode 100644
index 000000000000..fe75b03cedef
--- /dev/null
+++ b/llvm/lib/Analysis/AliasAnalysisSummary.h
@@ -0,0 +1,265 @@
+//===- AliasAnalysisSummary.h - Summary-based AA support -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file defines various utility types and functions useful to
+/// summary-based alias analysis.
+///
+/// Summary-based analysis, also known as bottom-up analysis, is a style of
+/// interprocedural static analysis that tries to analyze the callees before the
+/// callers get analyzed. The key idea of summary-based analysis is to first
+/// process each function independently, outline its behavior in a condensed
+/// summary, and then instantiate the summary at the callsite when the said
+/// function is called elsewhere. This is often in contrast to another style
+/// called top-down analysis, in which callers are always analyzed first before
+/// the callees.
+///
+/// In a summary-based analysis, functions must be examined independently and
+/// out-of-context. We have no information on the state of the memory, the
+/// arguments, the global values, and anything else external to the function. To
+/// carry out the analysis, conservative assumptions have to be made about those
+/// external states. In exchange for the potential loss of precision, the
+/// summary we obtain this way is highly reusable, which makes the analysis
+/// easier to scale to large programs even if carried out context-sensitively.
+///
+/// Currently, all CFL-based alias analyses adopt the summary-based approach
+/// and therefore heavily rely on this header.
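+///
+/// For illustration (a sketch, not tied to any particular analysis): given
+///
+///   int *identity(int *p) { return p; }
+///
+/// a single summary can record that the return value aliases the first
+/// parameter. At each callsite `q = identity(r)`, that summary is instantiated
+/// to conclude that `q` may alias `r`, without re-analyzing the callee.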
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_ALIASANALYSISSUMMARY_H
+#define LLVM_ANALYSIS_ALIASANALYSISSUMMARY_H
+
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/InstrTypes.h"
+#include <bitset>
+
+namespace llvm {
+namespace cflaa {
+
+//===----------------------------------------------------------------------===//
+// AliasAttr related stuffs
+//===----------------------------------------------------------------------===//
+
+/// The number of attributes that AliasAttr should contain. Attributes are
+/// described below, and 32 was an arbitrary choice because it fits nicely in 32
+/// bits (because we use a bitset for AliasAttr).
+static const unsigned NumAliasAttrs = 32;
+
+/// These are attributes that an alias analysis can use to mark certain special
+/// properties of a given pointer. Refer to the related functions below to see
+/// what kinds of attributes are currently defined.
+typedef std::bitset<NumAliasAttrs> AliasAttrs;
+
+/// AttrNone represents the empty set of attributes: nothing special is known
+/// about the said pointer.
+AliasAttrs getAttrNone();
+
+/// AttrUnknown represents whether the said pointer comes from a source not
+/// known to alias analyses (such as opaque memory or an integer cast).
+AliasAttrs getAttrUnknown();
+bool hasUnknownAttr(AliasAttrs);
+
+/// AttrCaller represents whether the said pointer comes from a source not
+/// known to the current function but known to the caller. Values pointed to by
+/// the arguments of the current function have this attribute set.
+AliasAttrs getAttrCaller();
+bool hasCallerAttr(AliasAttrs);
+bool hasUnknownOrCallerAttr(AliasAttrs);
+
+/// AttrEscaped represents whether the said pointer comes from a known source
+/// but escapes to the unknown world (e.g. it is cast to an integer, or passed
+/// as an argument to an opaque function). Unlike non-escaped pointers, escaped
+/// ones may alias pointers coming from unknown sources.
+AliasAttrs getAttrEscaped();
+bool hasEscapedAttr(AliasAttrs);
+
+/// AttrGlobal represents whether the said pointer is a global value.
+/// AttrArg represents whether the said pointer is an argument, and if so, what
+/// index the argument has.
+AliasAttrs getGlobalOrArgAttrFromValue(const Value &);
+bool isGlobalOrArgAttr(AliasAttrs);
+
+/// Given an AliasAttrs, return a new AliasAttrs that only contains attributes
+/// meaningful to the caller. This function is primarily used for
+/// interprocedural analysis.
+/// Currently, externally visible AliasAttrs include AttrUnknown, AttrGlobal,
+/// and AttrEscaped.
+AliasAttrs getExternallyVisibleAttrs(AliasAttrs);
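+
+// A minimal usage sketch of the helpers above (variable names are
+// illustrative):
+//
+//   AliasAttrs Attrs = getAttrEscaped();
+//   Attrs |= getAttrCaller();
+//   assert(hasEscapedAttr(Attrs) && hasCallerAttr(Attrs));
+//   assert(!hasUnknownAttr(Attrs));
+//   // Of the two, only AttrEscaped remains visible across a call boundary.
+//   AliasAttrs Visible = getExternallyVisibleAttrs(Attrs);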
+
+//===----------------------------------------------------------------------===//
+// Function summary related stuffs
+//===----------------------------------------------------------------------===//
+
+/// The maximum number of arguments we can put into a summary.
+static const unsigned MaxSupportedArgsInSummary = 50;
+
+/// We use InterfaceValue to describe parameters/return value, as well as
+/// potential memory locations that are pointed to by parameters/return value,
+/// of a function.
+/// Index is an integer which represents a single parameter or a return value.
+/// When the index is 0, it refers to the return value. Non-zero index i refers
+/// to the i-th parameter.
+/// DerefLevel indicates the number of dereferences one must perform on the
+/// parameter/return value to get this InterfaceValue.
+struct InterfaceValue {
+ unsigned Index;
+ unsigned DerefLevel;
+};
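+
+// For example (a sketch), for a function `int **f(int **p)`:
+//
+//   InterfaceValue Ret = {0, 0};      // the returned pointer
+//   InterfaceValue Param = {1, 0};    // the argument `p` itself
+//   InterfaceValue Pointee = {1, 1};  // the location `*p`, one deref down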
+
+inline bool operator==(InterfaceValue LHS, InterfaceValue RHS) {
+ return LHS.Index == RHS.Index && LHS.DerefLevel == RHS.DerefLevel;
+}
+inline bool operator!=(InterfaceValue LHS, InterfaceValue RHS) {
+ return !(LHS == RHS);
+}
+inline bool operator<(InterfaceValue LHS, InterfaceValue RHS) {
+ return LHS.Index < RHS.Index ||
+ (LHS.Index == RHS.Index && LHS.DerefLevel < RHS.DerefLevel);
+}
+inline bool operator>(InterfaceValue LHS, InterfaceValue RHS) {
+ return RHS < LHS;
+}
+inline bool operator<=(InterfaceValue LHS, InterfaceValue RHS) {
+ return !(RHS < LHS);
+}
+inline bool operator>=(InterfaceValue LHS, InterfaceValue RHS) {
+ return !(LHS < RHS);
+}
+
+// We use UnknownOffset to represent pointer offsets that cannot be determined
+// at compile time. Note that MemoryLocation::UnknownSize cannot be used here
+// because we require a signed value.
+static const int64_t UnknownOffset = INT64_MAX;
+
+inline int64_t addOffset(int64_t LHS, int64_t RHS) {
+ if (LHS == UnknownOffset || RHS == UnknownOffset)
+ return UnknownOffset;
+ // FIXME: Do we need to guard against integer overflow here?
+ return LHS + RHS;
+}
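+
+// For instance (a sketch):
+//
+//   assert(addOffset(8, 16) == 24);
+//   assert(addOffset(8, UnknownOffset) == UnknownOffset);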
+
+/// We use ExternalRelation to describe an externally visible aliasing relation
+/// between parameters/return value of a function.
+struct ExternalRelation {
+ InterfaceValue From, To;
+ int64_t Offset;
+};
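+
+// For example (a sketch): for `int *identity(int *p) { return p; }`, the
+// summary could contain
+//
+//   ExternalRelation R = {/*From=*/{1, 0}, /*To=*/{0, 0}, /*Offset=*/0};
+//
+// i.e. an externally visible aliasing relation, at offset 0, between the first
+// parameter (index 1) and the return value (index 0).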
+
+inline bool operator==(ExternalRelation LHS, ExternalRelation RHS) {
+ return LHS.From == RHS.From && LHS.To == RHS.To && LHS.Offset == RHS.Offset;
+}
+inline bool operator!=(ExternalRelation LHS, ExternalRelation RHS) {
+ return !(LHS == RHS);
+}
+inline bool operator<(ExternalRelation LHS, ExternalRelation RHS) {
+ if (LHS.From < RHS.From)
+ return true;
+ if (LHS.From > RHS.From)
+ return false;
+ if (LHS.To < RHS.To)
+ return true;
+ if (LHS.To > RHS.To)
+ return false;
+ return LHS.Offset < RHS.Offset;
+}
+inline bool operator>(ExternalRelation LHS, ExternalRelation RHS) {
+ return RHS < LHS;
+}
+inline bool operator<=(ExternalRelation LHS, ExternalRelation RHS) {
+ return !(RHS < LHS);
+}
+inline bool operator>=(ExternalRelation LHS, ExternalRelation RHS) {
+ return !(LHS < RHS);
+}
+
+/// We use ExternalAttribute to describe an externally visible AliasAttrs
+/// for parameters/return value.
+struct ExternalAttribute {
+ InterfaceValue IValue;
+ AliasAttrs Attr;
+};
+
+/// AliasSummary is just a collection of ExternalRelation and ExternalAttribute
+struct AliasSummary {
+ // RetParamRelations is a collection of ExternalRelations.
+ SmallVector<ExternalRelation, 8> RetParamRelations;
+
+ // RetParamAttributes is a collection of ExternalAttributes.
+ SmallVector<ExternalAttribute, 8> RetParamAttributes;
+};
+
+/// This is the result of instantiating InterfaceValue at a particular call
+struct InstantiatedValue {
+ Value *Val;
+ unsigned DerefLevel;
+};
+Optional<InstantiatedValue> instantiateInterfaceValue(InterfaceValue IValue,
+ CallBase &Call);
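+
+// For example (a sketch): for a call `%q = call i8* @identity(i8* %r)`,
+// instantiateInterfaceValue({0, 0}, Call) yields InstantiatedValue{%q, 0} and
+// instantiateInterfaceValue({1, 0}, Call) yields InstantiatedValue{%r, 0},
+// with DerefLevel carried over unchanged.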
+
+inline bool operator==(InstantiatedValue LHS, InstantiatedValue RHS) {
+ return LHS.Val == RHS.Val && LHS.DerefLevel == RHS.DerefLevel;
+}
+inline bool operator!=(InstantiatedValue LHS, InstantiatedValue RHS) {
+ return !(LHS == RHS);
+}
+inline bool operator<(InstantiatedValue LHS, InstantiatedValue RHS) {
+ return std::less<Value *>()(LHS.Val, RHS.Val) ||
+ (LHS.Val == RHS.Val && LHS.DerefLevel < RHS.DerefLevel);
+}
+inline bool operator>(InstantiatedValue LHS, InstantiatedValue RHS) {
+ return RHS < LHS;
+}
+inline bool operator<=(InstantiatedValue LHS, InstantiatedValue RHS) {
+ return !(RHS < LHS);
+}
+inline bool operator>=(InstantiatedValue LHS, InstantiatedValue RHS) {
+ return !(LHS < RHS);
+}
+
+/// This is the result of instantiating ExternalRelation at a particular
+/// callsite
+struct InstantiatedRelation {
+ InstantiatedValue From, To;
+ int64_t Offset;
+};
+Optional<InstantiatedRelation>
+instantiateExternalRelation(ExternalRelation ERelation, CallBase &Call);
+
+/// This is the result of instantiating ExternalAttribute at a particular
+/// callsite
+struct InstantiatedAttr {
+ InstantiatedValue IValue;
+ AliasAttrs Attr;
+};
+Optional<InstantiatedAttr> instantiateExternalAttribute(ExternalAttribute EAttr,
+ CallBase &Call);
+}
+
+template <> struct DenseMapInfo<cflaa::InstantiatedValue> {
+ static inline cflaa::InstantiatedValue getEmptyKey() {
+ return cflaa::InstantiatedValue{DenseMapInfo<Value *>::getEmptyKey(),
+ DenseMapInfo<unsigned>::getEmptyKey()};
+ }
+ static inline cflaa::InstantiatedValue getTombstoneKey() {
+ return cflaa::InstantiatedValue{DenseMapInfo<Value *>::getTombstoneKey(),
+ DenseMapInfo<unsigned>::getTombstoneKey()};
+ }
+ static unsigned getHashValue(const cflaa::InstantiatedValue &IV) {
+ return DenseMapInfo<std::pair<Value *, unsigned>>::getHashValue(
+ std::make_pair(IV.Val, IV.DerefLevel));
+ }
+ static bool isEqual(const cflaa::InstantiatedValue &LHS,
+ const cflaa::InstantiatedValue &RHS) {
+ return LHS.Val == RHS.Val && LHS.DerefLevel == RHS.DerefLevel;
+ }
+};
+}
+
+#endif
diff --git a/llvm/lib/Analysis/AliasSetTracker.cpp b/llvm/lib/Analysis/AliasSetTracker.cpp
new file mode 100644
index 000000000000..79fbcd464c1b
--- /dev/null
+++ b/llvm/lib/Analysis/AliasSetTracker.cpp
@@ -0,0 +1,776 @@
+//===- AliasSetTracker.cpp - Alias Sets Tracker implementation -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the AliasSetTracker and AliasSet classes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/AliasSetTracker.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/GuardUtils.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/AtomicOrdering.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+#include <cstdint>
+#include <vector>
+
+using namespace llvm;
+
+static cl::opt<unsigned>
+ SaturationThreshold("alias-set-saturation-threshold", cl::Hidden,
+ cl::init(250),
+ cl::desc("The maximum number of pointers may-alias "
+ "sets may contain before degradation"));
+
+/// mergeSetIn - Merge the specified alias set into this alias set.
+///
+void AliasSet::mergeSetIn(AliasSet &AS, AliasSetTracker &AST) {
+ assert(!AS.Forward && "Alias set is already forwarding!");
+ assert(!Forward && "This set is a forwarding set!!");
+
+ bool WasMustAlias = (Alias == SetMustAlias);
+ // Update the alias and access types of this set...
+ Access |= AS.Access;
+ Alias |= AS.Alias;
+
+ if (Alias == SetMustAlias) {
+ // Check that these two merged sets really are must aliases. Since both
+ // used to be must-alias sets, we can just check any pointer from each set
+ // for aliasing.
+ AliasAnalysis &AA = AST.getAliasAnalysis();
+ PointerRec *L = getSomePointer();
+ PointerRec *R = AS.getSomePointer();
+
+ // If the pointers are not a must-alias pair, this set becomes a may alias.
+ if (AA.alias(MemoryLocation(L->getValue(), L->getSize(), L->getAAInfo()),
+ MemoryLocation(R->getValue(), R->getSize(), R->getAAInfo())) !=
+ MustAlias)
+ Alias = SetMayAlias;
+ }
+
+ if (Alias == SetMayAlias) {
+ if (WasMustAlias)
+ AST.TotalMayAliasSetSize += size();
+ if (AS.Alias == SetMustAlias)
+ AST.TotalMayAliasSetSize += AS.size();
+ }
+
+ bool ASHadUnknownInsts = !AS.UnknownInsts.empty();
+ if (UnknownInsts.empty()) { // Merge call sites...
+ if (ASHadUnknownInsts) {
+ std::swap(UnknownInsts, AS.UnknownInsts);
+ addRef();
+ }
+ } else if (ASHadUnknownInsts) {
+ UnknownInsts.insert(UnknownInsts.end(), AS.UnknownInsts.begin(), AS.UnknownInsts.end());
+ AS.UnknownInsts.clear();
+ }
+
+ AS.Forward = this; // Forward across AS now...
+ addRef(); // AS is now pointing to us...
+
+ // Merge the list of constituent pointers...
+ if (AS.PtrList) {
+ SetSize += AS.size();
+ AS.SetSize = 0;
+ *PtrListEnd = AS.PtrList;
+ AS.PtrList->setPrevInList(PtrListEnd);
+ PtrListEnd = AS.PtrListEnd;
+
+ AS.PtrList = nullptr;
+ AS.PtrListEnd = &AS.PtrList;
+ assert(*AS.PtrListEnd == nullptr && "End of list is not null?");
+ }
+ if (ASHadUnknownInsts)
+ AS.dropRef(AST);
+}
+
+void AliasSetTracker::removeAliasSet(AliasSet *AS) {
+ if (AliasSet *Fwd = AS->Forward) {
+ Fwd->dropRef(*this);
+ AS->Forward = nullptr;
+ } else // Update TotalMayAliasSetSize only if not forwarding.
+ if (AS->Alias == AliasSet::SetMayAlias)
+ TotalMayAliasSetSize -= AS->size();
+
+ AliasSets.erase(AS);
+ // If we've removed the saturated alias set, set saturated marker back to
+ // nullptr and ensure this tracker is empty.
+ if (AS == AliasAnyAS) {
+ AliasAnyAS = nullptr;
+ assert(AliasSets.empty() && "Tracker not empty");
+ }
+}
+
+void AliasSet::removeFromTracker(AliasSetTracker &AST) {
+ assert(RefCount == 0 && "Cannot remove non-dead alias set from tracker!");
+ AST.removeAliasSet(this);
+}
+
+void AliasSet::addPointer(AliasSetTracker &AST, PointerRec &Entry,
+ LocationSize Size, const AAMDNodes &AAInfo,
+ bool KnownMustAlias, bool SkipSizeUpdate) {
+ assert(!Entry.hasAliasSet() && "Entry already in set!");
+
+ // Check to see if we have to downgrade to _may_ alias.
+ if (isMustAlias())
+ if (PointerRec *P = getSomePointer()) {
+ if (!KnownMustAlias) {
+ AliasAnalysis &AA = AST.getAliasAnalysis();
+ AliasResult Result = AA.alias(
+ MemoryLocation(P->getValue(), P->getSize(), P->getAAInfo()),
+ MemoryLocation(Entry.getValue(), Size, AAInfo));
+ if (Result != MustAlias) {
+ Alias = SetMayAlias;
+ AST.TotalMayAliasSetSize += size();
+ }
+ assert(Result != NoAlias && "Cannot be part of must set!");
+ } else if (!SkipSizeUpdate)
+ P->updateSizeAndAAInfo(Size, AAInfo);
+ }
+
+ Entry.setAliasSet(this);
+ Entry.updateSizeAndAAInfo(Size, AAInfo);
+
+ // Add it to the end of the list...
+ ++SetSize;
+ assert(*PtrListEnd == nullptr && "End of list is not null?");
+ *PtrListEnd = &Entry;
+ PtrListEnd = Entry.setPrevInList(PtrListEnd);
+ assert(*PtrListEnd == nullptr && "End of list is not null?");
+ // Entry points to alias set.
+ addRef();
+
+ if (Alias == SetMayAlias)
+ AST.TotalMayAliasSetSize++;
+}
+
+void AliasSet::addUnknownInst(Instruction *I, AliasAnalysis &AA) {
+ if (UnknownInsts.empty())
+ addRef();
+ UnknownInsts.emplace_back(I);
+
+ // Guards are marked as modifying memory for control flow modelling purposes,
+ // but don't actually modify any specific memory location.
+ using namespace PatternMatch;
+ bool MayWriteMemory = I->mayWriteToMemory() && !isGuard(I) &&
+ !(I->use_empty() && match(I, m_Intrinsic<Intrinsic::invariant_start>()));
+ if (!MayWriteMemory) {
+ Alias = SetMayAlias;
+ Access |= RefAccess;
+ return;
+ }
+
+ // FIXME: This should use mod/ref information to make this not suck so bad
+ Alias = SetMayAlias;
+ Access = ModRefAccess;
+}
+
+/// aliasesPointer - If the specified pointer "may" (or must) alias one of the
+/// members in the set return the appropriate AliasResult. Otherwise return
+/// NoAlias.
+///
+AliasResult AliasSet::aliasesPointer(const Value *Ptr, LocationSize Size,
+ const AAMDNodes &AAInfo,
+ AliasAnalysis &AA) const {
+ if (AliasAny)
+ return MayAlias;
+
+ if (Alias == SetMustAlias) {
+ assert(UnknownInsts.empty() && "Illegal must alias set!");
+
+ // If this is a set of MustAliases, only check to see if the pointer aliases
+ // SOME value in the set.
+ PointerRec *SomePtr = getSomePointer();
+ assert(SomePtr && "Empty must-alias set??");
+ return AA.alias(MemoryLocation(SomePtr->getValue(), SomePtr->getSize(),
+ SomePtr->getAAInfo()),
+ MemoryLocation(Ptr, Size, AAInfo));
+ }
+
+ // If this is a may-alias set, we have to check all of the pointers in the set
+ // to be sure it doesn't alias the set...
+ for (iterator I = begin(), E = end(); I != E; ++I)
+ if (AliasResult AR = AA.alias(
+ MemoryLocation(Ptr, Size, AAInfo),
+ MemoryLocation(I.getPointer(), I.getSize(), I.getAAInfo())))
+ return AR;
+
+ // Check the unknown instructions...
+ if (!UnknownInsts.empty()) {
+ for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i)
+ if (auto *Inst = getUnknownInst(i))
+ if (isModOrRefSet(
+ AA.getModRefInfo(Inst, MemoryLocation(Ptr, Size, AAInfo))))
+ return MayAlias;
+ }
+
+ return NoAlias;
+}
+
+bool AliasSet::aliasesUnknownInst(const Instruction *Inst,
+ AliasAnalysis &AA) const {
+
+ if (AliasAny)
+ return true;
+
+ assert(Inst->mayReadOrWriteMemory() &&
+ "Instruction must either read or write memory.");
+
+ for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i) {
+ if (auto *UnknownInst = getUnknownInst(i)) {
+ const auto *C1 = dyn_cast<CallBase>(UnknownInst);
+ const auto *C2 = dyn_cast<CallBase>(Inst);
+ if (!C1 || !C2 || isModOrRefSet(AA.getModRefInfo(C1, C2)) ||
+ isModOrRefSet(AA.getModRefInfo(C2, C1)))
+ return true;
+ }
+ }
+
+ for (iterator I = begin(), E = end(); I != E; ++I)
+ if (isModOrRefSet(AA.getModRefInfo(
+ Inst, MemoryLocation(I.getPointer(), I.getSize(), I.getAAInfo()))))
+ return true;
+
+ return false;
+}
+
+Instruction* AliasSet::getUniqueInstruction() {
+ if (AliasAny)
+ // The alias set may have been collapsed.
+ return nullptr;
+ if (begin() != end()) {
+ if (!UnknownInsts.empty())
+ // Another instruction found
+ return nullptr;
+ if (std::next(begin()) != end())
+ // Another instruction found
+ return nullptr;
+ Value *Addr = begin()->getValue();
+ assert(!Addr->user_empty() &&
+ "where's the instruction which added this pointer?");
+ if (std::next(Addr->user_begin()) != Addr->user_end())
+ // Another instruction found -- this is really restrictive
+ // TODO: generalize!
+ return nullptr;
+ return cast<Instruction>(*(Addr->user_begin()));
+ }
+ if (1 != UnknownInsts.size())
+ return nullptr;
+ return cast<Instruction>(UnknownInsts[0]);
+}
+
+void AliasSetTracker::clear() {
+ // Delete all the PointerRec entries.
+ for (PointerMapType::iterator I = PointerMap.begin(), E = PointerMap.end();
+ I != E; ++I)
+ I->second->eraseFromList();
+
+ PointerMap.clear();
+
+ // The alias sets should all be clear now.
+ AliasSets.clear();
+}
+
+/// mergeAliasSetsForPointer - Given a pointer, merge all alias sets that may
+/// alias the pointer. Return the unified set, or nullptr if no set that aliases
+/// the pointer was found. MustAliasAll is set to true iff the pointer is
+/// MustAlias with every set it was merged with.
+AliasSet *AliasSetTracker::mergeAliasSetsForPointer(const Value *Ptr,
+ LocationSize Size,
+ const AAMDNodes &AAInfo,
+ bool &MustAliasAll) {
+ AliasSet *FoundSet = nullptr;
+ AliasResult AllAR = MustAlias;
+ for (iterator I = begin(), E = end(); I != E;) {
+ iterator Cur = I++;
+ if (Cur->Forward)
+ continue;
+
+ AliasResult AR = Cur->aliasesPointer(Ptr, Size, AAInfo, AA);
+ if (AR == NoAlias)
+ continue;
+
+ AllAR =
+ AliasResult(AllAR & AR); // Possible downgrade to May/Partial, even No
+
+ if (!FoundSet) {
+ // If this is the first alias set ptr can go into, remember it.
+ FoundSet = &*Cur;
+ } else {
+ // Otherwise, we must merge the sets.
+ FoundSet->mergeSetIn(*Cur, *this);
+ }
+ }
+
+ MustAliasAll = (AllAR == MustAlias);
+ return FoundSet;
+}
+
+AliasSet *AliasSetTracker::findAliasSetForUnknownInst(Instruction *Inst) {
+ AliasSet *FoundSet = nullptr;
+ for (iterator I = begin(), E = end(); I != E;) {
+ iterator Cur = I++;
+ if (Cur->Forward || !Cur->aliasesUnknownInst(Inst, AA))
+ continue;
+ if (!FoundSet) {
+ // If this is the first alias set ptr can go into, remember it.
+ FoundSet = &*Cur;
+ } else {
+ // Otherwise, we must merge the sets.
+ FoundSet->mergeSetIn(*Cur, *this);
+ }
+ }
+ return FoundSet;
+}
+
+AliasSet &AliasSetTracker::getAliasSetFor(const MemoryLocation &MemLoc) {
+
+ Value * const Pointer = const_cast<Value*>(MemLoc.Ptr);
+ const LocationSize Size = MemLoc.Size;
+ const AAMDNodes &AAInfo = MemLoc.AATags;
+
+ AliasSet::PointerRec &Entry = getEntryFor(Pointer);
+
+ if (AliasAnyAS) {
+ // At this point, the AST is saturated, so we only have one active alias
+ // set. That means we already know which alias set we want to return, and
+ // just need to add the pointer to that set to keep the data structure
+ // consistent.
+ // This, of course, means that we will never need a merge here.
+ if (Entry.hasAliasSet()) {
+ Entry.updateSizeAndAAInfo(Size, AAInfo);
+ assert(Entry.getAliasSet(*this) == AliasAnyAS &&
+ "Entry in saturated AST must belong to only alias set");
+ } else {
+ AliasAnyAS->addPointer(*this, Entry, Size, AAInfo);
+ }
+ return *AliasAnyAS;
+ }
+
+ bool MustAliasAll = false;
+ // Check to see if the pointer is already known.
+ if (Entry.hasAliasSet()) {
+ // If the size changed, we may need to merge several alias sets.
+ // Note that we can *not* return the result of mergeAliasSetsForPointer
+ // due to a quirk of alias analysis behavior. Since alias(undef, undef)
+ // is NoAlias, mergeAliasSetsForPointer(undef, ...) will not find the
+ // right set for undef, even if it exists.
+ if (Entry.updateSizeAndAAInfo(Size, AAInfo))
+ mergeAliasSetsForPointer(Pointer, Size, AAInfo, MustAliasAll);
+ // Return the set!
+ return *Entry.getAliasSet(*this)->getForwardedTarget(*this);
+ }
+
+ if (AliasSet *AS =
+ mergeAliasSetsForPointer(Pointer, Size, AAInfo, MustAliasAll)) {
+ // Add it to the alias set it aliases.
+ AS->addPointer(*this, Entry, Size, AAInfo, MustAliasAll);
+ return *AS;
+ }
+
+ // Otherwise create a new alias set to hold the loaded pointer.
+ AliasSets.push_back(new AliasSet());
+ AliasSets.back().addPointer(*this, Entry, Size, AAInfo, true);
+ return AliasSets.back();
+}
+
+void AliasSetTracker::add(Value *Ptr, LocationSize Size,
+ const AAMDNodes &AAInfo) {
+ addPointer(MemoryLocation(Ptr, Size, AAInfo), AliasSet::NoAccess);
+}
+
+void AliasSetTracker::add(LoadInst *LI) {
+ if (isStrongerThanMonotonic(LI->getOrdering()))
+ return addUnknown(LI);
+ addPointer(MemoryLocation::get(LI), AliasSet::RefAccess);
+}
+
+void AliasSetTracker::add(StoreInst *SI) {
+ if (isStrongerThanMonotonic(SI->getOrdering()))
+ return addUnknown(SI);
+ addPointer(MemoryLocation::get(SI), AliasSet::ModAccess);
+}
+
+void AliasSetTracker::add(VAArgInst *VAAI) {
+ addPointer(MemoryLocation::get(VAAI), AliasSet::ModRefAccess);
+}
+
+void AliasSetTracker::add(AnyMemSetInst *MSI) {
+ addPointer(MemoryLocation::getForDest(MSI), AliasSet::ModAccess);
+}
+
+void AliasSetTracker::add(AnyMemTransferInst *MTI) {
+ addPointer(MemoryLocation::getForDest(MTI), AliasSet::ModAccess);
+ addPointer(MemoryLocation::getForSource(MTI), AliasSet::RefAccess);
+}
+
+void AliasSetTracker::addUnknown(Instruction *Inst) {
+ if (isa<DbgInfoIntrinsic>(Inst))
+ return; // Ignore DbgInfo Intrinsics.
+
+ if (auto *II = dyn_cast<IntrinsicInst>(Inst)) {
+ // These intrinsics will show up as affecting memory, but they are just
+ // markers.
+ switch (II->getIntrinsicID()) {
+ default:
+ break;
+ // FIXME: Add lifetime/invariant intrinsics (See: PR30807).
+ case Intrinsic::assume:
+ case Intrinsic::sideeffect:
+ return;
+ }
+ }
+ if (!Inst->mayReadOrWriteMemory())
+ return; // doesn't alias anything
+
+ if (AliasSet *AS = findAliasSetForUnknownInst(Inst)) {
+ AS->addUnknownInst(Inst, AA);
+ return;
+ }
+ AliasSets.push_back(new AliasSet());
+ AliasSets.back().addUnknownInst(Inst, AA);
+}
+
+void AliasSetTracker::add(Instruction *I) {
+ // Dispatch to one of the other add methods.
+ if (LoadInst *LI = dyn_cast<LoadInst>(I))
+ return add(LI);
+ if (StoreInst *SI = dyn_cast<StoreInst>(I))
+ return add(SI);
+ if (VAArgInst *VAAI = dyn_cast<VAArgInst>(I))
+ return add(VAAI);
+ if (AnyMemSetInst *MSI = dyn_cast<AnyMemSetInst>(I))
+ return add(MSI);
+ if (AnyMemTransferInst *MTI = dyn_cast<AnyMemTransferInst>(I))
+ return add(MTI);
+
+ // Handle all calls with known mod/ref sets generically.
+ if (auto *Call = dyn_cast<CallBase>(I))
+ if (Call->onlyAccessesArgMemory()) {
+ auto getAccessFromModRef = [](ModRefInfo MRI) {
+ if (isRefSet(MRI) && isModSet(MRI))
+ return AliasSet::ModRefAccess;
+ else if (isModSet(MRI))
+ return AliasSet::ModAccess;
+ else if (isRefSet(MRI))
+ return AliasSet::RefAccess;
+ else
+ return AliasSet::NoAccess;
+ };
+
+ ModRefInfo CallMask = createModRefInfo(AA.getModRefBehavior(Call));
+
+ // Some intrinsics are marked as modifying memory for control flow
+ // modelling purposes, but don't actually modify any specific memory
+ // location.
+ using namespace PatternMatch;
+ if (Call->use_empty() &&
+ match(Call, m_Intrinsic<Intrinsic::invariant_start>()))
+ CallMask = clearMod(CallMask);
+
+ for (auto IdxArgPair : enumerate(Call->args())) {
+ int ArgIdx = IdxArgPair.index();
+ const Value *Arg = IdxArgPair.value();
+ if (!Arg->getType()->isPointerTy())
+ continue;
+ MemoryLocation ArgLoc =
+ MemoryLocation::getForArgument(Call, ArgIdx, nullptr);
+ ModRefInfo ArgMask = AA.getArgModRefInfo(Call, ArgIdx);
+ ArgMask = intersectModRef(CallMask, ArgMask);
+ if (!isNoModRef(ArgMask))
+ addPointer(ArgLoc, getAccessFromModRef(ArgMask));
+ }
+ return;
+ }
+
+ return addUnknown(I);
+}
+
+void AliasSetTracker::add(BasicBlock &BB) {
+ for (auto &I : BB)
+ add(&I);
+}
+
+void AliasSetTracker::add(const AliasSetTracker &AST) {
+ assert(&AA == &AST.AA &&
+ "Merging AliasSetTracker objects with different Alias Analyses!");
+
+ // Loop over all of the alias sets in AST, adding the pointers contained
+ // therein into the current alias sets. This can cause alias sets to be
+ // merged together in the current AST.
+ for (const AliasSet &AS : AST) {
+ if (AS.Forward)
+ continue; // Ignore forwarding alias sets
+
+ // If there are any call sites in the alias set, add them to this AST.
+ for (unsigned i = 0, e = AS.UnknownInsts.size(); i != e; ++i)
+ if (auto *Inst = AS.getUnknownInst(i))
+ add(Inst);
+
+ // Loop over all of the pointers in this alias set.
+ for (AliasSet::iterator ASI = AS.begin(), E = AS.end(); ASI != E; ++ASI)
+ addPointer(
+ MemoryLocation(ASI.getPointer(), ASI.getSize(), ASI.getAAInfo()),
+ (AliasSet::AccessLattice)AS.Access);
+ }
+}
+
+void AliasSetTracker::addAllInstructionsInLoopUsingMSSA() {
+ assert(MSSA && L && "MSSA and L must be available");
+ for (const BasicBlock *BB : L->blocks())
+ if (auto *Accesses = MSSA->getBlockAccesses(BB))
+ for (auto &Access : *Accesses)
+ if (auto *MUD = dyn_cast<MemoryUseOrDef>(&Access))
+ add(MUD->getMemoryInst());
+}
+
+// deleteValue method - This method is used to remove a pointer value from the
+// AliasSetTracker entirely. It should be used when an instruction is deleted
+// from the program to update the AST. If you don't use this, you will be left
+// with dangling pointers to deleted instructions.
+//
+void AliasSetTracker::deleteValue(Value *PtrVal) {
+ // First, look up the PointerRec for this pointer.
+ PointerMapType::iterator I = PointerMap.find_as(PtrVal);
+ if (I == PointerMap.end()) return; // Noop
+
+ // If we found one, remove the pointer from the alias set it is in.
+ AliasSet::PointerRec *PtrValEnt = I->second;
+ AliasSet *AS = PtrValEnt->getAliasSet(*this);
+
+ // Unlink and delete from the list of values.
+ PtrValEnt->eraseFromList();
+
+ if (AS->Alias == AliasSet::SetMayAlias) {
+ AS->SetSize--;
+ TotalMayAliasSetSize--;
+ }
+
+ // Stop using the alias set.
+ AS->dropRef(*this);
+
+ PointerMap.erase(I);
+}
+
+// copyValue - This method should be used whenever a preexisting value in the
+// program is copied or cloned, introducing a new value. Note that it is ok for
+// clients that use this method to introduce the same value multiple times: if
+// the tracker already knows about a value, it will ignore the request.
+//
+void AliasSetTracker::copyValue(Value *From, Value *To) {
+ // First, look up the PointerRec for this pointer.
+ PointerMapType::iterator I = PointerMap.find_as(From);
+ if (I == PointerMap.end())
+ return; // Noop
+ assert(I->second->hasAliasSet() && "Dead entry?");
+
+ AliasSet::PointerRec &Entry = getEntryFor(To);
+ if (Entry.hasAliasSet()) return; // Already in the tracker!
+
+ // getEntryFor above may invalidate iterator \c I, so reinitialize it.
+ I = PointerMap.find_as(From);
+ // Add it to the alias set it aliases...
+ AliasSet *AS = I->second->getAliasSet(*this);
+ AS->addPointer(*this, Entry, I->second->getSize(), I->second->getAAInfo(),
+ true, true);
+}
+
+AliasSet &AliasSetTracker::mergeAllAliasSets() {
+ assert(!AliasAnyAS && (TotalMayAliasSetSize > SaturationThreshold) &&
+ "Full merge should happen once, when the saturation threshold is "
+ "reached");
+
+ // Collect all alias sets, so that we can drop references with impunity
+ // without worrying about iterator invalidation.
+ std::vector<AliasSet *> ASVector;
+ ASVector.reserve(SaturationThreshold);
+ for (iterator I = begin(), E = end(); I != E; I++)
+ ASVector.push_back(&*I);
+
+ // Copy all instructions and pointers into a new set, and forward all other
+ // sets to it.
+ AliasSets.push_back(new AliasSet());
+ AliasAnyAS = &AliasSets.back();
+ AliasAnyAS->Alias = AliasSet::SetMayAlias;
+ AliasAnyAS->Access = AliasSet::ModRefAccess;
+ AliasAnyAS->AliasAny = true;
+
+ for (auto Cur : ASVector) {
+ // If Cur was already forwarding, just forward to the new AS instead.
+ AliasSet *FwdTo = Cur->Forward;
+ if (FwdTo) {
+ Cur->Forward = AliasAnyAS;
+ AliasAnyAS->addRef();
+ FwdTo->dropRef(*this);
+ continue;
+ }
+
+ // Otherwise, perform the actual merge.
+ AliasAnyAS->mergeSetIn(*Cur, *this);
+ }
+
+ return *AliasAnyAS;
+}
+
+AliasSet &AliasSetTracker::addPointer(MemoryLocation Loc,
+ AliasSet::AccessLattice E) {
+ AliasSet &AS = getAliasSetFor(Loc);
+ AS.Access |= E;
+
+ if (!AliasAnyAS && (TotalMayAliasSetSize > SaturationThreshold)) {
+ // The AST is now saturated. From here on, we conservatively consider all
+ // pointers to alias each other.
+ return mergeAllAliasSets();
+ }
+
+ return AS;
+}
+
+//===----------------------------------------------------------------------===//
+// AliasSet/AliasSetTracker Printing Support
+//===----------------------------------------------------------------------===//
+
+void AliasSet::print(raw_ostream &OS) const {
+ OS << " AliasSet[" << (const void*)this << ", " << RefCount << "] ";
+ OS << (Alias == SetMustAlias ? "must" : "may") << " alias, ";
+ switch (Access) {
+ case NoAccess: OS << "No access "; break;
+ case RefAccess: OS << "Ref "; break;
+ case ModAccess: OS << "Mod "; break;
+ case ModRefAccess: OS << "Mod/Ref "; break;
+ default: llvm_unreachable("Bad value for Access!");
+ }
+ if (Forward)
+ OS << " forwarding to " << (void*)Forward;
+
+ if (!empty()) {
+ OS << "Pointers: ";
+ for (iterator I = begin(), E = end(); I != E; ++I) {
+ if (I != begin()) OS << ", ";
+ I.getPointer()->printAsOperand(OS << "(");
+ if (I.getSize() == LocationSize::unknown())
+ OS << ", unknown)";
+ else
+ OS << ", " << I.getSize() << ")";
+ }
+ }
+ if (!UnknownInsts.empty()) {
+ OS << "\n " << UnknownInsts.size() << " Unknown instructions: ";
+ for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i) {
+ if (i) OS << ", ";
+ if (auto *I = getUnknownInst(i)) {
+ if (I->hasName())
+ I->printAsOperand(OS);
+ else
+ I->print(OS);
+ }
+ }
+ }
+ OS << "\n";
+}
+
+void AliasSetTracker::print(raw_ostream &OS) const {
+ OS << "Alias Set Tracker: " << AliasSets.size();
+ if (AliasAnyAS)
+ OS << " (Saturated)";
+ OS << " alias sets for " << PointerMap.size() << " pointer values.\n";
+ for (const AliasSet &AS : *this)
+ AS.print(OS);
+ OS << "\n";
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void AliasSet::dump() const { print(dbgs()); }
+LLVM_DUMP_METHOD void AliasSetTracker::dump() const { print(dbgs()); }
+#endif
+
+//===----------------------------------------------------------------------===//
+// ASTCallbackVH Class Implementation
+//===----------------------------------------------------------------------===//
+
+void AliasSetTracker::ASTCallbackVH::deleted() {
+ assert(AST && "ASTCallbackVH called with a null AliasSetTracker!");
+ AST->deleteValue(getValPtr());
+ // this now dangles!
+}
+
+void AliasSetTracker::ASTCallbackVH::allUsesReplacedWith(Value *V) {
+ AST->copyValue(getValPtr(), V);
+}
+
+AliasSetTracker::ASTCallbackVH::ASTCallbackVH(Value *V, AliasSetTracker *ast)
+ : CallbackVH(V), AST(ast) {}
+
+AliasSetTracker::ASTCallbackVH &
+AliasSetTracker::ASTCallbackVH::operator=(Value *V) {
+ return *this = ASTCallbackVH(V, AST);
+}
+
+//===----------------------------------------------------------------------===//
+// AliasSetPrinter Pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+ class AliasSetPrinter : public FunctionPass {
+ AliasSetTracker *Tracker;
+
+ public:
+ static char ID; // Pass identification, replacement for typeid
+
+ AliasSetPrinter() : FunctionPass(ID) {
+ initializeAliasSetPrinterPass(*PassRegistry::getPassRegistry());
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ AU.addRequired<AAResultsWrapperPass>();
+ }
+
+ bool runOnFunction(Function &F) override {
+ auto &AAWP = getAnalysis<AAResultsWrapperPass>();
+ Tracker = new AliasSetTracker(AAWP.getAAResults());
+ errs() << "Alias sets for function '" << F.getName() << "':\n";
+ for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
+ Tracker->add(&*I);
+ Tracker->print(errs());
+ delete Tracker;
+ return false;
+ }
+ };
+
+} // end anonymous namespace
+
+char AliasSetPrinter::ID = 0;
+
+INITIALIZE_PASS_BEGIN(AliasSetPrinter, "print-alias-sets",
+ "Alias Set Printer", false, true)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_END(AliasSetPrinter, "print-alias-sets",
+ "Alias Set Printer", false, true)
diff --git a/llvm/lib/Analysis/Analysis.cpp b/llvm/lib/Analysis/Analysis.cpp
new file mode 100644
index 000000000000..af718526684b
--- /dev/null
+++ b/llvm/lib/Analysis/Analysis.cpp
@@ -0,0 +1,138 @@
+//===-- Analysis.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm-c/Analysis.h"
+#include "llvm-c/Initialization.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/PassRegistry.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstring>
+
+using namespace llvm;
+
+/// initializeAnalysis - Initialize all passes linked into the Analysis library.
+void llvm::initializeAnalysis(PassRegistry &Registry) {
+ initializeAAEvalLegacyPassPass(Registry);
+ initializeAliasSetPrinterPass(Registry);
+ initializeBasicAAWrapperPassPass(Registry);
+ initializeBlockFrequencyInfoWrapperPassPass(Registry);
+ initializeBranchProbabilityInfoWrapperPassPass(Registry);
+ initializeCallGraphWrapperPassPass(Registry);
+ initializeCallGraphDOTPrinterPass(Registry);
+ initializeCallGraphPrinterLegacyPassPass(Registry);
+ initializeCallGraphViewerPass(Registry);
+ initializeCostModelAnalysisPass(Registry);
+ initializeCFGViewerLegacyPassPass(Registry);
+ initializeCFGPrinterLegacyPassPass(Registry);
+ initializeCFGOnlyViewerLegacyPassPass(Registry);
+ initializeCFGOnlyPrinterLegacyPassPass(Registry);
+ initializeCFLAndersAAWrapperPassPass(Registry);
+ initializeCFLSteensAAWrapperPassPass(Registry);
+ initializeDependenceAnalysisWrapperPassPass(Registry);
+ initializeDelinearizationPass(Registry);
+ initializeDemandedBitsWrapperPassPass(Registry);
+ initializeDominanceFrontierWrapperPassPass(Registry);
+ initializeDomViewerPass(Registry);
+ initializeDomPrinterPass(Registry);
+ initializeDomOnlyViewerPass(Registry);
+ initializePostDomViewerPass(Registry);
+ initializeDomOnlyPrinterPass(Registry);
+ initializePostDomPrinterPass(Registry);
+ initializePostDomOnlyViewerPass(Registry);
+ initializePostDomOnlyPrinterPass(Registry);
+ initializeAAResultsWrapperPassPass(Registry);
+ initializeGlobalsAAWrapperPassPass(Registry);
+ initializeIVUsersWrapperPassPass(Registry);
+ initializeInstCountPass(Registry);
+ initializeIntervalPartitionPass(Registry);
+ initializeLazyBranchProbabilityInfoPassPass(Registry);
+ initializeLazyBlockFrequencyInfoPassPass(Registry);
+ initializeLazyValueInfoWrapperPassPass(Registry);
+ initializeLazyValueInfoPrinterPass(Registry);
+ initializeLegacyDivergenceAnalysisPass(Registry);
+ initializeLintPass(Registry);
+ initializeLoopInfoWrapperPassPass(Registry);
+ initializeMemDepPrinterPass(Registry);
+ initializeMemDerefPrinterPass(Registry);
+ initializeMemoryDependenceWrapperPassPass(Registry);
+ initializeModuleDebugInfoPrinterPass(Registry);
+ initializeModuleSummaryIndexWrapperPassPass(Registry);
+ initializeMustExecutePrinterPass(Registry);
+ initializeMustBeExecutedContextPrinterPass(Registry);
+ initializeObjCARCAAWrapperPassPass(Registry);
+ initializeOptimizationRemarkEmitterWrapperPassPass(Registry);
+ initializePhiValuesWrapperPassPass(Registry);
+ initializePostDominatorTreeWrapperPassPass(Registry);
+ initializeRegionInfoPassPass(Registry);
+ initializeRegionViewerPass(Registry);
+ initializeRegionPrinterPass(Registry);
+ initializeRegionOnlyViewerPass(Registry);
+ initializeRegionOnlyPrinterPass(Registry);
+ initializeSCEVAAWrapperPassPass(Registry);
+ initializeScalarEvolutionWrapperPassPass(Registry);
+ initializeStackSafetyGlobalInfoWrapperPassPass(Registry);
+ initializeStackSafetyInfoWrapperPassPass(Registry);
+ initializeTargetTransformInfoWrapperPassPass(Registry);
+ initializeTypeBasedAAWrapperPassPass(Registry);
+ initializeScopedNoAliasAAWrapperPassPass(Registry);
+ initializeLCSSAVerificationPassPass(Registry);
+ initializeMemorySSAWrapperPassPass(Registry);
+ initializeMemorySSAPrinterLegacyPassPass(Registry);
+}
+
+void LLVMInitializeAnalysis(LLVMPassRegistryRef R) {
+ initializeAnalysis(*unwrap(R));
+}
+
+void LLVMInitializeIPA(LLVMPassRegistryRef R) {
+ initializeAnalysis(*unwrap(R));
+}
+
+LLVMBool LLVMVerifyModule(LLVMModuleRef M, LLVMVerifierFailureAction Action,
+ char **OutMessages) {
+ raw_ostream *DebugOS = Action != LLVMReturnStatusAction ? &errs() : nullptr;
+ std::string Messages;
+ raw_string_ostream MsgsOS(Messages);
+
+ LLVMBool Result = verifyModule(*unwrap(M), OutMessages ? &MsgsOS : DebugOS);
+
+ // Duplicate the output to stderr.
+ if (DebugOS && OutMessages)
+ *DebugOS << MsgsOS.str();
+
+ if (Action == LLVMAbortProcessAction && Result)
+ report_fatal_error("Broken module found, compilation aborted!");
+
+ if (OutMessages)
+ *OutMessages = strdup(MsgsOS.str().c_str());
+
+ return Result;
+}
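+
+// A minimal usage sketch of the C API above (the module M is illustrative):
+//
+//   char *Err = nullptr;
+//   if (LLVMVerifyModule(M, LLVMReturnStatusAction, &Err))
+//     fprintf(stderr, "module is broken: %s\n", Err);
+//   LLVMDisposeMessage(Err);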
+
+LLVMBool LLVMVerifyFunction(LLVMValueRef Fn, LLVMVerifierFailureAction Action) {
+ LLVMBool Result = verifyFunction(
+ *unwrap<Function>(Fn), Action != LLVMReturnStatusAction ? &errs()
+ : nullptr);
+
+ if (Action == LLVMAbortProcessAction && Result)
+ report_fatal_error("Broken function found, compilation aborted!");
+
+ return Result;
+}
+
+void LLVMViewFunctionCFG(LLVMValueRef Fn) {
+ Function *F = unwrap<Function>(Fn);
+ F->viewCFG();
+}
+
+void LLVMViewFunctionCFGOnly(LLVMValueRef Fn) {
+ Function *F = unwrap<Function>(Fn);
+ F->viewCFGOnly();
+}
diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
new file mode 100644
index 000000000000..129944743c5e
--- /dev/null
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -0,0 +1,302 @@
+//===- AssumptionCache.cpp - Cache finding @llvm.assume calls -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a pass that keeps track of @llvm.assume intrinsics in
+// the functions of a module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <utility>
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+static cl::opt<bool>
+ VerifyAssumptionCache("verify-assumption-cache", cl::Hidden,
+ cl::desc("Enable verification of assumption cache"),
+ cl::init(false));
+
+SmallVector<WeakTrackingVH, 1> &
+AssumptionCache::getOrInsertAffectedValues(Value *V) {
+ // Try using find_as first to avoid creating extra value handles just for the
+ // purpose of doing the lookup.
+ auto AVI = AffectedValues.find_as(V);
+ if (AVI != AffectedValues.end())
+ return AVI->second;
+
+ auto AVIP = AffectedValues.insert(
+ {AffectedValueCallbackVH(V, this), SmallVector<WeakTrackingVH, 1>()});
+ return AVIP.first->second;
+}
+
+static void findAffectedValues(CallInst *CI,
+ SmallVectorImpl<Value *> &Affected) {
+ // Note: This code must be kept in-sync with the code in
+ // computeKnownBitsFromAssume in ValueTracking.
+
+ auto AddAffected = [&Affected](Value *V) {
+ if (isa<Argument>(V)) {
+ Affected.push_back(V);
+ } else if (auto *I = dyn_cast<Instruction>(V)) {
+ Affected.push_back(I);
+
+ // Peek through unary operators to find the source of the condition.
+ Value *Op;
+ if (match(I, m_BitCast(m_Value(Op))) ||
+ match(I, m_PtrToInt(m_Value(Op))) ||
+ match(I, m_Not(m_Value(Op)))) {
+ if (isa<Instruction>(Op) || isa<Argument>(Op))
+ Affected.push_back(Op);
+ }
+ }
+ };
+
+ Value *Cond = CI->getArgOperand(0), *A, *B;
+ AddAffected(Cond);
+
+ CmpInst::Predicate Pred;
+ if (match(Cond, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
+ AddAffected(A);
+ AddAffected(B);
+
+ if (Pred == ICmpInst::ICMP_EQ) {
+ // For equality comparisons, we handle the case of bit inversion.
+ auto AddAffectedFromEq = [&AddAffected](Value *V) {
+ Value *A;
+ if (match(V, m_Not(m_Value(A)))) {
+ AddAffected(A);
+ V = A;
+ }
+
+ Value *B;
+ ConstantInt *C;
+ // (A & B) or (A | B) or (A ^ B).
+ if (match(V, m_BitwiseLogic(m_Value(A), m_Value(B)))) {
+ AddAffected(A);
+ AddAffected(B);
+ // (A << C) or (A >>_s C) or (A >>_u C) where C is some constant.
+ } else if (match(V, m_Shift(m_Value(A), m_ConstantInt(C)))) {
+ AddAffected(A);
+ }
+ };
+
+ AddAffectedFromEq(A);
+ AddAffectedFromEq(B);
+ }
+ }
+}
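+
+// For example (a sketch), given the IR
+//
+//   %c = icmp eq i32 %x, 0
+//   call void @llvm.assume(i1 %c)
+//
+// the affected values are %c and %x (assuming %x is itself an instruction or
+// argument), so later queries about %x can find this assumption in the cache.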
+
+void AssumptionCache::updateAffectedValues(CallInst *CI) {
+ SmallVector<Value *, 16> Affected;
+ findAffectedValues(CI, Affected);
+
+ for (auto &AV : Affected) {
+ auto &AVV = getOrInsertAffectedValues(AV);
+ if (std::find(AVV.begin(), AVV.end(), CI) == AVV.end())
+ AVV.push_back(CI);
+ }
+}
+
+void AssumptionCache::unregisterAssumption(CallInst *CI) {
+ SmallVector<Value *, 16> Affected;
+ findAffectedValues(CI, Affected);
+
+ for (auto &AV : Affected) {
+ auto AVI = AffectedValues.find_as(AV);
+ if (AVI != AffectedValues.end())
+ AffectedValues.erase(AVI);
+ }
+
+ AssumeHandles.erase(
+ remove_if(AssumeHandles, [CI](WeakTrackingVH &VH) { return CI == VH; }),
+ AssumeHandles.end());
+}
+
+void AssumptionCache::AffectedValueCallbackVH::deleted() {
+ auto AVI = AC->AffectedValues.find(getValPtr());
+ if (AVI != AC->AffectedValues.end())
+ AC->AffectedValues.erase(AVI);
+ // 'this' now dangles!
+}
+
+void AssumptionCache::transferAffectedValuesInCache(Value *OV, Value *NV) {
+ auto &NAVV = getOrInsertAffectedValues(NV);
+ auto AVI = AffectedValues.find(OV);
+ if (AVI == AffectedValues.end())
+ return;
+
+ for (auto &A : AVI->second)
+ if (std::find(NAVV.begin(), NAVV.end(), A) == NAVV.end())
+ NAVV.push_back(A);
+ AffectedValues.erase(OV);
+}
+
+void AssumptionCache::AffectedValueCallbackVH::allUsesReplacedWith(Value *NV) {
+ if (!isa<Instruction>(NV) && !isa<Argument>(NV))
+ return;
+
+ // Any assumptions that affected this value now affect the new value.
+
+ AC->transferAffectedValuesInCache(getValPtr(), NV);
+ // 'this' now might dangle! If the AffectedValues map was resized to add an
+ // entry for NV then this object might have been destroyed in favor of some
+ // copy in the grown map.
+}
+
+void AssumptionCache::scanFunction() {
+ assert(!Scanned && "Tried to scan the function twice!");
+ assert(AssumeHandles.empty() && "Already have assumes when scanning!");
+
+ // Go through all instructions in all blocks, add all calls to @llvm.assume
+ // to this cache.
+ for (BasicBlock &B : F)
+ for (Instruction &II : B)
+ if (match(&II, m_Intrinsic<Intrinsic::assume>()))
+ AssumeHandles.push_back(&II);
+
+ // Mark the scan as complete.
+ Scanned = true;
+
+ // Update affected values.
+ for (auto &A : AssumeHandles)
+ updateAffectedValues(cast<CallInst>(A));
+}
+
+void AssumptionCache::registerAssumption(CallInst *CI) {
+ assert(match(CI, m_Intrinsic<Intrinsic::assume>()) &&
+ "Registered call does not call @llvm.assume");
+
+ // If we haven't scanned the function yet, just drop this assumption. It will
+ // be found when we scan later.
+ if (!Scanned)
+ return;
+
+ AssumeHandles.push_back(CI);
+
+#ifndef NDEBUG
+ assert(CI->getParent() &&
+ "Cannot register @llvm.assume call not in a basic block");
+ assert(&F == CI->getParent()->getParent() &&
+ "Cannot register @llvm.assume call not in this function");
+
+ // We expect the number of assumptions to be small, so in an asserts build
+ // check that we don't accumulate duplicates and that all assumptions point
+ // to the same function.
+ SmallPtrSet<Value *, 16> AssumptionSet;
+ for (auto &VH : AssumeHandles) {
+ if (!VH)
+ continue;
+
+ assert(&F == cast<Instruction>(VH)->getParent()->getParent() &&
+ "Cached assumption not inside this function!");
+ assert(match(cast<CallInst>(VH), m_Intrinsic<Intrinsic::assume>()) &&
+ "Cached something other than a call to @llvm.assume!");
+ assert(AssumptionSet.insert(VH).second &&
+ "Cache contains multiple copies of a call!");
+ }
+#endif
+
+ updateAffectedValues(CI);
+}
+
+AnalysisKey AssumptionAnalysis::Key;
+
+PreservedAnalyses AssumptionPrinterPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
+
+ OS << "Cached assumptions for function: " << F.getName() << "\n";
+ for (auto &VH : AC.assumptions())
+ if (VH)
+ OS << " " << *cast<CallInst>(VH)->getArgOperand(0) << "\n";
+
+ return PreservedAnalyses::all();
+}
+
+void AssumptionCacheTracker::FunctionCallbackVH::deleted() {
+ auto I = ACT->AssumptionCaches.find_as(cast<Function>(getValPtr()));
+ if (I != ACT->AssumptionCaches.end())
+ ACT->AssumptionCaches.erase(I);
+ // 'this' now dangles!
+}
+
+AssumptionCache &AssumptionCacheTracker::getAssumptionCache(Function &F) {
+ // We probe the function map twice to try and avoid creating a value handle
+ // around the function in common cases. This makes insertion a bit slower,
+ // but if we have to insert we're going to scan the whole function so that
+ // shouldn't matter.
+ auto I = AssumptionCaches.find_as(&F);
+ if (I != AssumptionCaches.end())
+ return *I->second;
+
+ // Ok, build a new cache by scanning the function, insert it and the value
+ // handle into our map, and return the newly populated cache.
+ auto IP = AssumptionCaches.insert(std::make_pair(
+ FunctionCallbackVH(&F, this), std::make_unique<AssumptionCache>(F)));
+ assert(IP.second && "Scanning function already in the map?");
+ return *IP.first->second;
+}
+
+AssumptionCache *AssumptionCacheTracker::lookupAssumptionCache(Function &F) {
+ auto I = AssumptionCaches.find_as(&F);
+ if (I != AssumptionCaches.end())
+ return I->second.get();
+ return nullptr;
+}
+
+void AssumptionCacheTracker::verifyAnalysis() const {
+ // FIXME: In the long term the verifier should not be controllable with a
+ // flag. We should either fix all passes to correctly update the assumption
+ // cache and enable the verifier unconditionally or somehow arrange for the
+ // assumption list to be updated automatically by passes.
+ if (!VerifyAssumptionCache)
+ return;
+
+ SmallPtrSet<const CallInst *, 4> AssumptionSet;
+ for (const auto &I : AssumptionCaches) {
+ for (auto &VH : I.second->assumptions())
+ if (VH)
+ AssumptionSet.insert(cast<CallInst>(VH));
+
+ for (const BasicBlock &B : cast<Function>(*I.first))
+ for (const Instruction &II : B)
+ if (match(&II, m_Intrinsic<Intrinsic::assume>()) &&
+ !AssumptionSet.count(cast<CallInst>(&II)))
+ report_fatal_error("Assumption in scanned function not in cache");
+ }
+}
+
+AssumptionCacheTracker::AssumptionCacheTracker() : ImmutablePass(ID) {
+ initializeAssumptionCacheTrackerPass(*PassRegistry::getPassRegistry());
+}
+
+AssumptionCacheTracker::~AssumptionCacheTracker() = default;
+
+char AssumptionCacheTracker::ID = 0;
+
+INITIALIZE_PASS(AssumptionCacheTracker, "assumption-cache-tracker",
+ "Assumption Cache Tracker", false, true)
diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
new file mode 100644
index 000000000000..f3c30c258c19
--- /dev/null
+++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
@@ -0,0 +1,2100 @@
+//===- BasicAliasAnalysis.cpp - Stateless Alias Analysis Impl -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the primary stateless implementation of the
+// Alias Analysis interface that implements identities (two different
+// globals cannot alias, etc), but does no stateful analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/CaptureTracking.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/PhiValues.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
+#include "llvm/IR/GlobalAlias.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/KnownBits.h"
+#include <cassert>
+#include <cstdint>
+#include <cstdlib>
+#include <utility>
+
+#define DEBUG_TYPE "basicaa"
+
+using namespace llvm;
+
+/// Enable analysis of recursive PHI nodes.
+static cl::opt<bool> EnableRecPhiAnalysis("basicaa-recphi", cl::Hidden,
+ cl::init(false));
+
+/// By default, even on 32-bit architectures we use 64-bit integers for
+/// calculations. This will allow us to more-aggressively decompose indexing
+/// expressions calculated using i64 values (e.g., long long in C) which is
+/// common enough to worry about.
+static cl::opt<bool> ForceAtLeast64Bits("basicaa-force-at-least-64b",
+ cl::Hidden, cl::init(true));
+static cl::opt<bool> DoubleCalcBits("basicaa-double-calc-bits",
+ cl::Hidden, cl::init(false));
+
+/// SearchLimitReached / SearchTimes show how often the limit to decompose
+/// GEPs is reached. Hitting it affects the precision of basic alias analysis.
+STATISTIC(SearchLimitReached, "Number of times the limit to "
+ "decompose GEPs is reached");
+STATISTIC(SearchTimes, "Number of times a GEP is decomposed");
+
+/// Cutoff after which to stop analysing a set of phi nodes potentially involved
+/// in a cycle. Because we are analysing 'through' phi nodes, we need to be
+/// careful with value equivalence. We use reachability to make sure a value
+/// cannot be involved in a cycle.
+const unsigned MaxNumPhiBBsValueReachabilityCheck = 20;
+
+// The maximum search depth in DecomposeGEPExpression() and
+// GetUnderlyingObject(). Both functions need to use the same search depth;
+// otherwise the algorithm in aliasGEP will assert.
+static const unsigned MaxLookupSearchDepth = 6;
+
+bool BasicAAResult::invalidate(Function &Fn, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &Inv) {
+ // We don't care if this analysis itself is preserved, it has no state. But
+ // we need to check that the analyses it depends on have been. Note that we
+ // may be created without handles to some analyses and in that case don't
+ // depend on them.
+ if (Inv.invalidate<AssumptionAnalysis>(Fn, PA) ||
+ (DT && Inv.invalidate<DominatorTreeAnalysis>(Fn, PA)) ||
+ (LI && Inv.invalidate<LoopAnalysis>(Fn, PA)) ||
+ (PV && Inv.invalidate<PhiValuesAnalysis>(Fn, PA)))
+ return true;
+
+ // Otherwise this analysis result remains valid.
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Useful predicates
+//===----------------------------------------------------------------------===//
+
+/// Returns true if the pointer is to a function-local object that never
+/// escapes from the function.
+static bool isNonEscapingLocalObject(
+ const Value *V,
+ SmallDenseMap<const Value *, bool, 8> *IsCapturedCache = nullptr) {
+ SmallDenseMap<const Value *, bool, 8>::iterator CacheIt;
+ if (IsCapturedCache) {
+ bool Inserted;
+ std::tie(CacheIt, Inserted) = IsCapturedCache->insert({V, false});
+ if (!Inserted)
+ // Found cached result, return it!
+ return CacheIt->second;
+ }
+
+ // If this is a local allocation, check to see if it escapes.
+ if (isa<AllocaInst>(V) || isNoAliasCall(V)) {
+ // Set StoreCaptures to True so that we can assume in our callers that the
+ // pointer is not the result of a load instruction. Currently
+ // PointerMayBeCaptured doesn't have any special analysis for the
+ // StoreCaptures=false case; if it did, our callers could be refined to be
+ // more precise.
+ auto Ret = !PointerMayBeCaptured(V, false, /*StoreCaptures=*/true);
+ if (IsCapturedCache)
+ CacheIt->second = Ret;
+ return Ret;
+ }
+
+ // If this is an argument that corresponds to a byval or noalias argument,
+ // then it has not escaped before entering the function. Check if it escapes
+ // inside the function.
+ if (const Argument *A = dyn_cast<Argument>(V))
+ if (A->hasByValAttr() || A->hasNoAliasAttr()) {
+ // Note even if the argument is marked nocapture, we still need to check
+ // for copies made inside the function. The nocapture attribute only
+ // specifies that there are no copies made that outlive the function.
+ auto Ret = !PointerMayBeCaptured(V, false, /*StoreCaptures=*/true);
+ if (IsCapturedCache)
+ CacheIt->second = Ret;
+ return Ret;
+ }
+
+ return false;
+}
+
+/// Returns true if the pointer is one which would have been considered an
+/// escape by isNonEscapingLocalObject.
+static bool isEscapeSource(const Value *V) {
+ if (isa<CallBase>(V))
+ return true;
+
+ if (isa<Argument>(V))
+ return true;
+
+ // The load case works because isNonEscapingLocalObject considers all
+ // stores to be escapes (it passes true for the StoreCaptures argument
+ // to PointerMayBeCaptured).
+ if (isa<LoadInst>(V))
+ return true;
+
+ return false;
+}
+
+/// Returns the size of the object specified by V or UnknownSize if unknown.
+static uint64_t getObjectSize(const Value *V, const DataLayout &DL,
+ const TargetLibraryInfo &TLI,
+ bool NullIsValidLoc,
+ bool RoundToAlign = false) {
+ uint64_t Size;
+ ObjectSizeOpts Opts;
+ Opts.RoundToAlign = RoundToAlign;
+ Opts.NullIsUnknownSize = NullIsValidLoc;
+ if (getObjectSize(V, Size, DL, &TLI, Opts))
+ return Size;
+ return MemoryLocation::UnknownSize;
+}
+
+/// Returns true if we can prove that the object specified by V is smaller than
+/// Size.
+static bool isObjectSmallerThan(const Value *V, uint64_t Size,
+ const DataLayout &DL,
+ const TargetLibraryInfo &TLI,
+ bool NullIsValidLoc) {
+ // Note that the meanings of the "object" are slightly different in the
+ // following contexts:
+ // c1: llvm::getObjectSize()
+ // c2: llvm.objectsize() intrinsic
+ // c3: isObjectSmallerThan()
+ // c1 and c2 share the same meaning; however, the meaning of "object" in c3
+ // refers to the "entire object".
+ //
+ // Consider this example:
+ // char *p = (char*)malloc(100)
+ // char *q = p+80;
+ //
+ // In the context of c1 and c2, the "object" pointed by q refers to the
+ // stretch of memory of q[0:19]. So, getObjectSize(q) should return 20.
+ //
+ // However, in the context of c3, the "object" refers to the chunk of memory
+ // being allocated. So, the "object" has 100 bytes, and q points to the middle
+ // of the "object". If q is passed to isObjectSmallerThan() as the 1st
+ // parameter, then before llvm::getObjectSize() is called to get the size of
+ // the entire object, we should:
+ // - either rewind the pointer q to the base address of the object in
+ // question (in this case rewind to p), or
+ // - just give up. It is up to the caller to make sure the pointer points
+ // to the base address of the object.
+ //
+ // We go for the 2nd option for simplicity.
+ if (!isIdentifiedObject(V))
+ return false;
+
+ // This function needs to use the aligned object size because we allow
+ // reads a bit past the end given sufficient alignment.
+ uint64_t ObjectSize = getObjectSize(V, DL, TLI, NullIsValidLoc,
+ /*RoundToAlign*/ true);
+
+ return ObjectSize != MemoryLocation::UnknownSize && ObjectSize < Size;
+}
+
+/// Return the minimal extent from \p V to the end of the underlying object,
+/// assuming the result is used in an aliasing query. E.g., we do use the query
+/// location size and the fact that null pointers cannot alias here.
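+///
+/// For example, a pointer attributed dereferenceable(16) that is queried with
+/// a precise 4-byte location yields max(16, 4) == 16 bytes; with an imprecise
+/// location size, only the 16 dereferenceable bytes are used.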
+static uint64_t getMinimalExtentFrom(const Value &V,
+ const LocationSize &LocSize,
+ const DataLayout &DL,
+ bool NullIsValidLoc) {
+ // If we have dereferenceability information we know a lower bound for the
+ // extent as accesses for a lower offset would be valid. We need to exclude
+ // the "or null" part if null is a valid pointer.
+ bool CanBeNull;
+ uint64_t DerefBytes = V.getPointerDereferenceableBytes(DL, CanBeNull);
+ DerefBytes = (CanBeNull && NullIsValidLoc) ? 0 : DerefBytes;
+ // If queried with a precise location size, we assume that location size is
+ // actually accessed and thus valid.
+ if (LocSize.isPrecise())
+ DerefBytes = std::max(DerefBytes, LocSize.getValue());
+ return DerefBytes;
+}
+
+/// Returns true if we can prove that the object specified by V has size Size.
+static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL,
+ const TargetLibraryInfo &TLI, bool NullIsValidLoc) {
+ uint64_t ObjectSize = getObjectSize(V, DL, TLI, NullIsValidLoc);
+ return ObjectSize != MemoryLocation::UnknownSize && ObjectSize == Size;
+}
+
+//===----------------------------------------------------------------------===//
+// GetElementPtr Instruction Decomposition and Analysis
+//===----------------------------------------------------------------------===//
+
+/// Analyzes the specified value as a linear expression: "A*V + B", where A and
+/// B are constant integers.
+///
+/// Returns the scale and offset values as APInts (via Scale and Offset),
+/// returns the base value V as a Value*, and reports whether it looked through
+/// any sign or zero extends (via ZExtBits and SExtBits). The incoming Value is
+/// known to have IntegerType, and it may already be sign or zero extended.
+///
+/// Note that this looks through extends, so the high bits may not be
+/// represented in the result.
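+///
+/// For illustration (hypothetical IR): given
+///   %a = add nsw i32 %x, 4
+///   %b = shl i32 %a, 1
+/// analyzing %b yields V == %x with Scale == 2 and Offset == 8, i.e.
+/// %b == 2*%x + 8 (the shl case also clears the propagated NSW/NUW flags).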
+/*static*/ const Value *BasicAAResult::GetLinearExpression(
+ const Value *V, APInt &Scale, APInt &Offset, unsigned &ZExtBits,
+ unsigned &SExtBits, const DataLayout &DL, unsigned Depth,
+ AssumptionCache *AC, DominatorTree *DT, bool &NSW, bool &NUW) {
+ assert(V->getType()->isIntegerTy() && "Not an integer value");
+
+ // Limit our recursion depth.
+ if (Depth == 6) {
+ Scale = 1;
+ Offset = 0;
+ return V;
+ }
+
+ if (const ConstantInt *Const = dyn_cast<ConstantInt>(V)) {
+ // If it's a constant, just convert it to an offset and remove the variable.
+ // If we've been called recursively, the Offset bit width will be greater
+ // than the constant's (the Offset's always as wide as the outermost call),
+ // so we'll zext here and process any extension in the isa<SExtInst> &
+ // isa<ZExtInst> cases below.
+ Offset += Const->getValue().zextOrSelf(Offset.getBitWidth());
+ assert(Scale == 0 && "Constant values don't have a scale");
+ return V;
+ }
+
+ if (const BinaryOperator *BOp = dyn_cast<BinaryOperator>(V)) {
+ if (ConstantInt *RHSC = dyn_cast<ConstantInt>(BOp->getOperand(1))) {
+ // If we've been called recursively, then Offset and Scale will be wider
+ // than the BOp operands. We'll always zext it here as we'll process sign
+ // extensions below (see the isa<SExtInst> / isa<ZExtInst> cases).
+ APInt RHS = RHSC->getValue().zextOrSelf(Offset.getBitWidth());
+
+ switch (BOp->getOpcode()) {
+ default:
+ // We don't understand this instruction, so we can't decompose it any
+ // further.
+ Scale = 1;
+ Offset = 0;
+ return V;
+ case Instruction::Or:
+ // X|C == X+C if all the bits in C are unset in X. Otherwise we can't
+ // analyze it.
+ if (!MaskedValueIsZero(BOp->getOperand(0), RHSC->getValue(), DL, 0, AC,
+ BOp, DT)) {
+ Scale = 1;
+ Offset = 0;
+ return V;
+ }
+ LLVM_FALLTHROUGH;
+ case Instruction::Add:
+ V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits,
+ SExtBits, DL, Depth + 1, AC, DT, NSW, NUW);
+ Offset += RHS;
+ break;
+ case Instruction::Sub:
+ V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits,
+ SExtBits, DL, Depth + 1, AC, DT, NSW, NUW);
+ Offset -= RHS;
+ break;
+ case Instruction::Mul:
+ V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits,
+ SExtBits, DL, Depth + 1, AC, DT, NSW, NUW);
+ Offset *= RHS;
+ Scale *= RHS;
+ break;
+ case Instruction::Shl:
+ V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits,
+ SExtBits, DL, Depth + 1, AC, DT, NSW, NUW);
+
+ // We're trying to linearize an expression of the kind:
+ // shl i8 -128, 36
+ // where the shift count exceeds the bitwidth of the type.
+ // We can't decompose this further (the expression would return
+ // a poison value).
+ if (Offset.getBitWidth() < RHS.getLimitedValue() ||
+ Scale.getBitWidth() < RHS.getLimitedValue()) {
+ Scale = 1;
+ Offset = 0;
+ return V;
+ }
+
+ Offset <<= RHS.getLimitedValue();
+ Scale <<= RHS.getLimitedValue();
+ // The semantics of nsw and nuw for left shifts don't match those of
+ // multiplications, so we won't propagate them.
+ NSW = NUW = false;
+ return V;
+ }
+
+ if (isa<OverflowingBinaryOperator>(BOp)) {
+ NUW &= BOp->hasNoUnsignedWrap();
+ NSW &= BOp->hasNoSignedWrap();
+ }
+ return V;
+ }
+ }
+
+ // Since GEP indices are sign extended anyway, we don't care about the high
+ // bits of a sign or zero extended value - just scales and offsets. The
+ // extensions have to be consistent though.
+ if (isa<SExtInst>(V) || isa<ZExtInst>(V)) {
+ Value *CastOp = cast<CastInst>(V)->getOperand(0);
+ unsigned NewWidth = V->getType()->getPrimitiveSizeInBits();
+ unsigned SmallWidth = CastOp->getType()->getPrimitiveSizeInBits();
+ unsigned OldZExtBits = ZExtBits, OldSExtBits = SExtBits;
+ const Value *Result =
+ GetLinearExpression(CastOp, Scale, Offset, ZExtBits, SExtBits, DL,
+ Depth + 1, AC, DT, NSW, NUW);
+
+ // zext(zext(%x)) == zext(%x), and similarly for sext; we'll handle this
+ // by just incrementing the number of bits we've extended by.
+ unsigned ExtendedBy = NewWidth - SmallWidth;
+
+ if (isa<SExtInst>(V) && ZExtBits == 0) {
+ // sext(sext(%x, a), b) == sext(%x, a + b)
+
+ if (NSW) {
+ // We haven't sign-wrapped, so it's valid to decompose sext(%x + c)
+ // into sext(%x) + sext(c). We'll sext the Offset ourselves:
+ unsigned OldWidth = Offset.getBitWidth();
+ Offset = Offset.trunc(SmallWidth).sext(NewWidth).zextOrSelf(OldWidth);
+ } else {
+ // We may have signed-wrapped, so don't decompose sext(%x + c) into
+ // sext(%x) + sext(c)
+ Scale = 1;
+ Offset = 0;
+ Result = CastOp;
+ ZExtBits = OldZExtBits;
+ SExtBits = OldSExtBits;
+ }
+ SExtBits += ExtendedBy;
+ } else {
+ // sext(zext(%x, a), b) = zext(zext(%x, a), b) = zext(%x, a + b)
+
+ if (!NUW) {
+ // We may have unsigned-wrapped, so don't decompose zext(%x + c) into
+ // zext(%x) + zext(c)
+ Scale = 1;
+ Offset = 0;
+ Result = CastOp;
+ ZExtBits = OldZExtBits;
+ SExtBits = OldSExtBits;
+ }
+ ZExtBits += ExtendedBy;
+ }
+
+ return Result;
+ }
+
+ Scale = 1;
+ Offset = 0;
+ return V;
+}
+
+/// Ensures that a pointer offset fits in an integer of size PointerSize
+/// (in bits) when that size is smaller than the maximum pointer size. This is
+/// an issue, in particular, for 32b pointers with negative indices that rely
+/// on two's complement wrap-arounds for precise alias information, where the
+/// maximum pointer size is 64b.
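+///
+/// For example, with a 64-bit wide Offset and PointerSize == 32, an offset of
+/// 0x00000000FFFFFFFF is sign-extended from its low 32 bits and becomes -1.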
+static APInt adjustToPointerSize(APInt Offset, unsigned PointerSize) {
+ assert(PointerSize <= Offset.getBitWidth() && "Invalid PointerSize!");
+ unsigned ShiftBits = Offset.getBitWidth() - PointerSize;
+ return (Offset << ShiftBits).ashr(ShiftBits);
+}
+
+static unsigned getMaxPointerSize(const DataLayout &DL) {
+ unsigned MaxPointerSize = DL.getMaxPointerSizeInBits();
+ if (MaxPointerSize < 64 && ForceAtLeast64Bits) MaxPointerSize = 64;
+ if (DoubleCalcBits) MaxPointerSize *= 2;
+
+ return MaxPointerSize;
+}
+
+/// If V is a symbolic pointer expression, decompose it into a base pointer
+/// with a constant offset and a number of scaled symbolic offsets.
+///
+/// The scaled symbolic offsets (represented by pairs of a Value* and a scale
+/// in the VarIndices vector) are Value*'s that are known to be scaled by the
+/// specified amount, but which may have other unrepresented high bits. As
+/// such, the gep cannot necessarily be reconstructed from its decomposed form.
+///
+/// When DataLayout is around, this function is capable of analyzing everything
+/// that GetUnderlyingObject can look through. To be able to do that,
+/// GetUnderlyingObject and DecomposeGEPExpression must use the same search
+/// depth (MaxLookupSearchDepth). When DataLayout is not around, it just looks
+/// through pointer casts.
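+///
+/// For illustration (hypothetical IR, assuming 64-bit pointers): for
+///   %p = getelementptr inbounds [10 x i32], [10 x i32]* %base, i64 0, i64 %i
+/// the result is Base == %base, a zero constant offset, and a single variable
+/// index {V = %i, Scale = 4} (4 being the allocation size of i32).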
+bool BasicAAResult::DecomposeGEPExpression(const Value *V,
+ DecomposedGEP &Decomposed, const DataLayout &DL, AssumptionCache *AC,
+ DominatorTree *DT) {
+ // Limit recursion depth to limit compile time in crazy cases.
+ unsigned MaxLookup = MaxLookupSearchDepth;
+ SearchTimes++;
+
+ unsigned MaxPointerSize = getMaxPointerSize(DL);
+ Decomposed.VarIndices.clear();
+ do {
+ // See if this is a bitcast or GEP.
+ const Operator *Op = dyn_cast<Operator>(V);
+ if (!Op) {
+ // The only non-operator cases we can handle are GlobalAliases.
+ if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
+ if (!GA->isInterposable()) {
+ V = GA->getAliasee();
+ continue;
+ }
+ }
+ Decomposed.Base = V;
+ return false;
+ }
+
+ if (Op->getOpcode() == Instruction::BitCast ||
+ Op->getOpcode() == Instruction::AddrSpaceCast) {
+ V = Op->getOperand(0);
+ continue;
+ }
+
+ const GEPOperator *GEPOp = dyn_cast<GEPOperator>(Op);
+ if (!GEPOp) {
+ if (const auto *Call = dyn_cast<CallBase>(V)) {
+ // CaptureTracking can know about special capturing properties of some
+ // intrinsics like launder.invariant.group that can't be expressed with
+ // attributes, but that have properties like returning an aliasing pointer.
+ // Because some analyses may assume that a nocapture pointer is not
+ // returned from a special intrinsic (since the function would have to be
+ // marked with the returned attribute), it is crucial to use this function,
+ // which should stay in sync with CaptureTracking. Not using it may cause
+ // weird miscompilations where two aliasing pointers are assumed to be
+ // noalias.
+ if (auto *RP = getArgumentAliasingToReturnedPointer(Call, false)) {
+ V = RP;
+ continue;
+ }
+ }
+
+ // If it's not a GEP, hand it off to SimplifyInstruction to see if it
+ // can come up with something. This matches what GetUnderlyingObject does.
+ if (const Instruction *I = dyn_cast<Instruction>(V))
+ // TODO: Get a DominatorTree and AssumptionCache and use them here
+ // (these are both now available in this function, but this should be
+ // updated when GetUnderlyingObject is updated). TLI should be
+ // provided also.
+ if (const Value *Simplified =
+ SimplifyInstruction(const_cast<Instruction *>(I), DL)) {
+ V = Simplified;
+ continue;
+ }
+
+ Decomposed.Base = V;
+ return false;
+ }
+
+ // Don't attempt to analyze GEPs over unsized objects.
+ if (!GEPOp->getSourceElementType()->isSized()) {
+ Decomposed.Base = V;
+ return false;
+ }
+
+ unsigned AS = GEPOp->getPointerAddressSpace();
+ // Walk the indices of the GEP, accumulating them into the constant offsets
+ // (StructOffset/OtherOffset) and VarIndices of Decomposed.
+ gep_type_iterator GTI = gep_type_begin(GEPOp);
+ unsigned PointerSize = DL.getPointerSizeInBits(AS);
+ // Assume all GEP operands are constants until proven otherwise.
+ bool GepHasConstantOffset = true;
+ for (User::const_op_iterator I = GEPOp->op_begin() + 1, E = GEPOp->op_end();
+ I != E; ++I, ++GTI) {
+ const Value *Index = *I;
+ // Compute the (potentially symbolic) offset in bytes for this index.
+ if (StructType *STy = GTI.getStructTypeOrNull()) {
+ // For a struct, add the member offset.
+ unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
+ if (FieldNo == 0)
+ continue;
+
+ Decomposed.StructOffset +=
+ DL.getStructLayout(STy)->getElementOffset(FieldNo);
+ continue;
+ }
+
+ // For an array/pointer, add the element offset, explicitly scaled.
+ if (const ConstantInt *CIdx = dyn_cast<ConstantInt>(Index)) {
+ if (CIdx->isZero())
+ continue;
+ Decomposed.OtherOffset +=
+ (DL.getTypeAllocSize(GTI.getIndexedType()) *
+ CIdx->getValue().sextOrSelf(MaxPointerSize))
+ .sextOrTrunc(MaxPointerSize);
+ continue;
+ }
+
+ GepHasConstantOffset = false;
+
+ APInt Scale(MaxPointerSize, DL.getTypeAllocSize(GTI.getIndexedType()));
+ unsigned ZExtBits = 0, SExtBits = 0;
+
+ // If the integer type is smaller than the pointer size, it is implicitly
+ // sign extended to pointer size.
+ unsigned Width = Index->getType()->getIntegerBitWidth();
+ if (PointerSize > Width)
+ SExtBits += PointerSize - Width;
+
+ // Use GetLinearExpression to decompose the index into a C1*V+C2 form.
+ APInt IndexScale(Width, 0), IndexOffset(Width, 0);
+ bool NSW = true, NUW = true;
+ const Value *OrigIndex = Index;
+ Index = GetLinearExpression(Index, IndexScale, IndexOffset, ZExtBits,
+ SExtBits, DL, 0, AC, DT, NSW, NUW);
+
+ // The GEP index scale ("Scale") scales C1*V+C2, yielding (C1*V+C2)*Scale.
+ // This gives us an aggregate computation of (C1*Scale)*V + C2*Scale.
+
+ // It can be the case that, even though C1*V+C2 does not overflow for
+ // relevant values of V, (C2*Scale) can overflow. In that case, we cannot
+ // decompose the expression in this way.
+ //
+ // FIXME: C1*Scale and the other operations in the decomposed
+ // (C1*Scale)*V+C2*Scale can also overflow. We should check for this
+ // possibility.
+ APInt WideScaledOffset = IndexOffset.sextOrTrunc(MaxPointerSize*2) *
+ Scale.sext(MaxPointerSize*2);
+ if (WideScaledOffset.getMinSignedBits() > MaxPointerSize) {
+ Index = OrigIndex;
+ IndexScale = 1;
+ IndexOffset = 0;
+
+ ZExtBits = SExtBits = 0;
+ if (PointerSize > Width)
+ SExtBits += PointerSize - Width;
+ } else {
+ Decomposed.OtherOffset += IndexOffset.sextOrTrunc(MaxPointerSize) * Scale;
+ Scale *= IndexScale.sextOrTrunc(MaxPointerSize);
+ }
+
+ // If we already had an occurrence of this index variable, merge this
+ // scale into it. For example, we want to handle:
+ // A[x][x] -> x*16 + x*4 -> x*20
+ // This also ensures that 'x' only appears in the index list once.
+ for (unsigned i = 0, e = Decomposed.VarIndices.size(); i != e; ++i) {
+ if (Decomposed.VarIndices[i].V == Index &&
+ Decomposed.VarIndices[i].ZExtBits == ZExtBits &&
+ Decomposed.VarIndices[i].SExtBits == SExtBits) {
+ Scale += Decomposed.VarIndices[i].Scale;
+ Decomposed.VarIndices.erase(Decomposed.VarIndices.begin() + i);
+ break;
+ }
+ }
+
+ // Make sure that we have a scale that makes sense for this target's
+ // pointer size.
+ Scale = adjustToPointerSize(Scale, PointerSize);
+
+ if (!!Scale) {
+ VariableGEPIndex Entry = {Index, ZExtBits, SExtBits, Scale};
+ Decomposed.VarIndices.push_back(Entry);
+ }
+ }
+
+ // Take care of wrap-arounds
+ if (GepHasConstantOffset) {
+ Decomposed.StructOffset =
+ adjustToPointerSize(Decomposed.StructOffset, PointerSize);
+ Decomposed.OtherOffset =
+ adjustToPointerSize(Decomposed.OtherOffset, PointerSize);
+ }
+
+ // Analyze the base pointer next.
+ V = GEPOp->getOperand(0);
+ } while (--MaxLookup);
+
+ // If the chain of expressions is too deep, just return early.
+ Decomposed.Base = V;
+ SearchLimitReached++;
+ return true;
+}
+
+/// Returns whether the given pointer value points to memory that is local to
+/// the function, with global constants being considered local to all
+/// functions.
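+///
+/// For example, a select between the addresses of two constant globals still
+/// satisfies the query; if it reaches an underlying object that is neither a
+/// constant global nor (with OrLocal) an alloca, the query is delegated to
+/// the base implementation.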
+bool BasicAAResult::pointsToConstantMemory(const MemoryLocation &Loc,
+ AAQueryInfo &AAQI, bool OrLocal) {
+ assert(Visited.empty() && "Visited must be cleared after use!");
+
+ unsigned MaxLookup = 8;
+ SmallVector<const Value *, 16> Worklist;
+ Worklist.push_back(Loc.Ptr);
+ do {
+ const Value *V = GetUnderlyingObject(Worklist.pop_back_val(), DL);
+ if (!Visited.insert(V).second) {
+ Visited.clear();
+ return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal);
+ }
+
+ // An alloca instruction defines local memory.
+ if (OrLocal && isa<AllocaInst>(V))
+ continue;
+
+ // A global constant counts as local memory for our purposes.
+ if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
+ // Note: this doesn't require GV to be "ODR" because it isn't legal for a
+ // global to be marked constant in some modules and non-constant in
+ // others. GV may even be a declaration, not a definition.
+ if (!GV->isConstant()) {
+ Visited.clear();
+ return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal);
+ }
+ continue;
+ }
+
+ // If both select values point to local memory, then so does the select.
+ if (const SelectInst *SI = dyn_cast<SelectInst>(V)) {
+ Worklist.push_back(SI->getTrueValue());
+ Worklist.push_back(SI->getFalseValue());
+ continue;
+ }
+
+ // If all values incoming to a phi node point to local memory, then so does
+ // the phi.
+ if (const PHINode *PN = dyn_cast<PHINode>(V)) {
+ // Don't bother inspecting phi nodes with many operands.
+ if (PN->getNumIncomingValues() > MaxLookup) {
+ Visited.clear();
+ return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal);
+ }
+ for (Value *IncValue : PN->incoming_values())
+ Worklist.push_back(IncValue);
+ continue;
+ }
+
+ // Otherwise be conservative.
+ Visited.clear();
+ return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal);
+ } while (!Worklist.empty() && --MaxLookup);
+
+ Visited.clear();
+ return Worklist.empty();
+}
+
+/// Returns the behavior when calling the given call site.
+FunctionModRefBehavior BasicAAResult::getModRefBehavior(const CallBase *Call) {
+ if (Call->doesNotAccessMemory())
+ // Can't do better than this.
+ return FMRB_DoesNotAccessMemory;
+
+ FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior;
+
+ // If the callsite knows it only reads memory, don't return worse
+ // than that.
+ if (Call->onlyReadsMemory())
+ Min = FMRB_OnlyReadsMemory;
+ else if (Call->doesNotReadMemory())
+ Min = FMRB_DoesNotReadMemory;
+
+ if (Call->onlyAccessesArgMemory())
+ Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesArgumentPointees);
+ else if (Call->onlyAccessesInaccessibleMemory())
+ Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesInaccessibleMem);
+ else if (Call->onlyAccessesInaccessibleMemOrArgMem())
+ Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesInaccessibleOrArgMem);
+
+ // If the call has operand bundles then aliasing attributes from the function
+ // it calls do not directly apply to the call. This can be made more precise
+ // in the future.
+ if (!Call->hasOperandBundles())
+ if (const Function *F = Call->getCalledFunction())
+ Min =
+ FunctionModRefBehavior(Min & getBestAAResults().getModRefBehavior(F));
+
+ return Min;
+}
+
+/// Returns the behavior when calling the given function. For use when the call
+/// site is not known.
+FunctionModRefBehavior BasicAAResult::getModRefBehavior(const Function *F) {
+ // If the function declares it doesn't access memory, we can't do better.
+ if (F->doesNotAccessMemory())
+ return FMRB_DoesNotAccessMemory;
+
+ FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior;
+
+ // If the function declares it only reads memory, go with that.
+ if (F->onlyReadsMemory())
+ Min = FMRB_OnlyReadsMemory;
+ else if (F->doesNotReadMemory())
+ Min = FMRB_DoesNotReadMemory;
+
+ if (F->onlyAccessesArgMemory())
+ Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesArgumentPointees);
+ else if (F->onlyAccessesInaccessibleMemory())
+ Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesInaccessibleMem);
+ else if (F->onlyAccessesInaccessibleMemOrArgMem())
+ Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesInaccessibleOrArgMem);
+
+ return Min;
+}
+
+/// Returns true if this is a writeonly (i.e., Mod only) parameter.
+static bool isWriteOnlyParam(const CallBase *Call, unsigned ArgIdx,
+ const TargetLibraryInfo &TLI) {
+ if (Call->paramHasAttr(ArgIdx, Attribute::WriteOnly))
+ return true;
+
+ // We can bound the aliasing properties of memset_pattern16 just as we can
+ // for memcpy/memset. This is particularly important because the
+ // LoopIdiomRecognizer likes to turn loops into calls to memset_pattern16
+ // whenever possible.
+ // FIXME Consider handling this in InferFunctionAttr.cpp together with other
+ // attributes.
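+ // For illustration (hypothetical call): for
+ //   memset_pattern16(i8* %dst, i8* %pattern, i64 %n)
+ // argument 0 (%dst) is reported write-only (Mod) here; the other arguments
+ // fall through to the attribute checks in getArgModRefInfo.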
+ LibFunc F;
+ if (Call->getCalledFunction() &&
+ TLI.getLibFunc(*Call->getCalledFunction(), F) &&
+ F == LibFunc_memset_pattern16 && TLI.has(F))
+ if (ArgIdx == 0)
+ return true;
+
+ // TODO: memset_pattern4, memset_pattern8
+ // TODO: _chk variants
+ // TODO: strcmp, strcpy
+
+ return false;
+}
+
+ModRefInfo BasicAAResult::getArgModRefInfo(const CallBase *Call,
+ unsigned ArgIdx) {
+ // Checking for known builtin intrinsics and target library functions.
+ if (isWriteOnlyParam(Call, ArgIdx, TLI))
+ return ModRefInfo::Mod;
+
+ if (Call->paramHasAttr(ArgIdx, Attribute::ReadOnly))
+ return ModRefInfo::Ref;
+
+ if (Call->paramHasAttr(ArgIdx, Attribute::ReadNone))
+ return ModRefInfo::NoModRef;
+
+ return AAResultBase::getArgModRefInfo(Call, ArgIdx);
+}
+
+static bool isIntrinsicCall(const CallBase *Call, Intrinsic::ID IID) {
+ const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Call);
+ return II && II->getIntrinsicID() == IID;
+}
+
+#ifndef NDEBUG
+static const Function *getParent(const Value *V) {
+ if (const Instruction *inst = dyn_cast<Instruction>(V)) {
+ if (!inst->getParent())
+ return nullptr;
+ return inst->getParent()->getParent();
+ }
+
+ if (const Argument *arg = dyn_cast<Argument>(V))
+ return arg->getParent();
+
+ return nullptr;
+}
+
+static bool notDifferentParent(const Value *O1, const Value *O2) {
+
+ const Function *F1 = getParent(O1);
+ const Function *F2 = getParent(O2);
+
+ return !F1 || !F2 || F1 == F2;
+}
+#endif
+
+AliasResult BasicAAResult::alias(const MemoryLocation &LocA,
+ const MemoryLocation &LocB,
+ AAQueryInfo &AAQI) {
+ assert(notDifferentParent(LocA.Ptr, LocB.Ptr) &&
+ "BasicAliasAnalysis doesn't support interprocedural queries.");
+
+ // If we have a directly cached entry for these locations, we have recursed
+ // through this once, so just return the cached results. Notably, when this
+ // happens, we don't clear the cache.
+ auto CacheIt = AAQI.AliasCache.find(AAQueryInfo::LocPair(LocA, LocB));
+ if (CacheIt != AAQI.AliasCache.end())
+ return CacheIt->second;
+
+ CacheIt = AAQI.AliasCache.find(AAQueryInfo::LocPair(LocB, LocA));
+ if (CacheIt != AAQI.AliasCache.end())
+ return CacheIt->second;
+
+ AliasResult Alias = aliasCheck(LocA.Ptr, LocA.Size, LocA.AATags, LocB.Ptr,
+ LocB.Size, LocB.AATags, AAQI);
+
+ VisitedPhiBBs.clear();
+ return Alias;
+}
+
+/// Checks to see if the specified callsite can clobber the specified memory
+/// object.
+///
+/// Since we only look at local properties of this function, we really can't
+/// say much about this query. We do, however, use simple "address taken"
+/// analysis on local objects.
+ModRefInfo BasicAAResult::getModRefInfo(const CallBase *Call,
+ const MemoryLocation &Loc,
+ AAQueryInfo &AAQI) {
+ assert(notDifferentParent(Call, Loc.Ptr) &&
+ "AliasAnalysis query involving multiple functions!");
+
+ const Value *Object = GetUnderlyingObject(Loc.Ptr, DL);
+
+ // Calls marked 'tail' cannot read or write allocas from the current frame
+ // because the current frame might be destroyed by the time they run. However,
+ // a tail call may use an alloca with byval. Calling with byval copies the
+ // contents of the alloca into argument registers or stack slots, so there is
+ // no lifetime issue.
+ if (isa<AllocaInst>(Object))
+ if (const CallInst *CI = dyn_cast<CallInst>(Call))
+ if (CI->isTailCall() &&
+ !CI->getAttributes().hasAttrSomewhere(Attribute::ByVal))
+ return ModRefInfo::NoModRef;
+
+ // Stack restore is able to modify unescaped dynamic allocas. Assume it may
+ // modify them even though the alloca is not escaped.
+ if (auto *AI = dyn_cast<AllocaInst>(Object))
+ if (!AI->isStaticAlloca() && isIntrinsicCall(Call, Intrinsic::stackrestore))
+ return ModRefInfo::Mod;
+
+ // If the pointer is to a locally allocated object that does not escape,
+ // then the call can not mod/ref the pointer unless the call takes the pointer
+ // as an argument, and itself doesn't capture it.
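+ // For illustration (hypothetical IR): with
+ //   %a = alloca i32
+ //   call void @f()
+ // and %a never escaping, the call has no pointer data operands that could
+ // alias %a, so the loop below leaves the result at NoModRef.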
+ if (!isa<Constant>(Object) && Call != Object &&
+ isNonEscapingLocalObject(Object, &AAQI.IsCapturedCache)) {
+
+ // Optimistically assume that call doesn't touch Object and check this
+ // assumption in the following loop.
+ ModRefInfo Result = ModRefInfo::NoModRef;
+ bool IsMustAlias = true;
+
+ unsigned OperandNo = 0;
+ for (auto CI = Call->data_operands_begin(), CE = Call->data_operands_end();
+ CI != CE; ++CI, ++OperandNo) {
+ // Only look at the no-capture or byval pointer arguments. If this
+ // pointer were passed to arguments that were neither of these, then it
+ // couldn't be no-capture.
+ if (!(*CI)->getType()->isPointerTy() ||
+ (!Call->doesNotCapture(OperandNo) &&
+ OperandNo < Call->getNumArgOperands() &&
+ !Call->isByValArgument(OperandNo)))
+ continue;
+
+ // Call doesn't access memory through this operand, so we don't care
+ // if it aliases with Object.
+ if (Call->doesNotAccessMemory(OperandNo))
+ continue;
+
+ // If this is a no-capture pointer argument, see if we can tell that it
+ // is impossible to alias the pointer we're checking.
+ AliasResult AR = getBestAAResults().alias(MemoryLocation(*CI),
+ MemoryLocation(Object), AAQI);
+ if (AR != MustAlias)
+ IsMustAlias = false;
+ // Operand doesn't alias 'Object', continue looking for other aliases
+ if (AR == NoAlias)
+ continue;
+ // Operand aliases 'Object', but the call doesn't modify it. Strengthen the
+ // initial assumption and keep looking in case there are more aliases.
+ if (Call->onlyReadsMemory(OperandNo)) {
+ Result = setRef(Result);
+ continue;
+ }
+ // Operand aliases 'Object' but call only writes into it.
+ if (Call->doesNotReadMemory(OperandNo)) {
+ Result = setMod(Result);
+ continue;
+ }
+ // This operand aliases 'Object' and the call reads and writes into it.
+ // Setting ModRef will not yield an early return below; MustAlias is not
+ // used further.
+ Result = ModRefInfo::ModRef;
+ break;
+ }
+
+ // If no operand aliases, reset the Must bit. It is added back below if at
+ // least one operand aliases and all aliases found are MustAlias.
+ if (isNoModRef(Result))
+ IsMustAlias = false;
+
+ // Early return if we improved mod ref information
+ if (!isModAndRefSet(Result)) {
+ if (isNoModRef(Result))
+ return ModRefInfo::NoModRef;
+ return IsMustAlias ? setMust(Result) : clearMust(Result);
+ }
+ }
+
+ // If the call is to malloc or calloc, we can assume that it doesn't
+ // modify any IR visible value. This is only valid because we assume these
+ // routines do not read values visible in the IR. TODO: Consider special
+ // casing realloc and strdup routines which access only their arguments as
+ // well. Or alternatively, replace all of this with inaccessiblememonly once
+ // that's implemented fully.
+ if (isMallocOrCallocLikeFn(Call, &TLI)) {
+ // Be conservative if the accessed pointer may alias the allocation -
+ // fallback to the generic handling below.
+ if (getBestAAResults().alias(MemoryLocation(Call), Loc, AAQI) == NoAlias)
+ return ModRefInfo::NoModRef;
+ }
+
+ // The semantics of memcpy intrinsics forbid overlap between their respective
+ // operands, i.e., source and destination of any given memcpy must no-alias.
+ // If Loc must-aliases either one of these two locations, then it necessarily
+ // no-aliases the other.
+ if (auto *Inst = dyn_cast<AnyMemCpyInst>(Call)) {
+ AliasResult SrcAA, DestAA;
+
+ if ((SrcAA = getBestAAResults().alias(MemoryLocation::getForSource(Inst),
+ Loc, AAQI)) == MustAlias)
+ // Loc is exactly the memcpy source thus disjoint from memcpy dest.
+ return ModRefInfo::Ref;
+ if ((DestAA = getBestAAResults().alias(MemoryLocation::getForDest(Inst),
+ Loc, AAQI)) == MustAlias)
+ // The converse case.
+ return ModRefInfo::Mod;
+
+ // It's also possible for Loc to alias both src and dest, or neither.
+ ModRefInfo rv = ModRefInfo::NoModRef;
+ if (SrcAA != NoAlias)
+ rv = setRef(rv);
+ if (DestAA != NoAlias)
+ rv = setMod(rv);
+ return rv;
+ }
+
+ // While the assume intrinsic is marked as arbitrarily writing so that
+ // proper control dependencies will be maintained, it never aliases any
+ // particular memory location.
+ if (isIntrinsicCall(Call, Intrinsic::assume))
+ return ModRefInfo::NoModRef;
+
+ // Like assumes, guard intrinsics are also marked as arbitrarily writing so
+ // that proper control dependencies are maintained, but they never mod any
+ // particular memory location.
+ //
+ // *Unlike* assumes, guard intrinsics are modeled as reading memory since the
+ // heap state at the point the guard is issued needs to be consistent in case
+ // the guard invokes the "deopt" continuation.
+ if (isIntrinsicCall(Call, Intrinsic::experimental_guard))
+ return ModRefInfo::Ref;
+
+ // Like assumes, invariant.start intrinsics were also marked as arbitrarily
+ // writing so that proper control dependencies are maintained but they never
+ // mod any particular memory location visible to the IR.
+ // *Unlike* assumes (which are now modeled as NoModRef), invariant.start
+ // intrinsic is now modeled as reading memory. This prevents hoisting the
+ // invariant.start intrinsic over stores. Consider:
+ // *ptr = 40;
+ // *ptr = 50;
+ // invariant_start(ptr)
+ // int val = *ptr;
+ // print(val);
+ //
+ // This cannot be transformed to:
+ //
+ // *ptr = 40;
+ // invariant_start(ptr)
+ // *ptr = 50;
+ // int val = *ptr;
+ // print(val);
+ //
+ // The transformation will cause the second store to be ignored (based on
+ // rules of invariant.start) and print 40, while the first program always
+ // prints 50.
+ if (isIntrinsicCall(Call, Intrinsic::invariant_start))
+ return ModRefInfo::Ref;
+
+ // The AAResultBase base class has some smarts, lets use them.
+ return AAResultBase::getModRefInfo(Call, Loc, AAQI);
+}
+
+ModRefInfo BasicAAResult::getModRefInfo(const CallBase *Call1,
+ const CallBase *Call2,
+ AAQueryInfo &AAQI) {
+ // While the assume intrinsic is marked as arbitrarily writing so that
+ // proper control dependencies will be maintained, it never aliases any
+ // particular memory location.
+ if (isIntrinsicCall(Call1, Intrinsic::assume) ||
+ isIntrinsicCall(Call2, Intrinsic::assume))
+ return ModRefInfo::NoModRef;
+
+ // Like assumes, guard intrinsics are also marked as arbitrarily writing so
+ // that proper control dependencies are maintained but they never mod any
+ // particular memory location.
+ //
+ // *Unlike* assumes, guard intrinsics are modeled as reading memory since the
+ // heap state at the point the guard is issued needs to be consistent in case
+ // the guard invokes the "deopt" continuation.
+
+ // NB! This function is *not* commutative, so we special case two
+ // possibilities for guard intrinsics.
+
+ if (isIntrinsicCall(Call1, Intrinsic::experimental_guard))
+ return isModSet(createModRefInfo(getModRefBehavior(Call2)))
+ ? ModRefInfo::Ref
+ : ModRefInfo::NoModRef;
+
+ if (isIntrinsicCall(Call2, Intrinsic::experimental_guard))
+ return isModSet(createModRefInfo(getModRefBehavior(Call1)))
+ ? ModRefInfo::Mod
+ : ModRefInfo::NoModRef;
+
+ // The AAResultBase base class has some smarts, lets use them.
+ return AAResultBase::getModRefInfo(Call1, Call2, AAQI);
+}
+
+/// Provide ad-hoc rules to disambiguate accesses through two GEP operators,
+/// both having the exact same pointer operand.
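+///
+/// For illustration (hypothetical IR): with %struct.S = type { i32, i32 } and
+///   %a = getelementptr inbounds %struct.S, %struct.S* %p, i64 %i, i32 0
+///   %b = getelementptr inbounds %struct.S, %struct.S* %p, i64 %j, i32 1
+/// and 4-byte accesses through both, the trailing constant indices select
+/// non-overlapping fields, so this returns NoAlias.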
+static AliasResult aliasSameBasePointerGEPs(const GEPOperator *GEP1,
+ LocationSize MaybeV1Size,
+ const GEPOperator *GEP2,
+ LocationSize MaybeV2Size,
+ const DataLayout &DL) {
+ assert(GEP1->getPointerOperand()->stripPointerCastsAndInvariantGroups() ==
+ GEP2->getPointerOperand()->stripPointerCastsAndInvariantGroups() &&
+ GEP1->getPointerOperandType() == GEP2->getPointerOperandType() &&
+ "Expected GEPs with the same pointer operand");
+
+ // Try to determine whether GEP1 and GEP2 index through arrays, into structs,
+ // such that the struct field accesses provably cannot alias.
+ // We also need at least two indices (the pointer, and the struct field).
+ if (GEP1->getNumIndices() != GEP2->getNumIndices() ||
+ GEP1->getNumIndices() < 2)
+ return MayAlias;
+
+ // If we don't know the size of the accesses through both GEPs, we can't
+ // determine whether the struct fields accessed can't alias.
+ if (MaybeV1Size == LocationSize::unknown() ||
+ MaybeV2Size == LocationSize::unknown())
+ return MayAlias;
+
+ const uint64_t V1Size = MaybeV1Size.getValue();
+ const uint64_t V2Size = MaybeV2Size.getValue();
+
+ ConstantInt *C1 =
+ dyn_cast<ConstantInt>(GEP1->getOperand(GEP1->getNumOperands() - 1));
+ ConstantInt *C2 =
+ dyn_cast<ConstantInt>(GEP2->getOperand(GEP2->getNumOperands() - 1));
+
+ // If the last (struct) indices are constants and are equal, the other indices
+ // might also be dynamically equal, so the GEPs can alias.
+ if (C1 && C2) {
+ unsigned BitWidth = std::max(C1->getBitWidth(), C2->getBitWidth());
+ if (C1->getValue().sextOrSelf(BitWidth) ==
+ C2->getValue().sextOrSelf(BitWidth))
+ return MayAlias;
+ }
+
+ // Find the last-indexed type of the GEP, i.e., the type you'd get if
+ // you stripped the last index.
+ // On the way, look at each indexed type. If there's something other
+ // than an array, different indices can lead to different final types.
+ SmallVector<Value *, 8> IntermediateIndices;
+
+ // Insert the first index; we don't need to check the type indexed
+ // through it as it only drops the pointer indirection.
+ assert(GEP1->getNumIndices() > 1 && "Not enough GEP indices to examine");
+ IntermediateIndices.push_back(GEP1->getOperand(1));
+
+ // Insert all the remaining indices but the last one.
+ // Also, check that they all index through arrays.
+ for (unsigned i = 1, e = GEP1->getNumIndices() - 1; i != e; ++i) {
+ if (!isa<ArrayType>(GetElementPtrInst::getIndexedType(
+ GEP1->getSourceElementType(), IntermediateIndices)))
+ return MayAlias;
+ IntermediateIndices.push_back(GEP1->getOperand(i + 1));
+ }
+
+ auto *Ty = GetElementPtrInst::getIndexedType(
+ GEP1->getSourceElementType(), IntermediateIndices);
+ StructType *LastIndexedStruct = dyn_cast<StructType>(Ty);
+
+ if (isa<SequentialType>(Ty)) {
+ // We know that:
+ // - both GEPs begin indexing from the exact same pointer;
+ // - the last indices in both GEPs are constants, indexing into a sequential
+ // type (array or pointer);
+ // - both GEPs only index through arrays prior to that.
+ //
+ // Because array indices greater than the number of elements are valid in
+ // GEPs, unless we know the intermediate indices are identical between
+ // GEP1 and GEP2 we cannot guarantee that the last indexed arrays don't
+ // partially overlap. We also need to check that the loaded size matches
+ // the element size, otherwise we could still have overlap.
+ const uint64_t ElementSize =
+ DL.getTypeStoreSize(cast<SequentialType>(Ty)->getElementType());
+ if (V1Size != ElementSize || V2Size != ElementSize)
+ return MayAlias;
+
+ for (unsigned i = 0, e = GEP1->getNumIndices() - 1; i != e; ++i)
+ if (GEP1->getOperand(i + 1) != GEP2->getOperand(i + 1))
+ return MayAlias;
+
+ // Now we know that the array/pointer that GEP1 indexes into and the one
+ // that GEP2 indexes into must either precisely overlap or be disjoint.
+ // Because they cannot partially overlap and because fields in an array
+ // cannot overlap, if we can prove the final indices are different between
+ // GEP1 and GEP2, we can conclude GEP1 and GEP2 don't alias.
+
+ // If the last indices are constants, we've already checked they don't
+ // equal each other so we can exit early.
+ if (C1 && C2)
+ return NoAlias;
+ {
+ Value *GEP1LastIdx = GEP1->getOperand(GEP1->getNumOperands() - 1);
+ Value *GEP2LastIdx = GEP2->getOperand(GEP2->getNumOperands() - 1);
+ if (isa<PHINode>(GEP1LastIdx) || isa<PHINode>(GEP2LastIdx)) {
+ // If one of the indices is a PHI node, be safe and only use
+ // computeKnownBits so we don't make any assumptions about the
+ // relationships between the two indices. This is important if we're
+ // asking about values from different loop iterations. See PR32314.
+ // TODO: We may be able to change the check so we only do this when
+ // we definitely looked through a PHINode.
+ if (GEP1LastIdx != GEP2LastIdx &&
+ GEP1LastIdx->getType() == GEP2LastIdx->getType()) {
+ KnownBits Known1 = computeKnownBits(GEP1LastIdx, DL);
+ KnownBits Known2 = computeKnownBits(GEP2LastIdx, DL);
+ if (Known1.Zero.intersects(Known2.One) ||
+ Known1.One.intersects(Known2.Zero))
+ return NoAlias;
+ }
+ } else if (isKnownNonEqual(GEP1LastIdx, GEP2LastIdx, DL))
+ return NoAlias;
+ }
+ return MayAlias;
+ } else if (!LastIndexedStruct || !C1 || !C2) {
+ return MayAlias;
+ }
+
+ if (C1->getValue().getActiveBits() > 64 ||
+ C2->getValue().getActiveBits() > 64)
+ return MayAlias;
+
+ // We know that:
+ // - both GEPs begin indexing from the exact same pointer;
+ // - the last indices in both GEPs are constants, indexing into a struct;
+ // - said indices are different, hence, the pointed-to fields are different;
+ // - both GEPs only index through arrays prior to that.
+ //
+ // This lets us determine that the struct that GEP1 indexes into and the
+ // struct that GEP2 indexes into must either precisely overlap or be
+ // completely disjoint. Because they cannot partially overlap, indexing into
+ // different non-overlapping fields of the struct will never alias.
+
+ // Therefore, the only remaining thing needed to show that both GEPs can't
+ // alias is that the fields are not overlapping.
+ const StructLayout *SL = DL.getStructLayout(LastIndexedStruct);
+ const uint64_t StructSize = SL->getSizeInBytes();
+ const uint64_t V1Off = SL->getElementOffset(C1->getZExtValue());
+ const uint64_t V2Off = SL->getElementOffset(C2->getZExtValue());
+
+ auto EltsDontOverlap = [StructSize](uint64_t V1Off, uint64_t V1Size,
+ uint64_t V2Off, uint64_t V2Size) {
+ return V1Off < V2Off && V1Off + V1Size <= V2Off &&
+ ((V2Off + V2Size <= StructSize) ||
+ (V2Off + V2Size - StructSize <= V1Off));
+ };
+
+ if (EltsDontOverlap(V1Off, V1Size, V2Off, V2Size) ||
+ EltsDontOverlap(V2Off, V2Size, V1Off, V1Size))
+ return NoAlias;
+
+ return MayAlias;
+}
+
+// If we have (a) a GEP and (b) a pointer based on an alloca, and the
+// beginning of the object the GEP points into would have a negative offset
+// with respect to the alloca, then the GEP cannot alias pointer (b).
+// Note that the pointer based on the alloca may not be a GEP. For
+// example, it may be the alloca itself.
+// The same applies if (b) is based on a GlobalVariable. Note that just being
+// based on isIdentifiedObject() is not enough - we need an identified object
+// that does not permit access to negative offsets. For example, a negative
+// offset from a noalias argument or call can be inbounds w.r.t the actual
+// underlying object.
+//
+// For example, consider:
+//
+// struct foo { int f0; int f1; ... };
+// foo alloca;
+// foo *random = bar(alloca);
+// int *f0 = &alloca.f0;
+// int *f1 = &random->f1;
+//
+// Which is lowered, approximately, to:
+//
+// %alloca = alloca %struct.foo
+// %random = call %struct.foo* @random(%struct.foo* %alloca)
+// %f0 = getelementptr inbounds %struct, %struct.foo* %alloca, i32 0, i32 0
+// %f1 = getelementptr inbounds %struct, %struct.foo* %random, i32 0, i32 1
+//
+// Assume %f1 and %f0 alias. Then %f1 would point into the object allocated
+// by %alloca. Since the %f1 GEP is inbounds, that means %random must also
+// point into the same object. But since %f0 points to the beginning of %alloca,
+// the highest %f1 can be is (%alloca + 3). This means %random can not be higher
+// than (%alloca - 1), and so is not inbounds, a contradiction.
+bool BasicAAResult::isGEPBaseAtNegativeOffset(const GEPOperator *GEPOp,
+ const DecomposedGEP &DecompGEP, const DecomposedGEP &DecompObject,
+ LocationSize MaybeObjectAccessSize) {
+ // If the object access size is unknown, or the GEP isn't inbounds, bail.
+ if (MaybeObjectAccessSize == LocationSize::unknown() || !GEPOp->isInBounds())
+ return false;
+
+ const uint64_t ObjectAccessSize = MaybeObjectAccessSize.getValue();
+
+ // We need the object to be an alloca or a global variable, and want to know
+ // the offset of the pointer from the object precisely, so no variable
+ // indices are allowed.
+ if (!(isa<AllocaInst>(DecompObject.Base) ||
+ isa<GlobalVariable>(DecompObject.Base)) ||
+ !DecompObject.VarIndices.empty())
+ return false;
+
+ APInt ObjectBaseOffset = DecompObject.StructOffset +
+ DecompObject.OtherOffset;
+
+ // If the GEP has no variable indices, we know the precise offset
+ // from the base and can use it. If the GEP has variable indices,
+ // we can't compute an exact offset to identify the alias relation, so
+ // return false in that case.
+ if (!DecompGEP.VarIndices.empty())
+ return false;
+
+ APInt GEPBaseOffset = DecompGEP.StructOffset;
+ GEPBaseOffset += DecompGEP.OtherOffset;
+
+ return GEPBaseOffset.sge(ObjectBaseOffset + (int64_t)ObjectAccessSize);
+}
+
+/// Provides a bunch of ad-hoc rules to disambiguate a GEP instruction against
+/// another pointer.
+///
+/// We know that V1 is a GEP, but we don't know anything about V2.
+/// UnderlyingV1 is GetUnderlyingObject(GEP1, DL), UnderlyingV2 is the same for
+/// V2.
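+///
+/// For illustration (hypothetical IR, with %p a plain pointer argument): for
+///   %q = getelementptr inbounds i8, i8* %p, i64 8
+/// a 4-byte access through %q and a 4-byte access through %p decompose to the
+/// same base with constant offsets 8 and 0, so BasicAA reports NoAlias; with
+/// an offset of 2 instead, the accesses overlap and the result is
+/// PartialAlias.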
+AliasResult BasicAAResult::aliasGEP(
+ const GEPOperator *GEP1, LocationSize V1Size, const AAMDNodes &V1AAInfo,
+ const Value *V2, LocationSize V2Size, const AAMDNodes &V2AAInfo,
+ const Value *UnderlyingV1, const Value *UnderlyingV2, AAQueryInfo &AAQI) {
+ DecomposedGEP DecompGEP1, DecompGEP2;
+ unsigned MaxPointerSize = getMaxPointerSize(DL);
+ DecompGEP1.StructOffset = DecompGEP1.OtherOffset = APInt(MaxPointerSize, 0);
+ DecompGEP2.StructOffset = DecompGEP2.OtherOffset = APInt(MaxPointerSize, 0);
+
+ bool GEP1MaxLookupReached =
+ DecomposeGEPExpression(GEP1, DecompGEP1, DL, &AC, DT);
+ bool GEP2MaxLookupReached =
+ DecomposeGEPExpression(V2, DecompGEP2, DL, &AC, DT);
+
+ APInt GEP1BaseOffset = DecompGEP1.StructOffset + DecompGEP1.OtherOffset;
+ APInt GEP2BaseOffset = DecompGEP2.StructOffset + DecompGEP2.OtherOffset;
+
+ assert(DecompGEP1.Base == UnderlyingV1 && DecompGEP2.Base == UnderlyingV2 &&
+ "DecomposeGEPExpression returned a result different from "
+ "GetUnderlyingObject");
+
+ // If the GEP's offset relative to its base is such that the base would
+ // fall below the start of the object underlying V2, then the GEP and V2
+ // cannot alias.
+ if (!GEP1MaxLookupReached && !GEP2MaxLookupReached &&
+ isGEPBaseAtNegativeOffset(GEP1, DecompGEP1, DecompGEP2, V2Size))
+ return NoAlias;
+ // If we have two gep instructions with must-aliasing or non-aliasing base
+ // pointers, figure out if the indices of the GEPs tell us anything about the
+ // derived pointer.
+ if (const GEPOperator *GEP2 = dyn_cast<GEPOperator>(V2)) {
+ // Check for the GEP base being at a negative offset, this time in the other
+ // direction.
+ if (!GEP1MaxLookupReached && !GEP2MaxLookupReached &&
+ isGEPBaseAtNegativeOffset(GEP2, DecompGEP2, DecompGEP1, V1Size))
+ return NoAlias;
+ // Do the base pointers alias?
+ AliasResult BaseAlias =
+ aliasCheck(UnderlyingV1, LocationSize::unknown(), AAMDNodes(),
+ UnderlyingV2, LocationSize::unknown(), AAMDNodes(), AAQI);
+
+ // Check for geps of non-aliasing underlying pointers where the offsets are
+ // identical.
+ if ((BaseAlias == MayAlias) && V1Size == V2Size) {
+ // Do the base pointers alias assuming type and size.
+ AliasResult PreciseBaseAlias = aliasCheck(
+ UnderlyingV1, V1Size, V1AAInfo, UnderlyingV2, V2Size, V2AAInfo, AAQI);
+ if (PreciseBaseAlias == NoAlias) {
+ // See if the computed offset from the common pointer tells us about the
+ // relation of the resulting pointer.
+ // If the max search depth is reached the result is undefined
+ if (GEP2MaxLookupReached || GEP1MaxLookupReached)
+ return MayAlias;
+
+ // Same offsets.
+ if (GEP1BaseOffset == GEP2BaseOffset &&
+ DecompGEP1.VarIndices == DecompGEP2.VarIndices)
+ return NoAlias;
+ }
+ }
+
+ // If we get a No or May, then return it immediately, no amount of analysis
+ // will improve this situation.
+ if (BaseAlias != MustAlias) {
+ assert(BaseAlias == NoAlias || BaseAlias == MayAlias);
+ return BaseAlias;
+ }
+
+ // Otherwise, we have a MustAlias. Since the base pointers alias each other
+ // exactly, see if the computed offset from the common pointer tells us
+ // about the relation of the resulting pointer.
+ // If we know the two GEPs are based off of the exact same pointer (and not
+ // just the same underlying object), see if that tells us anything about
+ // the resulting pointers.
+ if (GEP1->getPointerOperand()->stripPointerCastsAndInvariantGroups() ==
+ GEP2->getPointerOperand()->stripPointerCastsAndInvariantGroups() &&
+ GEP1->getPointerOperandType() == GEP2->getPointerOperandType()) {
+ AliasResult R = aliasSameBasePointerGEPs(GEP1, V1Size, GEP2, V2Size, DL);
+ // If we couldn't find anything interesting, don't abandon just yet.
+ if (R != MayAlias)
+ return R;
+ }
+
+ // If the max search depth is reached, the result is undefined
+ if (GEP2MaxLookupReached || GEP1MaxLookupReached)
+ return MayAlias;
+
+ // Subtract the GEP2 pointer from the GEP1 pointer to find out their
+ // symbolic difference.
+ GEP1BaseOffset -= GEP2BaseOffset;
+ GetIndexDifference(DecompGEP1.VarIndices, DecompGEP2.VarIndices);
+
+ } else {
+ // Check to see if these two pointers are related by the getelementptr
+ // instruction. If one pointer is a GEP with a non-zero index of the other
+ // pointer, we know they cannot alias.
+
+ // If both accesses are unknown size, we can't do anything useful here.
+ if (V1Size == LocationSize::unknown() && V2Size == LocationSize::unknown())
+ return MayAlias;
+
+ AliasResult R = aliasCheck(UnderlyingV1, LocationSize::unknown(),
+ AAMDNodes(), V2, LocationSize::unknown(),
+ V2AAInfo, AAQI, nullptr, UnderlyingV2);
+ if (R != MustAlias) {
+ // If V2 may alias the GEP base pointer, conservatively return MayAlias.
+ // If V2 is known not to alias the GEP base pointer, then the two values
+ // cannot alias per GEP semantics: "Any memory access must be done through
+ // a pointer value associated with an address range of the memory access,
+ // otherwise the behavior is undefined.".
+ assert(R == NoAlias || R == MayAlias);
+ return R;
+ }
+
+ // If the max search depth is reached the result is undefined
+ if (GEP1MaxLookupReached)
+ return MayAlias;
+ }
+
+ // In the two-GEP case, if there is no difference in the offsets of the
+ // computed pointers, the resultant pointers are a must alias. This
+ // happens when we have two lexically identical GEPs (for example).
+ //
+ // In the other case, if we have getelementptr <ptr>, 0, 0, 0, 0, ... and V2
+ // must-aliases the GEP, the end result is a must alias as well.
+ if (GEP1BaseOffset == 0 && DecompGEP1.VarIndices.empty())
+ return MustAlias;
+
+ // If there is a constant difference between the pointers, but the difference
+ // is less than the size of the associated memory object, then we know
+ // that the objects are partially overlapping. If the difference is
+ // greater, we know they do not overlap.
+ if (GEP1BaseOffset != 0 && DecompGEP1.VarIndices.empty()) {
+ if (GEP1BaseOffset.sge(0)) {
+ if (V2Size != LocationSize::unknown()) {
+ if (GEP1BaseOffset.ult(V2Size.getValue()))
+ return PartialAlias;
+ return NoAlias;
+ }
+ } else {
+ // We have the situation where:
+ // +                +
+ // | BaseOffset     |
+ // ---------------->|
+ // |-->V1Size       |-------> V2Size
+ // GEP1             V2
+ // We need to know that V2Size is not unknown, otherwise we might have
+ // stripped a gep with a negative index ('gep <ptr>, -1, ...').
+ if (V1Size != LocationSize::unknown() &&
+ V2Size != LocationSize::unknown()) {
+ if ((-GEP1BaseOffset).ult(V1Size.getValue()))
+ return PartialAlias;
+ return NoAlias;
+ }
+ }
+ }
+
+ if (!DecompGEP1.VarIndices.empty()) {
+ APInt Modulo(MaxPointerSize, 0);
+ bool AllPositive = true;
+ for (unsigned i = 0, e = DecompGEP1.VarIndices.size(); i != e; ++i) {
+
+ // Try to distinguish something like &A[i][1] against &A[42][0].
+ // Grab the least significant bit set in any of the scales. We
+ // don't need std::abs here (even if the scale's negative) as we'll
+ // be ^'ing Modulo with itself later.
+ Modulo |= DecompGEP1.VarIndices[i].Scale;
+
+ if (AllPositive) {
+ // If the Value could change between cycles, then any reasoning about
+ // the Value this cycle may not hold in the next cycle. We'll just
+ // give up if we can't determine conditions that hold for every cycle:
+ const Value *V = DecompGEP1.VarIndices[i].V;
+
+ KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, DT);
+ bool SignKnownZero = Known.isNonNegative();
+ bool SignKnownOne = Known.isNegative();
+
+ // Zero-extension widens the variable, and so forces the sign
+ // bit to zero.
+ bool IsZExt = DecompGEP1.VarIndices[i].ZExtBits > 0 || isa<ZExtInst>(V);
+ SignKnownZero |= IsZExt;
+ SignKnownOne &= !IsZExt;
+
+ // If the variable begins with a zero then we know it's
+ // positive, regardless of whether the value is signed or
+ // unsigned.
+ APInt Scale = DecompGEP1.VarIndices[i].Scale;
+ AllPositive =
+ (SignKnownZero && Scale.sge(0)) || (SignKnownOne && Scale.slt(0));
+ }
+ }
+
+ Modulo = Modulo ^ (Modulo & (Modulo - 1));
+
+ // We can compute the difference between the two addresses
+ // mod Modulo. Check whether that difference guarantees that the
+ // two locations do not alias.
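+ // For illustration (hypothetical): comparing &A[i][1] against &A[42][0]
+ // over rows of type [4 x i32], the only surviving variable scale is 16, so
+ // Modulo == 16; the constant offset difference modulo 16 leaves the two
+ // 4-byte accesses disjoint, so the check below proves NoAlias.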
+ APInt ModOffset = GEP1BaseOffset & (Modulo - 1);
+ if (V1Size != LocationSize::unknown() &&
+ V2Size != LocationSize::unknown() && ModOffset.uge(V2Size.getValue()) &&
+ (Modulo - ModOffset).uge(V1Size.getValue()))
+ return NoAlias;
+
+ // If we know all the variables are positive, then GEP1 >= GEP1BasePtr.
+ // If GEP1BasePtr > V2 (GEP1BaseOffset > 0) then we know the pointers
+ // don't alias if V2Size can fit in the gap between V2 and GEP1BasePtr.
+ if (AllPositive && GEP1BaseOffset.sgt(0) &&
+ V2Size != LocationSize::unknown() &&
+ GEP1BaseOffset.uge(V2Size.getValue()))
+ return NoAlias;
+
+ if (constantOffsetHeuristic(DecompGEP1.VarIndices, V1Size, V2Size,
+ GEP1BaseOffset, &AC, DT))
+ return NoAlias;
+ }
+
+ // Statically, we can see that the base objects are the same, but the
+ // pointers have dynamic offsets which we can't resolve. And none of our
+ // little tricks above worked.
+ return MayAlias;
+}
+
+static AliasResult MergeAliasResults(AliasResult A, AliasResult B) {
+ // If the results agree, take it.
+ if (A == B)
+ return A;
+ // A mix of PartialAlias and MustAlias is PartialAlias.
+ if ((A == PartialAlias && B == MustAlias) ||
+ (B == PartialAlias && A == MustAlias))
+ return PartialAlias;
+ // Otherwise, we don't know anything.
+ return MayAlias;
+}
+
+/// Provides a bunch of ad-hoc rules to disambiguate a Select instruction
+/// against another.
+AliasResult
+BasicAAResult::aliasSelect(const SelectInst *SI, LocationSize SISize,
+ const AAMDNodes &SIAAInfo, const Value *V2,
+ LocationSize V2Size, const AAMDNodes &V2AAInfo,
+ const Value *UnderV2, AAQueryInfo &AAQI) {
+ // If the values are Selects with the same condition, we can do a more precise
+ // check: just check for aliases between the values on corresponding arms.
+ if (const SelectInst *SI2 = dyn_cast<SelectInst>(V2))
+ if (SI->getCondition() == SI2->getCondition()) {
+ AliasResult Alias =
+ aliasCheck(SI->getTrueValue(), SISize, SIAAInfo, SI2->getTrueValue(),
+ V2Size, V2AAInfo, AAQI);
+ if (Alias == MayAlias)
+ return MayAlias;
+ AliasResult ThisAlias =
+ aliasCheck(SI->getFalseValue(), SISize, SIAAInfo,
+ SI2->getFalseValue(), V2Size, V2AAInfo, AAQI);
+ return MergeAliasResults(ThisAlias, Alias);
+ }
+
+ // If both arms of the Select node NoAlias or MustAlias V2, then return
+ // NoAlias / MustAlias. Otherwise, return MayAlias.
+ AliasResult Alias = aliasCheck(V2, V2Size, V2AAInfo, SI->getTrueValue(),
+ SISize, SIAAInfo, AAQI, UnderV2);
+ if (Alias == MayAlias)
+ return MayAlias;
+
+ AliasResult ThisAlias = aliasCheck(V2, V2Size, V2AAInfo, SI->getFalseValue(),
+ SISize, SIAAInfo, AAQI, UnderV2);
+ return MergeAliasResults(ThisAlias, Alias);
+}
+
+/// Provide a bunch of ad-hoc rules to disambiguate a PHI instruction against
+/// another.
+AliasResult BasicAAResult::aliasPHI(const PHINode *PN, LocationSize PNSize,
+ const AAMDNodes &PNAAInfo, const Value *V2,
+ LocationSize V2Size,
+ const AAMDNodes &V2AAInfo,
+ const Value *UnderV2, AAQueryInfo &AAQI) {
+ // Track phi nodes we have visited. We use this information when we determine
+ // value equivalence.
+ VisitedPhiBBs.insert(PN->getParent());
+
+ // If the values are PHIs in the same block, we can do a more precise
+ // as well as efficient check: just check for aliases between the values
+ // on corresponding edges.
+ if (const PHINode *PN2 = dyn_cast<PHINode>(V2))
+ if (PN2->getParent() == PN->getParent()) {
+ AAQueryInfo::LocPair Locs(MemoryLocation(PN, PNSize, PNAAInfo),
+ MemoryLocation(V2, V2Size, V2AAInfo));
+ if (PN > V2)
+ std::swap(Locs.first, Locs.second);
+ // Analyse the PHIs' inputs under the assumption that the PHIs are
+ // NoAlias.
+ // If the PHIs are May/MustAlias there must be (recursively) an input
+ // operand from outside the PHIs' cycle that is MayAlias/MustAlias or
+ // there must be an operation on the PHIs within the PHIs' value cycle
+ // that causes a MayAlias.
+ // Pretend the phis do not alias.
+ AliasResult Alias = NoAlias;
+ AliasResult OrigAliasResult;
+ {
+ // Limited lifetime iterator invalidated by the aliasCheck call below.
+ auto CacheIt = AAQI.AliasCache.find(Locs);
+ assert((CacheIt != AAQI.AliasCache.end()) &&
+ "There must exist an entry for the phi node");
+ OrigAliasResult = CacheIt->second;
+ CacheIt->second = NoAlias;
+ }
+
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+ AliasResult ThisAlias =
+ aliasCheck(PN->getIncomingValue(i), PNSize, PNAAInfo,
+ PN2->getIncomingValueForBlock(PN->getIncomingBlock(i)),
+ V2Size, V2AAInfo, AAQI);
+ Alias = MergeAliasResults(ThisAlias, Alias);
+ if (Alias == MayAlias)
+ break;
+ }
+
+ // Reset if speculation failed.
+ if (Alias != NoAlias) {
+ auto Pair =
+ AAQI.AliasCache.insert(std::make_pair(Locs, OrigAliasResult));
+ assert(!Pair.second && "Entry must have existed");
+ Pair.first->second = OrigAliasResult;
+ }
+ return Alias;
+ }
+
+ SmallVector<Value *, 4> V1Srcs;
+ bool isRecursive = false;
+ if (PV) {
+ // If we have PhiValues then use it to get the underlying phi values.
+ const PhiValues::ValueSet &PhiValueSet = PV->getValuesForPhi(PN);
+    // If we have more phi values than the search depth then return MayAlias
+    // conservatively to avoid compile time explosion. The worst possible case
+    // is if both sides are PHI nodes, in which case this is O(m x n) time,
+    // where 'm' and 'n' are the number of PHI sources.
+ if (PhiValueSet.size() > MaxLookupSearchDepth)
+ return MayAlias;
+ // Add the values to V1Srcs
+ for (Value *PV1 : PhiValueSet) {
+ if (EnableRecPhiAnalysis) {
+ if (GEPOperator *PV1GEP = dyn_cast<GEPOperator>(PV1)) {
+ // Check whether the incoming value is a GEP that advances the pointer
+ // result of this PHI node (e.g. in a loop). If this is the case, we
+ // would recurse and always get a MayAlias. Handle this case specially
+ // below.
+ if (PV1GEP->getPointerOperand() == PN && PV1GEP->getNumIndices() == 1 &&
+ isa<ConstantInt>(PV1GEP->idx_begin())) {
+ isRecursive = true;
+ continue;
+ }
+ }
+ }
+ V1Srcs.push_back(PV1);
+ }
+ } else {
+    // If we don't have PhiValues then just look at the operands of the phi
+    // itself.
+    // FIXME: Remove this once we can guarantee that we always have PhiValues.
+ SmallPtrSet<Value *, 4> UniqueSrc;
+ for (Value *PV1 : PN->incoming_values()) {
+ if (isa<PHINode>(PV1))
+        // If any source is itself a PHI, return MayAlias conservatively to
+        // avoid compile time explosion. The worst possible case is if both
+        // sides are PHI nodes, in which case this is O(m x n) time, where 'm'
+        // and 'n' are the number of PHI sources.
+ return MayAlias;
+
+ if (EnableRecPhiAnalysis)
+ if (GEPOperator *PV1GEP = dyn_cast<GEPOperator>(PV1)) {
+ // Check whether the incoming value is a GEP that advances the pointer
+ // result of this PHI node (e.g. in a loop). If this is the case, we
+ // would recurse and always get a MayAlias. Handle this case specially
+ // below.
+ if (PV1GEP->getPointerOperand() == PN && PV1GEP->getNumIndices() == 1 &&
+ isa<ConstantInt>(PV1GEP->idx_begin())) {
+ isRecursive = true;
+ continue;
+ }
+ }
+
+ if (UniqueSrc.insert(PV1).second)
+ V1Srcs.push_back(PV1);
+ }
+ }
+
+ // If V1Srcs is empty then that means that the phi has no underlying non-phi
+ // value. This should only be possible in blocks unreachable from the entry
+ // block, but return MayAlias just in case.
+ if (V1Srcs.empty())
+ return MayAlias;
+
+ // If this PHI node is recursive, set the size of the accessed memory to
+ // unknown to represent all the possible values the GEP could advance the
+ // pointer to.
+ if (isRecursive)
+ PNSize = LocationSize::unknown();
+
+ AliasResult Alias = aliasCheck(V2, V2Size, V2AAInfo, V1Srcs[0], PNSize,
+ PNAAInfo, AAQI, UnderV2);
+
+  // Early exit if the check of the first PHI source against V2 is MayAlias;
+  // the merged result could then not be anything better.
+ if (Alias == MayAlias)
+ return MayAlias;
+
+  // If all sources of the PHI node NoAlias or MustAlias V2, then we return
+  // NoAlias / MustAlias accordingly. Otherwise, we return MayAlias.
+ for (unsigned i = 1, e = V1Srcs.size(); i != e; ++i) {
+ Value *V = V1Srcs[i];
+
+ AliasResult ThisAlias =
+ aliasCheck(V2, V2Size, V2AAInfo, V, PNSize, PNAAInfo, AAQI, UnderV2);
+ Alias = MergeAliasResults(ThisAlias, Alias);
+ if (Alias == MayAlias)
+ break;
+ }
+
+ return Alias;
+}
+
+/// Provides a bunch of ad-hoc rules to disambiguate in common cases, such as
+/// array references.
+AliasResult BasicAAResult::aliasCheck(const Value *V1, LocationSize V1Size,
+ AAMDNodes V1AAInfo, const Value *V2,
+ LocationSize V2Size, AAMDNodes V2AAInfo,
+ AAQueryInfo &AAQI, const Value *O1,
+ const Value *O2) {
+ // If either of the memory references is empty, it doesn't matter what the
+ // pointer values are.
+ if (V1Size.isZero() || V2Size.isZero())
+ return NoAlias;
+
+ // Strip off any casts if they exist.
+ V1 = V1->stripPointerCastsAndInvariantGroups();
+ V2 = V2->stripPointerCastsAndInvariantGroups();
+
+ // If V1 or V2 is undef, the result is NoAlias because we can always pick a
+ // value for undef that aliases nothing in the program.
+ if (isa<UndefValue>(V1) || isa<UndefValue>(V2))
+ return NoAlias;
+
+ // Are we checking for alias of the same value?
+ // Because we look 'through' phi nodes, we could look at "Value" pointers from
+ // different iterations. We must therefore make sure that this is not the
+ // case. The function isValueEqualInPotentialCycles ensures that this cannot
+ // happen by looking at the visited phi nodes and making sure they cannot
+ // reach the value.
+ if (isValueEqualInPotentialCycles(V1, V2))
+ return MustAlias;
+
+ if (!V1->getType()->isPointerTy() || !V2->getType()->isPointerTy())
+ return NoAlias; // Scalars cannot alias each other
+
+ // Figure out what objects these things are pointing to if we can.
+ if (O1 == nullptr)
+ O1 = GetUnderlyingObject(V1, DL, MaxLookupSearchDepth);
+
+ if (O2 == nullptr)
+ O2 = GetUnderlyingObject(V2, DL, MaxLookupSearchDepth);
+
+  // Null pointers don't point to any object in address spaces where null is
+  // not a defined (valid) address, so they don't alias any other pointer.
+ if (const ConstantPointerNull *CPN = dyn_cast<ConstantPointerNull>(O1))
+ if (!NullPointerIsDefined(&F, CPN->getType()->getAddressSpace()))
+ return NoAlias;
+ if (const ConstantPointerNull *CPN = dyn_cast<ConstantPointerNull>(O2))
+ if (!NullPointerIsDefined(&F, CPN->getType()->getAddressSpace()))
+ return NoAlias;
+
+ if (O1 != O2) {
+ // If V1/V2 point to two different objects, we know that we have no alias.
+ if (isIdentifiedObject(O1) && isIdentifiedObject(O2))
+ return NoAlias;
+
+ // Constant pointers can't alias with non-const isIdentifiedObject objects.
+ if ((isa<Constant>(O1) && isIdentifiedObject(O2) && !isa<Constant>(O2)) ||
+ (isa<Constant>(O2) && isIdentifiedObject(O1) && !isa<Constant>(O1)))
+ return NoAlias;
+
+ // Function arguments can't alias with things that are known to be
+    // unambiguously identified at the function level.
+ if ((isa<Argument>(O1) && isIdentifiedFunctionLocal(O2)) ||
+ (isa<Argument>(O2) && isIdentifiedFunctionLocal(O1)))
+ return NoAlias;
+
+ // If one pointer is the result of a call/invoke or load and the other is a
+ // non-escaping local object within the same function, then we know the
+ // object couldn't escape to a point where the call could return it.
+ //
+ // Note that if the pointers are in different functions, there are a
+ // variety of complications. A call with a nocapture argument may still
+    // temporarily store the nocapture argument's value in a temporary memory
+ // location if that memory location doesn't escape. Or it may pass a
+ // nocapture value to other functions as long as they don't capture it.
+ if (isEscapeSource(O1) &&
+ isNonEscapingLocalObject(O2, &AAQI.IsCapturedCache))
+ return NoAlias;
+ if (isEscapeSource(O2) &&
+ isNonEscapingLocalObject(O1, &AAQI.IsCapturedCache))
+ return NoAlias;
+ }
+
+ // If the size of one access is larger than the entire object on the other
+ // side, then we know such behavior is undefined and can assume no alias.
+ bool NullIsValidLocation = NullPointerIsDefined(&F);
+ if ((isObjectSmallerThan(
+ O2, getMinimalExtentFrom(*V1, V1Size, DL, NullIsValidLocation), DL,
+ TLI, NullIsValidLocation)) ||
+ (isObjectSmallerThan(
+ O1, getMinimalExtentFrom(*V2, V2Size, DL, NullIsValidLocation), DL,
+ TLI, NullIsValidLocation)))
+ return NoAlias;
+
+ // Check the cache before climbing up use-def chains. This also terminates
+ // otherwise infinitely recursive queries.
+ AAQueryInfo::LocPair Locs(MemoryLocation(V1, V1Size, V1AAInfo),
+ MemoryLocation(V2, V2Size, V2AAInfo));
+ if (V1 > V2)
+ std::swap(Locs.first, Locs.second);
+ std::pair<AAQueryInfo::AliasCacheT::iterator, bool> Pair =
+ AAQI.AliasCache.try_emplace(Locs, MayAlias);
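+  // A fresh entry is seeded with MayAlias so that any recursive query on the
+  // same pair sees a conservative answer; it is overwritten with the final
+  // result on each return path below.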
+ if (!Pair.second)
+ return Pair.first->second;
+
+ // FIXME: This isn't aggressively handling alias(GEP, PHI) for example: if the
+ // GEP can't simplify, we don't even look at the PHI cases.
+ if (!isa<GEPOperator>(V1) && isa<GEPOperator>(V2)) {
+ std::swap(V1, V2);
+ std::swap(V1Size, V2Size);
+ std::swap(O1, O2);
+ std::swap(V1AAInfo, V2AAInfo);
+ }
+ if (const GEPOperator *GV1 = dyn_cast<GEPOperator>(V1)) {
+ AliasResult Result =
+ aliasGEP(GV1, V1Size, V1AAInfo, V2, V2Size, V2AAInfo, O1, O2, AAQI);
+ if (Result != MayAlias) {
+ auto ItInsPair = AAQI.AliasCache.insert(std::make_pair(Locs, Result));
+ assert(!ItInsPair.second && "Entry must have existed");
+ ItInsPair.first->second = Result;
+ return Result;
+ }
+ }
+
+ if (isa<PHINode>(V2) && !isa<PHINode>(V1)) {
+ std::swap(V1, V2);
+ std::swap(O1, O2);
+ std::swap(V1Size, V2Size);
+ std::swap(V1AAInfo, V2AAInfo);
+ }
+ if (const PHINode *PN = dyn_cast<PHINode>(V1)) {
+ AliasResult Result =
+ aliasPHI(PN, V1Size, V1AAInfo, V2, V2Size, V2AAInfo, O2, AAQI);
+ if (Result != MayAlias) {
+ Pair = AAQI.AliasCache.try_emplace(Locs, Result);
+ assert(!Pair.second && "Entry must have existed");
+ return Pair.first->second = Result;
+ }
+ }
+
+ if (isa<SelectInst>(V2) && !isa<SelectInst>(V1)) {
+ std::swap(V1, V2);
+ std::swap(O1, O2);
+ std::swap(V1Size, V2Size);
+ std::swap(V1AAInfo, V2AAInfo);
+ }
+ if (const SelectInst *S1 = dyn_cast<SelectInst>(V1)) {
+ AliasResult Result =
+ aliasSelect(S1, V1Size, V1AAInfo, V2, V2Size, V2AAInfo, O2, AAQI);
+ if (Result != MayAlias) {
+ Pair = AAQI.AliasCache.try_emplace(Locs, Result);
+ assert(!Pair.second && "Entry must have existed");
+ return Pair.first->second = Result;
+ }
+ }
+
+ // If both pointers are pointing into the same object and one of them
+ // accesses the entire object, then the accesses must overlap in some way.
+ if (O1 == O2)
+ if (V1Size.isPrecise() && V2Size.isPrecise() &&
+ (isObjectSize(O1, V1Size.getValue(), DL, TLI, NullIsValidLocation) ||
+ isObjectSize(O2, V2Size.getValue(), DL, TLI, NullIsValidLocation))) {
+ Pair = AAQI.AliasCache.try_emplace(Locs, PartialAlias);
+ assert(!Pair.second && "Entry must have existed");
+ return Pair.first->second = PartialAlias;
+ }
+
+ // Recurse back into the best AA results we have, potentially with refined
+ // memory locations. We have already ensured that BasicAA has a MayAlias
+ // cache result for these, so any recursion back into BasicAA won't loop.
+ AliasResult Result = getBestAAResults().alias(Locs.first, Locs.second, AAQI);
+ Pair = AAQI.AliasCache.try_emplace(Locs, Result);
+ assert(!Pair.second && "Entry must have existed");
+ return Pair.first->second = Result;
+}
+
+/// Check whether two Values can be considered equivalent.
+///
+/// In addition to pointer equivalence of \p V and \p V2 this checks whether
+/// they cannot be part of a cycle in the value graph by looking at all
+/// visited phi nodes and making sure that the phis cannot reach the value. We
+/// have to do this because we are looking through phi nodes (that is, we say
+/// noalias(V, phi(VA, VB)) if noalias(V, VA) and noalias(V, VB)).
+bool BasicAAResult::isValueEqualInPotentialCycles(const Value *V,
+ const Value *V2) {
+ if (V != V2)
+ return false;
+
+ const Instruction *Inst = dyn_cast<Instruction>(V);
+ if (!Inst)
+ return true;
+
+ if (VisitedPhiBBs.empty())
+ return true;
+
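+  // Conservatively give up when there are too many visited phi blocks, so the
+  // reachability checks below stay cheap.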
+ if (VisitedPhiBBs.size() > MaxNumPhiBBsValueReachabilityCheck)
+ return false;
+
+ // Make sure that the visited phis cannot reach the Value. This ensures that
+ // the Values cannot come from different iterations of a potential cycle the
+ // phi nodes could be involved in.
+ for (auto *P : VisitedPhiBBs)
+ if (isPotentiallyReachable(&P->front(), Inst, nullptr, DT, LI))
+ return false;
+
+ return true;
+}
+
+/// Computes the symbolic difference between two decomposed GEPs.
+///
+/// Dest and Src are the variable indices from two decomposed GetElementPtr
+/// instructions GEP1 and GEP2 which have common base pointers.
+void BasicAAResult::GetIndexDifference(
+ SmallVectorImpl<VariableGEPIndex> &Dest,
+ const SmallVectorImpl<VariableGEPIndex> &Src) {
+ if (Src.empty())
+ return;
+
+ for (unsigned i = 0, e = Src.size(); i != e; ++i) {
+ const Value *V = Src[i].V;
+ unsigned ZExtBits = Src[i].ZExtBits, SExtBits = Src[i].SExtBits;
+ APInt Scale = Src[i].Scale;
+
+ // Find V in Dest. This is N^2, but pointer indices almost never have more
+ // than a few variable indexes.
+ for (unsigned j = 0, e = Dest.size(); j != e; ++j) {
+ if (!isValueEqualInPotentialCycles(Dest[j].V, V) ||
+ Dest[j].ZExtBits != ZExtBits || Dest[j].SExtBits != SExtBits)
+ continue;
+
+ // If we found it, subtract off Scale V's from the entry in Dest. If it
+ // goes to zero, remove the entry.
+ if (Dest[j].Scale != Scale)
+ Dest[j].Scale -= Scale;
+ else
+ Dest.erase(Dest.begin() + j);
+ Scale = 0;
+ break;
+ }
+
+ // If we didn't consume this entry, add it to the end of the Dest list.
+ if (!!Scale) {
+ VariableGEPIndex Entry = {V, ZExtBits, SExtBits, -Scale};
+ Dest.push_back(Entry);
+ }
+ }
+}
+
+bool BasicAAResult::constantOffsetHeuristic(
+ const SmallVectorImpl<VariableGEPIndex> &VarIndices,
+ LocationSize MaybeV1Size, LocationSize MaybeV2Size, APInt BaseOffset,
+ AssumptionCache *AC, DominatorTree *DT) {
+ if (VarIndices.size() != 2 || MaybeV1Size == LocationSize::unknown() ||
+ MaybeV2Size == LocationSize::unknown())
+ return false;
+
+ const uint64_t V1Size = MaybeV1Size.getValue();
+ const uint64_t V2Size = MaybeV2Size.getValue();
+
+ const VariableGEPIndex &Var0 = VarIndices[0], &Var1 = VarIndices[1];
+
+ if (Var0.ZExtBits != Var1.ZExtBits || Var0.SExtBits != Var1.SExtBits ||
+ Var0.Scale != -Var1.Scale)
+ return false;
+
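+  // At this point the two variable indices have opposite scales, so if their
+  // stripped values differ only by a constant, the entire variable part of the
+  // two GEPs reduces to a constant byte offset, which is what we check below.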
+ unsigned Width = Var1.V->getType()->getIntegerBitWidth();
+
+  // We'll strip off the Extensions of Var0 and Var1 and do another round
+  // of GetLinearExpression decomposition. For example, if Var0 is
+  // zext(%x + 1) we should get V0 == %x and V0Offset == 1.
+
+ APInt V0Scale(Width, 0), V0Offset(Width, 0), V1Scale(Width, 0),
+ V1Offset(Width, 0);
+ bool NSW = true, NUW = true;
+ unsigned V0ZExtBits = 0, V0SExtBits = 0, V1ZExtBits = 0, V1SExtBits = 0;
+ const Value *V0 = GetLinearExpression(Var0.V, V0Scale, V0Offset, V0ZExtBits,
+ V0SExtBits, DL, 0, AC, DT, NSW, NUW);
+ NSW = true;
+ NUW = true;
+ const Value *V1 = GetLinearExpression(Var1.V, V1Scale, V1Offset, V1ZExtBits,
+ V1SExtBits, DL, 0, AC, DT, NSW, NUW);
+
+ if (V0Scale != V1Scale || V0ZExtBits != V1ZExtBits ||
+ V0SExtBits != V1SExtBits || !isValueEqualInPotentialCycles(V0, V1))
+ return false;
+
+ // We have a hit - Var0 and Var1 only differ by a constant offset!
+
+  // If we've been sext'ed then zext'd, the maximum difference between Var0 and
+  // Var1 is possible to calculate, but we're just interested in the absolute
+ // minimum difference between the two. The minimum distance may occur due to
+ // wrapping; consider "add i3 %i, 5": if %i == 7 then 7 + 5 mod 8 == 4, and so
+ // the minimum distance between %i and %i + 5 is 3.
+ APInt MinDiff = V0Offset - V1Offset, Wrapped = -MinDiff;
+ MinDiff = APIntOps::umin(MinDiff, Wrapped);
+ APInt MinDiffBytes =
+ MinDiff.zextOrTrunc(Var0.Scale.getBitWidth()) * Var0.Scale.abs();
+
+ // We can't definitely say whether GEP1 is before or after V2 due to wrapping
+  // arithmetic (i.e. for some values of GEP1 and V2, GEP1 < V2, and for other
+ // values GEP1 > V2). We'll therefore only declare NoAlias if both V1Size and
+ // V2Size can fit in the MinDiffBytes gap.
+ return MinDiffBytes.uge(V1Size + BaseOffset.abs()) &&
+ MinDiffBytes.uge(V2Size + BaseOffset.abs());
+}
+
+//===----------------------------------------------------------------------===//
+// BasicAliasAnalysis Pass
+//===----------------------------------------------------------------------===//
+
+AnalysisKey BasicAA::Key;
+
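+// Note that LoopInfo and PhiValues are only soft dependencies here: just their
+// cached results are consumed, so running BasicAA never forces them to be
+// computed.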
+BasicAAResult BasicAA::run(Function &F, FunctionAnalysisManager &AM) {
+ return BasicAAResult(F.getParent()->getDataLayout(),
+ F,
+ AM.getResult<TargetLibraryAnalysis>(F),
+ AM.getResult<AssumptionAnalysis>(F),
+ &AM.getResult<DominatorTreeAnalysis>(F),
+ AM.getCachedResult<LoopAnalysis>(F),
+ AM.getCachedResult<PhiValuesAnalysis>(F));
+}
+
+BasicAAWrapperPass::BasicAAWrapperPass() : FunctionPass(ID) {
+ initializeBasicAAWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+char BasicAAWrapperPass::ID = 0;
+
+void BasicAAWrapperPass::anchor() {}
+
+INITIALIZE_PASS_BEGIN(BasicAAWrapperPass, "basicaa",
+ "Basic Alias Analysis (stateless AA impl)", false, true)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(BasicAAWrapperPass, "basicaa",
+ "Basic Alias Analysis (stateless AA impl)", false, true)
+
+FunctionPass *llvm::createBasicAAWrapperPass() {
+ return new BasicAAWrapperPass();
+}
+
+bool BasicAAWrapperPass::runOnFunction(Function &F) {
+ auto &ACT = getAnalysis<AssumptionCacheTracker>();
+ auto &TLIWP = getAnalysis<TargetLibraryInfoWrapperPass>();
+ auto &DTWP = getAnalysis<DominatorTreeWrapperPass>();
+ auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
+ auto *PVWP = getAnalysisIfAvailable<PhiValuesWrapperPass>();
+
+ Result.reset(new BasicAAResult(F.getParent()->getDataLayout(), F,
+ TLIWP.getTLI(F), ACT.getAssumptionCache(F),
+ &DTWP.getDomTree(),
+ LIWP ? &LIWP->getLoopInfo() : nullptr,
+ PVWP ? &PVWP->getResult() : nullptr));
+
+ return false;
+}
+
+void BasicAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequired<AssumptionCacheTracker>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
+ AU.addUsedIfAvailable<PhiValuesWrapperPass>();
+}
+
+BasicAAResult llvm::createLegacyPMBasicAAResult(Pass &P, Function &F) {
+ return BasicAAResult(
+ F.getParent()->getDataLayout(), F,
+ P.getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
+ P.getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F));
+}
diff --git a/llvm/lib/Analysis/BlockFrequencyInfo.cpp b/llvm/lib/Analysis/BlockFrequencyInfo.cpp
new file mode 100644
index 000000000000..de183bbde173
--- /dev/null
+++ b/llvm/lib/Analysis/BlockFrequencyInfo.cpp
@@ -0,0 +1,342 @@
+//===- BlockFrequencyInfo.cpp - Block Frequency Analysis ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Loops should be simplified before this analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/iterator.h"
+#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <string>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "block-freq"
+
+static cl::opt<GVDAGType> ViewBlockFreqPropagationDAG(
+ "view-block-freq-propagation-dags", cl::Hidden,
+ cl::desc("Pop up a window to show a dag displaying how block "
+ "frequencies propagation through the CFG."),
+ cl::values(clEnumValN(GVDT_None, "none", "do not display graphs."),
+ clEnumValN(GVDT_Fraction, "fraction",
+ "display a graph using the "
+ "fractional block frequency representation."),
+ clEnumValN(GVDT_Integer, "integer",
+ "display a graph using the raw "
+ "integer fractional block frequency representation."),
+ clEnumValN(GVDT_Count, "count", "display a graph using the real "
+ "profile count if available.")));
+
+cl::opt<std::string>
+ ViewBlockFreqFuncName("view-bfi-func-name", cl::Hidden,
+ cl::desc("The option to specify "
+ "the name of the function "
+ "whose CFG will be displayed."));
+
+cl::opt<unsigned>
+ ViewHotFreqPercent("view-hot-freq-percent", cl::init(10), cl::Hidden,
+ cl::desc("An integer in percent used to specify "
+ "the hot blocks/edges to be displayed "
+ "in red: a block or edge whose frequency "
+ "is no less than the max frequency of the "
+ "function multiplied by this percent."));
+
+// Command line option to turn on CFG dot or text dump after profile annotation.
+cl::opt<PGOViewCountsType> PGOViewCounts(
+ "pgo-view-counts", cl::Hidden,
+ cl::desc("A boolean option to show CFG dag or text with "
+ "block profile counts and branch probabilities "
+ "right after PGO profile annotation step. The "
+ "profile counts are computed using branch "
+ "probabilities from the runtime profile data and "
+ "block frequency propagation algorithm. To view "
+ "the raw counts from the profile, use option "
+ "-pgo-view-raw-counts instead. To limit graph "
+ "display to only one function, use filtering option "
+ "-view-bfi-func-name."),
+ cl::values(clEnumValN(PGOVCT_None, "none", "do not show."),
+ clEnumValN(PGOVCT_Graph, "graph", "show a graph."),
+ clEnumValN(PGOVCT_Text, "text", "show in text.")));
+
+static cl::opt<bool> PrintBlockFreq(
+ "print-bfi", cl::init(false), cl::Hidden,
+ cl::desc("Print the block frequency info."));
+
+cl::opt<std::string> PrintBlockFreqFuncName(
+ "print-bfi-func-name", cl::Hidden,
+ cl::desc("The option to specify the name of the function "
+ "whose block frequency info is printed."));
+
+namespace llvm {
+
+static GVDAGType getGVDT() {
+ if (PGOViewCounts == PGOVCT_Graph)
+ return GVDT_Count;
+ return ViewBlockFreqPropagationDAG;
+}
+
+template <>
+struct GraphTraits<BlockFrequencyInfo *> {
+ using NodeRef = const BasicBlock *;
+ using ChildIteratorType = succ_const_iterator;
+ using nodes_iterator = pointer_iterator<Function::const_iterator>;
+
+ static NodeRef getEntryNode(const BlockFrequencyInfo *G) {
+ return &G->getFunction()->front();
+ }
+
+ static ChildIteratorType child_begin(const NodeRef N) {
+ return succ_begin(N);
+ }
+
+ static ChildIteratorType child_end(const NodeRef N) { return succ_end(N); }
+
+ static nodes_iterator nodes_begin(const BlockFrequencyInfo *G) {
+ return nodes_iterator(G->getFunction()->begin());
+ }
+
+ static nodes_iterator nodes_end(const BlockFrequencyInfo *G) {
+ return nodes_iterator(G->getFunction()->end());
+ }
+};
+
+using BFIDOTGTraitsBase =
+ BFIDOTGraphTraitsBase<BlockFrequencyInfo, BranchProbabilityInfo>;
+
+template <>
+struct DOTGraphTraits<BlockFrequencyInfo *> : public BFIDOTGTraitsBase {
+ explicit DOTGraphTraits(bool isSimple = false)
+ : BFIDOTGTraitsBase(isSimple) {}
+
+ std::string getNodeLabel(const BasicBlock *Node,
+ const BlockFrequencyInfo *Graph) {
+
+ return BFIDOTGTraitsBase::getNodeLabel(Node, Graph, getGVDT());
+ }
+
+ std::string getNodeAttributes(const BasicBlock *Node,
+ const BlockFrequencyInfo *Graph) {
+ return BFIDOTGTraitsBase::getNodeAttributes(Node, Graph,
+ ViewHotFreqPercent);
+ }
+
+ std::string getEdgeAttributes(const BasicBlock *Node, EdgeIter EI,
+ const BlockFrequencyInfo *BFI) {
+ return BFIDOTGTraitsBase::getEdgeAttributes(Node, EI, BFI, BFI->getBPI(),
+ ViewHotFreqPercent);
+ }
+};
+
+} // end namespace llvm
+
+BlockFrequencyInfo::BlockFrequencyInfo() = default;
+
+BlockFrequencyInfo::BlockFrequencyInfo(const Function &F,
+ const BranchProbabilityInfo &BPI,
+ const LoopInfo &LI) {
+ calculate(F, BPI, LI);
+}
+
+BlockFrequencyInfo::BlockFrequencyInfo(BlockFrequencyInfo &&Arg)
+ : BFI(std::move(Arg.BFI)) {}
+
+BlockFrequencyInfo &BlockFrequencyInfo::operator=(BlockFrequencyInfo &&RHS) {
+ releaseMemory();
+ BFI = std::move(RHS.BFI);
+ return *this;
+}
+
+// Explicitly define the destructor here, otherwise it would be implicitly
+// defined at its first ODR-use, which is the BFI member in the
+// LazyBlockFrequencyInfo header. The dtor needs the BlockFrequencyInfoImpl
+// template instantiated, which is not available in the header.
+BlockFrequencyInfo::~BlockFrequencyInfo() = default;
+
+bool BlockFrequencyInfo::invalidate(Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &) {
+ // Check whether the analysis, all analyses on functions, or the function's
+ // CFG have been preserved.
+ auto PAC = PA.getChecker<BlockFrequencyAnalysis>();
+ return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() ||
+ PAC.preservedSet<CFGAnalyses>());
+}
+
+void BlockFrequencyInfo::calculate(const Function &F,
+ const BranchProbabilityInfo &BPI,
+ const LoopInfo &LI) {
+ if (!BFI)
+ BFI.reset(new ImplType);
+ BFI->calculate(F, BPI, LI);
+ if (ViewBlockFreqPropagationDAG != GVDT_None &&
+ (ViewBlockFreqFuncName.empty() ||
+ F.getName().equals(ViewBlockFreqFuncName))) {
+ view();
+ }
+ if (PrintBlockFreq &&
+ (PrintBlockFreqFuncName.empty() ||
+ F.getName().equals(PrintBlockFreqFuncName))) {
+ print(dbgs());
+ }
+}
+
+BlockFrequency BlockFrequencyInfo::getBlockFreq(const BasicBlock *BB) const {
+ return BFI ? BFI->getBlockFreq(BB) : 0;
+}
+
+Optional<uint64_t>
+BlockFrequencyInfo::getBlockProfileCount(const BasicBlock *BB,
+ bool AllowSynthetic) const {
+ if (!BFI)
+ return None;
+
+ return BFI->getBlockProfileCount(*getFunction(), BB, AllowSynthetic);
+}
+
+Optional<uint64_t>
+BlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
+ if (!BFI)
+ return None;
+ return BFI->getProfileCountFromFreq(*getFunction(), Freq);
+}
+
+bool BlockFrequencyInfo::isIrrLoopHeader(const BasicBlock *BB) {
+ assert(BFI && "Expected analysis to be available");
+ return BFI->isIrrLoopHeader(BB);
+}
+
+void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) {
+ assert(BFI && "Expected analysis to be available");
+ BFI->setBlockFreq(BB, Freq);
+}
+
+void BlockFrequencyInfo::setBlockFreqAndScale(
+ const BasicBlock *ReferenceBB, uint64_t Freq,
+ SmallPtrSetImpl<BasicBlock *> &BlocksToScale) {
+ assert(BFI && "Expected analysis to be available");
+ // Use 128 bits APInt to avoid overflow.
+ APInt NewFreq(128, Freq);
+ APInt OldFreq(128, BFI->getBlockFreq(ReferenceBB).getFrequency());
+ APInt BBFreq(128, 0);
+ for (auto *BB : BlocksToScale) {
+ BBFreq = BFI->getBlockFreq(BB).getFrequency();
+ // Multiply first by NewFreq and then divide by OldFreq
+ // to minimize loss of precision.
+ BBFreq *= NewFreq;
+ // udiv is an expensive operation in the general case. If this ends up being
+ // a hot spot, one of the options proposed in
+ // https://reviews.llvm.org/D28535#650071 could be used to avoid this.
+ BBFreq = BBFreq.udiv(OldFreq);
+ BFI->setBlockFreq(BB, BBFreq.getLimitedValue());
+ }
+ BFI->setBlockFreq(ReferenceBB, Freq);
+}
+
+/// Pop up a ghostview window with the current block frequency propagation
+/// rendered using dot.
+void BlockFrequencyInfo::view(StringRef title) const {
+ ViewGraph(const_cast<BlockFrequencyInfo *>(this), title);
+}
+
+const Function *BlockFrequencyInfo::getFunction() const {
+ return BFI ? BFI->getFunction() : nullptr;
+}
+
+const BranchProbabilityInfo *BlockFrequencyInfo::getBPI() const {
+ return BFI ? &BFI->getBPI() : nullptr;
+}
+
+raw_ostream &BlockFrequencyInfo::
+printBlockFreq(raw_ostream &OS, const BlockFrequency Freq) const {
+ return BFI ? BFI->printBlockFreq(OS, Freq) : OS;
+}
+
+raw_ostream &
+BlockFrequencyInfo::printBlockFreq(raw_ostream &OS,
+ const BasicBlock *BB) const {
+ return BFI ? BFI->printBlockFreq(OS, BB) : OS;
+}
+
+uint64_t BlockFrequencyInfo::getEntryFreq() const {
+ return BFI ? BFI->getEntryFreq() : 0;
+}
+
+void BlockFrequencyInfo::releaseMemory() { BFI.reset(); }
+
+void BlockFrequencyInfo::print(raw_ostream &OS) const {
+ if (BFI)
+ BFI->print(OS);
+}
+
+INITIALIZE_PASS_BEGIN(BlockFrequencyInfoWrapperPass, "block-freq",
+ "Block Frequency Analysis", true, true)
+INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(BlockFrequencyInfoWrapperPass, "block-freq",
+ "Block Frequency Analysis", true, true)
+
+char BlockFrequencyInfoWrapperPass::ID = 0;
+
+BlockFrequencyInfoWrapperPass::BlockFrequencyInfoWrapperPass()
+ : FunctionPass(ID) {
+ initializeBlockFrequencyInfoWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+BlockFrequencyInfoWrapperPass::~BlockFrequencyInfoWrapperPass() = default;
+
+void BlockFrequencyInfoWrapperPass::print(raw_ostream &OS,
+ const Module *) const {
+ BFI.print(OS);
+}
+
+void BlockFrequencyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addRequired<BranchProbabilityInfoWrapperPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.setPreservesAll();
+}
+
+void BlockFrequencyInfoWrapperPass::releaseMemory() { BFI.releaseMemory(); }
+
+bool BlockFrequencyInfoWrapperPass::runOnFunction(Function &F) {
+ BranchProbabilityInfo &BPI =
+ getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
+ LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ BFI.calculate(F, BPI, LI);
+ return false;
+}
+
+AnalysisKey BlockFrequencyAnalysis::Key;
+BlockFrequencyInfo BlockFrequencyAnalysis::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ BlockFrequencyInfo BFI;
+ BFI.calculate(F, AM.getResult<BranchProbabilityAnalysis>(F),
+ AM.getResult<LoopAnalysis>(F));
+ return BFI;
+}
+
+PreservedAnalyses
+BlockFrequencyPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
+ OS << "Printing analysis results of BFI for function "
+ << "'" << F.getName() << "':"
+ << "\n";
+ AM.getResult<BlockFrequencyAnalysis>(F).print(OS);
+ return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp b/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
new file mode 100644
index 000000000000..0db6dd04a7e8
--- /dev/null
+++ b/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
@@ -0,0 +1,851 @@
+//===- BlockFrequencyInfoImpl.cpp - Block Frequency Info Implementation ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Loops should be simplified before this analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/GraphTraits.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/SCCIterator.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/Function.h"
+#include "llvm/Support/BlockFrequency.h"
+#include "llvm/Support/BranchProbability.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ScaledNumber.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <list>
+#include <numeric>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+using namespace llvm::bfi_detail;
+
+#define DEBUG_TYPE "block-freq"
+
+ScaledNumber<uint64_t> BlockMass::toScaled() const {
+ if (isFull())
+ return ScaledNumber<uint64_t>(1, 0);
+ return ScaledNumber<uint64_t>(getMass() + 1, -64);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void BlockMass::dump() const { print(dbgs()); }
+#endif
+
+static char getHexDigit(int N) {
+ assert(N < 16);
+ if (N < 10)
+ return '0' + N;
+ return 'a' + N - 10;
+}
+
+raw_ostream &BlockMass::print(raw_ostream &OS) const {
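+  // Print the 64-bit mass as 16 hex digits, most significant nibble first.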
+ for (int Digits = 0; Digits < 16; ++Digits)
+ OS << getHexDigit(Mass >> (60 - Digits * 4) & 0xf);
+ return OS;
+}
+
+namespace {
+
+using BlockNode = BlockFrequencyInfoImplBase::BlockNode;
+using Distribution = BlockFrequencyInfoImplBase::Distribution;
+using WeightList = BlockFrequencyInfoImplBase::Distribution::WeightList;
+using Scaled64 = BlockFrequencyInfoImplBase::Scaled64;
+using LoopData = BlockFrequencyInfoImplBase::LoopData;
+using Weight = BlockFrequencyInfoImplBase::Weight;
+using FrequencyData = BlockFrequencyInfoImplBase::FrequencyData;
+
+/// Dithering mass distributer.
+///
+/// This class splits up a single mass into portions by weight, dithering to
+/// spread out error. No mass is lost. The dithering precision depends on the
+/// precision of the product of \a BlockMass and \a BranchProbability.
+///
+/// The distribution algorithm follows.
+///
+/// 1. Initialize by saving the sum of the weights in \a RemWeight and the
+/// mass to distribute in \a RemMass.
+///
+/// 2. For each portion:
+///
+/// 1. Construct a branch probability, P, as the portion's weight divided
+/// by the current value of \a RemWeight.
+/// 2. Calculate the portion's mass as \a RemMass times P.
+/// 3. Update \a RemWeight and \a RemMass at each portion by subtracting
+/// the current portion's weight and mass.
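+///
+/// For example, distributing a full mass over weights {3, 1} first takes 3/4
+/// of the remaining mass and then the entire remainder, so no mass is lost to
+/// rounding.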
+struct DitheringDistributer {
+ uint32_t RemWeight;
+ BlockMass RemMass;
+
+ DitheringDistributer(Distribution &Dist, const BlockMass &Mass);
+
+ BlockMass takeMass(uint32_t Weight);
+};
+
+} // end anonymous namespace
+
+DitheringDistributer::DitheringDistributer(Distribution &Dist,
+ const BlockMass &Mass) {
+ Dist.normalize();
+ RemWeight = Dist.Total;
+ RemMass = Mass;
+}
+
+BlockMass DitheringDistributer::takeMass(uint32_t Weight) {
+ assert(Weight && "invalid weight");
+ assert(Weight <= RemWeight);
+ BlockMass Mass = RemMass * BranchProbability(Weight, RemWeight);
+
+ // Decrement totals (dither).
+ RemWeight -= Weight;
+ RemMass -= Mass;
+ return Mass;
+}
+
+void Distribution::add(const BlockNode &Node, uint64_t Amount,
+ Weight::DistType Type) {
+ assert(Amount && "invalid weight of 0");
+ uint64_t NewTotal = Total + Amount;
+
+ // Check for overflow. It should be impossible to overflow twice.
+ bool IsOverflow = NewTotal < Total;
+ assert(!(DidOverflow && IsOverflow) && "unexpected repeated overflow");
+ DidOverflow |= IsOverflow;
+
+ // Update the total.
+ Total = NewTotal;
+
+ // Save the weight.
+ Weights.push_back(Weight(Type, Node, Amount));
+}
+
+static void combineWeight(Weight &W, const Weight &OtherW) {
+ assert(OtherW.TargetNode.isValid());
+ if (!W.Amount) {
+ W = OtherW;
+ return;
+ }
+ assert(W.Type == OtherW.Type);
+ assert(W.TargetNode == OtherW.TargetNode);
+ assert(OtherW.Amount && "Expected non-zero weight");
+ if (W.Amount > W.Amount + OtherW.Amount)
+ // Saturate on overflow.
+ W.Amount = UINT64_MAX;
+ else
+ W.Amount += OtherW.Amount;
+}
+
+static void combineWeightsBySorting(WeightList &Weights) {
+ // Sort so edges to the same node are adjacent.
+ llvm::sort(Weights, [](const Weight &L, const Weight &R) {
+ return L.TargetNode < R.TargetNode;
+ });
+
+ // Combine adjacent edges.
+ WeightList::iterator O = Weights.begin();
+ for (WeightList::const_iterator I = O, L = O, E = Weights.end(); I != E;
+ ++O, (I = L)) {
+ *O = *I;
+
+ // Find the adjacent weights to the same node.
+ for (++L; L != E && I->TargetNode == L->TargetNode; ++L)
+ combineWeight(*O, *L);
+ }
+
+ // Erase extra entries.
+ Weights.erase(O, Weights.end());
+}
+
+static void combineWeightsByHashing(WeightList &Weights) {
+ // Collect weights into a DenseMap.
+ using HashTable = DenseMap<BlockNode::IndexType, Weight>;
+
+ HashTable Combined(NextPowerOf2(2 * Weights.size()));
+ for (const Weight &W : Weights)
+ combineWeight(Combined[W.TargetNode.Index], W);
+
+ // Check whether anything changed.
+ if (Weights.size() == Combined.size())
+ return;
+
+ // Fill in the new weights.
+ Weights.clear();
+ Weights.reserve(Combined.size());
+ for (const auto &I : Combined)
+ Weights.push_back(I.second);
+}
+
+static void combineWeights(WeightList &Weights) {
+ // Use a hash table for many successors to keep this linear.
+ if (Weights.size() > 128) {
+ combineWeightsByHashing(Weights);
+ return;
+ }
+
+ combineWeightsBySorting(Weights);
+}
+
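+/// Shift \c N right by \c Shift, rounding to nearest rather than truncating:
+/// the term (UINT64_C(1) & N >> (Shift - 1)) adds back the highest bit shifted
+/// out, so e.g. shiftRightAndRound(7, 2) == (7 >> 2) + 1 == 2.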
+static uint64_t shiftRightAndRound(uint64_t N, int Shift) {
+ assert(Shift >= 0);
+ assert(Shift < 64);
+ if (!Shift)
+ return N;
+ return (N >> Shift) + (UINT64_C(1) & N >> (Shift - 1));
+}
+
+void Distribution::normalize() {
+ // Early exit for termination nodes.
+ if (Weights.empty())
+ return;
+
+ // Only bother if there are multiple successors.
+ if (Weights.size() > 1)
+ combineWeights(Weights);
+
+ // Early exit when combined into a single successor.
+ if (Weights.size() == 1) {
+ Total = 1;
+ Weights.front().Amount = 1;
+ return;
+ }
+
+ // Determine how much to shift right so that the total fits into 32-bits.
+ //
+ // If we shift at all, shift by 1 extra. Otherwise, the lower limit of 1
+ // for each weight can cause a 32-bit overflow.
+ int Shift = 0;
+ if (DidOverflow)
+ Shift = 33;
+ else if (Total > UINT32_MAX)
+ Shift = 33 - countLeadingZeros(Total);
+
+ // Early exit if nothing needs to be scaled.
+ if (!Shift) {
+ // If we didn't overflow then combineWeights() shouldn't have changed the
+ // sum of the weights, but let's double-check.
+ assert(Total == std::accumulate(Weights.begin(), Weights.end(), UINT64_C(0),
+ [](uint64_t Sum, const Weight &W) {
+ return Sum + W.Amount;
+ }) &&
+ "Expected total to be correct");
+ return;
+ }
+
+ // Recompute the total through accumulation (rather than shifting it) so that
+ // it's accurate after shifting and any changes combineWeights() made above.
+ Total = 0;
+
+ // Sum the weights to each node and shift right if necessary.
+ for (Weight &W : Weights) {
+ // Scale down below UINT32_MAX. Since Shift is larger than necessary, we
+ // can round here without concern about overflow.
+ assert(W.TargetNode.isValid());
+ W.Amount = std::max(UINT64_C(1), shiftRightAndRound(W.Amount, Shift));
+ assert(W.Amount <= UINT32_MAX);
+
+ // Update the total.
+ Total += W.Amount;
+ }
+ assert(Total <= UINT32_MAX);
+}
+
+void BlockFrequencyInfoImplBase::clear() {
+ // Swap with a default-constructed std::vector, since std::vector<>::clear()
+ // does not actually clear heap storage.
+ std::vector<FrequencyData>().swap(Freqs);
+ IsIrrLoopHeader.clear();
+ std::vector<WorkingData>().swap(Working);
+ Loops.clear();
+}
+
+/// Clear all memory not needed downstream.
+///
+/// Releases all memory not used downstream. In particular, saves Freqs.
+static void cleanup(BlockFrequencyInfoImplBase &BFI) {
+ std::vector<FrequencyData> SavedFreqs(std::move(BFI.Freqs));
+ SparseBitVector<> SavedIsIrrLoopHeader(std::move(BFI.IsIrrLoopHeader));
+ BFI.clear();
+ BFI.Freqs = std::move(SavedFreqs);
+ BFI.IsIrrLoopHeader = std::move(SavedIsIrrLoopHeader);
+}
+
+bool BlockFrequencyInfoImplBase::addToDist(Distribution &Dist,
+ const LoopData *OuterLoop,
+ const BlockNode &Pred,
+ const BlockNode &Succ,
+ uint64_t Weight) {
+ if (!Weight)
+ Weight = 1;
+
+ auto isLoopHeader = [&OuterLoop](const BlockNode &Node) {
+ return OuterLoop && OuterLoop->isHeader(Node);
+ };
+
+ BlockNode Resolved = Working[Succ.Index].getResolvedNode();
+
+#ifndef NDEBUG
+ auto debugSuccessor = [&](const char *Type) {
+ dbgs() << " =>"
+ << " [" << Type << "] weight = " << Weight;
+ if (!isLoopHeader(Resolved))
+ dbgs() << ", succ = " << getBlockName(Succ);
+ if (Resolved != Succ)
+ dbgs() << ", resolved = " << getBlockName(Resolved);
+ dbgs() << "\n";
+ };
+ (void)debugSuccessor;
+#endif
+
+ if (isLoopHeader(Resolved)) {
+ LLVM_DEBUG(debugSuccessor("backedge"));
+ Dist.addBackedge(Resolved, Weight);
+ return true;
+ }
+
+ if (Working[Resolved.Index].getContainingLoop() != OuterLoop) {
+ LLVM_DEBUG(debugSuccessor(" exit "));
+ Dist.addExit(Resolved, Weight);
+ return true;
+ }
+
+ if (Resolved < Pred) {
+ if (!isLoopHeader(Pred)) {
+ // If OuterLoop is an irreducible loop, we can't actually handle this.
+ assert((!OuterLoop || !OuterLoop->isIrreducible()) &&
+ "unhandled irreducible control flow");
+
+ // Irreducible backedge. Abort.
+ LLVM_DEBUG(debugSuccessor("abort!!!"));
+ return false;
+ }
+
+ // If "Pred" is a loop header, then this isn't really a backedge; rather,
+ // OuterLoop must be irreducible. These false backedges can come only from
+ // secondary loop headers.
+ assert(OuterLoop && OuterLoop->isIrreducible() && !isLoopHeader(Resolved) &&
+ "unhandled irreducible control flow");
+ }
+
+ LLVM_DEBUG(debugSuccessor(" local "));
+ Dist.addLocal(Resolved, Weight);
+ return true;
+}
+
+bool BlockFrequencyInfoImplBase::addLoopSuccessorsToDist(
+ const LoopData *OuterLoop, LoopData &Loop, Distribution &Dist) {
+ // Copy the exit map into Dist.
+ for (const auto &I : Loop.Exits)
+ if (!addToDist(Dist, OuterLoop, Loop.getHeader(), I.first,
+ I.second.getMass()))
+ // Irreducible backedge.
+ return false;
+
+ return true;
+}
+
+/// Compute the loop scale for a loop.
+void BlockFrequencyInfoImplBase::computeLoopScale(LoopData &Loop) {
+ // Compute loop scale.
+ LLVM_DEBUG(dbgs() << "compute-loop-scale: " << getLoopName(Loop) << "\n");
+
+ // Infinite loops need special handling. If we give the back edge an infinite
+ // mass, they may saturate all the other scales in the function down to 1,
+ // making all the other region temperatures look exactly the same. Choose an
+ // arbitrary scale to avoid these issues.
+ //
+ // FIXME: An alternate way would be to select a symbolic scale which is later
+ // replaced to be the maximum of all computed scales plus 1. This would
+ // appropriately describe the loop as having a large scale, without skewing
+ // the final frequency computation.
+ const Scaled64 InfiniteLoopScale(1, 12);
+
+ // LoopScale == 1 / ExitMass
+ // ExitMass == HeadMass - BackedgeMass
+ BlockMass TotalBackedgeMass;
+ for (auto &Mass : Loop.BackedgeMass)
+ TotalBackedgeMass += Mass;
+ BlockMass ExitMass = BlockMass::getFull() - TotalBackedgeMass;
+
+ // Block scale stores the inverse of the scale. If this is an infinite loop,
+ // its exit mass will be zero. In this case, use an arbitrary scale for the
+ // loop scale.
+ Loop.Scale =
+ ExitMass.isEmpty() ? InfiniteLoopScale : ExitMass.toScaled().inverse();
+
+ LLVM_DEBUG(dbgs() << " - exit-mass = " << ExitMass << " ("
+ << BlockMass::getFull() << " - " << TotalBackedgeMass
+ << ")\n"
+ << " - scale = " << Loop.Scale << "\n");
+}
+
+/// Package up a loop.
+void BlockFrequencyInfoImplBase::packageLoop(LoopData &Loop) {
+ LLVM_DEBUG(dbgs() << "packaging-loop: " << getLoopName(Loop) << "\n");
+
+ // Clear the subloop exits to prevent quadratic memory usage.
+ for (const BlockNode &M : Loop.Nodes) {
+ if (auto *Loop = Working[M.Index].getPackagedLoop())
+ Loop->Exits.clear();
+ LLVM_DEBUG(dbgs() << " - node: " << getBlockName(M.Index) << "\n");
+ }
+ Loop.IsPackaged = true;
+}
+
+#ifndef NDEBUG
+static void debugAssign(const BlockFrequencyInfoImplBase &BFI,
+ const DitheringDistributer &D, const BlockNode &T,
+ const BlockMass &M, const char *Desc) {
+ dbgs() << " => assign " << M << " (" << D.RemMass << ")";
+ if (Desc)
+ dbgs() << " [" << Desc << "]";
+ if (T.isValid())
+ dbgs() << " to " << BFI.getBlockName(T);
+ dbgs() << "\n";
+}
+#endif
+
+void BlockFrequencyInfoImplBase::distributeMass(const BlockNode &Source,
+ LoopData *OuterLoop,
+ Distribution &Dist) {
+ BlockMass Mass = Working[Source.Index].getMass();
+ LLVM_DEBUG(dbgs() << " => mass: " << Mass << "\n");
+
+ // Distribute mass to successors as laid out in Dist.
+ DitheringDistributer D(Dist, Mass);
+
+ for (const Weight &W : Dist.Weights) {
+ // Check for a local edge (non-backedge and non-exit).
+ BlockMass Taken = D.takeMass(W.Amount);
+ if (W.Type == Weight::Local) {
+ Working[W.TargetNode.Index].getMass() += Taken;
+ LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, nullptr));
+ continue;
+ }
+
+ // Backedges and exits only make sense if we're processing a loop.
+ assert(OuterLoop && "backedge or exit outside of loop");
+
+ // Check for a backedge.
+ if (W.Type == Weight::Backedge) {
+ OuterLoop->BackedgeMass[OuterLoop->getHeaderIndex(W.TargetNode)] += Taken;
+ LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, "back"));
+ continue;
+ }
+
+ // This must be an exit.
+ assert(W.Type == Weight::Exit);
+ OuterLoop->Exits.push_back(std::make_pair(W.TargetNode, Taken));
+ LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, "exit"));
+ }
+}
+
+static void convertFloatingToInteger(BlockFrequencyInfoImplBase &BFI,
+ const Scaled64 &Min, const Scaled64 &Max) {
+ // Scale the Factor to a size that creates integers. Ideally, integers would
+ // be scaled so that Max == UINT64_MAX so that they can be best
+ // differentiated. However, in the presence of large frequency values, small
+ // frequencies are scaled down to 1, making it impossible to differentiate
+ // small, unequal numbers. When the spread between Min and Max frequencies
+ // fits well within MaxBits, we make the scale be at least 8.
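+  // For example, with Min == 0.5 and Max == 8.0 the spread is 4 bits, so the
+  // scaling factor becomes Min.inverse() << 3 == 16 and the scaled frequencies
+  // land in [8, 128].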
+ const unsigned MaxBits = 64;
+ const unsigned SpreadBits = (Max / Min).lg();
+ Scaled64 ScalingFactor;
+ if (SpreadBits <= MaxBits - 3) {
+ // If the values are small enough, make the scaling factor at least 8 to
+ // allow distinguishing small values.
+ ScalingFactor = Min.inverse();
+ ScalingFactor <<= 3;
+ } else {
+ // If the values need more than MaxBits to be represented, saturate small
+ // frequency values down to 1 by using a scaling factor that benefits large
+ // frequency values.
+ ScalingFactor = Scaled64(1, MaxBits) / Max;
+ }
+
+ // Translate the floats to integers.
+ LLVM_DEBUG(dbgs() << "float-to-int: min = " << Min << ", max = " << Max
+ << ", factor = " << ScalingFactor << "\n");
+ for (size_t Index = 0; Index < BFI.Freqs.size(); ++Index) {
+ Scaled64 Scaled = BFI.Freqs[Index].Scaled * ScalingFactor;
+ BFI.Freqs[Index].Integer = std::max(UINT64_C(1), Scaled.toInt<uint64_t>());
+ LLVM_DEBUG(dbgs() << " - " << BFI.getBlockName(Index) << ": float = "
+ << BFI.Freqs[Index].Scaled << ", scaled = " << Scaled
+ << ", int = " << BFI.Freqs[Index].Integer << "\n");
+ }
+}
+
+/// Unwrap a loop package.
+///
+/// Visits all the members of a loop, adjusting their BlockData according to
+/// the loop's pseudo-node.
+static void unwrapLoop(BlockFrequencyInfoImplBase &BFI, LoopData &Loop) {
+ LLVM_DEBUG(dbgs() << "unwrap-loop-package: " << BFI.getLoopName(Loop)
+ << ": mass = " << Loop.Mass << ", scale = " << Loop.Scale
+ << "\n");
+ Loop.Scale *= Loop.Mass.toScaled();
+ Loop.IsPackaged = false;
+ LLVM_DEBUG(dbgs() << " => combined-scale = " << Loop.Scale << "\n");
+
+ // Propagate the head scale through the loop. Since members are visited in
+ // RPO, the head scale will be updated by the loop scale first, and then the
+  // final head scale will be used to update the rest of the members.
+ for (const BlockNode &N : Loop.Nodes) {
+ const auto &Working = BFI.Working[N.Index];
+ Scaled64 &F = Working.isAPackage() ? Working.getPackagedLoop()->Scale
+ : BFI.Freqs[N.Index].Scaled;
+ Scaled64 New = Loop.Scale * F;
+ LLVM_DEBUG(dbgs() << " - " << BFI.getBlockName(N) << ": " << F << " => "
+ << New << "\n");
+ F = New;
+ }
+}
+
+void BlockFrequencyInfoImplBase::unwrapLoops() {
+ // Set initial frequencies from loop-local masses.
+ for (size_t Index = 0; Index < Working.size(); ++Index)
+ Freqs[Index].Scaled = Working[Index].Mass.toScaled();
+
+ for (LoopData &Loop : Loops)
+ unwrapLoop(*this, Loop);
+}
+
+void BlockFrequencyInfoImplBase::finalizeMetrics() {
+ // Unwrap loop packages in reverse post-order, tracking min and max
+ // frequencies.
+ auto Min = Scaled64::getLargest();
+ auto Max = Scaled64::getZero();
+ for (size_t Index = 0; Index < Working.size(); ++Index) {
+ // Update min/max scale.
+ Min = std::min(Min, Freqs[Index].Scaled);
+ Max = std::max(Max, Freqs[Index].Scaled);
+ }
+
+ // Convert to integers.
+ convertFloatingToInteger(*this, Min, Max);
+
+ // Clean up data structures.
+ cleanup(*this);
+
+ // Print out the final stats.
+ LLVM_DEBUG(dump());
+}
+
+BlockFrequency
+BlockFrequencyInfoImplBase::getBlockFreq(const BlockNode &Node) const {
+ if (!Node.isValid())
+ return 0;
+ return Freqs[Node.Index].Integer;
+}
+
+Optional<uint64_t>
+BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F,
+ const BlockNode &Node,
+ bool AllowSynthetic) const {
+ return getProfileCountFromFreq(F, getBlockFreq(Node).getFrequency(),
+ AllowSynthetic);
+}
+
+Optional<uint64_t>
+BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F,
+ uint64_t Freq,
+ bool AllowSynthetic) const {
+ auto EntryCount = F.getEntryCount(AllowSynthetic);
+ if (!EntryCount)
+ return None;
+ // Use 128 bit APInt to do the arithmetic to avoid overflow.
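+  // The block's count is EntryCount * Freq / EntryFreq, with the division
+  // rounded to nearest below.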
+ APInt BlockCount(128, EntryCount.getCount());
+ APInt BlockFreq(128, Freq);
+ APInt EntryFreq(128, getEntryFreq());
+ BlockCount *= BlockFreq;
+  // Rounded division of BlockCount by EntryFreq. Since EntryFreq is unsigned,
+  // lshr by 1 gives EntryFreq/2.
+ BlockCount = (BlockCount + EntryFreq.lshr(1)).udiv(EntryFreq);
+ return BlockCount.getLimitedValue();
+}
+
+bool
+BlockFrequencyInfoImplBase::isIrrLoopHeader(const BlockNode &Node) {
+ if (!Node.isValid())
+ return false;
+ return IsIrrLoopHeader.test(Node.Index);
+}
+
+Scaled64
+BlockFrequencyInfoImplBase::getFloatingBlockFreq(const BlockNode &Node) const {
+ if (!Node.isValid())
+ return Scaled64::getZero();
+ return Freqs[Node.Index].Scaled;
+}
+
+void BlockFrequencyInfoImplBase::setBlockFreq(const BlockNode &Node,
+ uint64_t Freq) {
+ assert(Node.isValid() && "Expected valid node");
+ assert(Node.Index < Freqs.size() && "Expected legal index");
+ Freqs[Node.Index].Integer = Freq;
+}
+
+std::string
+BlockFrequencyInfoImplBase::getBlockName(const BlockNode &Node) const {
+ return {};
+}
+
+std::string
+BlockFrequencyInfoImplBase::getLoopName(const LoopData &Loop) const {
+ return getBlockName(Loop.getHeader()) + (Loop.isIrreducible() ? "**" : "*");
+}
+
+raw_ostream &
+BlockFrequencyInfoImplBase::printBlockFreq(raw_ostream &OS,
+ const BlockNode &Node) const {
+ return OS << getFloatingBlockFreq(Node);
+}
+
+raw_ostream &
+BlockFrequencyInfoImplBase::printBlockFreq(raw_ostream &OS,
+ const BlockFrequency &Freq) const {
+ Scaled64 Block(Freq.getFrequency(), 0);
+ Scaled64 Entry(getEntryFreq(), 0);
+
+ return OS << Block / Entry;
+}
+
+void IrreducibleGraph::addNodesInLoop(const BFIBase::LoopData &OuterLoop) {
+ Start = OuterLoop.getHeader();
+ Nodes.reserve(OuterLoop.Nodes.size());
+ for (auto N : OuterLoop.Nodes)
+ addNode(N);
+ indexNodes();
+}
+
+void IrreducibleGraph::addNodesInFunction() {
+ Start = 0;
+ for (uint32_t Index = 0; Index < BFI.Working.size(); ++Index)
+ if (!BFI.Working[Index].isPackaged())
+ addNode(Index);
+ indexNodes();
+}
+
+void IrreducibleGraph::indexNodes() {
+ for (auto &I : Nodes)
+ Lookup[I.Node.Index] = &I;
+}
+
+void IrreducibleGraph::addEdge(IrrNode &Irr, const BlockNode &Succ,
+ const BFIBase::LoopData *OuterLoop) {
+ if (OuterLoop && OuterLoop->isHeader(Succ))
+ return;
+ auto L = Lookup.find(Succ.Index);
+ if (L == Lookup.end())
+ return;
+ IrrNode &SuccIrr = *L->second;
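+  // Record the edge on both ends: the successor is appended to Irr's edge
+  // list, while Irr is prepended to the successor's list and counted in NumIn.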
+ Irr.Edges.push_back(&SuccIrr);
+ SuccIrr.Edges.push_front(&Irr);
+ ++SuccIrr.NumIn;
+}
+
+namespace llvm {
+
+template <> struct GraphTraits<IrreducibleGraph> {
+ using GraphT = bfi_detail::IrreducibleGraph;
+ using NodeRef = const GraphT::IrrNode *;
+ using ChildIteratorType = GraphT::IrrNode::iterator;
+
+ static NodeRef getEntryNode(const GraphT &G) { return G.StartIrr; }
+ static ChildIteratorType child_begin(NodeRef N) { return N->succ_begin(); }
+ static ChildIteratorType child_end(NodeRef N) { return N->succ_end(); }
+};
+
+} // end namespace llvm
+
+/// Find extra irreducible headers.
+///
+/// Find entry blocks and other blocks with backedges, which exist when \c G
+/// contains irreducible sub-SCCs.
+static void findIrreducibleHeaders(
+ const BlockFrequencyInfoImplBase &BFI,
+ const IrreducibleGraph &G,
+ const std::vector<const IrreducibleGraph::IrrNode *> &SCC,
+ LoopData::NodeList &Headers, LoopData::NodeList &Others) {
+  // Map from each node in the SCC to whether it is an entry block.
+ SmallDenseMap<const IrreducibleGraph::IrrNode *, bool, 8> InSCC;
+
+  // InSCC also acts as the set of nodes in the graph. Seed it.
+ for (const auto *I : SCC)
+ InSCC[I] = false;
+
+ for (auto I = InSCC.begin(), E = InSCC.end(); I != E; ++I) {
+ auto &Irr = *I->first;
+ for (const auto *P : make_range(Irr.pred_begin(), Irr.pred_end())) {
+ if (InSCC.count(P))
+ continue;
+
+ // This is an entry block.
+ I->second = true;
+ Headers.push_back(Irr.Node);
+ LLVM_DEBUG(dbgs() << " => entry = " << BFI.getBlockName(Irr.Node)
+ << "\n");
+ break;
+ }
+ }
+ assert(Headers.size() >= 2 &&
+ "Expected irreducible CFG; -loop-info is likely invalid");
+ if (Headers.size() == InSCC.size()) {
+ // Every block is a header.
+ llvm::sort(Headers);
+ return;
+ }
+
+ // Look for extra headers from irreducible sub-SCCs.
+ for (const auto &I : InSCC) {
+ // Entry blocks are already headers.
+ if (I.second)
+ continue;
+
+ auto &Irr = *I.first;
+ for (const auto *P : make_range(Irr.pred_begin(), Irr.pred_end())) {
+ // Skip forward edges.
+ if (P->Node < Irr.Node)
+ continue;
+
+ // Skip predecessors from entry blocks. These can have inverted
+ // ordering.
+ if (InSCC.lookup(P))
+ continue;
+
+ // Store the extra header.
+ Headers.push_back(Irr.Node);
+ LLVM_DEBUG(dbgs() << " => extra = " << BFI.getBlockName(Irr.Node)
+ << "\n");
+ break;
+ }
+ if (Headers.back() == Irr.Node)
+ // Added this as a header.
+ continue;
+
+ // This is not a header.
+ Others.push_back(Irr.Node);
+ LLVM_DEBUG(dbgs() << " => other = " << BFI.getBlockName(Irr.Node) << "\n");
+ }
+ llvm::sort(Headers);
+ llvm::sort(Others);
+}
+
+static void createIrreducibleLoop(
+ BlockFrequencyInfoImplBase &BFI, const IrreducibleGraph &G,
+ LoopData *OuterLoop, std::list<LoopData>::iterator Insert,
+ const std::vector<const IrreducibleGraph::IrrNode *> &SCC) {
+ // Translate the SCC into RPO.
+ LLVM_DEBUG(dbgs() << " - found-scc\n");
+
+ LoopData::NodeList Headers;
+ LoopData::NodeList Others;
+ findIrreducibleHeaders(BFI, G, SCC, Headers, Others);
+
+ auto Loop = BFI.Loops.emplace(Insert, OuterLoop, Headers.begin(),
+ Headers.end(), Others.begin(), Others.end());
+
+ // Update loop hierarchy.
+ for (const auto &N : Loop->Nodes)
+ if (BFI.Working[N.Index].isLoopHeader())
+ BFI.Working[N.Index].Loop->Parent = &*Loop;
+ else
+ BFI.Working[N.Index].Loop = &*Loop;
+}
+
+iterator_range<std::list<LoopData>::iterator>
+BlockFrequencyInfoImplBase::analyzeIrreducible(
+ const IrreducibleGraph &G, LoopData *OuterLoop,
+ std::list<LoopData>::iterator Insert) {
+ assert((OuterLoop == nullptr) == (Insert == Loops.begin()));
+ auto Prev = OuterLoop ? std::prev(Insert) : Loops.end();
+
+ for (auto I = scc_begin(G); !I.isAtEnd(); ++I) {
+ if (I->size() < 2)
+ continue;
+
+ // Translate the SCC into RPO.
+ createIrreducibleLoop(*this, G, OuterLoop, Insert, *I);
+ }
+
+ if (OuterLoop)
+ return make_range(std::next(Prev), Insert);
+ return make_range(Loops.begin(), Insert);
+}
+
+void
+BlockFrequencyInfoImplBase::updateLoopWithIrreducible(LoopData &OuterLoop) {
+ OuterLoop.Exits.clear();
+ for (auto &Mass : OuterLoop.BackedgeMass)
+ Mass = BlockMass::getEmpty();
+ auto O = OuterLoop.Nodes.begin() + 1;
+ for (auto I = O, E = OuterLoop.Nodes.end(); I != E; ++I)
+ if (!Working[I->Index].isPackaged())
+ *O++ = *I;
+ OuterLoop.Nodes.erase(O, OuterLoop.Nodes.end());
+}
+
+void BlockFrequencyInfoImplBase::adjustLoopHeaderMass(LoopData &Loop) {
+ assert(Loop.isIrreducible() && "this only makes sense on irreducible loops");
+
+ // Since the loop has more than one header block, the mass flowing back into
+  // each header will be different. Adjust the mass in each loop header to
+  // reflect the masses flowing through back edges.
+ //
+ // To do this, we distribute the initial mass using the backedge masses
+ // as weights for the distribution.
+ BlockMass LoopMass = BlockMass::getFull();
+ Distribution Dist;
+
+ LLVM_DEBUG(dbgs() << "adjust-loop-header-mass:\n");
+ for (uint32_t H = 0; H < Loop.NumHeaders; ++H) {
+ auto &HeaderNode = Loop.Nodes[H];
+ auto &BackedgeMass = Loop.BackedgeMass[Loop.getHeaderIndex(HeaderNode)];
+ LLVM_DEBUG(dbgs() << " - Add back edge mass for node "
+ << getBlockName(HeaderNode) << ": " << BackedgeMass
+ << "\n");
+ if (BackedgeMass.getMass() > 0)
+ Dist.addLocal(HeaderNode, BackedgeMass.getMass());
+ else
+ LLVM_DEBUG(dbgs() << " Nothing added. Back edge mass is zero\n");
+ }
+
+ DitheringDistributer D(Dist, LoopMass);
+
+ LLVM_DEBUG(dbgs() << " Distribute loop mass " << LoopMass
+ << " to headers using above weights\n");
+ for (const Weight &W : Dist.Weights) {
+ BlockMass Taken = D.takeMass(W.Amount);
+ assert(W.Type == Weight::Local && "all weights should be local");
+ Working[W.TargetNode.Index].getMass() = Taken;
+ LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, nullptr));
+ }
+}
+
+void BlockFrequencyInfoImplBase::distributeIrrLoopHeaderMass(Distribution &Dist) {
+ BlockMass LoopMass = BlockMass::getFull();
+ DitheringDistributer D(Dist, LoopMass);
+ for (const Weight &W : Dist.Weights) {
+ BlockMass Taken = D.takeMass(W.Amount);
+ assert(W.Type == Weight::Local && "all weights should be local");
+ Working[W.TargetNode.Index].getMass() = Taken;
+ LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, nullptr));
+ }
+}
diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
new file mode 100644
index 000000000000..a06ee096d54c
--- /dev/null
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -0,0 +1,1057 @@
+//===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Loops should be simplified before this analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SCCIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/BranchProbability.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+#include <cstdint>
+#include <iterator>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "branch-prob"
+
+static cl::opt<bool> PrintBranchProb(
+ "print-bpi", cl::init(false), cl::Hidden,
+ cl::desc("Print the branch probability info."));
+
+cl::opt<std::string> PrintBranchProbFuncName(
+ "print-bpi-func-name", cl::Hidden,
+ cl::desc("The option to specify the name of the function "
+ "whose branch probability info is printed."));
+
+INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
+ "Branch Probability Analysis", false, true)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
+ "Branch Probability Analysis", false, true)
+
+char BranchProbabilityInfoWrapperPass::ID = 0;
+
+// Weights are for internal use only. They are used by heuristics to help
+// estimate edge probabilities. Example:
+//
+// Using "Loop Branch Heuristics" we predict weights of edges for the
+// block BB2.
+// ...
+// |
+// V
+// BB1<-+
+// | |
+// | | (Weight = 124)
+// V |
+// BB2--+
+// |
+// | (Weight = 4)
+// V
+// BB3
+//
+// Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
+// Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125
+static const uint32_t LBH_TAKEN_WEIGHT = 124;
+static const uint32_t LBH_NONTAKEN_WEIGHT = 4;
+// Unlikely edges within a loop are half as likely as other edges
+static const uint32_t LBH_UNLIKELY_WEIGHT = 62;
+
+/// Unreachable-terminating branch taken probability.
+///
+/// This is the probability for a branch being taken to a block that terminates
+/// (eventually) in unreachable. These are predicted as unlikely as possible.
+/// All reachable edges equally share the remaining probability.
+static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);
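+// Illustrative example (not from the original source): with one successor
+// leading to unreachable and one reachable successor,
+// calcUnreachableHeuristics below assigns the unreachable edge UR_TAKEN_PROB
+// (the smallest representable nonzero probability) and gives the reachable
+// edge the entire remainder, (1 - UR_TAKEN_PROB) / 1.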
+
+/// Weight for a branch taken going into a cold block.
+///
+/// This is the weight for a branch taken toward a block marked
+/// cold. A block is marked cold if it's postdominated by a
+/// block containing a call to a cold function. Cold functions
+/// are those marked with attribute 'cold'.
+static const uint32_t CC_TAKEN_WEIGHT = 4;
+
+/// Weight for a branch not-taken into a cold block.
+///
+/// This is the weight for a branch not taken toward a block marked
+/// cold.
+static const uint32_t CC_NONTAKEN_WEIGHT = 64;
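+// Illustrative example (not from the original source): for a block with one
+// cold successor and one normal successor, calcColdCallHeuristics below
+// assigns the cold edge CC_TAKEN_WEIGHT / (CC_TAKEN_WEIGHT +
+// CC_NONTAKEN_WEIGHT) = 4/68 and the normal edge 64/68.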
+
+static const uint32_t PH_TAKEN_WEIGHT = 20;
+static const uint32_t PH_NONTAKEN_WEIGHT = 12;
+
+static const uint32_t ZH_TAKEN_WEIGHT = 20;
+static const uint32_t ZH_NONTAKEN_WEIGHT = 12;
+
+static const uint32_t FPH_TAKEN_WEIGHT = 20;
+static const uint32_t FPH_NONTAKEN_WEIGHT = 12;
+
+/// This is the weight for an ordered floating point comparison.
+static const uint32_t FPH_ORD_WEIGHT = 1024 * 1024 - 1;
+/// This is the weight for an unordered floating point comparison: it means
+/// that one or both of the operands are NaN. Such a comparison is usually used
+/// to test for an exceptional case, so the result is unlikely.
+static const uint32_t FPH_UNO_WEIGHT = 1;
+
+/// Invoke-terminating normal branch taken weight
+///
+/// This is the weight for branching to the normal destination of an invoke
+/// instruction. We expect this to happen most of the time. Set the weight to an
+/// absurdly high value so that nested loops subsume it.
+static const uint32_t IH_TAKEN_WEIGHT = 1024 * 1024 - 1;
+
+/// Invoke-terminating normal branch not-taken weight.
+///
+/// This is the weight for branching to the unwind destination of an invoke
+/// instruction. This is essentially never taken.
+static const uint32_t IH_NONTAKEN_WEIGHT = 1;
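+// Illustrative example (not from the original source): calcInvokeHeuristics
+// below gives the normal destination IH_TAKEN_WEIGHT / (IH_TAKEN_WEIGHT +
+// IH_NONTAKEN_WEIGHT) = 1048575/1048576 of the probability, leaving roughly
+// one in a million for the unwind destination.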
+
+/// Add \p BB to PostDominatedByUnreachable set if applicable.
+void
+BranchProbabilityInfo::updatePostDominatedByUnreachable(const BasicBlock *BB) {
+ const Instruction *TI = BB->getTerminator();
+ if (TI->getNumSuccessors() == 0) {
+ if (isa<UnreachableInst>(TI) ||
+ // If this block is terminated by a call to
+ // @llvm.experimental.deoptimize then treat it like an unreachable since
+ // the @llvm.experimental.deoptimize call is expected to practically
+ // never execute.
+ BB->getTerminatingDeoptimizeCall())
+ PostDominatedByUnreachable.insert(BB);
+ return;
+ }
+
+  // If the terminator is an InvokeInst, check only the normal destination
+  // block, as the unwind edge of an InvokeInst is also very unlikely to be
+  // taken.
+ if (auto *II = dyn_cast<InvokeInst>(TI)) {
+ if (PostDominatedByUnreachable.count(II->getNormalDest()))
+ PostDominatedByUnreachable.insert(BB);
+ return;
+ }
+
+ for (auto *I : successors(BB))
+    // If any successor is not post-dominated, then BB is not either.
+ if (!PostDominatedByUnreachable.count(I))
+ return;
+
+ PostDominatedByUnreachable.insert(BB);
+}
+
+/// Add \p BB to PostDominatedByColdCall set if applicable.
+void
+BranchProbabilityInfo::updatePostDominatedByColdCall(const BasicBlock *BB) {
+ assert(!PostDominatedByColdCall.count(BB));
+ const Instruction *TI = BB->getTerminator();
+ if (TI->getNumSuccessors() == 0)
+ return;
+
+  // If all successors are post-dominated, then so is BB.
+ if (llvm::all_of(successors(BB), [&](const BasicBlock *SuccBB) {
+ return PostDominatedByColdCall.count(SuccBB);
+ })) {
+ PostDominatedByColdCall.insert(BB);
+ return;
+ }
+
+ // If the terminator is an InvokeInst, check only the normal destination
+  // block, as the unwind edge of an InvokeInst is also very unlikely to be
+  // taken.
+ if (auto *II = dyn_cast<InvokeInst>(TI))
+ if (PostDominatedByColdCall.count(II->getNormalDest())) {
+ PostDominatedByColdCall.insert(BB);
+ return;
+ }
+
+ // Otherwise, if the block itself contains a cold function, add it to the
+ // set of blocks post-dominated by a cold call.
+ for (auto &I : *BB)
+ if (const CallInst *CI = dyn_cast<CallInst>(&I))
+ if (CI->hasFnAttr(Attribute::Cold)) {
+ PostDominatedByColdCall.insert(BB);
+ return;
+ }
+}
+
+/// Calculate edge weights for successors that lead to unreachable.
+///
+/// Predict a successor that necessarily leads to an unreachable-terminated
+/// block as extremely unlikely.
+bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
+ const Instruction *TI = BB->getTerminator();
+ (void) TI;
+ assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
+ assert(!isa<InvokeInst>(TI) &&
+ "Invokes should have already been handled by calcInvokeHeuristics");
+
+ SmallVector<unsigned, 4> UnreachableEdges;
+ SmallVector<unsigned, 4> ReachableEdges;
+
+ for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
+ if (PostDominatedByUnreachable.count(*I))
+ UnreachableEdges.push_back(I.getSuccessorIndex());
+ else
+ ReachableEdges.push_back(I.getSuccessorIndex());
+
+ // Skip probabilities if all were reachable.
+ if (UnreachableEdges.empty())
+ return false;
+
+ if (ReachableEdges.empty()) {
+ BranchProbability Prob(1, UnreachableEdges.size());
+ for (unsigned SuccIdx : UnreachableEdges)
+ setEdgeProbability(BB, SuccIdx, Prob);
+ return true;
+ }
+
+ auto UnreachableProb = UR_TAKEN_PROB;
+ auto ReachableProb =
+ (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) /
+ ReachableEdges.size();
+
+ for (unsigned SuccIdx : UnreachableEdges)
+ setEdgeProbability(BB, SuccIdx, UnreachableProb);
+ for (unsigned SuccIdx : ReachableEdges)
+ setEdgeProbability(BB, SuccIdx, ReachableProb);
+
+ return true;
+}
+
+// Propagate existing explicit probabilities from either profile data or
+// 'expect' intrinsic processing. The metadata is also checked against the
+// unreachable heuristic: the probability of an edge leading to an unreachable
+// block is set to the minimum of the metadata value and the unreachable
+// heuristic.
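+//
+// Illustrative example (hypothetical IR, not taken from this file): a
+// conditional branch annotated with profile metadata typically looks like
+//
+//   br i1 %cmp, label %then, label %else, !prof !0
+//   !0 = !{!"branch_weights", i32 2000, i32 1}
+//
+// The first metadata operand is the name ("branch_weights"); the remaining
+// operands are the per-successor weights consumed below.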
+bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
+ const Instruction *TI = BB->getTerminator();
+ assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
+ if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI)))
+ return false;
+
+ MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
+ if (!WeightsNode)
+ return false;
+
+ // Check that the number of successors is manageable.
+ assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
+
+ // Ensure there are weights for all of the successors. Note that the first
+ // operand to the metadata node is a name, not a weight.
+ if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
+ return false;
+
+ // Build up the final weights that will be used in a temporary buffer.
+ // Compute the sum of all weights to later decide whether they need to
+ // be scaled to fit in 32 bits.
+ uint64_t WeightSum = 0;
+ SmallVector<uint32_t, 2> Weights;
+ SmallVector<unsigned, 2> UnreachableIdxs;
+ SmallVector<unsigned, 2> ReachableIdxs;
+ Weights.reserve(TI->getNumSuccessors());
+ for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
+ ConstantInt *Weight =
+ mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
+ if (!Weight)
+ return false;
+ assert(Weight->getValue().getActiveBits() <= 32 &&
+ "Too many bits for uint32_t");
+ Weights.push_back(Weight->getZExtValue());
+ WeightSum += Weights.back();
+ if (PostDominatedByUnreachable.count(TI->getSuccessor(i - 1)))
+ UnreachableIdxs.push_back(i - 1);
+ else
+ ReachableIdxs.push_back(i - 1);
+ }
+ assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
+
+ // If the sum of weights does not fit in 32 bits, scale every weight down
+ // accordingly.
+ uint64_t ScalingFactor =
+ (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
+
+ if (ScalingFactor > 1) {
+ WeightSum = 0;
+ for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
+ Weights[i] /= ScalingFactor;
+ WeightSum += Weights[i];
+ }
+ }
+ assert(WeightSum <= UINT32_MAX &&
+ "Expected weights to scale down to 32 bits");
+
+ if (WeightSum == 0 || ReachableIdxs.size() == 0) {
+ for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
+ Weights[i] = 1;
+ WeightSum = TI->getNumSuccessors();
+ }
+
+ // Set the probability.
+ SmallVector<BranchProbability, 2> BP;
+ for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
+ BP.push_back({ Weights[i], static_cast<uint32_t>(WeightSum) });
+
+  // Examine the metadata against the unreachable heuristic.
+  // If the unreachable heuristic is stronger, use it for this edge.
+ if (UnreachableIdxs.size() > 0 && ReachableIdxs.size() > 0) {
+ auto ToDistribute = BranchProbability::getZero();
+ auto UnreachableProb = UR_TAKEN_PROB;
+ for (auto i : UnreachableIdxs)
+ if (UnreachableProb < BP[i]) {
+ ToDistribute += BP[i] - UnreachableProb;
+ BP[i] = UnreachableProb;
+ }
+
+ // If we modified the probability of some edges then we must distribute
+ // the difference between reachable blocks.
+ if (ToDistribute > BranchProbability::getZero()) {
+ BranchProbability PerEdge = ToDistribute / ReachableIdxs.size();
+ for (auto i : ReachableIdxs)
+ BP[i] += PerEdge;
+ }
+ }
+
+ for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
+ setEdgeProbability(BB, i, BP[i]);
+
+ return true;
+}
+
+/// Calculate edge weights for edges leading to cold blocks.
+///
+/// A cold block is one post-dominated by a block with a call to a
+/// cold function. Those edges are unlikely to be taken, so we give
+/// them relatively low weight.
+///
+/// Return true if we could compute the weights for cold edges.
+/// Return false, otherwise.
+bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
+ const Instruction *TI = BB->getTerminator();
+ (void) TI;
+ assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
+ assert(!isa<InvokeInst>(TI) &&
+ "Invokes should have already been handled by calcInvokeHeuristics");
+
+ // Determine which successors are post-dominated by a cold block.
+ SmallVector<unsigned, 4> ColdEdges;
+ SmallVector<unsigned, 4> NormalEdges;
+ for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
+ if (PostDominatedByColdCall.count(*I))
+ ColdEdges.push_back(I.getSuccessorIndex());
+ else
+ NormalEdges.push_back(I.getSuccessorIndex());
+
+ // Skip probabilities if no cold edges.
+ if (ColdEdges.empty())
+ return false;
+
+ if (NormalEdges.empty()) {
+ BranchProbability Prob(1, ColdEdges.size());
+ for (unsigned SuccIdx : ColdEdges)
+ setEdgeProbability(BB, SuccIdx, Prob);
+ return true;
+ }
+
+ auto ColdProb = BranchProbability::getBranchProbability(
+ CC_TAKEN_WEIGHT,
+ (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
+ auto NormalProb = BranchProbability::getBranchProbability(
+ CC_NONTAKEN_WEIGHT,
+ (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));
+
+ for (unsigned SuccIdx : ColdEdges)
+ setEdgeProbability(BB, SuccIdx, ColdProb);
+ for (unsigned SuccIdx : NormalEdges)
+ setEdgeProbability(BB, SuccIdx, NormalProb);
+
+ return true;
+}
+
+// Calculate Edge Weights using "Pointer Heuristics". Predict that a comparison
+// between two pointers, or between a pointer and NULL, will fail.
+bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
+ const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
+ if (!BI || !BI->isConditional())
+ return false;
+
+ Value *Cond = BI->getCondition();
+ ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
+ if (!CI || !CI->isEquality())
+ return false;
+
+ Value *LHS = CI->getOperand(0);
+
+ if (!LHS->getType()->isPointerTy())
+ return false;
+
+ assert(CI->getOperand(1)->getType()->isPointerTy());
+
+ // p != 0 -> isProb = true
+ // p == 0 -> isProb = false
+ // p != q -> isProb = true
+ // p == q -> isProb = false;
+ unsigned TakenIdx = 0, NonTakenIdx = 1;
+ bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE;
+ if (!isProb)
+ std::swap(TakenIdx, NonTakenIdx);
+
+ BranchProbability TakenProb(PH_TAKEN_WEIGHT,
+ PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
+ setEdgeProbability(BB, TakenIdx, TakenProb);
+ setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
+ return true;
+}
+
+static int getSCCNum(const BasicBlock *BB,
+ const BranchProbabilityInfo::SccInfo &SccI) {
+ auto SccIt = SccI.SccNums.find(BB);
+ if (SccIt == SccI.SccNums.end())
+ return -1;
+ return SccIt->second;
+}
+
+// Consider any block that is an entry point to the SCC as a header.
+static bool isSCCHeader(const BasicBlock *BB, int SccNum,
+ BranchProbabilityInfo::SccInfo &SccI) {
+ assert(getSCCNum(BB, SccI) == SccNum);
+
+ // Lazily compute the set of headers for a given SCC and cache the results
+ // in the SccHeaderMap.
+ if (SccI.SccHeaders.size() <= static_cast<unsigned>(SccNum))
+ SccI.SccHeaders.resize(SccNum + 1);
+ auto &HeaderMap = SccI.SccHeaders[SccNum];
+ bool Inserted;
+ BranchProbabilityInfo::SccHeaderMap::iterator HeaderMapIt;
+ std::tie(HeaderMapIt, Inserted) = HeaderMap.insert(std::make_pair(BB, false));
+ if (Inserted) {
+ bool IsHeader = llvm::any_of(make_range(pred_begin(BB), pred_end(BB)),
+ [&](const BasicBlock *Pred) {
+ return getSCCNum(Pred, SccI) != SccNum;
+ });
+ HeaderMapIt->second = IsHeader;
+ return IsHeader;
+ } else
+ return HeaderMapIt->second;
+}
+
+// Compute the unlikely successors to the block BB in the loop L, specifically
+// those that are unlikely because this is a loop, and add them to the
+// UnlikelyBlocks set.
+static void
+computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
+ SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
+ // Sometimes in a loop we have a branch whose condition is made false by
+ // taking it. This is typically something like
+ // int n = 0;
+ // while (...) {
+ // if (++n >= MAX) {
+ // n = 0;
+ // }
+ // }
+ // In this sort of situation taking the branch means that at the very least it
+ // won't be taken again in the next iteration of the loop, so we should
+ // consider it less likely than a typical branch.
+ //
+ // We detect this by looking back through the graph of PHI nodes that sets the
+ // value that the condition depends on, and seeing if we can reach a successor
+ // block which can be determined to make the condition false.
+ //
+ // FIXME: We currently consider unlikely blocks to be half as likely as other
+  // blocks, but if we consider the example above the likelihood is actually
+ // 1/MAX. We could therefore be more precise in how unlikely we consider
+ // blocks to be, but it would require more careful examination of the form
+ // of the comparison expression.
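+  //
+  // Illustrative sketch (hypothetical IR for the C example above): the
+  // condition might be computed as
+  //
+  //   %inc = add i32 %n, 1          ; %n is, or traces back to, a PHI node
+  //   %cmp = icmp sge i32 %inc, 1000
+  //
+  // CmpLHS starts at %inc, InstChain collects the add, and CmpPHI is the PHI
+  // reached through %n. A constant 0 flowing into that PHI from the "reset"
+  // successor folds the compare to false, so the branch into the reset block
+  // is marked unlikely.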
+ const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
+ if (!BI || !BI->isConditional())
+ return;
+
+ // Check if the branch is based on an instruction compared with a constant
+ CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
+ if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
+ !isa<Constant>(CI->getOperand(1)))
+ return;
+
+ // Either the instruction must be a PHI, or a chain of operations involving
+ // constants that ends in a PHI which we can then collapse into a single value
+ // if the PHI value is known.
+ Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0));
+ PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS);
+ Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1));
+ // Collect the instructions until we hit a PHI
+ SmallVector<BinaryOperator *, 1> InstChain;
+ while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) &&
+ isa<Constant>(CmpLHS->getOperand(1))) {
+ // Stop if the chain extends outside of the loop
+ if (!L->contains(CmpLHS))
+ return;
+ InstChain.push_back(cast<BinaryOperator>(CmpLHS));
+ CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0));
+ if (CmpLHS)
+ CmpPHI = dyn_cast<PHINode>(CmpLHS);
+ }
+ if (!CmpPHI || !L->contains(CmpPHI))
+ return;
+
+ // Trace the phi node to find all values that come from successors of BB
+ SmallPtrSet<PHINode*, 8> VisitedInsts;
+ SmallVector<PHINode*, 8> WorkList;
+ WorkList.push_back(CmpPHI);
+ VisitedInsts.insert(CmpPHI);
+ while (!WorkList.empty()) {
+ PHINode *P = WorkList.back();
+ WorkList.pop_back();
+ for (BasicBlock *B : P->blocks()) {
+ // Skip blocks that aren't part of the loop
+ if (!L->contains(B))
+ continue;
+ Value *V = P->getIncomingValueForBlock(B);
+ // If the source is a PHI add it to the work list if we haven't
+ // already visited it.
+ if (PHINode *PN = dyn_cast<PHINode>(V)) {
+ if (VisitedInsts.insert(PN).second)
+ WorkList.push_back(PN);
+ continue;
+ }
+ // If this incoming value is a constant and B is a successor of BB, then
+ // we can constant-evaluate the compare to see if it makes the branch be
+ // taken or not.
+ Constant *CmpLHSConst = dyn_cast<Constant>(V);
+ if (!CmpLHSConst ||
+ std::find(succ_begin(BB), succ_end(BB), B) == succ_end(BB))
+ continue;
+ // First collapse InstChain
+ for (Instruction *I : llvm::reverse(InstChain)) {
+ CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst,
+ cast<Constant>(I->getOperand(1)), true);
+ if (!CmpLHSConst)
+ break;
+ }
+ if (!CmpLHSConst)
+ continue;
+ // Now constant-evaluate the compare
+ Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
+ CmpLHSConst, CmpConst, true);
+ // If the result means we don't branch to the block then that block is
+ // unlikely.
+ if (Result &&
+ ((Result->isZeroValue() && B == BI->getSuccessor(0)) ||
+ (Result->isOneValue() && B == BI->getSuccessor(1))))
+ UnlikelyBlocks.insert(B);
+ }
+ }
+}
+
+// Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges
+// as taken, exiting edges as not-taken.
+bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB,
+ const LoopInfo &LI,
+ SccInfo &SccI) {
+ int SccNum;
+ Loop *L = LI.getLoopFor(BB);
+ if (!L) {
+ SccNum = getSCCNum(BB, SccI);
+ if (SccNum < 0)
+ return false;
+ }
+
+ SmallPtrSet<const BasicBlock*, 8> UnlikelyBlocks;
+ if (L)
+ computeUnlikelySuccessors(BB, L, UnlikelyBlocks);
+
+ SmallVector<unsigned, 8> BackEdges;
+ SmallVector<unsigned, 8> ExitingEdges;
+ SmallVector<unsigned, 8> InEdges; // Edges from header to the loop.
+ SmallVector<unsigned, 8> UnlikelyEdges;
+
+ for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
+ // Use LoopInfo if we have it, otherwise fall-back to SCC info to catch
+ // irreducible loops.
+ if (L) {
+ if (UnlikelyBlocks.count(*I) != 0)
+ UnlikelyEdges.push_back(I.getSuccessorIndex());
+ else if (!L->contains(*I))
+ ExitingEdges.push_back(I.getSuccessorIndex());
+ else if (L->getHeader() == *I)
+ BackEdges.push_back(I.getSuccessorIndex());
+ else
+ InEdges.push_back(I.getSuccessorIndex());
+ } else {
+ if (getSCCNum(*I, SccI) != SccNum)
+ ExitingEdges.push_back(I.getSuccessorIndex());
+ else if (isSCCHeader(*I, SccNum, SccI))
+ BackEdges.push_back(I.getSuccessorIndex());
+ else
+ InEdges.push_back(I.getSuccessorIndex());
+ }
+ }
+
+ if (BackEdges.empty() && ExitingEdges.empty() && UnlikelyEdges.empty())
+ return false;
+
+  // Compute a common denominator from the weight classes that are present
+  // (backedges, in-edges, unlikely edges, exiting edges) so that the assigned
+  // probabilities sum to one.
+ unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
+ (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
+ (UnlikelyEdges.empty() ? 0 : LBH_UNLIKELY_WEIGHT) +
+ (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT);
+
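+  // Illustrative arithmetic (not from the original source): for a latch block
+  // with one backedge and one exiting edge, Denom = LBH_TAKEN_WEIGHT +
+  // LBH_NONTAKEN_WEIGHT = 128, so the backedge gets 124/128 and the exiting
+  // edge 4/128, matching the example at the top of this file.
+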
+ if (uint32_t numBackEdges = BackEdges.size()) {
+ BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
+ auto Prob = TakenProb / numBackEdges;
+ for (unsigned SuccIdx : BackEdges)
+ setEdgeProbability(BB, SuccIdx, Prob);
+ }
+
+ if (uint32_t numInEdges = InEdges.size()) {
+ BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
+ auto Prob = TakenProb / numInEdges;
+ for (unsigned SuccIdx : InEdges)
+ setEdgeProbability(BB, SuccIdx, Prob);
+ }
+
+ if (uint32_t numExitingEdges = ExitingEdges.size()) {
+ BranchProbability NotTakenProb = BranchProbability(LBH_NONTAKEN_WEIGHT,
+ Denom);
+ auto Prob = NotTakenProb / numExitingEdges;
+ for (unsigned SuccIdx : ExitingEdges)
+ setEdgeProbability(BB, SuccIdx, Prob);
+ }
+
+ if (uint32_t numUnlikelyEdges = UnlikelyEdges.size()) {
+ BranchProbability UnlikelyProb = BranchProbability(LBH_UNLIKELY_WEIGHT,
+ Denom);
+ auto Prob = UnlikelyProb / numUnlikelyEdges;
+ for (unsigned SuccIdx : UnlikelyEdges)
+ setEdgeProbability(BB, SuccIdx, Prob);
+ }
+
+ return true;
+}
+
+bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
+ const TargetLibraryInfo *TLI) {
+ const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
+ if (!BI || !BI->isConditional())
+ return false;
+
+ Value *Cond = BI->getCondition();
+ ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
+ if (!CI)
+ return false;
+
+ auto GetConstantInt = [](Value *V) {
+ if (auto *I = dyn_cast<BitCastInst>(V))
+ return dyn_cast<ConstantInt>(I->getOperand(0));
+ return dyn_cast<ConstantInt>(V);
+ };
+
+ Value *RHS = CI->getOperand(1);
+ ConstantInt *CV = GetConstantInt(RHS);
+ if (!CV)
+ return false;
+
+  // If the LHS is the result of AND'ing a value with a single-bit bitmask,
+  // we don't have information about probabilities.
+ if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
+ if (LHS->getOpcode() == Instruction::And)
+ if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(LHS->getOperand(1)))
+ if (AndRHS->getValue().isPowerOf2())
+ return false;
+
+ // Check if the LHS is the return value of a library function
+ LibFunc Func = NumLibFuncs;
+ if (TLI)
+ if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
+ if (Function *CalledFn = Call->getCalledFunction())
+ TLI->getLibFunc(*CalledFn, Func);
+
+ bool isProb;
+ if (Func == LibFunc_strcasecmp ||
+ Func == LibFunc_strcmp ||
+ Func == LibFunc_strncasecmp ||
+ Func == LibFunc_strncmp ||
+ Func == LibFunc_memcmp) {
+    // strcmp and similar functions return zero, negative, or positive if the
+    // first string is equal to, less than, or greater than the second. We
+    // consider it likely that the strings are not equal, so a comparison with
+    // zero is probably false, and a comparison with any other number is also
+    // probably false given that exactly what is returned for nonzero values is
+    // not specified. For any kind of comparison other than equality, we know
+    // nothing.
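+    //
+    // For example (illustrative): `if (strcmp(a, b) == 0)` is predicted
+    // not-taken (isProb = false), while `if (strcmp(a, b) != 0)` is predicted
+    // taken.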
+ switch (CI->getPredicate()) {
+ case CmpInst::ICMP_EQ:
+ isProb = false;
+ break;
+ case CmpInst::ICMP_NE:
+ isProb = true;
+ break;
+ default:
+ return false;
+ }
+ } else if (CV->isZero()) {
+ switch (CI->getPredicate()) {
+ case CmpInst::ICMP_EQ:
+ // X == 0 -> Unlikely
+ isProb = false;
+ break;
+ case CmpInst::ICMP_NE:
+ // X != 0 -> Likely
+ isProb = true;
+ break;
+ case CmpInst::ICMP_SLT:
+ // X < 0 -> Unlikely
+ isProb = false;
+ break;
+ case CmpInst::ICMP_SGT:
+ // X > 0 -> Likely
+ isProb = true;
+ break;
+ default:
+ return false;
+ }
+ } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) {
+ // InstCombine canonicalizes X <= 0 into X < 1.
+ // X <= 0 -> Unlikely
+ isProb = false;
+ } else if (CV->isMinusOne()) {
+ switch (CI->getPredicate()) {
+ case CmpInst::ICMP_EQ:
+ // X == -1 -> Unlikely
+ isProb = false;
+ break;
+ case CmpInst::ICMP_NE:
+ // X != -1 -> Likely
+ isProb = true;
+ break;
+ case CmpInst::ICMP_SGT:
+ // InstCombine canonicalizes X >= 0 into X > -1.
+ // X >= 0 -> Likely
+ isProb = true;
+ break;
+ default:
+ return false;
+ }
+ } else {
+ return false;
+ }
+
+ unsigned TakenIdx = 0, NonTakenIdx = 1;
+
+ if (!isProb)
+ std::swap(TakenIdx, NonTakenIdx);
+
+ BranchProbability TakenProb(ZH_TAKEN_WEIGHT,
+ ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
+ setEdgeProbability(BB, TakenIdx, TakenProb);
+ setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
+ return true;
+}
+
+bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
+ const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
+ if (!BI || !BI->isConditional())
+ return false;
+
+ Value *Cond = BI->getCondition();
+ FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
+ if (!FCmp)
+ return false;
+
+ uint32_t TakenWeight = FPH_TAKEN_WEIGHT;
+ uint32_t NontakenWeight = FPH_NONTAKEN_WEIGHT;
+ bool isProb;
+ if (FCmp->isEquality()) {
+ // f1 == f2 -> Unlikely
+ // f1 != f2 -> Likely
+ isProb = !FCmp->isTrueWhenEqual();
+ } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) {
+ // !isnan -> Likely
+ isProb = true;
+ TakenWeight = FPH_ORD_WEIGHT;
+ NontakenWeight = FPH_UNO_WEIGHT;
+ } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) {
+ // isnan -> Unlikely
+ isProb = false;
+ TakenWeight = FPH_ORD_WEIGHT;
+ NontakenWeight = FPH_UNO_WEIGHT;
+ } else {
+ return false;
+ }
+
+ unsigned TakenIdx = 0, NonTakenIdx = 1;
+
+ if (!isProb)
+ std::swap(TakenIdx, NonTakenIdx);
+
+ BranchProbability TakenProb(TakenWeight, TakenWeight + NontakenWeight);
+ setEdgeProbability(BB, TakenIdx, TakenProb);
+ setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
+ return true;
+}
+
+bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) {
+ const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator());
+ if (!II)
+ return false;
+
+ BranchProbability TakenProb(IH_TAKEN_WEIGHT,
+ IH_TAKEN_WEIGHT + IH_NONTAKEN_WEIGHT);
+ setEdgeProbability(BB, 0 /*Index for Normal*/, TakenProb);
+ setEdgeProbability(BB, 1 /*Index for Unwind*/, TakenProb.getCompl());
+ return true;
+}
+
+void BranchProbabilityInfo::releaseMemory() {
+ Probs.clear();
+}
+
+void BranchProbabilityInfo::print(raw_ostream &OS) const {
+ OS << "---- Branch Probabilities ----\n";
+ // We print the probabilities from the last function the analysis ran over,
+ // or the function it is currently running over.
+ assert(LastF && "Cannot print prior to running over a function");
+ for (const auto &BI : *LastF) {
+ for (succ_const_iterator SI = succ_begin(&BI), SE = succ_end(&BI); SI != SE;
+ ++SI) {
+ printEdgeProbability(OS << " ", &BI, *SI);
+ }
+ }
+}
+
+bool BranchProbabilityInfo::
+isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
+ // Hot probability is at least 4/5 = 80%
+ // FIXME: Compare against a static "hot" BranchProbability.
+ return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
+}
+
+const BasicBlock *
+BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const {
+ auto MaxProb = BranchProbability::getZero();
+ const BasicBlock *MaxSucc = nullptr;
+
+ for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
+ const BasicBlock *Succ = *I;
+ auto Prob = getEdgeProbability(BB, Succ);
+ if (Prob > MaxProb) {
+ MaxProb = Prob;
+ MaxSucc = Succ;
+ }
+ }
+
+ // Hot probability is at least 4/5 = 80%
+ if (MaxProb > BranchProbability(4, 5))
+ return MaxSucc;
+
+ return nullptr;
+}
+
+/// Get the raw edge probability for the edge. If the probability is unknown,
+/// return a default probability of 1/N, where N is the number of successors.
+/// Here an edge is specified by a predecessor block and an index into its
+/// successors.
+BranchProbability
+BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
+ unsigned IndexInSuccessors) const {
+ auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
+
+ if (I != Probs.end())
+ return I->second;
+
+ return {1, static_cast<uint32_t>(succ_size(Src))};
+}
+
+BranchProbability
+BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
+ succ_const_iterator Dst) const {
+ return getEdgeProbability(Src, Dst.getSuccessorIndex());
+}
+
+/// Get the raw edge probability calculated for the block pair. This returns the
+/// sum of all raw edge probabilities from Src to Dst.
+BranchProbability
+BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
+ const BasicBlock *Dst) const {
+ auto Prob = BranchProbability::getZero();
+ bool FoundProb = false;
+ for (succ_const_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
+ if (*I == Dst) {
+ auto MapI = Probs.find(std::make_pair(Src, I.getSuccessorIndex()));
+ if (MapI != Probs.end()) {
+ FoundProb = true;
+ Prob += MapI->second;
+ }
+ }
+ uint32_t succ_num = std::distance(succ_begin(Src), succ_end(Src));
+ return FoundProb ? Prob : BranchProbability(1, succ_num);
+}
+
+/// Set the edge probability for a given edge specified by PredBlock and an
+/// index into its successors.
+void BranchProbabilityInfo::setEdgeProbability(const BasicBlock *Src,
+ unsigned IndexInSuccessors,
+ BranchProbability Prob) {
+ Probs[std::make_pair(Src, IndexInSuccessors)] = Prob;
+ Handles.insert(BasicBlockCallbackVH(Src, this));
+ LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> "
+ << IndexInSuccessors << " successor probability to " << Prob
+ << "\n");
+}
+
+raw_ostream &
+BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
+ const BasicBlock *Src,
+ const BasicBlock *Dst) const {
+ const BranchProbability Prob = getEdgeProbability(Src, Dst);
+ OS << "edge " << Src->getName() << " -> " << Dst->getName()
+ << " probability is " << Prob
+ << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
+
+ return OS;
+}
+
+void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
+ for (auto I = Probs.begin(), E = Probs.end(); I != E; ++I) {
+ auto Key = I->first;
+ if (Key.first == BB)
+ Probs.erase(Key);
+ }
+}
+
+void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,
+ const TargetLibraryInfo *TLI) {
+ LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
+ << " ----\n\n");
+ LastF = &F; // Store the last function we ran on for printing.
+ assert(PostDominatedByUnreachable.empty());
+ assert(PostDominatedByColdCall.empty());
+
+ // Record SCC numbers of blocks in the CFG to identify irreducible loops.
+  // FIXME: We could calculate this only if the CFG is known to be irreducible
+ // (perhaps cache this info in LoopInfo if we can easily calculate it there?).
+ int SccNum = 0;
+ SccInfo SccI;
+ for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd();
+ ++It, ++SccNum) {
+ // Ignore single-block SCCs since they either aren't loops or LoopInfo will
+ // catch them.
+ const std::vector<const BasicBlock *> &Scc = *It;
+ if (Scc.size() == 1)
+ continue;
+
+ LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":");
+ for (auto *BB : Scc) {
+ LLVM_DEBUG(dbgs() << " " << BB->getName());
+ SccI.SccNums[BB] = SccNum;
+ }
+ LLVM_DEBUG(dbgs() << "\n");
+ }
+
+ // Walk the basic blocks in post-order so that we can build up state about
+ // the successors of a block iteratively.
+ for (auto BB : post_order(&F.getEntryBlock())) {
+ LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName()
+ << "\n");
+ updatePostDominatedByUnreachable(BB);
+ updatePostDominatedByColdCall(BB);
+    // If the block has fewer than two successors, there is no probability to
+    // set.
+ if (BB->getTerminator()->getNumSuccessors() < 2)
+ continue;
+ if (calcMetadataWeights(BB))
+ continue;
+ if (calcInvokeHeuristics(BB))
+ continue;
+ if (calcUnreachableHeuristics(BB))
+ continue;
+ if (calcColdCallHeuristics(BB))
+ continue;
+ if (calcLoopBranchHeuristics(BB, LI, SccI))
+ continue;
+ if (calcPointerHeuristics(BB))
+ continue;
+ if (calcZeroHeuristics(BB, TLI))
+ continue;
+ if (calcFloatingPointHeuristics(BB))
+ continue;
+ }
+
+ PostDominatedByUnreachable.clear();
+ PostDominatedByColdCall.clear();
+
+ if (PrintBranchProb &&
+ (PrintBranchProbFuncName.empty() ||
+ F.getName().equals(PrintBranchProbFuncName))) {
+ print(dbgs());
+ }
+}
+
+void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
+ AnalysisUsage &AU) const {
+ // We require DT so it's available when LI is available. The LI updating code
+ // asserts that DT is also present so if we don't make sure that we have DT
+ // here, that assert will trigger.
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
+ AU.setPreservesAll();
+}
+
+bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
+ const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ const TargetLibraryInfo &TLI =
+ getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ BPI.calculate(F, LI, &TLI);
+ return false;
+}
+
+void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
+
+void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
+ const Module *) const {
+ BPI.print(OS);
+}
+
+AnalysisKey BranchProbabilityAnalysis::Key;
+BranchProbabilityInfo
+BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
+ BranchProbabilityInfo BPI;
+ BPI.calculate(F, AM.getResult<LoopAnalysis>(F), &AM.getResult<TargetLibraryAnalysis>(F));
+ return BPI;
+}
+
+PreservedAnalyses
+BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
+ OS << "Printing analysis results of BPI for function "
+ << "'" << F.getName() << "':"
+ << "\n";
+ AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
+ return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Analysis/CFG.cpp b/llvm/lib/Analysis/CFG.cpp
new file mode 100644
index 000000000000..8215b4ecbb03
--- /dev/null
+++ b/llvm/lib/Analysis/CFG.cpp
@@ -0,0 +1,279 @@
+//===-- CFG.cpp - BasicBlock analysis --------------------------------------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This family of functions performs analyses on basic blocks, and instructions
+// contained within basic blocks.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CFG.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/Dominators.h"
+
+using namespace llvm;
+
+/// FindFunctionBackedges - Analyze the specified function to find all of the
+/// loop backedges in the function and return them. This is a relatively cheap
+/// (compared to computing dominators and loop info) analysis.
+///
+/// The output is added to Result, as pairs of <from,to> edge info.
+void llvm::FindFunctionBackedges(const Function &F,
+ SmallVectorImpl<std::pair<const BasicBlock*,const BasicBlock*> > &Result) {
+ const BasicBlock *BB = &F.getEntryBlock();
+ if (succ_empty(BB))
+ return;
+
+ SmallPtrSet<const BasicBlock*, 8> Visited;
+ SmallVector<std::pair<const BasicBlock*, succ_const_iterator>, 8> VisitStack;
+ SmallPtrSet<const BasicBlock*, 8> InStack;
+
+ Visited.insert(BB);
+ VisitStack.push_back(std::make_pair(BB, succ_begin(BB)));
+ InStack.insert(BB);
+ do {
+ std::pair<const BasicBlock*, succ_const_iterator> &Top = VisitStack.back();
+ const BasicBlock *ParentBB = Top.first;
+ succ_const_iterator &I = Top.second;
+
+ bool FoundNew = false;
+ while (I != succ_end(ParentBB)) {
+ BB = *I++;
+ if (Visited.insert(BB).second) {
+ FoundNew = true;
+ break;
+ }
+      // The successor is in VisitStack; this is a back edge.
+ if (InStack.count(BB))
+ Result.push_back(std::make_pair(ParentBB, BB));
+ }
+
+ if (FoundNew) {
+      // Go down one level if there is an unvisited successor.
+ InStack.insert(BB);
+ VisitStack.push_back(std::make_pair(BB, succ_begin(BB)));
+ } else {
+ // Go up one level.
+ InStack.erase(VisitStack.pop_back_val().first);
+ }
+ } while (!VisitStack.empty());
+}
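+
+// Illustrative use of FindFunctionBackedges (hypothetical caller, not part of
+// this file):
+//
+//   SmallVector<std::pair<const BasicBlock *, const BasicBlock *>, 8> Edges;
+//   FindFunctionBackedges(F, Edges);
+//   for (const auto &Edge : Edges)
+//     dbgs() << Edge.first->getName() << " -> " << Edge.second->getName()
+//            << "\n";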
+
+/// GetSuccessorNumber - Search for the specified successor of basic block BB
+/// and return its position in the terminator instruction's list of
+/// successors. It is an error to call this with a block that is not a
+/// successor.
+unsigned llvm::GetSuccessorNumber(const BasicBlock *BB,
+ const BasicBlock *Succ) {
+ const Instruction *Term = BB->getTerminator();
+#ifndef NDEBUG
+ unsigned e = Term->getNumSuccessors();
+#endif
+ for (unsigned i = 0; ; ++i) {
+ assert(i != e && "Didn't find edge?");
+ if (Term->getSuccessor(i) == Succ)
+ return i;
+ }
+}
+
+/// isCriticalEdge - Return true if the specified edge is a critical edge.
+/// Critical edges are edges from a block with multiple successors to a block
+/// with multiple predecessors.
+bool llvm::isCriticalEdge(const Instruction *TI, unsigned SuccNum,
+ bool AllowIdenticalEdges) {
+ assert(SuccNum < TI->getNumSuccessors() && "Illegal edge specification!");
+ return isCriticalEdge(TI, TI->getSuccessor(SuccNum), AllowIdenticalEdges);
+}
+
+bool llvm::isCriticalEdge(const Instruction *TI, const BasicBlock *Dest,
+ bool AllowIdenticalEdges) {
+ assert(TI->isTerminator() && "Must be a terminator to have successors!");
+ if (TI->getNumSuccessors() == 1) return false;
+
+ assert(find(predecessors(Dest), TI->getParent()) != pred_end(Dest) &&
+ "No edge between TI's block and Dest.");
+
+ const_pred_iterator I = pred_begin(Dest), E = pred_end(Dest);
+
+ // If there is more than one predecessor, this is a critical edge...
+ assert(I != E && "No preds, but we have an edge to the block?");
+ const BasicBlock *FirstPred = *I;
+ ++I; // Skip one edge due to the incoming arc from TI.
+ if (!AllowIdenticalEdges)
+ return I != E;
+
+ // If AllowIdenticalEdges is true, then we allow this edge to be considered
+ // non-critical iff all preds come from TI's block.
+ for (; I != E; ++I)
+ if (*I != FirstPred)
+ return true;
+ return false;
+}
+
+// LoopInfo contains a mapping from basic block to the innermost loop. Find
+// the outermost loop in the loop nest that contains BB.
+static const Loop *getOutermostLoop(const LoopInfo *LI, const BasicBlock *BB) {
+ const Loop *L = LI->getLoopFor(BB);
+ if (L) {
+ while (const Loop *Parent = L->getParentLoop())
+ L = Parent;
+ }
+ return L;
+}
+
+bool llvm::isPotentiallyReachableFromMany(
+ SmallVectorImpl<BasicBlock *> &Worklist, BasicBlock *StopBB,
+ const SmallPtrSetImpl<BasicBlock *> *ExclusionSet, const DominatorTree *DT,
+ const LoopInfo *LI) {
+ // When the stop block is unreachable, it's dominated from everywhere,
+ // regardless of whether there's a path between the two blocks.
+ if (DT && !DT->isReachableFromEntry(StopBB))
+ DT = nullptr;
+
+  // We can't skip directly from a block that dominates the stop block if an
+  // exclusion block is potentially in between.
+ if (ExclusionSet && !ExclusionSet->empty())
+ DT = nullptr;
+
+  // Normally any block in a loop is reachable from any other block in the
+  // loop; however, excluded blocks might partition the body of a loop and
+  // make that untrue.
+ SmallPtrSet<const Loop *, 8> LoopsWithHoles;
+ if (LI && ExclusionSet) {
+ for (auto BB : *ExclusionSet) {
+ if (const Loop *L = getOutermostLoop(LI, BB))
+ LoopsWithHoles.insert(L);
+ }
+ }
+
+ const Loop *StopLoop = LI ? getOutermostLoop(LI, StopBB) : nullptr;
+
+ // Limit the number of blocks we visit. The goal is to avoid run-away compile
+ // times on large CFGs without hampering sensible code. Arbitrarily chosen.
+ unsigned Limit = 32;
+ SmallPtrSet<const BasicBlock*, 32> Visited;
+ do {
+ BasicBlock *BB = Worklist.pop_back_val();
+ if (!Visited.insert(BB).second)
+ continue;
+ if (BB == StopBB)
+ return true;
+ if (ExclusionSet && ExclusionSet->count(BB))
+ continue;
+ if (DT && DT->dominates(BB, StopBB))
+ return true;
+
+ const Loop *Outer = nullptr;
+ if (LI) {
+ Outer = getOutermostLoop(LI, BB);
+ // If we're in a loop with a hole, not all blocks in the loop are
+ // reachable from all other blocks. That implies we can't simply jump to
+ // the loop's exit blocks, as that exit might need to pass through an
+ // excluded block. Clear Outer so we process BB's successors.
+ if (LoopsWithHoles.count(Outer))
+ Outer = nullptr;
+ if (StopLoop && Outer == StopLoop)
+ return true;
+ }
+
+ if (!--Limit) {
+ // We haven't been able to prove it one way or the other. Conservatively
+ // answer true -- that there is potentially a path.
+ return true;
+ }
+
+ if (Outer) {
+ // All blocks in a single loop are reachable from all other blocks. From
+ // any of these blocks, we can skip directly to the exits of the loop,
+ // ignoring any other blocks inside the loop body.
+ Outer->getExitBlocks(Worklist);
+ } else {
+ Worklist.append(succ_begin(BB), succ_end(BB));
+ }
+ } while (!Worklist.empty());
+
+  // We have exhausted all possible paths and are certain that 'StopBB' cannot
+  // be reached from any block in the worklist.
+ return false;
+}
+
+bool llvm::isPotentiallyReachable(const BasicBlock *A, const BasicBlock *B,
+ const DominatorTree *DT, const LoopInfo *LI) {
+ assert(A->getParent() == B->getParent() &&
+ "This analysis is function-local!");
+
+ SmallVector<BasicBlock*, 32> Worklist;
+ Worklist.push_back(const_cast<BasicBlock*>(A));
+
+ return isPotentiallyReachableFromMany(Worklist, const_cast<BasicBlock *>(B),
+ nullptr, DT, LI);
+}
+
+bool llvm::isPotentiallyReachable(
+ const Instruction *A, const Instruction *B,
+ const SmallPtrSetImpl<BasicBlock *> *ExclusionSet, const DominatorTree *DT,
+ const LoopInfo *LI) {
+ assert(A->getParent()->getParent() == B->getParent()->getParent() &&
+ "This analysis is function-local!");
+
+ SmallVector<BasicBlock*, 32> Worklist;
+
+ if (A->getParent() == B->getParent()) {
+ // The same block case is special because it's the only time we're looking
+ // within a single block to see which instruction comes first. Once we
+ // start looking at multiple blocks, the first instruction of the block is
+ // reachable, so we only need to determine reachability between whole
+ // blocks.
+ BasicBlock *BB = const_cast<BasicBlock *>(A->getParent());
+
+ // If the block is in a loop then we can reach any instruction in the block
+ // from any other instruction in the block by going around a backedge.
+ if (LI && LI->getLoopFor(BB) != nullptr)
+ return true;
+
+ // Linear scan, start at 'A', see whether we hit 'B' or the end first.
+ for (BasicBlock::const_iterator I = A->getIterator(), E = BB->end(); I != E;
+ ++I) {
+ if (&*I == B)
+ return true;
+ }
+
+ // Can't be in a loop if it's the entry block -- the entry block may not
+ // have predecessors.
+ if (BB == &BB->getParent()->getEntryBlock())
+ return false;
+
+ // Otherwise, continue doing the normal per-BB CFG walk.
+ Worklist.append(succ_begin(BB), succ_end(BB));
+
+ if (Worklist.empty()) {
+ // We've proven that there's no path!
+ return false;
+ }
+ } else {
+ Worklist.push_back(const_cast<BasicBlock*>(A->getParent()));
+ }
+
+ if (DT) {
+ if (DT->isReachableFromEntry(A->getParent()) &&
+ !DT->isReachableFromEntry(B->getParent()))
+ return false;
+ if (!ExclusionSet || ExclusionSet->empty()) {
+ if (A->getParent() == &A->getParent()->getParent()->getEntryBlock() &&
+ DT->isReachableFromEntry(B->getParent()))
+ return true;
+ if (B->getParent() == &A->getParent()->getParent()->getEntryBlock() &&
+ DT->isReachableFromEntry(A->getParent()))
+ return false;
+ }
+ }
+
+ return isPotentiallyReachableFromMany(
+ Worklist, const_cast<BasicBlock *>(B->getParent()), ExclusionSet, DT, LI);
+}
diff --git a/llvm/lib/Analysis/CFGPrinter.cpp b/llvm/lib/Analysis/CFGPrinter.cpp
new file mode 100644
index 000000000000..4f4103fefa25
--- /dev/null
+++ b/llvm/lib/Analysis/CFGPrinter.cpp
@@ -0,0 +1,200 @@
+//===- CFGPrinter.cpp - DOT printer for the control flow graph ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a `-dot-cfg` analysis pass, which emits the
+// `<prefix>.<fnname>.dot` file for each function in the program, with a graph
+// of the CFG for that function. The default value for `<prefix>` is `cfg` but
+// can be customized as needed.
+//
+// The other main feature of this file is that it implements the
+// Function::viewCFG method, which is useful for debugging passes which operate
+// on the CFG.
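+//
+// Illustrative usage (assumed, typical workflow rather than something this
+// file documents): the printer passes can be driven from `opt`, e.g.
+//
+//   opt -dot-cfg -disable-output input.ll
+//   opt -dot-cfg-only -disable-output input.ll
+//
+// and the resulting .dot files can be rendered with Graphviz, e.g.
+//
+//   dot -Tpng <prefix>.<fnname>.dot -o out.png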
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CFGPrinter.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/FileSystem.h"
+using namespace llvm;
+
+static cl::opt<std::string> CFGFuncName(
+ "cfg-func-name", cl::Hidden,
+ cl::desc("The name of a function (or its substring)"
+ " whose CFG is viewed/printed."));
+
+static cl::opt<std::string> CFGDotFilenamePrefix(
+ "cfg-dot-filename-prefix", cl::Hidden,
+ cl::desc("The prefix used for the CFG dot file names."));
+
+namespace {
+ struct CFGViewerLegacyPass : public FunctionPass {
+    static char ID; // Pass identification, replacement for typeid
+ CFGViewerLegacyPass() : FunctionPass(ID) {
+ initializeCFGViewerLegacyPassPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override {
+ F.viewCFG();
+ return false;
+ }
+
+ void print(raw_ostream &OS, const Module* = nullptr) const override {}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ }
+ };
+}
+
+char CFGViewerLegacyPass::ID = 0;
+INITIALIZE_PASS(CFGViewerLegacyPass, "view-cfg", "View CFG of function", false, true)
+
+PreservedAnalyses CFGViewerPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ F.viewCFG();
+ return PreservedAnalyses::all();
+}
+
+
+namespace {
+ struct CFGOnlyViewerLegacyPass : public FunctionPass {
+    static char ID; // Pass identification, replacement for typeid
+ CFGOnlyViewerLegacyPass() : FunctionPass(ID) {
+ initializeCFGOnlyViewerLegacyPassPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override {
+ F.viewCFGOnly();
+ return false;
+ }
+
+ void print(raw_ostream &OS, const Module* = nullptr) const override {}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ }
+ };
+}
+
+char CFGOnlyViewerLegacyPass::ID = 0;
+INITIALIZE_PASS(CFGOnlyViewerLegacyPass, "view-cfg-only",
+ "View CFG of function (with no function bodies)", false, true)
+
+PreservedAnalyses CFGOnlyViewerPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ F.viewCFGOnly();
+ return PreservedAnalyses::all();
+}
+
+static void writeCFGToDotFile(Function &F, bool CFGOnly = false) {
+ if (!CFGFuncName.empty() && !F.getName().contains(CFGFuncName))
+ return;
+ std::string Filename =
+ (CFGDotFilenamePrefix + "." + F.getName() + ".dot").str();
+ errs() << "Writing '" << Filename << "'...";
+
+ std::error_code EC;
+ raw_fd_ostream File(Filename, EC, sys::fs::OF_Text);
+
+ if (!EC)
+ WriteGraph(File, (const Function*)&F, CFGOnly);
+ else
+ errs() << " error opening file for writing!";
+ errs() << "\n";
+}
+
+namespace {
+ struct CFGPrinterLegacyPass : public FunctionPass {
+ static char ID; // Pass identification, replacement for typeid
+ CFGPrinterLegacyPass() : FunctionPass(ID) {
+ initializeCFGPrinterLegacyPassPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override {
+ writeCFGToDotFile(F);
+ return false;
+ }
+
+ void print(raw_ostream &OS, const Module* = nullptr) const override {}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ }
+ };
+}
+
+char CFGPrinterLegacyPass::ID = 0;
+INITIALIZE_PASS(CFGPrinterLegacyPass, "dot-cfg", "Print CFG of function to 'dot' file",
+ false, true)
+
+PreservedAnalyses CFGPrinterPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ writeCFGToDotFile(F);
+ return PreservedAnalyses::all();
+}
+
+namespace {
+ struct CFGOnlyPrinterLegacyPass : public FunctionPass {
+ static char ID; // Pass identification, replacement for typeid
+ CFGOnlyPrinterLegacyPass() : FunctionPass(ID) {
+ initializeCFGOnlyPrinterLegacyPassPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override {
+ writeCFGToDotFile(F, /*CFGOnly=*/true);
+ return false;
+ }
+ void print(raw_ostream &OS, const Module* = nullptr) const override {}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ }
+ };
+}
+
+char CFGOnlyPrinterLegacyPass::ID = 0;
+INITIALIZE_PASS(CFGOnlyPrinterLegacyPass, "dot-cfg-only",
+ "Print CFG of function to 'dot' file (with no function bodies)",
+ false, true)
+
+PreservedAnalyses CFGOnlyPrinterPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ writeCFGToDotFile(F, /*CFGOnly=*/true);
+ return PreservedAnalyses::all();
+}
+
+/// viewCFG - This function is meant for use from the debugger. You can just
+/// say 'call F->viewCFG()' and a ghostview window should pop up from the
+/// program, displaying the CFG of the current function. This depends on there
+/// being a 'dot' and 'gv' program in your path.
+///
+void Function::viewCFG() const {
+ if (!CFGFuncName.empty() && !getName().contains(CFGFuncName))
+ return;
+ ViewGraph(this, "cfg" + getName());
+}
+
+/// viewCFGOnly - This function is meant for use from the debugger. It works
+/// just like viewCFG, but it does not include the contents of basic blocks
+/// into the nodes, just the label. If you are only interested in the CFG
+/// this can make the graph smaller.
+///
+void Function::viewCFGOnly() const {
+ if (!CFGFuncName.empty() && !getName().contains(CFGFuncName))
+ return;
+ ViewGraph(this, "cfg" + getName(), true);
+}
+
+FunctionPass *llvm::createCFGPrinterLegacyPassPass () {
+ return new CFGPrinterLegacyPass();
+}
+
+FunctionPass *llvm::createCFGOnlyPrinterLegacyPassPass () {
+ return new CFGOnlyPrinterLegacyPass();
+}
+
diff --git a/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp b/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp
new file mode 100644
index 000000000000..fd90bd1521d6
--- /dev/null
+++ b/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp
@@ -0,0 +1,931 @@
+//===- CFLAndersAliasAnalysis.cpp - Inclusion-based Alias Analysis --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a CFL-based, summary-based alias analysis algorithm. It
+// differs from CFLSteensAliasAnalysis in its inclusion-based nature while
+// CFLSteensAliasAnalysis is unification-based. This pass has worse performance
+// than CFLSteensAliasAnalysis (the worst case complexity of
+// CFLAndersAliasAnalysis is cubic, while the worst case complexity of
+// CFLSteensAliasAnalysis is almost linear), but it is able to yield more
+// precise analysis results. The precision of this analysis is roughly the same
+// as that of a one-level context-sensitive Andersen's algorithm.
+//
+// The algorithm used here is based on the recursive state machine matching
+// scheme proposed in "Demand-driven alias analysis for C" by Xin Zheng and Radu
+// Rugina. The general idea is to extend the traditional transitive closure
+// algorithm to perform CFL matching along the way: instead of recording
+// "whether X is reachable from Y", we keep track of "whether X is reachable
+// from Y at state Z", where the "state" field indicates where we are in the CFL
+// matching process. To understand the matching better, it is advisable to have
+// the state machine shown in Figure 3 of the paper available when reading the
+// code: all we do here is selectively expand the transitive closure by
+// discarding edges that are not recognized by the state machine.
+//
+// There are two differences between our current implementation and the one
+// described in the paper:
+// - Our algorithm eagerly computes all alias pairs after the CFLGraph is built,
+// while in the paper the authors did the computation in a demand-driven
+// fashion. We did not implement the demand-driven algorithm due to the
+//   additional coding complexity and higher memory profile, but if we find it
+//   necessary we may switch to it eventually.
+// - In the paper the authors use a state machine that does not distinguish
+// value reads from value writes. For example, if Y is reachable from X at state
+// S3, it may be the case that X is written into Y, or it may be the case that
+// there's a third value Z that writes into both X and Y. To make that
+// distinction (which is crucial in building function summary as well as
+// retrieving mod-ref info), we choose to duplicate some of the states in the
+// paper's proposed state machine. The duplication does not change the set the
+// machine accepts. Given a pair of reachable values, it only provides more
+// detailed information on which value is being written into and which is being
+// read from.
+//
+//===----------------------------------------------------------------------===//
+
+// N.B. AliasAnalysis as a whole is phrased as a FunctionPass at the moment, and
+// CFLAndersAA is interprocedural. This is *technically* A Bad Thing, because
+// FunctionPasses are only allowed to inspect the Function that they're being
+// run on. Realistically, this likely isn't a problem until we allow
+// FunctionPasses to run concurrently.
+
+#include "llvm/Analysis/CFLAndersAliasAnalysis.h"
+#include "AliasAnalysisSummary.h"
+#include "CFLGraph.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <bitset>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+using namespace llvm::cflaa;
+
+#define DEBUG_TYPE "cfl-anders-aa"
+
+CFLAndersAAResult::CFLAndersAAResult(
+ std::function<const TargetLibraryInfo &(Function &F)> GetTLI)
+ : GetTLI(std::move(GetTLI)) {}
+CFLAndersAAResult::CFLAndersAAResult(CFLAndersAAResult &&RHS)
+ : AAResultBase(std::move(RHS)), GetTLI(std::move(RHS.GetTLI)) {}
+CFLAndersAAResult::~CFLAndersAAResult() = default;
+
+namespace {
+
+enum class MatchState : uint8_t {
+ // The following state represents S1 in the paper.
+ FlowFromReadOnly = 0,
+ // The following two states together represent S2 in the paper.
+ // The 'NoReadWrite' suffix indicates that there exists an alias path that
+ // does not contain assignment and reverse assignment edges.
+ // The 'ReadOnly' suffix indicates that there exists an alias path that
+ // contains reverse assignment edges only.
+ FlowFromMemAliasNoReadWrite,
+ FlowFromMemAliasReadOnly,
+ // The following two states together represent S3 in the paper.
+ // The 'WriteOnly' suffix indicates that there exists an alias path that
+ // contains assignment edges only.
+ // The 'ReadWrite' suffix indicates that there exists an alias path that
+ // contains both assignment and reverse assignment edges. Note that if X and Y
+ // are reachable at 'ReadWrite' state, it does NOT mean X is both read from
+ // and written to Y. Instead, it means that a third value Z is written to both
+ // X and Y.
+ FlowToWriteOnly,
+ FlowToReadWrite,
+ // The following two states together represent S4 in the paper.
+ FlowToMemAliasWriteOnly,
+ FlowToMemAliasReadWrite,
+};
+
+using StateSet = std::bitset<7>;
+
+const unsigned ReadOnlyStateMask =
+ (1U << static_cast<uint8_t>(MatchState::FlowFromReadOnly)) |
+ (1U << static_cast<uint8_t>(MatchState::FlowFromMemAliasReadOnly));
+const unsigned WriteOnlyStateMask =
+ (1U << static_cast<uint8_t>(MatchState::FlowToWriteOnly)) |
+ (1U << static_cast<uint8_t>(MatchState::FlowToMemAliasWriteOnly));
+
+// A pair that consists of a value and an offset
+struct OffsetValue {
+ const Value *Val;
+ int64_t Offset;
+};
+
+bool operator==(OffsetValue LHS, OffsetValue RHS) {
+ return LHS.Val == RHS.Val && LHS.Offset == RHS.Offset;
+}
+bool operator<(OffsetValue LHS, OffsetValue RHS) {
+ return std::less<const Value *>()(LHS.Val, RHS.Val) ||
+ (LHS.Val == RHS.Val && LHS.Offset < RHS.Offset);
+}
+
+// A pair that consists of an InstantiatedValue and an offset
+struct OffsetInstantiatedValue {
+ InstantiatedValue IVal;
+ int64_t Offset;
+};
+
+bool operator==(OffsetInstantiatedValue LHS, OffsetInstantiatedValue RHS) {
+ return LHS.IVal == RHS.IVal && LHS.Offset == RHS.Offset;
+}
+
+// We use ReachabilitySet to keep track of value aliases (the nonterminal "V" in
+// the paper) during the analysis.
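+// A successful insert(From, To, State) makes the pair (From, State) show up in
+// reachableValueAliases(To).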
+class ReachabilitySet {
+ using ValueStateMap = DenseMap<InstantiatedValue, StateSet>;
+ using ValueReachMap = DenseMap<InstantiatedValue, ValueStateMap>;
+
+ ValueReachMap ReachMap;
+
+public:
+ using const_valuestate_iterator = ValueStateMap::const_iterator;
+ using const_value_iterator = ValueReachMap::const_iterator;
+
+ // Insert edge 'From->To' at state 'State'
+ bool insert(InstantiatedValue From, InstantiatedValue To, MatchState State) {
+ assert(From != To);
+ auto &States = ReachMap[To][From];
+ auto Idx = static_cast<size_t>(State);
+ if (!States.test(Idx)) {
+ States.set(Idx);
+ return true;
+ }
+ return false;
+ }
+
+ // Return the set of all ('From', 'State') pairs for a given node 'To'
+ iterator_range<const_valuestate_iterator>
+ reachableValueAliases(InstantiatedValue V) const {
+ auto Itr = ReachMap.find(V);
+ if (Itr == ReachMap.end())
+ return make_range<const_valuestate_iterator>(const_valuestate_iterator(),
+ const_valuestate_iterator());
+ return make_range<const_valuestate_iterator>(Itr->second.begin(),
+ Itr->second.end());
+ }
+
+ iterator_range<const_value_iterator> value_mappings() const {
+ return make_range<const_value_iterator>(ReachMap.begin(), ReachMap.end());
+ }
+};
+
+// We use AliasMemSet to keep track of all memory aliases (the nonterminal "M"
+// in the paper) during the analysis.
+class AliasMemSet {
+ using MemSet = DenseSet<InstantiatedValue>;
+ using MemMapType = DenseMap<InstantiatedValue, MemSet>;
+
+ MemMapType MemMap;
+
+public:
+ using const_mem_iterator = MemSet::const_iterator;
+
+ bool insert(InstantiatedValue LHS, InstantiatedValue RHS) {
+ // Top-level values can never be memory aliases because one cannot take
+ // their addresses
+ assert(LHS.DerefLevel > 0 && RHS.DerefLevel > 0);
+ return MemMap[LHS].insert(RHS).second;
+ }
+
+ const MemSet *getMemoryAliases(InstantiatedValue V) const {
+ auto Itr = MemMap.find(V);
+ if (Itr == MemMap.end())
+ return nullptr;
+ return &Itr->second;
+ }
+};
+
+// We use AliasAttrMap to keep track of the AliasAttr of each node.
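+// The add() method reports whether the merged attribute set actually grew;
+// buildAttrMap() below relies on that return value to drive its fixed-point
+// propagation.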
+class AliasAttrMap {
+ using MapType = DenseMap<InstantiatedValue, AliasAttrs>;
+
+ MapType AttrMap;
+
+public:
+ using const_iterator = MapType::const_iterator;
+
+ bool add(InstantiatedValue V, AliasAttrs Attr) {
+ auto &OldAttr = AttrMap[V];
+ auto NewAttr = OldAttr | Attr;
+ if (OldAttr == NewAttr)
+ return false;
+ OldAttr = NewAttr;
+ return true;
+ }
+
+ AliasAttrs getAttrs(InstantiatedValue V) const {
+ AliasAttrs Attr;
+ auto Itr = AttrMap.find(V);
+ if (Itr != AttrMap.end())
+ Attr = Itr->second;
+ return Attr;
+ }
+
+ iterator_range<const_iterator> mappings() const {
+ return make_range<const_iterator>(AttrMap.begin(), AttrMap.end());
+ }
+};
+
+struct WorkListItem {
+ InstantiatedValue From;
+ InstantiatedValue To;
+ MatchState State;
+};
+
+struct ValueSummary {
+ struct Record {
+ InterfaceValue IValue;
+ unsigned DerefLevel;
+ };
+ SmallVector<Record, 4> FromRecords, ToRecords;
+};
+
+} // end anonymous namespace
+
+namespace llvm {
+
+// Specialize DenseMapInfo for OffsetValue.
+template <> struct DenseMapInfo<OffsetValue> {
+ static OffsetValue getEmptyKey() {
+ return OffsetValue{DenseMapInfo<const Value *>::getEmptyKey(),
+ DenseMapInfo<int64_t>::getEmptyKey()};
+ }
+
+ static OffsetValue getTombstoneKey() {
+ return OffsetValue{DenseMapInfo<const Value *>::getTombstoneKey(),
+ DenseMapInfo<int64_t>::getEmptyKey()};
+ }
+
+ static unsigned getHashValue(const OffsetValue &OVal) {
+ return DenseMapInfo<std::pair<const Value *, int64_t>>::getHashValue(
+ std::make_pair(OVal.Val, OVal.Offset));
+ }
+
+ static bool isEqual(const OffsetValue &LHS, const OffsetValue &RHS) {
+ return LHS == RHS;
+ }
+};
+
+// Specialize DenseMapInfo for OffsetInstantiatedValue.
+template <> struct DenseMapInfo<OffsetInstantiatedValue> {
+ static OffsetInstantiatedValue getEmptyKey() {
+ return OffsetInstantiatedValue{
+ DenseMapInfo<InstantiatedValue>::getEmptyKey(),
+ DenseMapInfo<int64_t>::getEmptyKey()};
+ }
+
+ static OffsetInstantiatedValue getTombstoneKey() {
+ return OffsetInstantiatedValue{
+ DenseMapInfo<InstantiatedValue>::getTombstoneKey(),
+ DenseMapInfo<int64_t>::getEmptyKey()};
+ }
+
+ static unsigned getHashValue(const OffsetInstantiatedValue &OVal) {
+ return DenseMapInfo<std::pair<InstantiatedValue, int64_t>>::getHashValue(
+ std::make_pair(OVal.IVal, OVal.Offset));
+ }
+
+ static bool isEqual(const OffsetInstantiatedValue &LHS,
+ const OffsetInstantiatedValue &RHS) {
+ return LHS == RHS;
+ }
+};
+
+} // end namespace llvm
+
+class CFLAndersAAResult::FunctionInfo {
+ /// Map a value to other values that may alias it
+ /// Since the alias relation is symmetric, to save some space we assume values
+ /// are properly ordered: if a and b alias each other, and a < b, then b is in
+ /// AliasMap[a] but not vice versa.
+ DenseMap<const Value *, std::vector<OffsetValue>> AliasMap;
+
+ /// Map a value to its corresponding AliasAttrs
+ DenseMap<const Value *, AliasAttrs> AttrMap;
+
+ /// Summary of externally visible effects.
+ AliasSummary Summary;
+
+ Optional<AliasAttrs> getAttrs(const Value *) const;
+
+public:
+ FunctionInfo(const Function &, const SmallVectorImpl<Value *> &,
+ const ReachabilitySet &, const AliasAttrMap &);
+
+ bool mayAlias(const Value *, LocationSize, const Value *, LocationSize) const;
+ const AliasSummary &getAliasSummary() const { return Summary; }
+};
+
+static bool hasReadOnlyState(StateSet Set) {
+ return (Set & StateSet(ReadOnlyStateMask)).any();
+}
+
+static bool hasWriteOnlyState(StateSet Set) {
+ return (Set & StateSet(WriteOnlyStateMask)).any();
+}
+
+static Optional<InterfaceValue>
+getInterfaceValue(InstantiatedValue IValue,
+ const SmallVectorImpl<Value *> &RetVals) {
+ auto Val = IValue.Val;
+
+ Optional<unsigned> Index;
+ if (auto Arg = dyn_cast<Argument>(Val))
+ Index = Arg->getArgNo() + 1;
+ else if (is_contained(RetVals, Val))
+ Index = 0;
+
+ if (Index)
+ return InterfaceValue{*Index, IValue.DerefLevel};
+ return None;
+}
+
+static void populateAttrMap(DenseMap<const Value *, AliasAttrs> &AttrMap,
+ const AliasAttrMap &AMap) {
+ for (const auto &Mapping : AMap.mappings()) {
+ auto IVal = Mapping.first;
+
+ // Insert IVal into the map
+ auto &Attr = AttrMap[IVal.Val];
+ // AttrMap only cares about top-level values
+ if (IVal.DerefLevel == 0)
+ Attr |= Mapping.second;
+ }
+}
+
+static void
+populateAliasMap(DenseMap<const Value *, std::vector<OffsetValue>> &AliasMap,
+ const ReachabilitySet &ReachSet) {
+ for (const auto &OuterMapping : ReachSet.value_mappings()) {
+ // AliasMap only cares about top-level values
+ if (OuterMapping.first.DerefLevel > 0)
+ continue;
+
+ auto Val = OuterMapping.first.Val;
+ auto &AliasList = AliasMap[Val];
+ for (const auto &InnerMapping : OuterMapping.second) {
+ // Again, AliasMap only cares about top-level values
+ if (InnerMapping.first.DerefLevel == 0)
+ AliasList.push_back(OffsetValue{InnerMapping.first.Val, UnknownOffset});
+ }
+
+ // Sort AliasList for faster lookup
+ llvm::sort(AliasList);
+ }
+}
+
+static void populateExternalRelations(
+ SmallVectorImpl<ExternalRelation> &ExtRelations, const Function &Fn,
+ const SmallVectorImpl<Value *> &RetVals, const ReachabilitySet &ReachSet) {
+ // If a function only returns one of its argument X, then X will be both an
+ // argument and a return value at the same time. This is an edge case that
+ // needs special handling here.
+ for (const auto &Arg : Fn.args()) {
+ if (is_contained(RetVals, &Arg)) {
+ auto ArgVal = InterfaceValue{Arg.getArgNo() + 1, 0};
+ auto RetVal = InterfaceValue{0, 0};
+ ExtRelations.push_back(ExternalRelation{ArgVal, RetVal, 0});
+ }
+ }
+
+ // Below is the core summary construction logic.
+ // A naive solution of adding only the value aliases that are parameters or
+ // return values in ReachSet to the summary won't work: It is possible that a
+ // parameter P is written into an intermediate value I, and the function
+ // subsequently returns *I. In that case, *I does not value-alias anything
+ // in ReachSet, and the naive solution will miss a summary edge from (P, 1) to
+ // (I, 1).
+ // To account for the aforementioned case, we need to check each non-parameter
+ // and non-return value for the possibility of acting as an intermediate.
+ // 'ValueMap' here records, for each value, which InterfaceValues read from or
+ // write into it. If both the read list and the write list of a given value
+ // are non-empty, we know that a particular value is an intermediate and we
+ // need to add summary edges from the writes to the reads.
+ DenseMap<Value *, ValueSummary> ValueMap;
+ for (const auto &OuterMapping : ReachSet.value_mappings()) {
+ if (auto Dst = getInterfaceValue(OuterMapping.first, RetVals)) {
+ for (const auto &InnerMapping : OuterMapping.second) {
+ // If Src is a param/return value, we get a same-level assignment.
+ if (auto Src = getInterfaceValue(InnerMapping.first, RetVals)) {
+ // This may happen if both Dst and Src are return values
+ if (*Dst == *Src)
+ continue;
+
+ if (hasReadOnlyState(InnerMapping.second))
+ ExtRelations.push_back(ExternalRelation{*Dst, *Src, UnknownOffset});
+ // No need to check for WriteOnly state, since ReachSet is symmetric
+ } else {
+ // If Src is not a param/return, add it to ValueMap
+ auto SrcIVal = InnerMapping.first;
+ if (hasReadOnlyState(InnerMapping.second))
+ ValueMap[SrcIVal.Val].FromRecords.push_back(
+ ValueSummary::Record{*Dst, SrcIVal.DerefLevel});
+ if (hasWriteOnlyState(InnerMapping.second))
+ ValueMap[SrcIVal.Val].ToRecords.push_back(
+ ValueSummary::Record{*Dst, SrcIVal.DerefLevel});
+ }
+ }
+ }
+ }
+
+ for (const auto &Mapping : ValueMap) {
+ for (const auto &FromRecord : Mapping.second.FromRecords) {
+ for (const auto &ToRecord : Mapping.second.ToRecords) {
+ auto ToLevel = ToRecord.DerefLevel;
+ auto FromLevel = FromRecord.DerefLevel;
+ // Same-level assignments should have already been processed by now
+ if (ToLevel == FromLevel)
+ continue;
+
+ auto SrcIndex = FromRecord.IValue.Index;
+ auto SrcLevel = FromRecord.IValue.DerefLevel;
+ auto DstIndex = ToRecord.IValue.Index;
+ auto DstLevel = ToRecord.IValue.DerefLevel;
+ if (ToLevel > FromLevel)
+ SrcLevel += ToLevel - FromLevel;
+ else
+ DstLevel += FromLevel - ToLevel;
+
+ ExtRelations.push_back(ExternalRelation{
+ InterfaceValue{SrcIndex, SrcLevel},
+ InterfaceValue{DstIndex, DstLevel}, UnknownOffset});
+ }
+ }
+ }
+
+ // Remove duplicates in ExtRelations
+ llvm::sort(ExtRelations);
+ ExtRelations.erase(std::unique(ExtRelations.begin(), ExtRelations.end()),
+ ExtRelations.end());
+}
+
+static void populateExternalAttributes(
+ SmallVectorImpl<ExternalAttribute> &ExtAttributes, const Function &Fn,
+ const SmallVectorImpl<Value *> &RetVals, const AliasAttrMap &AMap) {
+ for (const auto &Mapping : AMap.mappings()) {
+ if (auto IVal = getInterfaceValue(Mapping.first, RetVals)) {
+ auto Attr = getExternallyVisibleAttrs(Mapping.second);
+ if (Attr.any())
+ ExtAttributes.push_back(ExternalAttribute{*IVal, Attr});
+ }
+ }
+}
+
+CFLAndersAAResult::FunctionInfo::FunctionInfo(
+ const Function &Fn, const SmallVectorImpl<Value *> &RetVals,
+ const ReachabilitySet &ReachSet, const AliasAttrMap &AMap) {
+ populateAttrMap(AttrMap, AMap);
+ populateExternalAttributes(Summary.RetParamAttributes, Fn, RetVals, AMap);
+ populateAliasMap(AliasMap, ReachSet);
+ populateExternalRelations(Summary.RetParamRelations, Fn, RetVals, ReachSet);
+}
+
+Optional<AliasAttrs>
+CFLAndersAAResult::FunctionInfo::getAttrs(const Value *V) const {
+ assert(V != nullptr);
+
+ auto Itr = AttrMap.find(V);
+ if (Itr != AttrMap.end())
+ return Itr->second;
+ return None;
+}
+
+bool CFLAndersAAResult::FunctionInfo::mayAlias(
+ const Value *LHS, LocationSize MaybeLHSSize, const Value *RHS,
+ LocationSize MaybeRHSSize) const {
+ assert(LHS && RHS);
+
+ // Check if we've seen LHS and RHS before. Sometimes LHS or RHS can be created
+ // after the analysis gets executed, and we want to be conservative in those
+ // cases.
+ auto MaybeAttrsA = getAttrs(LHS);
+ auto MaybeAttrsB = getAttrs(RHS);
+ if (!MaybeAttrsA || !MaybeAttrsB)
+ return true;
+
+ // Check AliasAttrs before AliasMap lookup since it's cheaper
+ auto AttrsA = *MaybeAttrsA;
+ auto AttrsB = *MaybeAttrsB;
+ if (hasUnknownOrCallerAttr(AttrsA))
+ return AttrsB.any();
+ if (hasUnknownOrCallerAttr(AttrsB))
+ return AttrsA.any();
+ if (isGlobalOrArgAttr(AttrsA))
+ return isGlobalOrArgAttr(AttrsB);
+ if (isGlobalOrArgAttr(AttrsB))
+ return isGlobalOrArgAttr(AttrsA);
+
+ // At this point both LHS and RHS should point to locally allocated objects
+
+ auto Itr = AliasMap.find(LHS);
+ if (Itr != AliasMap.end()) {
+
+ // Find out all (X, Offset) where X == RHS
+ auto Comparator = [](OffsetValue LHS, OffsetValue RHS) {
+ return std::less<const Value *>()(LHS.Val, RHS.Val);
+ };
+#ifdef EXPENSIVE_CHECKS
+ assert(std::is_sorted(Itr->second.begin(), Itr->second.end(), Comparator));
+#endif
+ auto RangePair = std::equal_range(Itr->second.begin(), Itr->second.end(),
+ OffsetValue{RHS, 0}, Comparator);
+
+ if (RangePair.first != RangePair.second) {
+ // Be conservative about unknown sizes
+ if (MaybeLHSSize == LocationSize::unknown() ||
+ MaybeRHSSize == LocationSize::unknown())
+ return true;
+
+ const uint64_t LHSSize = MaybeLHSSize.getValue();
+ const uint64_t RHSSize = MaybeRHSSize.getValue();
+
+ for (const auto &OVal : make_range(RangePair)) {
+ // Be conservative about UnknownOffset
+ if (OVal.Offset == UnknownOffset)
+ return true;
+
+ // We know that LHS aliases (RHS + OVal.Offset) if the control flow
+ // reaches here. The may-alias query essentially becomes integer
+ // range-overlap queries over two ranges [OVal.Offset, OVal.Offset +
+ // LHSSize) and [0, RHSSize).
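+ // For example (illustrative numbers only): with OVal.Offset == 8,
+ // LHSSize == 4 and RHSSize == 8, the ranges are [8, 12) and [0, 8),
+ // which do not overlap, so this particular pair does not prove aliasing.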
+
+ // Try to be conservative on super large offsets
+ if (LLVM_UNLIKELY(LHSSize > INT64_MAX || RHSSize > INT64_MAX))
+ return true;
+
+ auto LHSStart = OVal.Offset;
+ // FIXME: Do we need to guard against integer overflow?
+ auto LHSEnd = OVal.Offset + static_cast<int64_t>(LHSSize);
+ auto RHSStart = 0;
+ auto RHSEnd = static_cast<int64_t>(RHSSize);
+ if (LHSEnd > RHSStart && LHSStart < RHSEnd)
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
+static void propagate(InstantiatedValue From, InstantiatedValue To,
+ MatchState State, ReachabilitySet &ReachSet,
+ std::vector<WorkListItem> &WorkList) {
+ if (From == To)
+ return;
+ if (ReachSet.insert(From, To, State))
+ WorkList.push_back(WorkListItem{From, To, State});
+}
+
+static void initializeWorkList(std::vector<WorkListItem> &WorkList,
+ ReachabilitySet &ReachSet,
+ const CFLGraph &Graph) {
+ for (const auto &Mapping : Graph.value_mappings()) {
+ auto Val = Mapping.first;
+ auto &ValueInfo = Mapping.second;
+ assert(ValueInfo.getNumLevels() > 0);
+
+ // Insert all immediate assignment neighbors to the worklist
+ for (unsigned I = 0, E = ValueInfo.getNumLevels(); I < E; ++I) {
+ auto Src = InstantiatedValue{Val, I};
+ // If there's an assignment edge from X to Y, it means Y is reachable from
+ // X at S3 and X is reachable from Y at S1
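+ // (In the MatchState enum above, S1 corresponds to FlowFromReadOnly and
+ // S3 to the FlowToWriteOnly/FlowToReadWrite pair, hence the two
+ // propagate calls below.)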
+ for (auto &Edge : ValueInfo.getNodeInfoAtLevel(I).Edges) {
+ propagate(Edge.Other, Src, MatchState::FlowFromReadOnly, ReachSet,
+ WorkList);
+ propagate(Src, Edge.Other, MatchState::FlowToWriteOnly, ReachSet,
+ WorkList);
+ }
+ }
+ }
+}
+
+static Optional<InstantiatedValue> getNodeBelow(const CFLGraph &Graph,
+ InstantiatedValue V) {
+ auto NodeBelow = InstantiatedValue{V.Val, V.DerefLevel + 1};
+ if (Graph.getNode(NodeBelow))
+ return NodeBelow;
+ return None;
+}
+
+static void processWorkListItem(const WorkListItem &Item, const CFLGraph &Graph,
+ ReachabilitySet &ReachSet, AliasMemSet &MemSet,
+ std::vector<WorkListItem> &WorkList) {
+ auto FromNode = Item.From;
+ auto ToNode = Item.To;
+
+ auto NodeInfo = Graph.getNode(ToNode);
+ assert(NodeInfo != nullptr);
+
+ // TODO: propagate field offsets
+
+ // FIXME: Here is a neat trick we can do: since both ReachSet and MemSet hold
+ // relations that are symmetric, we could actually cut the storage by half by
+ // sorting FromNode and ToNode before insertion happens.
+
+ // The newly added value alias pair may potentially generate more memory
+ // alias pairs. Check for them here.
+ auto FromNodeBelow = getNodeBelow(Graph, FromNode);
+ auto ToNodeBelow = getNodeBelow(Graph, ToNode);
+ if (FromNodeBelow && ToNodeBelow &&
+ MemSet.insert(*FromNodeBelow, *ToNodeBelow)) {
+ propagate(*FromNodeBelow, *ToNodeBelow,
+ MatchState::FlowFromMemAliasNoReadWrite, ReachSet, WorkList);
+ for (const auto &Mapping : ReachSet.reachableValueAliases(*FromNodeBelow)) {
+ auto Src = Mapping.first;
+ auto MemAliasPropagate = [&](MatchState FromState, MatchState ToState) {
+ if (Mapping.second.test(static_cast<size_t>(FromState)))
+ propagate(Src, *ToNodeBelow, ToState, ReachSet, WorkList);
+ };
+
+ MemAliasPropagate(MatchState::FlowFromReadOnly,
+ MatchState::FlowFromMemAliasReadOnly);
+ MemAliasPropagate(MatchState::FlowToWriteOnly,
+ MatchState::FlowToMemAliasWriteOnly);
+ MemAliasPropagate(MatchState::FlowToReadWrite,
+ MatchState::FlowToMemAliasReadWrite);
+ }
+ }
+
+ // This is the core of the state machine walking algorithm. We expand ReachSet
+ // based on which state we are at (which in turn dictates what edges we
+ // should examine).
+ // From a high-level point of view, the state machine here guarantees two
+ // properties:
+ // - If *X and *Y are memory aliases, then X and Y are value aliases
+ // - If Y is an alias of X, then reverse assignment edges (if any) should
+ // precede any assignment edges on the path from X to Y.
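+ // In particular, the switch below never transitions from a FlowTo* state
+ // back to a FlowFrom* state, which is how the second property is enforced.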
+ auto NextAssignState = [&](MatchState State) {
+ for (const auto &AssignEdge : NodeInfo->Edges)
+ propagate(FromNode, AssignEdge.Other, State, ReachSet, WorkList);
+ };
+ auto NextRevAssignState = [&](MatchState State) {
+ for (const auto &RevAssignEdge : NodeInfo->ReverseEdges)
+ propagate(FromNode, RevAssignEdge.Other, State, ReachSet, WorkList);
+ };
+ auto NextMemState = [&](MatchState State) {
+ if (auto AliasSet = MemSet.getMemoryAliases(ToNode)) {
+ for (const auto &MemAlias : *AliasSet)
+ propagate(FromNode, MemAlias, State, ReachSet, WorkList);
+ }
+ };
+
+ switch (Item.State) {
+ case MatchState::FlowFromReadOnly:
+ NextRevAssignState(MatchState::FlowFromReadOnly);
+ NextAssignState(MatchState::FlowToReadWrite);
+ NextMemState(MatchState::FlowFromMemAliasReadOnly);
+ break;
+
+ case MatchState::FlowFromMemAliasNoReadWrite:
+ NextRevAssignState(MatchState::FlowFromReadOnly);
+ NextAssignState(MatchState::FlowToWriteOnly);
+ break;
+
+ case MatchState::FlowFromMemAliasReadOnly:
+ NextRevAssignState(MatchState::FlowFromReadOnly);
+ NextAssignState(MatchState::FlowToReadWrite);
+ break;
+
+ case MatchState::FlowToWriteOnly:
+ NextAssignState(MatchState::FlowToWriteOnly);
+ NextMemState(MatchState::FlowToMemAliasWriteOnly);
+ break;
+
+ case MatchState::FlowToReadWrite:
+ NextAssignState(MatchState::FlowToReadWrite);
+ NextMemState(MatchState::FlowToMemAliasReadWrite);
+ break;
+
+ case MatchState::FlowToMemAliasWriteOnly:
+ NextAssignState(MatchState::FlowToWriteOnly);
+ break;
+
+ case MatchState::FlowToMemAliasReadWrite:
+ NextAssignState(MatchState::FlowToReadWrite);
+ break;
+ }
+}
+
+static AliasAttrMap buildAttrMap(const CFLGraph &Graph,
+ const ReachabilitySet &ReachSet) {
+ AliasAttrMap AttrMap;
+ std::vector<InstantiatedValue> WorkList, NextList;
+
+ // Initialize each node with its original AliasAttrs in CFLGraph
+ for (const auto &Mapping : Graph.value_mappings()) {
+ auto Val = Mapping.first;
+ auto &ValueInfo = Mapping.second;
+ for (unsigned I = 0, E = ValueInfo.getNumLevels(); I < E; ++I) {
+ auto Node = InstantiatedValue{Val, I};
+ AttrMap.add(Node, ValueInfo.getNodeInfoAtLevel(I).Attr);
+ WorkList.push_back(Node);
+ }
+ }
+
+ while (!WorkList.empty()) {
+ for (const auto &Dst : WorkList) {
+ auto DstAttr = AttrMap.getAttrs(Dst);
+ if (DstAttr.none())
+ continue;
+
+ // Propagate attr on the same level
+ for (const auto &Mapping : ReachSet.reachableValueAliases(Dst)) {
+ auto Src = Mapping.first;
+ if (AttrMap.add(Src, DstAttr))
+ NextList.push_back(Src);
+ }
+
+ // Propagate attr to the levels below
+ auto DstBelow = getNodeBelow(Graph, Dst);
+ while (DstBelow) {
+ if (AttrMap.add(*DstBelow, DstAttr)) {
+ NextList.push_back(*DstBelow);
+ break;
+ }
+ DstBelow = getNodeBelow(Graph, *DstBelow);
+ }
+ }
+ WorkList.swap(NextList);
+ NextList.clear();
+ }
+
+ return AttrMap;
+}
+
+CFLAndersAAResult::FunctionInfo
+CFLAndersAAResult::buildInfoFrom(const Function &Fn) {
+ CFLGraphBuilder<CFLAndersAAResult> GraphBuilder(
+ *this, GetTLI(const_cast<Function &>(Fn)),
+ // Cast away the constness here due to GraphBuilder's API requirement
+ const_cast<Function &>(Fn));
+ auto &Graph = GraphBuilder.getCFLGraph();
+
+ ReachabilitySet ReachSet;
+ AliasMemSet MemSet;
+
+ std::vector<WorkListItem> WorkList, NextList;
+ initializeWorkList(WorkList, ReachSet, Graph);
+ // TODO: make sure we don't stop before the fix point is reached
+ while (!WorkList.empty()) {
+ for (const auto &Item : WorkList)
+ processWorkListItem(Item, Graph, ReachSet, MemSet, NextList);
+
+ NextList.swap(WorkList);
+ NextList.clear();
+ }
+
+ // Now that we have all the reachability info, propagate AliasAttrs according
+ // to it
+ auto IValueAttrMap = buildAttrMap(Graph, ReachSet);
+
+ return FunctionInfo(Fn, GraphBuilder.getReturnValues(), ReachSet,
+ std::move(IValueAttrMap));
+}
+
+void CFLAndersAAResult::scan(const Function &Fn) {
+ auto InsertPair = Cache.insert(std::make_pair(&Fn, Optional<FunctionInfo>()));
+ (void)InsertPair;
+ assert(InsertPair.second &&
+ "Trying to scan a function that has already been cached");
+
+ // Note that we can't do Cache[Fn] = buildSetsFrom(Fn) here: the function call
+ // may get evaluated after operator[], potentially triggering a DenseMap
+ // resize and invalidating the reference returned by operator[]
+ auto FunInfo = buildInfoFrom(Fn);
+ Cache[&Fn] = std::move(FunInfo);
+ Handles.emplace_front(const_cast<Function *>(&Fn), this);
+}
+
+void CFLAndersAAResult::evict(const Function *Fn) { Cache.erase(Fn); }
+
+const Optional<CFLAndersAAResult::FunctionInfo> &
+CFLAndersAAResult::ensureCached(const Function &Fn) {
+ auto Iter = Cache.find(&Fn);
+ if (Iter == Cache.end()) {
+ scan(Fn);
+ Iter = Cache.find(&Fn);
+ assert(Iter != Cache.end());
+ assert(Iter->second.hasValue());
+ }
+ return Iter->second;
+}
+
+const AliasSummary *CFLAndersAAResult::getAliasSummary(const Function &Fn) {
+ auto &FunInfo = ensureCached(Fn);
+ if (FunInfo.hasValue())
+ return &FunInfo->getAliasSummary();
+ else
+ return nullptr;
+}
+
+AliasResult CFLAndersAAResult::query(const MemoryLocation &LocA,
+ const MemoryLocation &LocB) {
+ auto *ValA = LocA.Ptr;
+ auto *ValB = LocB.Ptr;
+
+ if (!ValA->getType()->isPointerTy() || !ValB->getType()->isPointerTy())
+ return NoAlias;
+
+ auto *Fn = parentFunctionOfValue(ValA);
+ if (!Fn) {
+ Fn = parentFunctionOfValue(ValB);
+ if (!Fn) {
+ // The only times this is known to happen are when globals + InlineAsm are
+ // involved
+ LLVM_DEBUG(
+ dbgs()
+ << "CFLAndersAA: could not extract parent function information.\n");
+ return MayAlias;
+ }
+ } else {
+ assert(!parentFunctionOfValue(ValB) || parentFunctionOfValue(ValB) == Fn);
+ }
+
+ assert(Fn != nullptr);
+ auto &FunInfo = ensureCached(*Fn);
+
+ // AliasMap lookup
+ if (FunInfo->mayAlias(ValA, LocA.Size, ValB, LocB.Size))
+ return MayAlias;
+ return NoAlias;
+}
+
+AliasResult CFLAndersAAResult::alias(const MemoryLocation &LocA,
+ const MemoryLocation &LocB,
+ AAQueryInfo &AAQI) {
+ if (LocA.Ptr == LocB.Ptr)
+ return MustAlias;
+
+ // Comparisons between global variables and other constants should be
+ // handled by BasicAA.
+ // CFLAndersAA may report NoAlias when comparing a GlobalValue and
+ // ConstantExpr, but every query needs to have at least one Value tied to a
+ // Function, and neither GlobalValues nor ConstantExprs are.
+ if (isa<Constant>(LocA.Ptr) && isa<Constant>(LocB.Ptr))
+ return AAResultBase::alias(LocA, LocB, AAQI);
+
+ AliasResult QueryResult = query(LocA, LocB);
+ if (QueryResult == MayAlias)
+ return AAResultBase::alias(LocA, LocB, AAQI);
+
+ return QueryResult;
+}
+
+AnalysisKey CFLAndersAA::Key;
+
+CFLAndersAAResult CFLAndersAA::run(Function &F, FunctionAnalysisManager &AM) {
+ auto GetTLI = [&AM](Function &F) -> TargetLibraryInfo & {
+ return AM.getResult<TargetLibraryAnalysis>(F);
+ };
+ return CFLAndersAAResult(GetTLI);
+}
+
+char CFLAndersAAWrapperPass::ID = 0;
+INITIALIZE_PASS(CFLAndersAAWrapperPass, "cfl-anders-aa",
+ "Inclusion-Based CFL Alias Analysis", false, true)
+
+ImmutablePass *llvm::createCFLAndersAAWrapperPass() {
+ return new CFLAndersAAWrapperPass();
+}
+
+CFLAndersAAWrapperPass::CFLAndersAAWrapperPass() : ImmutablePass(ID) {
+ initializeCFLAndersAAWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+void CFLAndersAAWrapperPass::initializePass() {
+ auto GetTLI = [this](Function &F) -> TargetLibraryInfo & {
+ return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ };
+ Result.reset(new CFLAndersAAResult(GetTLI));
+}
+
+void CFLAndersAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
+}
diff --git a/llvm/lib/Analysis/CFLGraph.h b/llvm/lib/Analysis/CFLGraph.h
new file mode 100644
index 000000000000..21842ed36487
--- /dev/null
+++ b/llvm/lib/Analysis/CFLGraph.h
@@ -0,0 +1,660 @@
+//===- CFLGraph.h - CFL graph for CFL-based alias analysis ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This file defines CFLGraph, an auxiliary data structure used by CFL-based
+/// alias analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_ANALYSIS_CFLGRAPH_H
+#define LLVM_LIB_ANALYSIS_CFLGRAPH_H
+
+#include "AliasAnalysisSummary.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <cassert>
+#include <cstdint>
+#include <vector>
+
+namespace llvm {
+namespace cflaa {
+
+/// The Program Expression Graph (PEG) of CFL analysis
+/// CFLGraph is an auxiliary data structure used by CFL-based alias analysis to
+/// describe flow-insensitive pointer-related behaviors. Given an LLVM function,
+/// the main purpose of this graph is to abstract away unrelated facts and
+/// translate the rest into a form that can be easily digested by CFL analyses.
+/// Each Node in the graph is an InstantiatedValue, and each edge represents a
+/// pointer assignment between InstantiatedValues. Pointer
+/// references/dereferences are not explicitly stored in the graph: we
+/// implicitly assume that each node (X, I) has a dereference edge to (X, I+1)
+/// and a reference edge to (X, I-1).
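+/// For example, as the builder below models it, a store "store i8* %v, i8** %p"
+/// becomes an assignment edge from (%v, 0) to (%p, 1); the dereference relation
+/// between (%p, 0) and (%p, 1) is implied by the node levels rather than stored
+/// as an edge.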
+class CFLGraph {
+public:
+ using Node = InstantiatedValue;
+
+ struct Edge {
+ Node Other;
+ int64_t Offset;
+ };
+
+ using EdgeList = std::vector<Edge>;
+
+ struct NodeInfo {
+ EdgeList Edges, ReverseEdges;
+ AliasAttrs Attr;
+ };
+
+ class ValueInfo {
+ std::vector<NodeInfo> Levels;
+
+ public:
+ bool addNodeToLevel(unsigned Level) {
+ auto NumLevels = Levels.size();
+ if (NumLevels > Level)
+ return false;
+ Levels.resize(Level + 1);
+ return true;
+ }
+
+ NodeInfo &getNodeInfoAtLevel(unsigned Level) {
+ assert(Level < Levels.size());
+ return Levels[Level];
+ }
+ const NodeInfo &getNodeInfoAtLevel(unsigned Level) const {
+ assert(Level < Levels.size());
+ return Levels[Level];
+ }
+
+ unsigned getNumLevels() const { return Levels.size(); }
+ };
+
+private:
+ using ValueMap = DenseMap<Value *, ValueInfo>;
+
+ ValueMap ValueImpls;
+
+ NodeInfo *getNode(Node N) {
+ auto Itr = ValueImpls.find(N.Val);
+ if (Itr == ValueImpls.end() || Itr->second.getNumLevels() <= N.DerefLevel)
+ return nullptr;
+ return &Itr->second.getNodeInfoAtLevel(N.DerefLevel);
+ }
+
+public:
+ using const_value_iterator = ValueMap::const_iterator;
+
+ bool addNode(Node N, AliasAttrs Attr = AliasAttrs()) {
+ assert(N.Val != nullptr);
+ auto &ValInfo = ValueImpls[N.Val];
+ auto Changed = ValInfo.addNodeToLevel(N.DerefLevel);
+ ValInfo.getNodeInfoAtLevel(N.DerefLevel).Attr |= Attr;
+ return Changed;
+ }
+
+ void addAttr(Node N, AliasAttrs Attr) {
+ auto *Info = getNode(N);
+ assert(Info != nullptr);
+ Info->Attr |= Attr;
+ }
+
+ void addEdge(Node From, Node To, int64_t Offset = 0) {
+ auto *FromInfo = getNode(From);
+ assert(FromInfo != nullptr);
+ auto *ToInfo = getNode(To);
+ assert(ToInfo != nullptr);
+
+ FromInfo->Edges.push_back(Edge{To, Offset});
+ ToInfo->ReverseEdges.push_back(Edge{From, Offset});
+ }
+
+ const NodeInfo *getNode(Node N) const {
+ auto Itr = ValueImpls.find(N.Val);
+ if (Itr == ValueImpls.end() || Itr->second.getNumLevels() <= N.DerefLevel)
+ return nullptr;
+ return &Itr->second.getNodeInfoAtLevel(N.DerefLevel);
+ }
+
+ AliasAttrs attrFor(Node N) const {
+ auto *Info = getNode(N);
+ assert(Info != nullptr);
+ return Info->Attr;
+ }
+
+ iterator_range<const_value_iterator> value_mappings() const {
+ return make_range<const_value_iterator>(ValueImpls.begin(),
+ ValueImpls.end());
+ }
+};
+
+/// A builder class used to create CFLGraph instance from a given function
+/// The CFL-AA that uses this builder must provide its own type as a template
+/// argument. This is necessary for interprocedural processing: CFLGraphBuilder
+/// needs a way of obtaining the summary of other functions when callinsts are
+/// encountered.
+/// As a result, we expect the said CFL-AA to expose a getAliasSummary() public
+/// member function that takes a Function& and returns the corresponding summary
+/// as a const AliasSummary*.
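+/// For instance, both CFLAndersAAResult and CFLSteensAAResult instantiate
+/// CFLGraphBuilder with their own type and expose such a getAliasSummary()
+/// member, which tryInterproceduralAnalysis() below relies on.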
+template <typename CFLAA> class CFLGraphBuilder {
+ // Input of the builder
+ CFLAA &Analysis;
+ const TargetLibraryInfo &TLI;
+
+ // Output of the builder
+ CFLGraph Graph;
+ SmallVector<Value *, 4> ReturnedValues;
+
+ // Helper class
+ /// Gets the edges our graph should have, based on an Instruction*
+ class GetEdgesVisitor : public InstVisitor<GetEdgesVisitor, void> {
+ CFLAA &AA;
+ const DataLayout &DL;
+ const TargetLibraryInfo &TLI;
+
+ CFLGraph &Graph;
+ SmallVectorImpl<Value *> &ReturnValues;
+
+ static bool hasUsefulEdges(ConstantExpr *CE) {
+ // ConstantExprs don't have terminators, invokes, or fences, so we only
+ // need to check for compares.
+ return CE->getOpcode() != Instruction::ICmp &&
+ CE->getOpcode() != Instruction::FCmp;
+ }
+
+ // Collects the possible functions called by Call into the given SmallVectorImpl.
+ // Returns true if targets found, false otherwise.
+ static bool getPossibleTargets(CallBase &Call,
+ SmallVectorImpl<Function *> &Output) {
+ if (auto *Fn = Call.getCalledFunction()) {
+ Output.push_back(Fn);
+ return true;
+ }
+
+ // TODO: If the call is indirect, we might be able to enumerate all
+ // potential targets of the call and return them, rather than just
+ // failing.
+ return false;
+ }
+
+ void addNode(Value *Val, AliasAttrs Attr = AliasAttrs()) {
+ assert(Val != nullptr && Val->getType()->isPointerTy());
+ if (auto GVal = dyn_cast<GlobalValue>(Val)) {
+ if (Graph.addNode(InstantiatedValue{GVal, 0},
+ getGlobalOrArgAttrFromValue(*GVal)))
+ Graph.addNode(InstantiatedValue{GVal, 1}, getAttrUnknown());
+ } else if (auto CExpr = dyn_cast<ConstantExpr>(Val)) {
+ if (hasUsefulEdges(CExpr)) {
+ if (Graph.addNode(InstantiatedValue{CExpr, 0}))
+ visitConstantExpr(CExpr);
+ }
+ } else
+ Graph.addNode(InstantiatedValue{Val, 0}, Attr);
+ }
+
+ void addAssignEdge(Value *From, Value *To, int64_t Offset = 0) {
+ assert(From != nullptr && To != nullptr);
+ if (!From->getType()->isPointerTy() || !To->getType()->isPointerTy())
+ return;
+ addNode(From);
+ if (To != From) {
+ addNode(To);
+ Graph.addEdge(InstantiatedValue{From, 0}, InstantiatedValue{To, 0},
+ Offset);
+ }
+ }
+
+ void addDerefEdge(Value *From, Value *To, bool IsRead) {
+ assert(From != nullptr && To != nullptr);
+ // FIXME: This is subtly broken, due to how we model some instructions
+ // (e.g. extractvalue, extractelement) as loads. Since those take
+ // non-pointer operands, we'll entirely skip adding edges for those.
+ //
+ // addAssignEdge seems to have a similar issue with insertvalue, etc.
+ if (!From->getType()->isPointerTy() || !To->getType()->isPointerTy())
+ return;
+ addNode(From);
+ addNode(To);
+ if (IsRead) {
+ Graph.addNode(InstantiatedValue{From, 1});
+ Graph.addEdge(InstantiatedValue{From, 1}, InstantiatedValue{To, 0});
+ } else {
+ Graph.addNode(InstantiatedValue{To, 1});
+ Graph.addEdge(InstantiatedValue{From, 0}, InstantiatedValue{To, 1});
+ }
+ }
+
+ void addLoadEdge(Value *From, Value *To) { addDerefEdge(From, To, true); }
+ void addStoreEdge(Value *From, Value *To) { addDerefEdge(From, To, false); }
+
+ public:
+ GetEdgesVisitor(CFLGraphBuilder &Builder, const DataLayout &DL)
+ : AA(Builder.Analysis), DL(DL), TLI(Builder.TLI), Graph(Builder.Graph),
+ ReturnValues(Builder.ReturnedValues) {}
+
+ void visitInstruction(Instruction &) {
+ llvm_unreachable("Unsupported instruction encountered");
+ }
+
+ void visitReturnInst(ReturnInst &Inst) {
+ if (auto RetVal = Inst.getReturnValue()) {
+ if (RetVal->getType()->isPointerTy()) {
+ addNode(RetVal);
+ ReturnValues.push_back(RetVal);
+ }
+ }
+ }
+
+ void visitPtrToIntInst(PtrToIntInst &Inst) {
+ auto *Ptr = Inst.getOperand(0);
+ addNode(Ptr, getAttrEscaped());
+ }
+
+ void visitIntToPtrInst(IntToPtrInst &Inst) {
+ auto *Ptr = &Inst;
+ addNode(Ptr, getAttrUnknown());
+ }
+
+ void visitCastInst(CastInst &Inst) {
+ auto *Src = Inst.getOperand(0);
+ addAssignEdge(Src, &Inst);
+ }
+
+ void visitBinaryOperator(BinaryOperator &Inst) {
+ auto *Op1 = Inst.getOperand(0);
+ auto *Op2 = Inst.getOperand(1);
+ addAssignEdge(Op1, &Inst);
+ addAssignEdge(Op2, &Inst);
+ }
+
+ void visitUnaryOperator(UnaryOperator &Inst) {
+ auto *Src = Inst.getOperand(0);
+ addAssignEdge(Src, &Inst);
+ }
+
+ void visitAtomicCmpXchgInst(AtomicCmpXchgInst &Inst) {
+ auto *Ptr = Inst.getPointerOperand();
+ auto *Val = Inst.getNewValOperand();
+ addStoreEdge(Val, Ptr);
+ }
+
+ void visitAtomicRMWInst(AtomicRMWInst &Inst) {
+ auto *Ptr = Inst.getPointerOperand();
+ auto *Val = Inst.getValOperand();
+ addStoreEdge(Val, Ptr);
+ }
+
+ void visitPHINode(PHINode &Inst) {
+ for (Value *Val : Inst.incoming_values())
+ addAssignEdge(Val, &Inst);
+ }
+
+ void visitGEP(GEPOperator &GEPOp) {
+ int64_t Offset = UnknownOffset;
+ APInt APOffset(DL.getPointerSizeInBits(GEPOp.getPointerAddressSpace()),
+ 0);
+ if (GEPOp.accumulateConstantOffset(DL, APOffset))
+ Offset = APOffset.getSExtValue();
+
+ auto *Op = GEPOp.getPointerOperand();
+ addAssignEdge(Op, &GEPOp, Offset);
+ }
+
+ void visitGetElementPtrInst(GetElementPtrInst &Inst) {
+ auto *GEPOp = cast<GEPOperator>(&Inst);
+ visitGEP(*GEPOp);
+ }
+
+ void visitSelectInst(SelectInst &Inst) {
+ // Condition is not processed here (the actual statement producing
+ // the condition result is processed elsewhere). For select, the
+ // condition is evaluated, but not loaded, stored, or assigned
+ // simply as a result of being the condition of a select.
+
+ auto *TrueVal = Inst.getTrueValue();
+ auto *FalseVal = Inst.getFalseValue();
+ addAssignEdge(TrueVal, &Inst);
+ addAssignEdge(FalseVal, &Inst);
+ }
+
+ void visitAllocaInst(AllocaInst &Inst) { addNode(&Inst); }
+
+ void visitLoadInst(LoadInst &Inst) {
+ auto *Ptr = Inst.getPointerOperand();
+ auto *Val = &Inst;
+ addLoadEdge(Ptr, Val);
+ }
+
+ void visitStoreInst(StoreInst &Inst) {
+ auto *Ptr = Inst.getPointerOperand();
+ auto *Val = Inst.getValueOperand();
+ addStoreEdge(Val, Ptr);
+ }
+
+ void visitVAArgInst(VAArgInst &Inst) {
+ // We can't fully model va_arg here. For *Ptr = Inst.getOperand(0), it does
+ // two things:
+ // 1. Loads a value from *((T*)*Ptr).
+ // 2. Increments (stores to) *Ptr by some target-specific amount.
+ // For now, we'll handle this like a landingpad instruction (by placing the
+ // result in its own group, and having that group alias externals).
+ if (Inst.getType()->isPointerTy())
+ addNode(&Inst, getAttrUnknown());
+ }
+
+ static bool isFunctionExternal(Function *Fn) {
+ return !Fn->hasExactDefinition();
+ }
+
+ bool tryInterproceduralAnalysis(CallBase &Call,
+ const SmallVectorImpl<Function *> &Fns) {
+ assert(Fns.size() > 0);
+
+ if (Call.arg_size() > MaxSupportedArgsInSummary)
+ return false;
+
+ // Exit early if we'll fail anyway
+ for (auto *Fn : Fns) {
+ if (isFunctionExternal(Fn) || Fn->isVarArg())
+ return false;
+ // Fail if the caller does not provide enough arguments
+ assert(Fn->arg_size() <= Call.arg_size());
+ if (!AA.getAliasSummary(*Fn))
+ return false;
+ }
+
+ for (auto *Fn : Fns) {
+ auto Summary = AA.getAliasSummary(*Fn);
+ assert(Summary != nullptr);
+
+ auto &RetParamRelations = Summary->RetParamRelations;
+ for (auto &Relation : RetParamRelations) {
+ auto IRelation = instantiateExternalRelation(Relation, Call);
+ if (IRelation.hasValue()) {
+ Graph.addNode(IRelation->From);
+ Graph.addNode(IRelation->To);
+ Graph.addEdge(IRelation->From, IRelation->To);
+ }
+ }
+
+ auto &RetParamAttributes = Summary->RetParamAttributes;
+ for (auto &Attribute : RetParamAttributes) {
+ auto IAttr = instantiateExternalAttribute(Attribute, Call);
+ if (IAttr.hasValue())
+ Graph.addNode(IAttr->IValue, IAttr->Attr);
+ }
+ }
+
+ return true;
+ }
+
+ void visitCallBase(CallBase &Call) {
+ // Make sure all arguments and return value are added to the graph first
+ for (Value *V : Call.args())
+ if (V->getType()->isPointerTy())
+ addNode(V);
+ if (Call.getType()->isPointerTy())
+ addNode(&Call);
+
+ // Check if Call is a call to a library function that
+ // allocates/deallocates on the heap. Those kinds of functions do not
+ // introduce any aliases.
+ // TODO: address other common library functions such as realloc(),
+ // strdup(), etc.
+ if (isMallocOrCallocLikeFn(&Call, &TLI) || isFreeCall(&Call, &TLI))
+ return;
+
+ // TODO: Add support for noalias args/all the other fun function
+ // attributes that we can tack on.
+ SmallVector<Function *, 4> Targets;
+ if (getPossibleTargets(Call, Targets))
+ if (tryInterproceduralAnalysis(Call, Targets))
+ return;
+
+ // Because the function is opaque, we need to note that anything
+ // could have happened to the arguments (unless the function is marked
+ // readonly or readnone), and that the result could alias just about
+ // anything, too (unless the result is marked noalias).
+ if (!Call.onlyReadsMemory())
+ for (Value *V : Call.args()) {
+ if (V->getType()->isPointerTy()) {
+ // The argument itself escapes.
+ Graph.addAttr(InstantiatedValue{V, 0}, getAttrEscaped());
+ // The fate of argument memory is unknown. Note that since
+ // AliasAttrs is transitive with respect to dereference, we only
+ // need to specify it for the first-level memory.
+ Graph.addNode(InstantiatedValue{V, 1}, getAttrUnknown());
+ }
+ }
+
+ if (Call.getType()->isPointerTy()) {
+ auto *Fn = Call.getCalledFunction();
+ if (Fn == nullptr || !Fn->returnDoesNotAlias())
+ // No need to call addNode() since we've added Call at the
+ // beginning of this function and we know it is not a global.
+ Graph.addAttr(InstantiatedValue{&Call, 0}, getAttrUnknown());
+ }
+ }
+
+ /// Because vectors/aggregates are immutable and unaddressable, there's
+ /// nothing we can do to coax a value out of them, other than calling
+ /// Extract{Element,Value}. We can effectively treat them as pointers to
+ /// arbitrary memory locations we can store in and load from.
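+ /// Concretely, extractelement/extractvalue are modelled below as loads from
+ /// their aggregate operand, while insertelement/insertvalue become an
+ /// assignment from the aggregate plus a store of the inserted value.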
+ void visitExtractElementInst(ExtractElementInst &Inst) {
+ auto *Ptr = Inst.getVectorOperand();
+ auto *Val = &Inst;
+ addLoadEdge(Ptr, Val);
+ }
+
+ void visitInsertElementInst(InsertElementInst &Inst) {
+ auto *Vec = Inst.getOperand(0);
+ auto *Val = Inst.getOperand(1);
+ addAssignEdge(Vec, &Inst);
+ addStoreEdge(Val, &Inst);
+ }
+
+ void visitLandingPadInst(LandingPadInst &Inst) {
+ // Exceptions come from "nowhere", from our analysis' perspective.
+ // So we place the instruction in its own group, noting that said group may
+ // alias externals
+ if (Inst.getType()->isPointerTy())
+ addNode(&Inst, getAttrUnknown());
+ }
+
+ void visitInsertValueInst(InsertValueInst &Inst) {
+ auto *Agg = Inst.getOperand(0);
+ auto *Val = Inst.getOperand(1);
+ addAssignEdge(Agg, &Inst);
+ addStoreEdge(Val, &Inst);
+ }
+
+ void visitExtractValueInst(ExtractValueInst &Inst) {
+ auto *Ptr = Inst.getAggregateOperand();
+ addLoadEdge(Ptr, &Inst);
+ }
+
+ void visitShuffleVectorInst(ShuffleVectorInst &Inst) {
+ auto *From1 = Inst.getOperand(0);
+ auto *From2 = Inst.getOperand(1);
+ addAssignEdge(From1, &Inst);
+ addAssignEdge(From2, &Inst);
+ }
+
+ void visitConstantExpr(ConstantExpr *CE) {
+ switch (CE->getOpcode()) {
+ case Instruction::GetElementPtr: {
+ auto GEPOp = cast<GEPOperator>(CE);
+ visitGEP(*GEPOp);
+ break;
+ }
+
+ case Instruction::PtrToInt: {
+ addNode(CE->getOperand(0), getAttrEscaped());
+ break;
+ }
+
+ case Instruction::IntToPtr: {
+ addNode(CE, getAttrUnknown());
+ break;
+ }
+
+ case Instruction::BitCast:
+ case Instruction::AddrSpaceCast:
+ case Instruction::Trunc:
+ case Instruction::ZExt:
+ case Instruction::SExt:
+ case Instruction::FPExt:
+ case Instruction::FPTrunc:
+ case Instruction::UIToFP:
+ case Instruction::SIToFP:
+ case Instruction::FPToUI:
+ case Instruction::FPToSI: {
+ addAssignEdge(CE->getOperand(0), CE);
+ break;
+ }
+
+ case Instruction::Select: {
+ addAssignEdge(CE->getOperand(1), CE);
+ addAssignEdge(CE->getOperand(2), CE);
+ break;
+ }
+
+ case Instruction::InsertElement:
+ case Instruction::InsertValue: {
+ addAssignEdge(CE->getOperand(0), CE);
+ addStoreEdge(CE->getOperand(1), CE);
+ break;
+ }
+
+ case Instruction::ExtractElement:
+ case Instruction::ExtractValue: {
+ addLoadEdge(CE->getOperand(0), CE);
+ break;
+ }
+
+ case Instruction::Add:
+ case Instruction::FAdd:
+ case Instruction::Sub:
+ case Instruction::FSub:
+ case Instruction::Mul:
+ case Instruction::FMul:
+ case Instruction::UDiv:
+ case Instruction::SDiv:
+ case Instruction::FDiv:
+ case Instruction::URem:
+ case Instruction::SRem:
+ case Instruction::FRem:
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Xor:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::ICmp:
+ case Instruction::FCmp:
+ case Instruction::ShuffleVector: {
+ addAssignEdge(CE->getOperand(0), CE);
+ addAssignEdge(CE->getOperand(1), CE);
+ break;
+ }
+
+ case Instruction::FNeg: {
+ addAssignEdge(CE->getOperand(0), CE);
+ break;
+ }
+
+ default:
+ llvm_unreachable("Unknown instruction type encountered!");
+ }
+ }
+ };
+
+ // Helper functions
+
+ // Determines whether or not an instruction is useless to us (e.g.
+ // FenceInst).
+ static bool hasUsefulEdges(Instruction *Inst) {
+ bool IsNonInvokeRetTerminator = Inst->isTerminator() &&
+ !isa<InvokeInst>(Inst) &&
+ !isa<ReturnInst>(Inst);
+ return !isa<CmpInst>(Inst) && !isa<FenceInst>(Inst) &&
+ !IsNonInvokeRetTerminator;
+ }
+
+ void addArgumentToGraph(Argument &Arg) {
+ if (Arg.getType()->isPointerTy()) {
+ Graph.addNode(InstantiatedValue{&Arg, 0},
+ getGlobalOrArgAttrFromValue(Arg));
+ // Pointees of a formal parameter are known to the caller
+ Graph.addNode(InstantiatedValue{&Arg, 1}, getAttrCaller());
+ }
+ }
+
+ // Given an Instruction, this will add it to the graph, along with any
+ // Instructions that are potentially only available from said Instruction
+ // For example, given the following line:
+ // %0 = load i16* getelementptr ([1 x i16]* @a, 0, 0), align 2
+ // addInstructionToGraph would add both the `load` and `getelementptr`
+ // instructions to the graph appropriately.
+ void addInstructionToGraph(GetEdgesVisitor &Visitor, Instruction &Inst) {
+ if (!hasUsefulEdges(&Inst))
+ return;
+
+ Visitor.visit(Inst);
+ }
+
+ // Builds the graph needed for constructing the StratifiedSets for the given
+ // function
+ void buildGraphFrom(Function &Fn) {
+ GetEdgesVisitor Visitor(*this, Fn.getParent()->getDataLayout());
+
+ for (auto &Bb : Fn.getBasicBlockList())
+ for (auto &Inst : Bb.getInstList())
+ addInstructionToGraph(Visitor, Inst);
+
+ for (auto &Arg : Fn.args())
+ addArgumentToGraph(Arg);
+ }
+
+public:
+ CFLGraphBuilder(CFLAA &Analysis, const TargetLibraryInfo &TLI, Function &Fn)
+ : Analysis(Analysis), TLI(TLI) {
+ buildGraphFrom(Fn);
+ }
+
+ const CFLGraph &getCFLGraph() const { return Graph; }
+ const SmallVector<Value *, 4> &getReturnValues() const {
+ return ReturnedValues;
+ }
+};
+
+} // end namespace cflaa
+} // end namespace llvm
+
+#endif // LLVM_LIB_ANALYSIS_CFLGRAPH_H
diff --git a/llvm/lib/Analysis/CFLSteensAliasAnalysis.cpp b/llvm/lib/Analysis/CFLSteensAliasAnalysis.cpp
new file mode 100644
index 000000000000..b87aa4065392
--- /dev/null
+++ b/llvm/lib/Analysis/CFLSteensAliasAnalysis.cpp
@@ -0,0 +1,363 @@
+//===- CFLSteensAliasAnalysis.cpp - Unification-based Alias Analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a CFL-based, summary-based alias analysis algorithm. It
+// does not depend on types. The algorithm is a mixture of the one described in
+// "Demand-driven alias analysis for C" by Xin Zheng and Radu Rugina, and "Fast
+// algorithms for Dyck-CFL-reachability with applications to Alias Analysis" by
+// Zhang Q, Lyu M R, Yuan H, and Su Z. -- to summarize the papers, we build a
+// graph of the uses of a variable, where each node is a memory location, and
+// each edge is an action that happened on that memory location. The "actions"
+// can be one of Dereference, Reference, or Assign. The precision of this
+// analysis is roughly the same as that of an one level context-sensitive
+// Steensgaard's algorithm.
+//
+// Two variables are considered to alias iff you can reach one value's node
+// from the other value's node and the language formed by concatenating all of
+// the edge labels (actions) conforms to a context-free grammar.
+//
+// Because this algorithm requires a graph search on each query, we execute the
+// algorithm outlined in "Fast algorithms..." (mentioned above)
+// in order to transform the graph into sets of variables that may alias in
+// ~nlogn time (n = number of variables), which makes queries take constant
+// time.
+//===----------------------------------------------------------------------===//
+
+// N.B. AliasAnalysis as a whole is phrased as a FunctionPass at the moment, and
+// CFLSteensAA is interprocedural. This is *technically* A Bad Thing, because
+// FunctionPasses are only allowed to inspect the Function that they're being
+// run on. Realistically, this likely isn't a problem until we allow
+// FunctionPasses to run concurrently.
+
+#include "llvm/Analysis/CFLSteensAliasAnalysis.h"
+#include "AliasAnalysisSummary.h"
+#include "CFLGraph.h"
+#include "StratifiedSets.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <limits>
+#include <memory>
+#include <utility>
+
+using namespace llvm;
+using namespace llvm::cflaa;
+
+#define DEBUG_TYPE "cfl-steens-aa"
+
+CFLSteensAAResult::CFLSteensAAResult(
+ std::function<const TargetLibraryInfo &(Function &F)> GetTLI)
+ : AAResultBase(), GetTLI(std::move(GetTLI)) {}
+CFLSteensAAResult::CFLSteensAAResult(CFLSteensAAResult &&Arg)
+ : AAResultBase(std::move(Arg)), GetTLI(std::move(Arg.GetTLI)) {}
+CFLSteensAAResult::~CFLSteensAAResult() = default;
+
+/// Information we have about a function and would like to keep around.
+class CFLSteensAAResult::FunctionInfo {
+ StratifiedSets<InstantiatedValue> Sets;
+ AliasSummary Summary;
+
+public:
+ FunctionInfo(Function &Fn, const SmallVectorImpl<Value *> &RetVals,
+ StratifiedSets<InstantiatedValue> S);
+
+ const StratifiedSets<InstantiatedValue> &getStratifiedSets() const {
+ return Sets;
+ }
+
+ const AliasSummary &getAliasSummary() const { return Summary; }
+};
+
+const StratifiedIndex StratifiedLink::SetSentinel =
+ std::numeric_limits<StratifiedIndex>::max();
+
+//===----------------------------------------------------------------------===//
+// Function declarations that require types defined in the namespace above
+//===----------------------------------------------------------------------===//
+
+/// Determines whether it would be pointless to add the given Value to our sets.
+static bool canSkipAddingToSets(Value *Val) {
+ // Constants can share instances, which may falsely unify multiple
+ // sets, e.g. in
+ // store i32* null, i32** %ptr1
+ // store i32* null, i32** %ptr2
+ // clearly ptr1 and ptr2 should not be unified into the same set, so
+ // we should filter out the (potentially shared) instance to
+ // i32* null.
+ if (isa<Constant>(Val)) {
+ // TODO: Because all of these things are constant, we can determine whether
+ // the data is *actually* mutable at graph building time. This will probably
+ // come for free/cheap with offset awareness.
+ bool CanStoreMutableData = isa<GlobalValue>(Val) ||
+ isa<ConstantExpr>(Val) ||
+ isa<ConstantAggregate>(Val);
+ return !CanStoreMutableData;
+ }
+
+ return false;
+}
+
+CFLSteensAAResult::FunctionInfo::FunctionInfo(
+ Function &Fn, const SmallVectorImpl<Value *> &RetVals,
+ StratifiedSets<InstantiatedValue> S)
+ : Sets(std::move(S)) {
+ // Historically, an arbitrary upper-bound of 50 args was selected. We may want
+ // to remove this if it doesn't really matter in practice.
+ if (Fn.arg_size() > MaxSupportedArgsInSummary)
+ return;
+
+ DenseMap<StratifiedIndex, InterfaceValue> InterfaceMap;
+
+ // Our intention here is to record all InterfaceValues that share the same
+ // StratifiedIndex in RetParamRelations. For each valid InterfaceValue, we
+ // scan its StratifiedIndex here and check whether the index is present
+ // in InterfaceMap: if it is not, we add the correspondence to the map;
+ // otherwise, an aliasing relation is found and we add it to
+ // RetParamRelations.
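+ // For example, if parameter #1 and the return value end up in the same
+ // stratified set, the later of the two lambda invocations below finds the
+ // set's index already mapped to the other interface value and records the
+ // pair in RetParamRelations.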
+
+ auto AddToRetParamRelations = [&](unsigned InterfaceIndex,
+ StratifiedIndex SetIndex) {
+ unsigned Level = 0;
+ while (true) {
+ InterfaceValue CurrValue{InterfaceIndex, Level};
+
+ auto Itr = InterfaceMap.find(SetIndex);
+ if (Itr != InterfaceMap.end()) {
+ if (CurrValue != Itr->second)
+ Summary.RetParamRelations.push_back(
+ ExternalRelation{CurrValue, Itr->second, UnknownOffset});
+ break;
+ }
+
+ auto &Link = Sets.getLink(SetIndex);
+ InterfaceMap.insert(std::make_pair(SetIndex, CurrValue));
+ auto ExternalAttrs = getExternallyVisibleAttrs(Link.Attrs);
+ if (ExternalAttrs.any())
+ Summary.RetParamAttributes.push_back(
+ ExternalAttribute{CurrValue, ExternalAttrs});
+
+ if (!Link.hasBelow())
+ break;
+
+ ++Level;
+ SetIndex = Link.Below;
+ }
+ };
+
+ // Populate RetParamRelations for return values
+ for (auto *RetVal : RetVals) {
+ assert(RetVal != nullptr);
+ assert(RetVal->getType()->isPointerTy());
+ auto RetInfo = Sets.find(InstantiatedValue{RetVal, 0});
+ if (RetInfo.hasValue())
+ AddToRetParamRelations(0, RetInfo->Index);
+ }
+
+ // Populate RetParamRelations for parameters
+ unsigned I = 0;
+ for (auto &Param : Fn.args()) {
+ if (Param.getType()->isPointerTy()) {
+ auto ParamInfo = Sets.find(InstantiatedValue{&Param, 0});
+ if (ParamInfo.hasValue())
+ AddToRetParamRelations(I + 1, ParamInfo->Index);
+ }
+ ++I;
+ }
+}
+
+// Builds the graph + StratifiedSets for a function.
+CFLSteensAAResult::FunctionInfo CFLSteensAAResult::buildSetsFrom(Function *Fn) {
+ CFLGraphBuilder<CFLSteensAAResult> GraphBuilder(*this, GetTLI(*Fn), *Fn);
+ StratifiedSetsBuilder<InstantiatedValue> SetBuilder;
+
+ // Add all CFLGraph nodes and all Dereference edges to StratifiedSets
+ auto &Graph = GraphBuilder.getCFLGraph();
+ for (const auto &Mapping : Graph.value_mappings()) {
+ auto Val = Mapping.first;
+ if (canSkipAddingToSets(Val))
+ continue;
+ auto &ValueInfo = Mapping.second;
+
+ assert(ValueInfo.getNumLevels() > 0);
+ SetBuilder.add(InstantiatedValue{Val, 0});
+ SetBuilder.noteAttributes(InstantiatedValue{Val, 0},
+ ValueInfo.getNodeInfoAtLevel(0).Attr);
+ for (unsigned I = 0, E = ValueInfo.getNumLevels() - 1; I < E; ++I) {
+ SetBuilder.add(InstantiatedValue{Val, I + 1});
+ SetBuilder.noteAttributes(InstantiatedValue{Val, I + 1},
+ ValueInfo.getNodeInfoAtLevel(I + 1).Attr);
+ SetBuilder.addBelow(InstantiatedValue{Val, I},
+ InstantiatedValue{Val, I + 1});
+ }
+ }
+
+ // Add all assign edges to StratifiedSets
+ for (const auto &Mapping : Graph.value_mappings()) {
+ auto Val = Mapping.first;
+ if (canSkipAddingToSets(Val))
+ continue;
+ auto &ValueInfo = Mapping.second;
+
+ for (unsigned I = 0, E = ValueInfo.getNumLevels(); I < E; ++I) {
+ auto Src = InstantiatedValue{Val, I};
+ for (auto &Edge : ValueInfo.getNodeInfoAtLevel(I).Edges)
+ SetBuilder.addWith(Src, Edge.Other);
+ }
+ }
+
+ return FunctionInfo(*Fn, GraphBuilder.getReturnValues(), SetBuilder.build());
+}
+
+void CFLSteensAAResult::scan(Function *Fn) {
+ auto InsertPair = Cache.insert(std::make_pair(Fn, Optional<FunctionInfo>()));
+ (void)InsertPair;
+ assert(InsertPair.second &&
+ "Trying to scan a function that has already been cached");
+
+ // Note that we can't do Cache[Fn] = buildSetsFrom(Fn) here: the function call
+ // may get evaluated after operator[], potentially triggering a DenseMap
+ // resize and invalidating the reference returned by operator[]
+ auto FunInfo = buildSetsFrom(Fn);
+ Cache[Fn] = std::move(FunInfo);
+
+ Handles.emplace_front(Fn, this);
+}
+
+void CFLSteensAAResult::evict(Function *Fn) { Cache.erase(Fn); }
+
+/// Ensures that the given function is available in the cache, and returns the
+/// entry.
+const Optional<CFLSteensAAResult::FunctionInfo> &
+CFLSteensAAResult::ensureCached(Function *Fn) {
+ auto Iter = Cache.find(Fn);
+ if (Iter == Cache.end()) {
+ scan(Fn);
+ Iter = Cache.find(Fn);
+ assert(Iter != Cache.end());
+ assert(Iter->second.hasValue());
+ }
+ return Iter->second;
+}
+
+const AliasSummary *CFLSteensAAResult::getAliasSummary(Function &Fn) {
+ auto &FunInfo = ensureCached(&Fn);
+ if (FunInfo.hasValue())
+ return &FunInfo->getAliasSummary();
+ else
+ return nullptr;
+}
+
+AliasResult CFLSteensAAResult::query(const MemoryLocation &LocA,
+ const MemoryLocation &LocB) {
+ auto *ValA = const_cast<Value *>(LocA.Ptr);
+ auto *ValB = const_cast<Value *>(LocB.Ptr);
+
+ if (!ValA->getType()->isPointerTy() || !ValB->getType()->isPointerTy())
+ return NoAlias;
+
+ Function *Fn = nullptr;
+ Function *MaybeFnA = const_cast<Function *>(parentFunctionOfValue(ValA));
+ Function *MaybeFnB = const_cast<Function *>(parentFunctionOfValue(ValB));
+ if (!MaybeFnA && !MaybeFnB) {
+ // The only times this is known to happen are when globals + InlineAsm are
+ // involved
+ LLVM_DEBUG(
+ dbgs()
+ << "CFLSteensAA: could not extract parent function information.\n");
+ return MayAlias;
+ }
+
+ if (MaybeFnA) {
+ Fn = MaybeFnA;
+ assert((!MaybeFnB || MaybeFnB == MaybeFnA) &&
+ "Interprocedural queries not supported");
+ } else {
+ Fn = MaybeFnB;
+ }
+
+ assert(Fn != nullptr);
+ auto &MaybeInfo = ensureCached(Fn);
+ assert(MaybeInfo.hasValue());
+
+ auto &Sets = MaybeInfo->getStratifiedSets();
+ auto MaybeA = Sets.find(InstantiatedValue{ValA, 0});
+ if (!MaybeA.hasValue())
+ return MayAlias;
+
+ auto MaybeB = Sets.find(InstantiatedValue{ValB, 0});
+ if (!MaybeB.hasValue())
+ return MayAlias;
+
+ auto SetA = *MaybeA;
+ auto SetB = *MaybeB;
+ auto AttrsA = Sets.getLink(SetA.Index).Attrs;
+ auto AttrsB = Sets.getLink(SetB.Index).Attrs;
+
+ // If both values are local (meaning the corresponding set has attribute
+ // AttrNone or AttrEscaped), then we know that CFLSteensAA fully models them:
+ // they may-alias each other if and only if they are in the same set.
+ // If at least one value is non-local (meaning it either is global/argument or
+ // it comes from unknown sources like integer cast), the situation becomes a
+ // bit more interesting. We follow three general rules described below:
+ // - Non-local values may alias each other
+ // - AttrNone values do not alias any non-local values
+ // - AttrEscaped do not alias globals/arguments, but they may alias
+ // AttrUnknown values
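+  //
+  // For example, two unrelated allocas whose addresses never escape fall into
+  // different sets with no attributes set, so the attribute check below
+  // reports NoAlias for them.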
+ if (SetA.Index == SetB.Index)
+ return MayAlias;
+ if (AttrsA.none() || AttrsB.none())
+ return NoAlias;
+ if (hasUnknownOrCallerAttr(AttrsA) || hasUnknownOrCallerAttr(AttrsB))
+ return MayAlias;
+ if (isGlobalOrArgAttr(AttrsA) && isGlobalOrArgAttr(AttrsB))
+ return MayAlias;
+ return NoAlias;
+}
+
+AnalysisKey CFLSteensAA::Key;
+
+CFLSteensAAResult CFLSteensAA::run(Function &F, FunctionAnalysisManager &AM) {
+ auto GetTLI = [&AM](Function &F) -> const TargetLibraryInfo & {
+ return AM.getResult<TargetLibraryAnalysis>(F);
+ };
+ return CFLSteensAAResult(GetTLI);
+}
+
+char CFLSteensAAWrapperPass::ID = 0;
+INITIALIZE_PASS(CFLSteensAAWrapperPass, "cfl-steens-aa",
+ "Unification-Based CFL Alias Analysis", false, true)
+
+ImmutablePass *llvm::createCFLSteensAAWrapperPass() {
+ return new CFLSteensAAWrapperPass();
+}
+
+CFLSteensAAWrapperPass::CFLSteensAAWrapperPass() : ImmutablePass(ID) {
+ initializeCFLSteensAAWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+void CFLSteensAAWrapperPass::initializePass() {
+ auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
+ return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ };
+ Result.reset(new CFLSteensAAResult(GetTLI));
+}
+
+void CFLSteensAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
+}
diff --git a/llvm/lib/Analysis/CGSCCPassManager.cpp b/llvm/lib/Analysis/CGSCCPassManager.cpp
new file mode 100644
index 000000000000..a0b3f83cca6a
--- /dev/null
+++ b/llvm/lib/Analysis/CGSCCPassManager.cpp
@@ -0,0 +1,709 @@
+//===- CGSCCPassManager.cpp - Managing & running CGSCC passes -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CGSCCPassManager.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/LazyCallGraph.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <iterator>
+
+#define DEBUG_TYPE "cgscc"
+
+using namespace llvm;
+
+// Explicit template instantiations and specialization definitions for core
+// template typedefs.
+namespace llvm {
+
+// Explicit instantiations for the core proxy templates.
+template class AllAnalysesOn<LazyCallGraph::SCC>;
+template class AnalysisManager<LazyCallGraph::SCC, LazyCallGraph &>;
+template class PassManager<LazyCallGraph::SCC, CGSCCAnalysisManager,
+ LazyCallGraph &, CGSCCUpdateResult &>;
+template class InnerAnalysisManagerProxy<CGSCCAnalysisManager, Module>;
+template class OuterAnalysisManagerProxy<ModuleAnalysisManager,
+ LazyCallGraph::SCC, LazyCallGraph &>;
+template class OuterAnalysisManagerProxy<CGSCCAnalysisManager, Function>;
+
+/// Explicitly specialize the pass manager run method to handle call graph
+/// updates.
+template <>
+PreservedAnalyses
+PassManager<LazyCallGraph::SCC, CGSCCAnalysisManager, LazyCallGraph &,
+ CGSCCUpdateResult &>::run(LazyCallGraph::SCC &InitialC,
+ CGSCCAnalysisManager &AM,
+ LazyCallGraph &G, CGSCCUpdateResult &UR) {
+ // Request PassInstrumentation from analysis manager, will use it to run
+ // instrumenting callbacks for the passes later.
+ PassInstrumentation PI =
+ AM.getResult<PassInstrumentationAnalysis>(InitialC, G);
+
+ PreservedAnalyses PA = PreservedAnalyses::all();
+
+ if (DebugLogging)
+ dbgs() << "Starting CGSCC pass manager run.\n";
+
+ // The SCC may be refined while we are running passes over it, so set up
+ // a pointer that we can update.
+ LazyCallGraph::SCC *C = &InitialC;
+
+ for (auto &Pass : Passes) {
+ if (DebugLogging)
+ dbgs() << "Running pass: " << Pass->name() << " on " << *C << "\n";
+
+ // Check the PassInstrumentation's BeforePass callbacks before running the
+ // pass, skip its execution completely if asked to (callback returns false).
+ if (!PI.runBeforePass(*Pass, *C))
+ continue;
+
+ PreservedAnalyses PassPA = Pass->run(*C, AM, G, UR);
+
+ if (UR.InvalidatedSCCs.count(C))
+ PI.runAfterPassInvalidated<LazyCallGraph::SCC>(*Pass);
+ else
+ PI.runAfterPass<LazyCallGraph::SCC>(*Pass, *C);
+
+ // Update the SCC if necessary.
+ C = UR.UpdatedC ? UR.UpdatedC : C;
+
+ // If the CGSCC pass wasn't able to provide a valid updated SCC, the
+ // current SCC may simply need to be skipped if invalid.
+ if (UR.InvalidatedSCCs.count(C)) {
+ LLVM_DEBUG(dbgs() << "Skipping invalidated root or island SCC!\n");
+ break;
+ }
+ // Check that we didn't miss any update scenario.
+ assert(C->begin() != C->end() && "Cannot have an empty SCC!");
+
+ // Update the analysis manager as each pass runs and potentially
+ // invalidates analyses.
+ AM.invalidate(*C, PassPA);
+
+ // Finally, we intersect the final preserved analyses to compute the
+ // aggregate preserved set for this pass manager.
+ PA.intersect(std::move(PassPA));
+
+ // FIXME: Historically, the pass managers all called the LLVM context's
+ // yield function here. We don't have a generic way to acquire the
+ // context and it isn't yet clear what the right pattern is for yielding
+ // in the new pass manager so it is currently omitted.
+ // ...getContext().yield();
+ }
+
+ // Before we mark all of *this* SCC's analyses as preserved below, intersect
+ // this with the cross-SCC preserved analysis set. This is used to allow
+ // CGSCC passes to mutate ancestor SCCs and still trigger proper invalidation
+ // for them.
+ UR.CrossSCCPA.intersect(PA);
+
+ // Invalidation was handled after each pass in the above loop for the current
+ // SCC. Therefore, the remaining analysis results in the AnalysisManager are
+ // preserved. We mark this with a set so that we don't need to inspect each
+ // one individually.
+ PA.preserveSet<AllAnalysesOn<LazyCallGraph::SCC>>();
+
+ if (DebugLogging)
+ dbgs() << "Finished CGSCC pass manager run.\n";
+
+ return PA;
+}
+
+bool CGSCCAnalysisManagerModuleProxy::Result::invalidate(
+ Module &M, const PreservedAnalyses &PA,
+ ModuleAnalysisManager::Invalidator &Inv) {
+ // If literally everything is preserved, we're done.
+ if (PA.areAllPreserved())
+ return false; // This is still a valid proxy.
+
+ // If this proxy or the call graph is going to be invalidated, we also need
+ // to clear all the keys coming from that analysis.
+ //
+ // We also directly invalidate the FAM's module proxy if necessary, and if
+ // that proxy isn't preserved we can't preserve this proxy either. We rely on
+ // it to handle module -> function analysis invalidation in the face of
+ // structural changes and so if it's unavailable we conservatively clear the
+ // entire SCC layer as well rather than trying to do invalidation ourselves.
+ auto PAC = PA.getChecker<CGSCCAnalysisManagerModuleProxy>();
+ if (!(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Module>>()) ||
+ Inv.invalidate<LazyCallGraphAnalysis>(M, PA) ||
+ Inv.invalidate<FunctionAnalysisManagerModuleProxy>(M, PA)) {
+ InnerAM->clear();
+
+ // And the proxy itself should be marked as invalid so that we can observe
+ // the new call graph. This isn't strictly necessary because we cheat
+ // above, but is still useful.
+ return true;
+ }
+
+ // Directly check if the relevant set is preserved so we can short circuit
+ // invalidating SCCs below.
+ bool AreSCCAnalysesPreserved =
+ PA.allAnalysesInSetPreserved<AllAnalysesOn<LazyCallGraph::SCC>>();
+
+ // Ok, we have a graph, so we can propagate the invalidation down into it.
+ G->buildRefSCCs();
+ for (auto &RC : G->postorder_ref_sccs())
+ for (auto &C : RC) {
+ Optional<PreservedAnalyses> InnerPA;
+
+ // Check to see whether the preserved set needs to be adjusted based on
+ // module-level analysis invalidation triggering deferred invalidation
+ // for this SCC.
+ if (auto *OuterProxy =
+ InnerAM->getCachedResult<ModuleAnalysisManagerCGSCCProxy>(C))
+ for (const auto &OuterInvalidationPair :
+ OuterProxy->getOuterInvalidations()) {
+ AnalysisKey *OuterAnalysisID = OuterInvalidationPair.first;
+ const auto &InnerAnalysisIDs = OuterInvalidationPair.second;
+ if (Inv.invalidate(OuterAnalysisID, M, PA)) {
+ if (!InnerPA)
+ InnerPA = PA;
+ for (AnalysisKey *InnerAnalysisID : InnerAnalysisIDs)
+ InnerPA->abandon(InnerAnalysisID);
+ }
+ }
+
+ // Check if we needed a custom PA set. If so we'll need to run the inner
+ // invalidation.
+ if (InnerPA) {
+ InnerAM->invalidate(C, *InnerPA);
+ continue;
+ }
+
+ // Otherwise we only need to do invalidation if the original PA set didn't
+ // preserve all SCC analyses.
+ if (!AreSCCAnalysesPreserved)
+ InnerAM->invalidate(C, PA);
+ }
+
+ // Return false to indicate that this result is still a valid proxy.
+ return false;
+}
+
+template <>
+CGSCCAnalysisManagerModuleProxy::Result
+CGSCCAnalysisManagerModuleProxy::run(Module &M, ModuleAnalysisManager &AM) {
+ // Force the Function analysis manager to also be available so that it can
+ // be accessed in an SCC analysis and proxied onward to function passes.
+ // FIXME: It is pretty awkward to just drop the result here and assert that
+ // we can find it again later.
+ (void)AM.getResult<FunctionAnalysisManagerModuleProxy>(M);
+
+ return Result(*InnerAM, AM.getResult<LazyCallGraphAnalysis>(M));
+}
+
+AnalysisKey FunctionAnalysisManagerCGSCCProxy::Key;
+
+FunctionAnalysisManagerCGSCCProxy::Result
+FunctionAnalysisManagerCGSCCProxy::run(LazyCallGraph::SCC &C,
+ CGSCCAnalysisManager &AM,
+ LazyCallGraph &CG) {
+ // Collect the FunctionAnalysisManager from the Module layer and use that to
+ // build the proxy result.
+ //
+  // This allows us to rely on the FunctionAnalysisManagerModuleProxy to
+ // invalidate the function analyses.
+ auto &MAM = AM.getResult<ModuleAnalysisManagerCGSCCProxy>(C, CG).getManager();
+ Module &M = *C.begin()->getFunction().getParent();
+ auto *FAMProxy = MAM.getCachedResult<FunctionAnalysisManagerModuleProxy>(M);
+ assert(FAMProxy && "The CGSCC pass manager requires that the FAM module "
+ "proxy is run on the module prior to entering the CGSCC "
+ "walk.");
+
+ // Note that we special-case invalidation handling of this proxy in the CGSCC
+ // analysis manager's Module proxy. This avoids the need to do anything
+ // special here to recompute all of this if ever the FAM's module proxy goes
+ // away.
+ return Result(FAMProxy->getManager());
+}
+
+bool FunctionAnalysisManagerCGSCCProxy::Result::invalidate(
+ LazyCallGraph::SCC &C, const PreservedAnalyses &PA,
+ CGSCCAnalysisManager::Invalidator &Inv) {
+ // If literally everything is preserved, we're done.
+ if (PA.areAllPreserved())
+ return false; // This is still a valid proxy.
+
+ // If this proxy isn't marked as preserved, then even if the result remains
+ // valid, the key itself may no longer be valid, so we clear everything.
+ //
+ // Note that in order to preserve this proxy, a module pass must ensure that
+ // the FAM has been completely updated to handle the deletion of functions.
+ // Specifically, any FAM-cached results for those functions need to have been
+ // forcibly cleared. When preserved, this proxy will only invalidate results
+ // cached on functions *still in the module* at the end of the module pass.
+ auto PAC = PA.getChecker<FunctionAnalysisManagerCGSCCProxy>();
+  if (!PAC.preserved() &&
+      !PAC.preservedSet<AllAnalysesOn<LazyCallGraph::SCC>>()) {
+ for (LazyCallGraph::Node &N : C)
+ FAM->clear(N.getFunction(), N.getFunction().getName());
+
+ return true;
+ }
+
+ // Directly check if the relevant set is preserved.
+ bool AreFunctionAnalysesPreserved =
+ PA.allAnalysesInSetPreserved<AllAnalysesOn<Function>>();
+
+ // Now walk all the functions to see if any inner analysis invalidation is
+ // necessary.
+ for (LazyCallGraph::Node &N : C) {
+ Function &F = N.getFunction();
+ Optional<PreservedAnalyses> FunctionPA;
+
+ // Check to see whether the preserved set needs to be pruned based on
+ // SCC-level analysis invalidation that triggers deferred invalidation
+ // registered with the outer analysis manager proxy for this function.
+ if (auto *OuterProxy =
+ FAM->getCachedResult<CGSCCAnalysisManagerFunctionProxy>(F))
+ for (const auto &OuterInvalidationPair :
+ OuterProxy->getOuterInvalidations()) {
+ AnalysisKey *OuterAnalysisID = OuterInvalidationPair.first;
+ const auto &InnerAnalysisIDs = OuterInvalidationPair.second;
+ if (Inv.invalidate(OuterAnalysisID, C, PA)) {
+ if (!FunctionPA)
+ FunctionPA = PA;
+ for (AnalysisKey *InnerAnalysisID : InnerAnalysisIDs)
+ FunctionPA->abandon(InnerAnalysisID);
+ }
+ }
+
+ // Check if we needed a custom PA set, and if so we'll need to run the
+ // inner invalidation.
+ if (FunctionPA) {
+ FAM->invalidate(F, *FunctionPA);
+ continue;
+ }
+
+ // Otherwise we only need to do invalidation if the original PA set didn't
+ // preserve all function analyses.
+ if (!AreFunctionAnalysesPreserved)
+ FAM->invalidate(F, PA);
+ }
+
+ // Return false to indicate that this result is still a valid proxy.
+ return false;
+}
+
+} // end namespace llvm
+
+/// When a new SCC is created for the graph and there might be function
+/// analysis results cached for the functions now in that SCC two forms of
+/// updates are required.
+///
+/// First, a proxy from the SCC to the FunctionAnalysisManager needs to be
+/// created so that any subsequent invalidation events to the SCC are
+/// propagated to the function analysis results cached for functions within it.
+///
+/// Second, if any of the functions within the SCC have analysis results with
+/// outer analysis dependencies, then those dependencies would point to the
+/// *wrong* SCC's analysis result. We forcibly invalidate the necessary
+/// function analyses so that they don't retain stale handles.
+static void updateNewSCCFunctionAnalyses(LazyCallGraph::SCC &C,
+ LazyCallGraph &G,
+ CGSCCAnalysisManager &AM) {
+ // Get the relevant function analysis manager.
+ auto &FAM =
+ AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, G).getManager();
+
+ // Now walk the functions in this SCC and invalidate any function analysis
+ // results that might have outer dependencies on an SCC analysis.
+ for (LazyCallGraph::Node &N : C) {
+ Function &F = N.getFunction();
+
+ auto *OuterProxy =
+ FAM.getCachedResult<CGSCCAnalysisManagerFunctionProxy>(F);
+ if (!OuterProxy)
+ // No outer analyses were queried, nothing to do.
+ continue;
+
+ // Forcibly abandon all the inner analyses with dependencies, but
+ // invalidate nothing else.
+ auto PA = PreservedAnalyses::all();
+ for (const auto &OuterInvalidationPair :
+ OuterProxy->getOuterInvalidations()) {
+ const auto &InnerAnalysisIDs = OuterInvalidationPair.second;
+ for (AnalysisKey *InnerAnalysisID : InnerAnalysisIDs)
+ PA.abandon(InnerAnalysisID);
+ }
+
+ // Now invalidate anything we found.
+ FAM.invalidate(F, PA);
+ }
+}
+
+/// Helper function to update both the \c CGSCCAnalysisManager \p AM and the \c
+/// CGSCCPassManager's \c CGSCCUpdateResult \p UR based on a range of newly
+/// added SCCs.
+///
+/// The range of new SCCs must be in postorder already. The SCC they were split
+/// out of must be provided as \p C. The current node being mutated and
+/// triggering updates must be passed as \p N.
+///
+/// This function returns the SCC containing \p N. This will be either \p C if
+/// no new SCCs have been split out, or it will be the new SCC containing \p N.
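+///
+/// For example, when deleting a call edge splits the current SCC, \p C is
+/// re-enqueued on the worklist (its shape changed), every newly formed SCC
+/// other than the one containing \p N is enqueued as well, and the SCC that
+/// now contains \p N is returned as the new current SCC.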
+template <typename SCCRangeT>
+static LazyCallGraph::SCC *
+incorporateNewSCCRange(const SCCRangeT &NewSCCRange, LazyCallGraph &G,
+ LazyCallGraph::Node &N, LazyCallGraph::SCC *C,
+ CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR) {
+ using SCC = LazyCallGraph::SCC;
+
+ if (NewSCCRange.begin() == NewSCCRange.end())
+ return C;
+
+ // Add the current SCC to the worklist as its shape has changed.
+ UR.CWorklist.insert(C);
+ LLVM_DEBUG(dbgs() << "Enqueuing the existing SCC in the worklist:" << *C
+ << "\n");
+
+ SCC *OldC = C;
+
+ // Update the current SCC. Note that if we have new SCCs, this must actually
+ // change the SCC.
+ assert(C != &*NewSCCRange.begin() &&
+ "Cannot insert new SCCs without changing current SCC!");
+ C = &*NewSCCRange.begin();
+ assert(G.lookupSCC(N) == C && "Failed to update current SCC!");
+
+ // If we had a cached FAM proxy originally, we will want to create more of
+ // them for each SCC that was split off.
+ bool NeedFAMProxy =
+ AM.getCachedResult<FunctionAnalysisManagerCGSCCProxy>(*OldC) != nullptr;
+
+ // We need to propagate an invalidation call to all but the newly current SCC
+ // because the outer pass manager won't do that for us after splitting them.
+  // FIXME: We should accept a PreservedAnalyses set from the CG updater so
+  // that if there are preserved analyses we can avoid invalidating them here
+  // for split-off SCCs.
+  // We know, however, that this will preserve any FAM proxy, so go ahead and
+  // mark that.
+ PreservedAnalyses PA;
+ PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
+ AM.invalidate(*OldC, PA);
+
+ // Ensure the now-current SCC's function analyses are updated.
+ if (NeedFAMProxy)
+ updateNewSCCFunctionAnalyses(*C, G, AM);
+
+ for (SCC &NewC : llvm::reverse(make_range(std::next(NewSCCRange.begin()),
+ NewSCCRange.end()))) {
+ assert(C != &NewC && "No need to re-visit the current SCC!");
+ assert(OldC != &NewC && "Already handled the original SCC!");
+ UR.CWorklist.insert(&NewC);
+ LLVM_DEBUG(dbgs() << "Enqueuing a newly formed SCC:" << NewC << "\n");
+
+ // Ensure new SCCs' function analyses are updated.
+ if (NeedFAMProxy)
+ updateNewSCCFunctionAnalyses(NewC, G, AM);
+
+ // Also propagate a normal invalidation to the new SCC as only the current
+ // will get one from the pass manager infrastructure.
+ AM.invalidate(NewC, PA);
+ }
+ return C;
+}
+
+LazyCallGraph::SCC &llvm::updateCGAndAnalysisManagerForFunctionPass(
+ LazyCallGraph &G, LazyCallGraph::SCC &InitialC, LazyCallGraph::Node &N,
+ CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR) {
+ using Node = LazyCallGraph::Node;
+ using Edge = LazyCallGraph::Edge;
+ using SCC = LazyCallGraph::SCC;
+ using RefSCC = LazyCallGraph::RefSCC;
+
+ RefSCC &InitialRC = InitialC.getOuterRefSCC();
+ SCC *C = &InitialC;
+ RefSCC *RC = &InitialRC;
+ Function &F = N.getFunction();
+
+ // Walk the function body and build up the set of retained, promoted, and
+ // demoted edges.
+ SmallVector<Constant *, 16> Worklist;
+ SmallPtrSet<Constant *, 16> Visited;
+ SmallPtrSet<Node *, 16> RetainedEdges;
+ SmallSetVector<Node *, 4> PromotedRefTargets;
+ SmallSetVector<Node *, 4> DemotedCallTargets;
+
+ // First walk the function and handle all called functions. We do this first
+ // because if there is a single call edge, whether there are ref edges is
+ // irrelevant.
+ for (Instruction &I : instructions(F))
+ if (auto CS = CallSite(&I))
+ if (Function *Callee = CS.getCalledFunction())
+ if (Visited.insert(Callee).second && !Callee->isDeclaration()) {
+ Node &CalleeN = *G.lookup(*Callee);
+ Edge *E = N->lookup(CalleeN);
+ // FIXME: We should really handle adding new calls. While it will
+ // make downstream usage more complex, there is no fundamental
+ // limitation and it will allow passes within the CGSCC to be a bit
+ // more flexible in what transforms they can do. Until then, we
+ // verify that new calls haven't been introduced.
+ assert(E && "No function transformations should introduce *new* "
+ "call edges! Any new calls should be modeled as "
+ "promoted existing ref edges!");
+ bool Inserted = RetainedEdges.insert(&CalleeN).second;
+ (void)Inserted;
+ assert(Inserted && "We should never visit a function twice.");
+ if (!E->isCall())
+ PromotedRefTargets.insert(&CalleeN);
+ }
+
+ // Now walk all references.
+ for (Instruction &I : instructions(F))
+ for (Value *Op : I.operand_values())
+ if (auto *C = dyn_cast<Constant>(Op))
+ if (Visited.insert(C).second)
+ Worklist.push_back(C);
+
+ auto VisitRef = [&](Function &Referee) {
+ Node &RefereeN = *G.lookup(Referee);
+ Edge *E = N->lookup(RefereeN);
+ // FIXME: Similarly to new calls, we also currently preclude
+ // introducing new references. See above for details.
+ assert(E && "No function transformations should introduce *new* ref "
+ "edges! Any new ref edges would require IPO which "
+ "function passes aren't allowed to do!");
+ bool Inserted = RetainedEdges.insert(&RefereeN).second;
+ (void)Inserted;
+ assert(Inserted && "We should never visit a function twice.");
+ if (E->isCall())
+ DemotedCallTargets.insert(&RefereeN);
+ };
+ LazyCallGraph::visitReferences(Worklist, Visited, VisitRef);
+
+ // Include synthetic reference edges to known, defined lib functions.
+ for (auto *F : G.getLibFunctions())
+ // While the list of lib functions doesn't have repeats, don't re-visit
+ // anything handled above.
+ if (!Visited.count(F))
+ VisitRef(*F);
+
+ // First remove all of the edges that are no longer present in this function.
+ // The first step makes these edges uniformly ref edges and accumulates them
+ // into a separate data structure so removal doesn't invalidate anything.
+ SmallVector<Node *, 4> DeadTargets;
+ for (Edge &E : *N) {
+ if (RetainedEdges.count(&E.getNode()))
+ continue;
+
+ SCC &TargetC = *G.lookupSCC(E.getNode());
+ RefSCC &TargetRC = TargetC.getOuterRefSCC();
+ if (&TargetRC == RC && E.isCall()) {
+ if (C != &TargetC) {
+ // For separate SCCs this is trivial.
+ RC->switchTrivialInternalEdgeToRef(N, E.getNode());
+ } else {
+ // Now update the call graph.
+ C = incorporateNewSCCRange(RC->switchInternalEdgeToRef(N, E.getNode()),
+ G, N, C, AM, UR);
+ }
+ }
+
+ // Now that this is ready for actual removal, put it into our list.
+ DeadTargets.push_back(&E.getNode());
+ }
+ // Remove the easy cases quickly and actually pull them out of our list.
+ DeadTargets.erase(
+ llvm::remove_if(DeadTargets,
+ [&](Node *TargetN) {
+ SCC &TargetC = *G.lookupSCC(*TargetN);
+ RefSCC &TargetRC = TargetC.getOuterRefSCC();
+
+ // We can't trivially remove internal targets, so skip
+ // those.
+ if (&TargetRC == RC)
+ return false;
+
+ RC->removeOutgoingEdge(N, *TargetN);
+ LLVM_DEBUG(dbgs() << "Deleting outgoing edge from '"
+                                        << N << "' to '" << *TargetN << "'\n");
+ return true;
+ }),
+ DeadTargets.end());
+
+ // Now do a batch removal of the internal ref edges left.
+ auto NewRefSCCs = RC->removeInternalRefEdge(N, DeadTargets);
+ if (!NewRefSCCs.empty()) {
+ // The old RefSCC is dead, mark it as such.
+ UR.InvalidatedRefSCCs.insert(RC);
+
+ // Note that we don't bother to invalidate analyses as ref-edge
+ // connectivity is not really observable in any way and is intended
+ // exclusively to be used for ordering of transforms rather than for
+ // analysis conclusions.
+
+ // Update RC to the "bottom".
+ assert(G.lookupSCC(N) == C && "Changed the SCC when splitting RefSCCs!");
+ RC = &C->getOuterRefSCC();
+ assert(G.lookupRefSCC(N) == RC && "Failed to update current RefSCC!");
+
+ // The RC worklist is in reverse postorder, so we enqueue the new ones in
+ // RPO except for the one which contains the source node as that is the
+ // "bottom" we will continue processing in the bottom-up walk.
+ assert(NewRefSCCs.front() == RC &&
+ "New current RefSCC not first in the returned list!");
+ for (RefSCC *NewRC : llvm::reverse(make_range(std::next(NewRefSCCs.begin()),
+ NewRefSCCs.end()))) {
+ assert(NewRC != RC && "Should not encounter the current RefSCC further "
+ "in the postorder list of new RefSCCs.");
+ UR.RCWorklist.insert(NewRC);
+ LLVM_DEBUG(dbgs() << "Enqueuing a new RefSCC in the update worklist: "
+ << *NewRC << "\n");
+ }
+ }
+
+ // Next demote all the call edges that are now ref edges. This helps make
+ // the SCCs small which should minimize the work below as we don't want to
+ // form cycles that this would break.
+ for (Node *RefTarget : DemotedCallTargets) {
+ SCC &TargetC = *G.lookupSCC(*RefTarget);
+ RefSCC &TargetRC = TargetC.getOuterRefSCC();
+
+ // The easy case is when the target RefSCC is not this RefSCC. This is
+ // only supported when the target RefSCC is a child of this RefSCC.
+ if (&TargetRC != RC) {
+ assert(RC->isAncestorOf(TargetRC) &&
+ "Cannot potentially form RefSCC cycles here!");
+ RC->switchOutgoingEdgeToRef(N, *RefTarget);
+ LLVM_DEBUG(dbgs() << "Switch outgoing call edge to a ref edge from '" << N
+ << "' to '" << *RefTarget << "'\n");
+ continue;
+ }
+
+ // We are switching an internal call edge to a ref edge. This may split up
+ // some SCCs.
+ if (C != &TargetC) {
+ // For separate SCCs this is trivial.
+ RC->switchTrivialInternalEdgeToRef(N, *RefTarget);
+ continue;
+ }
+
+ // Now update the call graph.
+ C = incorporateNewSCCRange(RC->switchInternalEdgeToRef(N, *RefTarget), G, N,
+ C, AM, UR);
+ }
+
+ // Now promote ref edges into call edges.
+ for (Node *CallTarget : PromotedRefTargets) {
+ SCC &TargetC = *G.lookupSCC(*CallTarget);
+ RefSCC &TargetRC = TargetC.getOuterRefSCC();
+
+ // The easy case is when the target RefSCC is not this RefSCC. This is
+ // only supported when the target RefSCC is a child of this RefSCC.
+ if (&TargetRC != RC) {
+ assert(RC->isAncestorOf(TargetRC) &&
+ "Cannot potentially form RefSCC cycles here!");
+ RC->switchOutgoingEdgeToCall(N, *CallTarget);
+ LLVM_DEBUG(dbgs() << "Switch outgoing ref edge to a call edge from '" << N
+ << "' to '" << *CallTarget << "'\n");
+ continue;
+ }
+ LLVM_DEBUG(dbgs() << "Switch an internal ref edge to a call edge from '"
+ << N << "' to '" << *CallTarget << "'\n");
+
+ // Otherwise we are switching an internal ref edge to a call edge. This
+ // may merge away some SCCs, and we add those to the UpdateResult. We also
+ // need to make sure to update the worklist in the event SCCs have moved
+  // before the current one in the post-order sequence.
+ bool HasFunctionAnalysisProxy = false;
+ auto InitialSCCIndex = RC->find(*C) - RC->begin();
+ bool FormedCycle = RC->switchInternalEdgeToCall(
+ N, *CallTarget, [&](ArrayRef<SCC *> MergedSCCs) {
+ for (SCC *MergedC : MergedSCCs) {
+ assert(MergedC != &TargetC && "Cannot merge away the target SCC!");
+
+ HasFunctionAnalysisProxy |=
+ AM.getCachedResult<FunctionAnalysisManagerCGSCCProxy>(
+ *MergedC) != nullptr;
+
+ // Mark that this SCC will no longer be valid.
+ UR.InvalidatedSCCs.insert(MergedC);
+
+ // FIXME: We should really do a 'clear' here to forcibly release
+ // memory, but we don't have a good way of doing that and
+ // preserving the function analyses.
+ auto PA = PreservedAnalyses::allInSet<AllAnalysesOn<Function>>();
+ PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
+ AM.invalidate(*MergedC, PA);
+ }
+ });
+
+ // If we formed a cycle by creating this call, we need to update more data
+ // structures.
+ if (FormedCycle) {
+ C = &TargetC;
+ assert(G.lookupSCC(N) == C && "Failed to update current SCC!");
+
+ // If one of the invalidated SCCs had a cached proxy to a function
+ // analysis manager, we need to create a proxy in the new current SCC as
+ // the invalidated SCCs had their functions moved.
+ if (HasFunctionAnalysisProxy)
+ AM.getResult<FunctionAnalysisManagerCGSCCProxy>(*C, G);
+
+ // Any analyses cached for this SCC are no longer precise as the shape
+ // has changed by introducing this cycle. However, we have taken care to
+      // update the proxies so they remain valid.
+ auto PA = PreservedAnalyses::allInSet<AllAnalysesOn<Function>>();
+ PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
+ AM.invalidate(*C, PA);
+ }
+ auto NewSCCIndex = RC->find(*C) - RC->begin();
+ // If we have actually moved an SCC to be topologically "below" the current
+ // one due to merging, we will need to revisit the current SCC after
+ // visiting those moved SCCs.
+ //
+ // It is critical that we *do not* revisit the current SCC unless we
+ // actually move SCCs in the process of merging because otherwise we may
+ // form a cycle where an SCC is split apart, merged, split, merged and so
+ // on infinitely.
+ if (InitialSCCIndex < NewSCCIndex) {
+ // Put our current SCC back onto the worklist as we'll visit other SCCs
+ // that are now definitively ordered prior to the current one in the
+ // post-order sequence, and may end up observing more precise context to
+ // optimize the current SCC.
+ UR.CWorklist.insert(C);
+ LLVM_DEBUG(dbgs() << "Enqueuing the existing SCC in the worklist: " << *C
+ << "\n");
+ // Enqueue in reverse order as we pop off the back of the worklist.
+ for (SCC &MovedC : llvm::reverse(make_range(RC->begin() + InitialSCCIndex,
+ RC->begin() + NewSCCIndex))) {
+ UR.CWorklist.insert(&MovedC);
+ LLVM_DEBUG(dbgs() << "Enqueuing a newly earlier in post-order SCC: "
+ << MovedC << "\n");
+ }
+ }
+ }
+
+ assert(!UR.InvalidatedSCCs.count(C) && "Invalidated the current SCC!");
+ assert(!UR.InvalidatedRefSCCs.count(RC) && "Invalidated the current RefSCC!");
+ assert(&C->getOuterRefSCC() == RC && "Current SCC not in current RefSCC!");
+
+ // Record the current RefSCC and SCC for higher layers of the CGSCC pass
+ // manager now that all the updates have been applied.
+ if (RC != &InitialRC)
+ UR.UpdatedRC = RC;
+ if (C != &InitialC)
+ UR.UpdatedC = C;
+
+ return *C;
+}
diff --git a/llvm/lib/Analysis/CallGraph.cpp b/llvm/lib/Analysis/CallGraph.cpp
new file mode 100644
index 000000000000..70aeb1a688ee
--- /dev/null
+++ b/llvm/lib/Analysis/CallGraph.cpp
@@ -0,0 +1,326 @@
+//===- CallGraph.cpp - Build a Module's call graph ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+
+using namespace llvm;
+
+//===----------------------------------------------------------------------===//
+// Implementations of the CallGraph class methods.
+//
+
+CallGraph::CallGraph(Module &M)
+ : M(M), ExternalCallingNode(getOrInsertFunction(nullptr)),
+ CallsExternalNode(std::make_unique<CallGraphNode>(nullptr)) {
+ // Add every function to the call graph.
+ for (Function &F : M)
+ addToCallGraph(&F);
+}
+
+CallGraph::CallGraph(CallGraph &&Arg)
+ : M(Arg.M), FunctionMap(std::move(Arg.FunctionMap)),
+ ExternalCallingNode(Arg.ExternalCallingNode),
+ CallsExternalNode(std::move(Arg.CallsExternalNode)) {
+ Arg.FunctionMap.clear();
+ Arg.ExternalCallingNode = nullptr;
+}
+
+CallGraph::~CallGraph() {
+ // CallsExternalNode is not in the function map, delete it explicitly.
+ if (CallsExternalNode)
+ CallsExternalNode->allReferencesDropped();
+
+// Reset all nodes' use counts to zero before deleting them to prevent an
+// assertion from firing.
+#ifndef NDEBUG
+ for (auto &I : FunctionMap)
+ I.second->allReferencesDropped();
+#endif
+}
+
+void CallGraph::addToCallGraph(Function *F) {
+ CallGraphNode *Node = getOrInsertFunction(F);
+
+ // If this function has external linkage or has its address taken, anything
+ // could call it.
+ if (!F->hasLocalLinkage() || F->hasAddressTaken())
+ ExternalCallingNode->addCalledFunction(nullptr, Node);
+
+ // If this function is not defined in this translation unit, it could call
+ // anything.
+ if (F->isDeclaration() && !F->isIntrinsic())
+ Node->addCalledFunction(nullptr, CallsExternalNode.get());
+
+ // Look for calls by this function.
+ for (BasicBlock &BB : *F)
+ for (Instruction &I : BB) {
+ if (auto *Call = dyn_cast<CallBase>(&I)) {
+ const Function *Callee = Call->getCalledFunction();
+ if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID()))
+ // Indirect calls of intrinsics are not allowed so no need to check.
+ // We can be more precise here by using TargetArg returned by
+ // Intrinsic::isLeaf.
+ Node->addCalledFunction(Call, CallsExternalNode.get());
+ else if (!Callee->isIntrinsic())
+ Node->addCalledFunction(Call, getOrInsertFunction(Callee));
+ }
+ }
+}
+
+void CallGraph::print(raw_ostream &OS) const {
+ // Print in a deterministic order by sorting CallGraphNodes by name. We do
+ // this here to avoid slowing down the non-printing fast path.
+
+ SmallVector<CallGraphNode *, 16> Nodes;
+ Nodes.reserve(FunctionMap.size());
+
+ for (const auto &I : *this)
+ Nodes.push_back(I.second.get());
+
+ llvm::sort(Nodes, [](CallGraphNode *LHS, CallGraphNode *RHS) {
+ if (Function *LF = LHS->getFunction())
+ if (Function *RF = RHS->getFunction())
+ return LF->getName() < RF->getName();
+
+ return RHS->getFunction() != nullptr;
+ });
+
+ for (CallGraphNode *CN : Nodes)
+ CN->print(OS);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void CallGraph::dump() const { print(dbgs()); }
+#endif
+
+// removeFunctionFromModule - Unlink the function from this module, returning
+// it. Because this removes the function from the module, the call graph node
+// is destroyed. This is only valid if the function does not call any other
+// functions (i.e., there are no edges in its CGN). The easiest way to do this
+// is to dropAllReferences before calling this.
+//
+Function *CallGraph::removeFunctionFromModule(CallGraphNode *CGN) {
+ assert(CGN->empty() && "Cannot remove function from call "
+ "graph if it references other functions!");
+ Function *F = CGN->getFunction(); // Get the function for the call graph node
+ FunctionMap.erase(F); // Remove the call graph node from the map
+
+ M.getFunctionList().remove(F);
+ return F;
+}
+
+/// spliceFunction - Replace the function represented by this node by another.
+/// This does not rescan the body of the function, so it is suitable when
+/// splicing the body of the old function to the new while also updating all
+/// callers from old to new.
+void CallGraph::spliceFunction(const Function *From, const Function *To) {
+ assert(FunctionMap.count(From) && "No CallGraphNode for function!");
+ assert(!FunctionMap.count(To) &&
+ "Pointing CallGraphNode at a function that already exists");
+ FunctionMapTy::iterator I = FunctionMap.find(From);
+ I->second->F = const_cast<Function*>(To);
+ FunctionMap[To] = std::move(I->second);
+ FunctionMap.erase(I);
+}
+
+// getOrInsertFunction - This method is identical to calling operator[], but
+// it will insert a new CallGraphNode for the specified function if one does
+// not already exist.
+CallGraphNode *CallGraph::getOrInsertFunction(const Function *F) {
+ auto &CGN = FunctionMap[F];
+ if (CGN)
+ return CGN.get();
+
+ assert((!F || F->getParent() == &M) && "Function not in current module!");
+ CGN = std::make_unique<CallGraphNode>(const_cast<Function *>(F));
+ return CGN.get();
+}
+
+//===----------------------------------------------------------------------===//
+// Implementations of the CallGraphNode class methods.
+//
+
+void CallGraphNode::print(raw_ostream &OS) const {
+ if (Function *F = getFunction())
+ OS << "Call graph node for function: '" << F->getName() << "'";
+ else
+ OS << "Call graph node <<null function>>";
+
+ OS << "<<" << this << ">> #uses=" << getNumReferences() << '\n';
+
+ for (const auto &I : *this) {
+ OS << " CS<" << I.first << "> calls ";
+ if (Function *FI = I.second->getFunction())
+ OS << "function '" << FI->getName() <<"'\n";
+ else
+ OS << "external node\n";
+ }
+ OS << '\n';
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void CallGraphNode::dump() const { print(dbgs()); }
+#endif
+
+/// removeCallEdgeFor - This method removes the edge in the node for the
+/// specified call site. Note that this method takes linear time, so it
+/// should be used sparingly.
+void CallGraphNode::removeCallEdgeFor(CallBase &Call) {
+ for (CalledFunctionsVector::iterator I = CalledFunctions.begin(); ; ++I) {
+ assert(I != CalledFunctions.end() && "Cannot find callsite to remove!");
+ if (I->first == &Call) {
+ I->second->DropRef();
+ *I = CalledFunctions.back();
+ CalledFunctions.pop_back();
+ return;
+ }
+ }
+}
+
+// removeAnyCallEdgeTo - This method removes any call edges from this node to
+// the specified callee function. This takes more time to execute than
+// removeCallEdgeTo, so it should not be used unless necessary.
+void CallGraphNode::removeAnyCallEdgeTo(CallGraphNode *Callee) {
+ for (unsigned i = 0, e = CalledFunctions.size(); i != e; ++i)
+ if (CalledFunctions[i].second == Callee) {
+ Callee->DropRef();
+ CalledFunctions[i] = CalledFunctions.back();
+ CalledFunctions.pop_back();
+ --i; --e;
+ }
+}
+
+/// removeOneAbstractEdgeTo - Remove one edge associated with a null callsite
+/// from this node to the specified callee function.
+void CallGraphNode::removeOneAbstractEdgeTo(CallGraphNode *Callee) {
+ for (CalledFunctionsVector::iterator I = CalledFunctions.begin(); ; ++I) {
+ assert(I != CalledFunctions.end() && "Cannot find callee to remove!");
+ CallRecord &CR = *I;
+ if (CR.second == Callee && CR.first == nullptr) {
+ Callee->DropRef();
+ *I = CalledFunctions.back();
+ CalledFunctions.pop_back();
+ return;
+ }
+ }
+}
+
+/// replaceCallEdge - This method replaces the edge in the node for the
+/// specified call site with a new one. Note that this method takes linear
+/// time, so it should be used sparingly.
+void CallGraphNode::replaceCallEdge(CallBase &Call, CallBase &NewCall,
+ CallGraphNode *NewNode) {
+ for (CalledFunctionsVector::iterator I = CalledFunctions.begin(); ; ++I) {
+ assert(I != CalledFunctions.end() && "Cannot find callsite to remove!");
+ if (I->first == &Call) {
+ I->second->DropRef();
+ I->first = &NewCall;
+ I->second = NewNode;
+ NewNode->AddRef();
+ return;
+ }
+ }
+}
+
+// Provide an explicit template instantiation for the static ID.
+AnalysisKey CallGraphAnalysis::Key;
+
+PreservedAnalyses CallGraphPrinterPass::run(Module &M,
+ ModuleAnalysisManager &AM) {
+ AM.getResult<CallGraphAnalysis>(M).print(OS);
+ return PreservedAnalyses::all();
+}
+
+//===----------------------------------------------------------------------===//
+// Out-of-line definitions of CallGraphAnalysis class members.
+//
+
+//===----------------------------------------------------------------------===//
+// Implementations of the CallGraphWrapperPass class methods.
+//
+
+CallGraphWrapperPass::CallGraphWrapperPass() : ModulePass(ID) {
+ initializeCallGraphWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+CallGraphWrapperPass::~CallGraphWrapperPass() = default;
+
+void CallGraphWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+}
+
+bool CallGraphWrapperPass::runOnModule(Module &M) {
+ // All the real work is done in the constructor for the CallGraph.
+ G.reset(new CallGraph(M));
+ return false;
+}
+
+INITIALIZE_PASS(CallGraphWrapperPass, "basiccg", "CallGraph Construction",
+ false, true)
+
+char CallGraphWrapperPass::ID = 0;
+
+void CallGraphWrapperPass::releaseMemory() { G.reset(); }
+
+void CallGraphWrapperPass::print(raw_ostream &OS, const Module *) const {
+ if (!G) {
+ OS << "No call graph has been built!\n";
+ return;
+ }
+
+ // Just delegate.
+ G->print(OS);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD
+void CallGraphWrapperPass::dump() const { print(dbgs(), nullptr); }
+#endif
+
+namespace {
+
+struct CallGraphPrinterLegacyPass : public ModulePass {
+ static char ID; // Pass ID, replacement for typeid
+
+ CallGraphPrinterLegacyPass() : ModulePass(ID) {
+ initializeCallGraphPrinterLegacyPassPass(*PassRegistry::getPassRegistry());
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ AU.addRequiredTransitive<CallGraphWrapperPass>();
+ }
+
+ bool runOnModule(Module &M) override {
+ getAnalysis<CallGraphWrapperPass>().print(errs(), &M);
+ return false;
+ }
+};
+
+} // end anonymous namespace
+
+char CallGraphPrinterLegacyPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(CallGraphPrinterLegacyPass, "print-callgraph",
+ "Print a call graph", true, true)
+INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
+INITIALIZE_PASS_END(CallGraphPrinterLegacyPass, "print-callgraph",
+ "Print a call graph", true, true)
diff --git a/llvm/lib/Analysis/CallGraphSCCPass.cpp b/llvm/lib/Analysis/CallGraphSCCPass.cpp
new file mode 100644
index 000000000000..196ef400bc4e
--- /dev/null
+++ b/llvm/lib/Analysis/CallGraphSCCPass.cpp
@@ -0,0 +1,711 @@
+//===- CallGraphSCCPass.cpp - Pass that operates BU on call graph ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the CallGraphSCCPass class, which is used for passes
+// which are implemented as bottom-up traversals on the call graph. Because
+// there may be cycles in the call graph, passes of this type operate on the
+// call-graph in SCC order: that is, they process functions bottom-up, except
+// for recursive functions, which they process all at once.
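+//
+// For example, if f calls g and g calls h with no recursion, the passes run
+// on h first, then g, then f; if f and g call each other, they form a single
+// SCC and are processed together.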
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CallGraphSCCPass.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SCCIterator.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRPrintingPasses.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassManagers.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/OptBisect.h"
+#include "llvm/IR/PassTimingInfo.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Timer.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+#include <string>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "cgscc-passmgr"
+
+static cl::opt<unsigned>
+MaxIterations("max-cg-scc-iterations", cl::ReallyHidden, cl::init(4));
+
+STATISTIC(MaxSCCIterations, "Maximum CGSCCPassMgr iterations on one SCC");
+
+//===----------------------------------------------------------------------===//
+// CGPassManager
+//
+/// CGPassManager manages FPPassManagers and CallGraphSCCPasses.
+
+namespace {
+
+class CGPassManager : public ModulePass, public PMDataManager {
+public:
+ static char ID;
+
+ explicit CGPassManager() : ModulePass(ID), PMDataManager() {}
+
+ /// Execute all of the passes scheduled for execution. Keep track of
+ /// whether any of the passes modifies the module, and if so, return true.
+ bool runOnModule(Module &M) override;
+
+ using ModulePass::doInitialization;
+ using ModulePass::doFinalization;
+
+ bool doInitialization(CallGraph &CG);
+ bool doFinalization(CallGraph &CG);
+
+ /// Pass Manager itself does not invalidate any analysis info.
+ void getAnalysisUsage(AnalysisUsage &Info) const override {
+ // CGPassManager walks SCC and it needs CallGraph.
+ Info.addRequired<CallGraphWrapperPass>();
+ Info.setPreservesAll();
+ }
+
+ StringRef getPassName() const override { return "CallGraph Pass Manager"; }
+
+ PMDataManager *getAsPMDataManager() override { return this; }
+ Pass *getAsPass() override { return this; }
+
+ // Print passes managed by this manager
+ void dumpPassStructure(unsigned Offset) override {
+ errs().indent(Offset*2) << "Call Graph SCC Pass Manager\n";
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ Pass *P = getContainedPass(Index);
+ P->dumpPassStructure(Offset + 1);
+ dumpLastUses(P, Offset+1);
+ }
+ }
+
+ Pass *getContainedPass(unsigned N) {
+ assert(N < PassVector.size() && "Pass number out of range!");
+ return static_cast<Pass *>(PassVector[N]);
+ }
+
+ PassManagerType getPassManagerType() const override {
+ return PMT_CallGraphPassManager;
+ }
+
+private:
+ bool RunAllPassesOnSCC(CallGraphSCC &CurSCC, CallGraph &CG,
+ bool &DevirtualizedCall);
+
+ bool RunPassOnSCC(Pass *P, CallGraphSCC &CurSCC,
+ CallGraph &CG, bool &CallGraphUpToDate,
+ bool &DevirtualizedCall);
+ bool RefreshCallGraph(const CallGraphSCC &CurSCC, CallGraph &CG,
+ bool IsCheckingMode);
+};
+
+} // end anonymous namespace.
+
+char CGPassManager::ID = 0;
+
+bool CGPassManager::RunPassOnSCC(Pass *P, CallGraphSCC &CurSCC,
+ CallGraph &CG, bool &CallGraphUpToDate,
+ bool &DevirtualizedCall) {
+ bool Changed = false;
+ PMDataManager *PM = P->getAsPMDataManager();
+ Module &M = CG.getModule();
+
+ if (!PM) {
+ CallGraphSCCPass *CGSP = (CallGraphSCCPass *)P;
+ if (!CallGraphUpToDate) {
+ DevirtualizedCall |= RefreshCallGraph(CurSCC, CG, false);
+ CallGraphUpToDate = true;
+ }
+
+ {
+ unsigned InstrCount, SCCCount = 0;
+ StringMap<std::pair<unsigned, unsigned>> FunctionToInstrCount;
+ bool EmitICRemark = M.shouldEmitInstrCountChangedRemark();
+ TimeRegion PassTimer(getPassTimer(CGSP));
+ if (EmitICRemark)
+ InstrCount = initSizeRemarkInfo(M, FunctionToInstrCount);
+ Changed = CGSP->runOnSCC(CurSCC);
+
+ if (EmitICRemark) {
+ // FIXME: Add getInstructionCount to CallGraphSCC.
+ SCCCount = M.getInstructionCount();
+ // Is there a difference in the number of instructions in the module?
+ if (SCCCount != InstrCount) {
+ // Yep. Emit a remark and update InstrCount.
+ int64_t Delta =
+ static_cast<int64_t>(SCCCount) - static_cast<int64_t>(InstrCount);
+ emitInstrCountChangedRemark(P, M, Delta, InstrCount,
+ FunctionToInstrCount);
+ InstrCount = SCCCount;
+ }
+ }
+ }
+
+ // After the CGSCCPass is done, when assertions are enabled, use
+ // RefreshCallGraph to verify that the callgraph was correctly updated.
+#ifndef NDEBUG
+ if (Changed)
+ RefreshCallGraph(CurSCC, CG, true);
+#endif
+
+ return Changed;
+ }
+
+ assert(PM->getPassManagerType() == PMT_FunctionPassManager &&
+ "Invalid CGPassManager member");
+ FPPassManager *FPP = (FPPassManager*)P;
+
+ // Run pass P on all functions in the current SCC.
+ for (CallGraphNode *CGN : CurSCC) {
+ if (Function *F = CGN->getFunction()) {
+ dumpPassInfo(P, EXECUTION_MSG, ON_FUNCTION_MSG, F->getName());
+ {
+ TimeRegion PassTimer(getPassTimer(FPP));
+ Changed |= FPP->runOnFunction(*F);
+ }
+ F->getContext().yield();
+ }
+ }
+
+  // The function pass(es) modified the IR, so they may have clobbered the
+ // callgraph.
+ if (Changed && CallGraphUpToDate) {
+ LLVM_DEBUG(dbgs() << "CGSCCPASSMGR: Pass Dirtied SCC: " << P->getPassName()
+ << '\n');
+ CallGraphUpToDate = false;
+ }
+ return Changed;
+}
+
+/// Scan the functions in the specified SCC and resync the
+/// callgraph with the call sites found in it. This is used after
+/// FunctionPasses have potentially munged the callgraph, and can be used after
+/// CallGraphSCC passes to verify that they correctly updated the callgraph.
+///
+/// This function returns true if it devirtualized an existing function call,
+/// meaning it turned an indirect call into a direct call. This happens when
+/// a function pass like GVN optimizes away stuff feeding the indirect call.
+/// This never happens in checking mode.
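+///
+/// A typical case: a function pass proves that an indirect call's function
+/// pointer is a known function and rewrites the call site into a direct call;
+/// the next refresh then sees a direct call where the callgraph recorded an
+/// indirect one and reports the devirtualization.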
+bool CGPassManager::RefreshCallGraph(const CallGraphSCC &CurSCC, CallGraph &CG,
+ bool CheckingMode) {
+ DenseMap<Value *, CallGraphNode *> Calls;
+
+ LLVM_DEBUG(dbgs() << "CGSCCPASSMGR: Refreshing SCC with " << CurSCC.size()
+ << " nodes:\n";
+             for (CallGraphNode *CGN : CurSCC)
+               CGN->dump(););
+
+ bool MadeChange = false;
+ bool DevirtualizedCall = false;
+
+ // Scan all functions in the SCC.
+ unsigned FunctionNo = 0;
+ for (CallGraphSCC::iterator SCCIdx = CurSCC.begin(), E = CurSCC.end();
+ SCCIdx != E; ++SCCIdx, ++FunctionNo) {
+ CallGraphNode *CGN = *SCCIdx;
+ Function *F = CGN->getFunction();
+ if (!F || F->isDeclaration()) continue;
+
+ // Walk the function body looking for call sites. Sync up the call sites in
+ // CGN with those actually in the function.
+
+ // Keep track of the number of direct and indirect calls that were
+ // invalidated and removed.
+ unsigned NumDirectRemoved = 0, NumIndirectRemoved = 0;
+
+ // Get the set of call sites currently in the function.
+ for (CallGraphNode::iterator I = CGN->begin(), E = CGN->end(); I != E; ) {
+ // If this call site is null, then the function pass deleted the call
+ // entirely and the WeakTrackingVH nulled it out.
+ auto *Call = dyn_cast_or_null<CallBase>(I->first);
+ if (!I->first ||
+ // If we've already seen this call site, then the FunctionPass RAUW'd
+ // one call with another, which resulted in two "uses" in the edge
+ // list of the same call.
+ Calls.count(I->first) ||
+
+          // If the call edge is not from a call or invoke, or it is an
+          // intrinsic call, then the function pass RAUW'd a call with
+          // another value. This can happen when well-known functions are
+          // constant folded, etc.
+ !Call ||
+ (Call->getCalledFunction() &&
+ Call->getCalledFunction()->isIntrinsic() &&
+ Intrinsic::isLeaf(Call->getCalledFunction()->getIntrinsicID()))) {
+ assert(!CheckingMode &&
+ "CallGraphSCCPass did not update the CallGraph correctly!");
+
+ // If this was an indirect call site, count it.
+ if (!I->second->getFunction())
+ ++NumIndirectRemoved;
+ else
+ ++NumDirectRemoved;
+
+ // Just remove the edge from the set of callees, keep track of whether
+ // I points to the last element of the vector.
+ bool WasLast = I + 1 == E;
+ CGN->removeCallEdge(I);
+
+ // If I pointed to the last element of the vector, we have to bail out:
+ // iterator checking rejects comparisons of the resultant pointer with
+ // end.
+ if (WasLast)
+ break;
+ E = CGN->end();
+ continue;
+ }
+
+ assert(!Calls.count(I->first) &&
+ "Call site occurs in node multiple times");
+
+ if (Call) {
+ Function *Callee = Call->getCalledFunction();
+ // Ignore intrinsics because they're not really function calls.
+ if (!Callee || !(Callee->isIntrinsic()))
+ Calls.insert(std::make_pair(I->first, I->second));
+ }
+ ++I;
+ }
+
+ // Loop over all of the instructions in the function, getting the callsites.
+ // Keep track of the number of direct/indirect calls added.
+ unsigned NumDirectAdded = 0, NumIndirectAdded = 0;
+
+ for (BasicBlock &BB : *F)
+ for (Instruction &I : BB) {
+ auto *Call = dyn_cast<CallBase>(&I);
+ if (!Call)
+ continue;
+ Function *Callee = Call->getCalledFunction();
+ if (Callee && Callee->isIntrinsic())
+ continue;
+
+ // If this call site already existed in the callgraph, just verify it
+ // matches up to expectations and remove it from Calls.
+ DenseMap<Value *, CallGraphNode *>::iterator ExistingIt =
+ Calls.find(Call);
+ if (ExistingIt != Calls.end()) {
+ CallGraphNode *ExistingNode = ExistingIt->second;
+
+ // Remove from Calls since we have now seen it.
+ Calls.erase(ExistingIt);
+
+ // Verify that the callee is right.
+ if (ExistingNode->getFunction() == Call->getCalledFunction())
+ continue;
+
+ // If we are in checking mode, we are not allowed to actually mutate
+ // the callgraph. If this is a case where we can infer that the
+ // callgraph is less precise than it could be (e.g. an indirect call
+ // site could be turned direct), don't reject it in checking mode, and
+ // don't tweak it to be more precise.
+ if (CheckingMode && Call->getCalledFunction() &&
+ ExistingNode->getFunction() == nullptr)
+ continue;
+
+ assert(!CheckingMode &&
+ "CallGraphSCCPass did not update the CallGraph correctly!");
+
+ // If not, we either went from a direct call to indirect, indirect to
+ // direct, or direct to different direct.
+ CallGraphNode *CalleeNode;
+ if (Function *Callee = Call->getCalledFunction()) {
+ CalleeNode = CG.getOrInsertFunction(Callee);
+ // Keep track of whether we turned an indirect call into a direct
+ // one.
+ if (!ExistingNode->getFunction()) {
+ DevirtualizedCall = true;
+ LLVM_DEBUG(dbgs() << " CGSCCPASSMGR: Devirtualized call to '"
+ << Callee->getName() << "'\n");
+ }
+ } else {
+ CalleeNode = CG.getCallsExternalNode();
+ }
+
+ // Update the edge target in CGN.
+ CGN->replaceCallEdge(*Call, *Call, CalleeNode);
+ MadeChange = true;
+ continue;
+ }
+
+ assert(!CheckingMode &&
+ "CallGraphSCCPass did not update the CallGraph correctly!");
+
+ // If the call site didn't exist in the CGN yet, add it.
+ CallGraphNode *CalleeNode;
+ if (Function *Callee = Call->getCalledFunction()) {
+ CalleeNode = CG.getOrInsertFunction(Callee);
+ ++NumDirectAdded;
+ } else {
+ CalleeNode = CG.getCallsExternalNode();
+ ++NumIndirectAdded;
+ }
+
+ CGN->addCalledFunction(Call, CalleeNode);
+ MadeChange = true;
+ }
+
+ // We scanned the old callgraph node, removing invalidated call sites and
+ // then added back newly found call sites. One thing that can happen is
+ // that an old indirect call site was deleted and replaced with a new direct
+ // call. In this case, we have devirtualized a call, and CGSCCPM would like
+ // to iteratively optimize the new code. Unfortunately, we don't really
+ // have a great way to detect when this happens. As an approximation, we
+ // just look at whether the number of indirect calls is reduced and the
+ // number of direct calls is increased. There are tons of ways to fool this
+ // (e.g. DCE'ing an indirect call and duplicating an unrelated block with a
+ // direct call) but this is close enough.
+ if (NumIndirectRemoved > NumIndirectAdded &&
+ NumDirectRemoved < NumDirectAdded)
+ DevirtualizedCall = true;
+
+  // After scanning this function, if we still have entries in Calls, then
+  // they are dangling pointers. WeakTrackingVH should save us from this, so
+  // abort if this happens.
+ assert(Calls.empty() && "Dangling pointers found in call sites map");
+
+ // Periodically do an explicit clear to remove tombstones when processing
+  // large SCCs.
+ if ((FunctionNo & 15) == 15)
+ Calls.clear();
+ }
+
+ LLVM_DEBUG(if (MadeChange) {
+ dbgs() << "CGSCCPASSMGR: Refreshed SCC is now:\n";
+ for (CallGraphNode *CGN : CurSCC)
+ CGN->dump();
+ if (DevirtualizedCall)
+ dbgs() << "CGSCCPASSMGR: Refresh devirtualized a call!\n";
+ } else {
+ dbgs() << "CGSCCPASSMGR: SCC Refresh didn't change call graph.\n";
+ });
+ (void)MadeChange;
+
+ return DevirtualizedCall;
+}
+
+/// Execute the body of the entire pass manager on the specified SCC.
+/// This keeps track of whether a function pass devirtualizes
+/// any calls and returns it in DevirtualizedCall.
+bool CGPassManager::RunAllPassesOnSCC(CallGraphSCC &CurSCC, CallGraph &CG,
+ bool &DevirtualizedCall) {
+ bool Changed = false;
+
+ // Keep track of whether the callgraph is known to be up-to-date or not.
+  // The CGSCC pass manager runs two types of passes:
+ // CallGraphSCC Passes and other random function passes. Because other
+ // random function passes are not CallGraph aware, they may clobber the
+ // call graph by introducing new calls or deleting other ones. This flag
+ // is set to false when we run a function pass so that we know to clean up
+ // the callgraph when we need to run a CGSCCPass again.
+ bool CallGraphUpToDate = true;
+
+ // Run all passes on current SCC.
+ for (unsigned PassNo = 0, e = getNumContainedPasses();
+ PassNo != e; ++PassNo) {
+ Pass *P = getContainedPass(PassNo);
+
+ // If we're in -debug-pass=Executions mode, construct the SCC node list,
+ // otherwise avoid constructing this string as it is expensive.
+ if (isPassDebuggingExecutionsOrMore()) {
+ std::string Functions;
+ #ifndef NDEBUG
+ raw_string_ostream OS(Functions);
+ for (CallGraphSCC::iterator I = CurSCC.begin(), E = CurSCC.end();
+ I != E; ++I) {
+ if (I != CurSCC.begin()) OS << ", ";
+ (*I)->print(OS);
+ }
+ OS.flush();
+ #endif
+ dumpPassInfo(P, EXECUTION_MSG, ON_CG_MSG, Functions);
+ }
+ dumpRequiredSet(P);
+
+ initializeAnalysisImpl(P);
+
+ // Actually run this pass on the current SCC.
+ Changed |= RunPassOnSCC(P, CurSCC, CG,
+ CallGraphUpToDate, DevirtualizedCall);
+
+ if (Changed)
+ dumpPassInfo(P, MODIFICATION_MSG, ON_CG_MSG, "");
+ dumpPreservedSet(P);
+
+ verifyPreservedAnalysis(P);
+ removeNotPreservedAnalysis(P);
+ recordAvailableAnalysis(P);
+ removeDeadPasses(P, "", ON_CG_MSG);
+ }
+
+ // If the callgraph was left out of date (because the last pass run was a
+  // function pass), refresh it before we move on to the next SCC.
+ if (!CallGraphUpToDate)
+ DevirtualizedCall |= RefreshCallGraph(CurSCC, CG, false);
+ return Changed;
+}
+
+/// Execute all of the passes scheduled for execution. Keep track of
+/// whether any of the passes modifies the module, and if so, return true.
+bool CGPassManager::runOnModule(Module &M) {
+ CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
+ bool Changed = doInitialization(CG);
+
+ // Walk the callgraph in bottom-up SCC order.
+ scc_iterator<CallGraph*> CGI = scc_begin(&CG);
+
+ CallGraphSCC CurSCC(CG, &CGI);
+ while (!CGI.isAtEnd()) {
+ // Copy the current SCC and increment past it so that the pass can hack
+ // on the SCC if it wants to without invalidating our iterator.
+ const std::vector<CallGraphNode *> &NodeVec = *CGI;
+ CurSCC.initialize(NodeVec);
+ ++CGI;
+
+ // At the top level, we run all the passes in this pass manager on the
+ // functions in this SCC. However, we support iterative compilation in the
+ // case where a function pass devirtualizes a call to a function. For
+ // example, it is very common for a function pass (often GVN or instcombine)
+ // to eliminate the addressing that feeds into a call. With that improved
+ // information, we would like the call to be an inline candidate, infer
+ // mod-ref information etc.
+ //
+ // Because of this, we allow iteration up to a specified iteration count.
+ // This only happens in the case of a devirtualized call, so we only burn
+ // compile time in the case that we're making progress. We also have a hard
+ // iteration count limit in case there is crazy code.
+ unsigned Iteration = 0;
+ bool DevirtualizedCall = false;
+ do {
+ LLVM_DEBUG(if (Iteration) dbgs()
+ << " SCCPASSMGR: Re-visiting SCC, iteration #" << Iteration
+ << '\n');
+ DevirtualizedCall = false;
+ Changed |= RunAllPassesOnSCC(CurSCC, CG, DevirtualizedCall);
+ } while (Iteration++ < MaxIterations && DevirtualizedCall);
+
+ if (DevirtualizedCall)
+ LLVM_DEBUG(dbgs() << " CGSCCPASSMGR: Stopped iteration after "
+ << Iteration
+ << " times, due to -max-cg-scc-iterations\n");
+
+ MaxSCCIterations.updateMax(Iteration);
+ }
+ Changed |= doFinalization(CG);
+ return Changed;
+}
+
+/// Initialize CG
+bool CGPassManager::doInitialization(CallGraph &CG) {
+ bool Changed = false;
+ for (unsigned i = 0, e = getNumContainedPasses(); i != e; ++i) {
+ if (PMDataManager *PM = getContainedPass(i)->getAsPMDataManager()) {
+ assert(PM->getPassManagerType() == PMT_FunctionPassManager &&
+ "Invalid CGPassManager member");
+ Changed |= ((FPPassManager*)PM)->doInitialization(CG.getModule());
+ } else {
+ Changed |= ((CallGraphSCCPass*)getContainedPass(i))->doInitialization(CG);
+ }
+ }
+ return Changed;
+}
+
+/// Finalize CG
+bool CGPassManager::doFinalization(CallGraph &CG) {
+ bool Changed = false;
+ for (unsigned i = 0, e = getNumContainedPasses(); i != e; ++i) {
+ if (PMDataManager *PM = getContainedPass(i)->getAsPMDataManager()) {
+ assert(PM->getPassManagerType() == PMT_FunctionPassManager &&
+ "Invalid CGPassManager member");
+ Changed |= ((FPPassManager*)PM)->doFinalization(CG.getModule());
+ } else {
+ Changed |= ((CallGraphSCCPass*)getContainedPass(i))->doFinalization(CG);
+ }
+ }
+ return Changed;
+}
+
+//===----------------------------------------------------------------------===//
+// CallGraphSCC Implementation
+//===----------------------------------------------------------------------===//
+
+/// This informs the SCC and the pass manager that the specified
+/// Old node has been deleted, and New is to be used in its place.
+void CallGraphSCC::ReplaceNode(CallGraphNode *Old, CallGraphNode *New) {
+ assert(Old != New && "Should not replace node with self");
+ for (unsigned i = 0; ; ++i) {
+ assert(i != Nodes.size() && "Node not in SCC");
+ if (Nodes[i] != Old) continue;
+ Nodes[i] = New;
+ break;
+ }
+
+ // Update the active scc_iterator so that it doesn't contain dangling
+ // pointers to the old CallGraphNode.
+ scc_iterator<CallGraph*> *CGI = (scc_iterator<CallGraph*>*)Context;
+ CGI->ReplaceNode(Old, New);
+}
+
+//===----------------------------------------------------------------------===//
+// CallGraphSCCPass Implementation
+//===----------------------------------------------------------------------===//
+
+/// Assign pass manager to manage this pass.
+void CallGraphSCCPass::assignPassManager(PMStack &PMS,
+ PassManagerType PreferredType) {
+ // Find CGPassManager
+ while (!PMS.empty() &&
+ PMS.top()->getPassManagerType() > PMT_CallGraphPassManager)
+ PMS.pop();
+
+ assert(!PMS.empty() && "Unable to handle Call Graph Pass");
+ CGPassManager *CGP;
+
+ if (PMS.top()->getPassManagerType() == PMT_CallGraphPassManager)
+ CGP = (CGPassManager*)PMS.top();
+ else {
+ // Create new Call Graph SCC Pass Manager if it does not exist.
+ assert(!PMS.empty() && "Unable to create Call Graph Pass Manager");
+ PMDataManager *PMD = PMS.top();
+
+ // [1] Create new Call Graph Pass Manager
+ CGP = new CGPassManager();
+
+ // [2] Set up new manager's top level manager
+ PMTopLevelManager *TPM = PMD->getTopLevelManager();
+ TPM->addIndirectPassManager(CGP);
+
+ // [3] Assign manager to manage this new manager. This may create
+ // and push new managers into PMS
+ Pass *P = CGP;
+ TPM->schedulePass(P);
+
+ // [4] Push new manager into PMS
+ PMS.push(CGP);
+ }
+
+ CGP->add(this);
+}
+
+/// For this class, we declare that we require and preserve the call graph.
+/// If the derived class implements this method, it should
+/// always explicitly call the implementation here.
+void CallGraphSCCPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addRequired<CallGraphWrapperPass>();
+ AU.addPreserved<CallGraphWrapperPass>();
+}
+
+//===----------------------------------------------------------------------===//
+// PrintCallGraphPass Implementation
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+ /// PrintCallGraphPass - Print a Module corresponding to a call graph.
+ ///
+ class PrintCallGraphPass : public CallGraphSCCPass {
+ std::string Banner;
+ raw_ostream &OS; // raw_ostream to print on.
+
+ public:
+ static char ID;
+
+ PrintCallGraphPass(const std::string &B, raw_ostream &OS)
+ : CallGraphSCCPass(ID), Banner(B), OS(OS) {}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ }
+
+ bool runOnSCC(CallGraphSCC &SCC) override {
+ bool BannerPrinted = false;
+ auto PrintBannerOnce = [&]() {
+ if (BannerPrinted)
+ return;
+ OS << Banner;
+ BannerPrinted = true;
+ };
+
+ bool NeedModule = llvm::forcePrintModuleIR();
+ if (isFunctionInPrintList("*") && NeedModule) {
+ PrintBannerOnce();
+ OS << "\n";
+ SCC.getCallGraph().getModule().print(OS, nullptr);
+ return false;
+ }
+ bool FoundFunction = false;
+ for (CallGraphNode *CGN : SCC) {
+ if (Function *F = CGN->getFunction()) {
+ if (!F->isDeclaration() && isFunctionInPrintList(F->getName())) {
+ FoundFunction = true;
+ if (!NeedModule) {
+ PrintBannerOnce();
+ F->print(OS);
+ }
+ }
+ } else if (isFunctionInPrintList("*")) {
+ PrintBannerOnce();
+ OS << "\nPrinting <null> Function\n";
+ }
+ }
+ if (NeedModule && FoundFunction) {
+ PrintBannerOnce();
+ OS << "\n";
+ SCC.getCallGraph().getModule().print(OS, nullptr);
+ }
+ return false;
+ }
+
+ StringRef getPassName() const override { return "Print CallGraph IR"; }
+ };
+
+} // end anonymous namespace.
+
+char PrintCallGraphPass::ID = 0;
+
+Pass *CallGraphSCCPass::createPrinterPass(raw_ostream &OS,
+ const std::string &Banner) const {
+ return new PrintCallGraphPass(Banner, OS);
+}
+
+static std::string getDescription(const CallGraphSCC &SCC) {
+ std::string Desc = "SCC (";
+ bool First = true;
+ for (CallGraphNode *CGN : SCC) {
+ if (First)
+ First = false;
+ else
+ Desc += ", ";
+ Function *F = CGN->getFunction();
+ if (F)
+ Desc += F->getName();
+ else
+ Desc += "<<null function>>";
+ }
+ Desc += ")";
+ return Desc;
+}
+
+bool CallGraphSCCPass::skipSCC(CallGraphSCC &SCC) const {
+ OptPassGate &Gate =
+ SCC.getCallGraph().getModule().getContext().getOptPassGate();
+ return Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(SCC));
+}
+
+char DummyCGSCCPass::ID = 0;
+
+INITIALIZE_PASS(DummyCGSCCPass, "DummyCGSCCPass", "DummyCGSCCPass", false,
+ false)
diff --git a/llvm/lib/Analysis/CallPrinter.cpp b/llvm/lib/Analysis/CallPrinter.cpp
new file mode 100644
index 000000000000..d24cbd104bf6
--- /dev/null
+++ b/llvm/lib/Analysis/CallPrinter.cpp
@@ -0,0 +1,91 @@
+//===- CallPrinter.cpp - DOT printer for call graph -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines '-dot-callgraph', which emits a callgraph.<fnname>.dot
+// containing the call graph of a module.
+//
+// There is also a pass available to directly call dotty ('-view-callgraph').
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CallPrinter.h"
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/Analysis/DOTGraphTraitsPass.h"
+
+using namespace llvm;
+
+namespace llvm {
+
+template <> struct DOTGraphTraits<CallGraph *> : public DefaultDOTGraphTraits {
+ DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
+
+ static std::string getGraphName(CallGraph *Graph) { return "Call graph"; }
+
+ std::string getNodeLabel(CallGraphNode *Node, CallGraph *Graph) {
+ if (Function *Func = Node->getFunction())
+ return Func->getName();
+
+ return "external node";
+ }
+};
+
+struct AnalysisCallGraphWrapperPassTraits {
+ static CallGraph *getGraph(CallGraphWrapperPass *P) {
+ return &P->getCallGraph();
+ }
+};
+
+} // end llvm namespace
+
+namespace {
+
+struct CallGraphViewer
+ : public DOTGraphTraitsModuleViewer<CallGraphWrapperPass, true, CallGraph *,
+ AnalysisCallGraphWrapperPassTraits> {
+ static char ID;
+
+ CallGraphViewer()
+ : DOTGraphTraitsModuleViewer<CallGraphWrapperPass, true, CallGraph *,
+ AnalysisCallGraphWrapperPassTraits>(
+ "callgraph", ID) {
+ initializeCallGraphViewerPass(*PassRegistry::getPassRegistry());
+ }
+};
+
+struct CallGraphDOTPrinter : public DOTGraphTraitsModulePrinter<
+ CallGraphWrapperPass, true, CallGraph *,
+ AnalysisCallGraphWrapperPassTraits> {
+ static char ID;
+
+ CallGraphDOTPrinter()
+ : DOTGraphTraitsModulePrinter<CallGraphWrapperPass, true, CallGraph *,
+ AnalysisCallGraphWrapperPassTraits>(
+ "callgraph", ID) {
+ initializeCallGraphDOTPrinterPass(*PassRegistry::getPassRegistry());
+ }
+};
+
+} // end anonymous namespace
+
+char CallGraphViewer::ID = 0;
+INITIALIZE_PASS(CallGraphViewer, "view-callgraph", "View call graph", false,
+ false)
+
+char CallGraphDOTPrinter::ID = 0;
+INITIALIZE_PASS(CallGraphDOTPrinter, "dot-callgraph",
+ "Print call graph to 'dot' file", false, false)
+
+// Create methods available outside of this file, so they can be referenced
+// from "include/llvm/LinkAllPasses.h". Otherwise the passes would be deleted
+// by link time optimization.
+
+ModulePass *llvm::createCallGraphViewerPass() { return new CallGraphViewer(); }
+
+ModulePass *llvm::createCallGraphDOTPrinterPass() {
+ return new CallGraphDOTPrinter();
+}
diff --git a/llvm/lib/Analysis/CaptureTracking.cpp b/llvm/lib/Analysis/CaptureTracking.cpp
new file mode 100644
index 000000000000..20e2f06540a3
--- /dev/null
+++ b/llvm/lib/Analysis/CaptureTracking.cpp
@@ -0,0 +1,389 @@
+//===--- CaptureTracking.cpp - Determine whether a pointer is captured ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains routines that help determine which pointers are captured.
+// A pointer value is captured if the function makes a copy of any part of the
+// pointer that outlives the call. Not being captured means, more or less, that
+// the pointer is only dereferenced and not stored in a global. Returning part
+// of the pointer as the function return value may or may not count as capturing
+// the pointer, depending on the context.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CaptureTracking.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/OrderedBasicBlock.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+
+using namespace llvm;
+
+CaptureTracker::~CaptureTracker() {}
+
+bool CaptureTracker::shouldExplore(const Use *U) { return true; }
+
+bool CaptureTracker::isDereferenceableOrNull(Value *O, const DataLayout &DL) {
+ // An inbounds GEP can either be a valid pointer (pointing into
+ // or to the end of an allocation), or be null in the default
+ // address space. So for an inbounds GEP there is no way to let
+ // the pointer escape using clever GEP hacking because doing so
+ // would make the pointer point outside of the allocated object
+ // and thus make the GEP result a poison value. Similarly, other
+ // dereferenceable pointers cannot be manipulated without producing
+ // poison.
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(O))
+ if (GEP->isInBounds())
+ return true;
+ bool CanBeNull;
+ return O->getPointerDereferenceableBytes(DL, CanBeNull);
+}
+
+namespace {
+ struct SimpleCaptureTracker : public CaptureTracker {
+ explicit SimpleCaptureTracker(bool ReturnCaptures)
+ : ReturnCaptures(ReturnCaptures), Captured(false) {}
+
+ void tooManyUses() override { Captured = true; }
+
+ bool captured(const Use *U) override {
+ if (isa<ReturnInst>(U->getUser()) && !ReturnCaptures)
+ return false;
+
+ Captured = true;
+ return true;
+ }
+
+ bool ReturnCaptures;
+
+ bool Captured;
+ };
+
+  /// Only finds pointer captures which happen before the given instruction.
+  /// Uses the dominator tree to determine whether one instruction is before
+  /// another. Only supports the case where the Value is defined in the same
+  /// basic block as the given instruction and the use.
+ struct CapturesBefore : public CaptureTracker {
+
+ CapturesBefore(bool ReturnCaptures, const Instruction *I, const DominatorTree *DT,
+ bool IncludeI, OrderedBasicBlock *IC)
+ : OrderedBB(IC), BeforeHere(I), DT(DT),
+ ReturnCaptures(ReturnCaptures), IncludeI(IncludeI), Captured(false) {}
+
+ void tooManyUses() override { Captured = true; }
+
+ bool isSafeToPrune(Instruction *I) {
+ BasicBlock *BB = I->getParent();
+ // We explore this usage only if the usage can reach "BeforeHere".
+ // If use is not reachable from entry, there is no need to explore.
+ if (BeforeHere != I && !DT->isReachableFromEntry(BB))
+ return true;
+
+ // Compute the case where both instructions are inside the same basic
+ // block. Since instructions in the same BB as BeforeHere are numbered in
+ // 'OrderedBB', avoid using 'dominates' and 'isPotentiallyReachable'
+ // which are very expensive for large basic blocks.
+ if (BB == BeforeHere->getParent()) {
+ // 'I' dominates 'BeforeHere' => not safe to prune.
+ //
+ // The value defined by an invoke dominates an instruction only
+ // if it dominates every instruction in UseBB. A PHI is dominated only
+ // if the instruction dominates every possible use in the UseBB. Since
+ // UseBB == BB, avoid pruning.
+ if (isa<InvokeInst>(BeforeHere) || isa<PHINode>(I) || I == BeforeHere)
+ return false;
+ if (!OrderedBB->dominates(BeforeHere, I))
+ return false;
+
+ // 'BeforeHere' comes before 'I', it's safe to prune if we also
+ // guarantee that 'I' never reaches 'BeforeHere' through a back-edge or
+ // by its successors, i.e, prune if:
+ //
+      // (1) BB is the entry block or has no successors.
+ // (2) There's no path coming back through BB successors.
+ if (BB == &BB->getParent()->getEntryBlock() ||
+ !BB->getTerminator()->getNumSuccessors())
+ return true;
+
+ SmallVector<BasicBlock*, 32> Worklist;
+ Worklist.append(succ_begin(BB), succ_end(BB));
+ return !isPotentiallyReachableFromMany(Worklist, BB, nullptr, DT);
+ }
+
+ // If the value is defined in the same basic block as use and BeforeHere,
+ // there is no need to explore the use if BeforeHere dominates use.
+ // Check whether there is a path from I to BeforeHere.
+ if (BeforeHere != I && DT->dominates(BeforeHere, I) &&
+ !isPotentiallyReachable(I, BeforeHere, nullptr, DT))
+ return true;
+
+ return false;
+ }
+
+ bool shouldExplore(const Use *U) override {
+ Instruction *I = cast<Instruction>(U->getUser());
+
+ if (BeforeHere == I && !IncludeI)
+ return false;
+
+ if (isSafeToPrune(I))
+ return false;
+
+ return true;
+ }
+
+ bool captured(const Use *U) override {
+ if (isa<ReturnInst>(U->getUser()) && !ReturnCaptures)
+ return false;
+
+ if (!shouldExplore(U))
+ return false;
+
+ Captured = true;
+ return true;
+ }
+
+ OrderedBasicBlock *OrderedBB;
+ const Instruction *BeforeHere;
+ const DominatorTree *DT;
+
+ bool ReturnCaptures;
+ bool IncludeI;
+
+ bool Captured;
+ };
+}
+
+/// PointerMayBeCaptured - Return true if this pointer value may be captured
+/// by the enclosing function (which is required to exist). This routine can
+/// be expensive, so consider caching the results. The boolean ReturnCaptures
+/// specifies whether returning the value (or part of it) from the function
+/// counts as capturing it or not. The boolean StoreCaptures specifies whether
+/// storing the value (or part of it) into memory anywhere automatically
+/// counts as capturing it or not.
+bool llvm::PointerMayBeCaptured(const Value *V,
+ bool ReturnCaptures, bool StoreCaptures,
+ unsigned MaxUsesToExplore) {
+ assert(!isa<GlobalValue>(V) &&
+ "It doesn't make sense to ask whether a global is captured.");
+
+ // TODO: If StoreCaptures is not true, we could do Fancy analysis
+ // to determine whether this store is not actually an escape point.
+ // In that case, BasicAliasAnalysis should be updated as well to
+ // take advantage of this.
+ (void)StoreCaptures;
+
+ SimpleCaptureTracker SCT(ReturnCaptures);
+ PointerMayBeCaptured(V, &SCT, MaxUsesToExplore);
+ return SCT.Captured;
+}
+
+/// PointerMayBeCapturedBefore - Return true if this pointer value may be
+/// captured by the enclosing function (which is required to exist). If a
+/// DominatorTree is provided, only captures which happen before the given
+/// instruction are considered. This routine can be expensive, so consider
+/// caching the results. The boolean ReturnCaptures specifies whether
+/// returning the value (or part of it) from the function counts as capturing
+/// it or not. The boolean StoreCaptures specifies whether storing the value
+/// (or part of it) into memory anywhere automatically counts as capturing it
+/// or not. An ordered basic block \p OBB can be used in order to speed up
+/// queries about relative order among instructions in the same basic block.
+bool llvm::PointerMayBeCapturedBefore(const Value *V, bool ReturnCaptures,
+ bool StoreCaptures, const Instruction *I,
+ const DominatorTree *DT, bool IncludeI,
+ OrderedBasicBlock *OBB,
+ unsigned MaxUsesToExplore) {
+ assert(!isa<GlobalValue>(V) &&
+ "It doesn't make sense to ask whether a global is captured.");
+ bool UseNewOBB = OBB == nullptr;
+
+ if (!DT)
+ return PointerMayBeCaptured(V, ReturnCaptures, StoreCaptures,
+ MaxUsesToExplore);
+ if (UseNewOBB)
+ OBB = new OrderedBasicBlock(I->getParent());
+
+ // TODO: See comment in PointerMayBeCaptured regarding what could be done
+ // with StoreCaptures.
+
+ CapturesBefore CB(ReturnCaptures, I, DT, IncludeI, OBB);
+ PointerMayBeCaptured(V, &CB, MaxUsesToExplore);
+
+ if (UseNewOBB)
+ delete OBB;
+ return CB.Captured;
+}
+
+void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker,
+ unsigned MaxUsesToExplore) {
+ assert(V->getType()->isPointerTy() && "Capture is for pointers only!");
+ SmallVector<const Use *, DefaultMaxUsesToExplore> Worklist;
+ SmallSet<const Use *, DefaultMaxUsesToExplore> Visited;
+
+ auto AddUses = [&](const Value *V) {
+ unsigned Count = 0;
+ for (const Use &U : V->uses()) {
+ // If there are lots of uses, conservatively say that the value
+ // is captured to avoid taking too much compile time.
+ if (Count++ >= MaxUsesToExplore)
+ return Tracker->tooManyUses();
+ if (!Visited.insert(&U).second)
+ continue;
+ if (!Tracker->shouldExplore(&U))
+ continue;
+ Worklist.push_back(&U);
+ }
+ };
+ AddUses(V);
+
+ while (!Worklist.empty()) {
+ const Use *U = Worklist.pop_back_val();
+ Instruction *I = cast<Instruction>(U->getUser());
+ V = U->get();
+
+ switch (I->getOpcode()) {
+ case Instruction::Call:
+ case Instruction::Invoke: {
+ auto *Call = cast<CallBase>(I);
+ // Not captured if the callee is readonly, doesn't return a copy through
+ // its return value and doesn't unwind (a readonly function can leak bits
+ // by throwing an exception or not depending on the input value).
+ if (Call->onlyReadsMemory() && Call->doesNotThrow() &&
+ Call->getType()->isVoidTy())
+ break;
+
+      // The pointer is not captured if the returned pointer is not captured.
+ // NOTE: CaptureTracking users should not assume that only functions
+ // marked with nocapture do not capture. This means that places like
+ // GetUnderlyingObject in ValueTracking or DecomposeGEPExpression
+ // in BasicAA also need to know about this property.
+ if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(Call,
+ true)) {
+ AddUses(Call);
+ break;
+ }
+
+ // Volatile operations effectively capture the memory location that they
+ // load and store to.
+ if (auto *MI = dyn_cast<MemIntrinsic>(Call))
+ if (MI->isVolatile())
+ if (Tracker->captured(U))
+ return;
+
+ // Not captured if only passed via 'nocapture' arguments. Note that
+ // calling a function pointer does not in itself cause the pointer to
+ // be captured. This is a subtle point considering that (for example)
+ // the callee might return its own address. It is analogous to saying
+ // that loading a value from a pointer does not cause the pointer to be
+ // captured, even though the loaded value might be the pointer itself
+ // (think of self-referential objects).
+ for (auto IdxOpPair : enumerate(Call->data_ops())) {
+ int Idx = IdxOpPair.index();
+ Value *A = IdxOpPair.value();
+ if (A == V && !Call->doesNotCapture(Idx))
+ // The parameter is not marked 'nocapture' - captured.
+ if (Tracker->captured(U))
+ return;
+ }
+ break;
+ }
+ case Instruction::Load:
+ // Volatile loads make the address observable.
+ if (cast<LoadInst>(I)->isVolatile())
+ if (Tracker->captured(U))
+ return;
+ break;
+ case Instruction::VAArg:
+ // "va-arg" from a pointer does not cause it to be captured.
+ break;
+ case Instruction::Store:
+ // Stored the pointer - conservatively assume it may be captured.
+ // Volatile stores make the address observable.
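+      // Storing *through* the pointer (i.e. when V is the address operand)
+      // does not by itself capture it unless the store is volatile.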
+ if (V == I->getOperand(0) || cast<StoreInst>(I)->isVolatile())
+ if (Tracker->captured(U))
+ return;
+ break;
+ case Instruction::AtomicRMW: {
+ // atomicrmw conceptually includes both a load and store from
+ // the same location.
+ // As with a store, the location being accessed is not captured,
+ // but the value being stored is.
+ // Volatile stores make the address observable.
+ auto *ARMWI = cast<AtomicRMWInst>(I);
+ if (ARMWI->getValOperand() == V || ARMWI->isVolatile())
+ if (Tracker->captured(U))
+ return;
+ break;
+ }
+ case Instruction::AtomicCmpXchg: {
+ // cmpxchg conceptually includes both a load and store from
+ // the same location.
+ // As with a store, the location being accessed is not captured,
+ // but the value being stored is.
+ // Volatile stores make the address observable.
+ auto *ACXI = cast<AtomicCmpXchgInst>(I);
+ if (ACXI->getCompareOperand() == V || ACXI->getNewValOperand() == V ||
+ ACXI->isVolatile())
+ if (Tracker->captured(U))
+ return;
+ break;
+ }
+ case Instruction::BitCast:
+ case Instruction::GetElementPtr:
+ case Instruction::PHI:
+ case Instruction::Select:
+ case Instruction::AddrSpaceCast:
+ // The original value is not captured via this if the new value isn't.
+ AddUses(I);
+ break;
+ case Instruction::ICmp: {
+ unsigned Idx = (I->getOperand(0) == V) ? 0 : 1;
+ unsigned OtherIdx = 1 - Idx;
+ if (auto *CPN = dyn_cast<ConstantPointerNull>(I->getOperand(OtherIdx))) {
+ // Don't count comparisons of a no-alias return value against null as
+ // captures. This allows us to ignore comparisons of malloc results
+ // with null, for example.
+ if (CPN->getType()->getAddressSpace() == 0)
+ if (isNoAliasCall(V->stripPointerCasts()))
+ break;
+ if (!I->getFunction()->nullPointerIsDefined()) {
+ auto *O = I->getOperand(Idx)->stripPointerCastsSameRepresentation();
+ // Comparing a dereferenceable_or_null pointer against null cannot
+ // lead to pointer escapes, because if it is not null it must be a
+ // valid (in-bounds) pointer.
+ if (Tracker->isDereferenceableOrNull(O, I->getModule()->getDataLayout()))
+ break;
+ }
+ }
+      // Comparison against a value loaded from a global variable. Since the
+      // pointer does not otherwise escape, its value cannot have been guessed
+      // and stored separately in a global variable.
+ auto *LI = dyn_cast<LoadInst>(I->getOperand(OtherIdx));
+ if (LI && isa<GlobalVariable>(LI->getPointerOperand()))
+ break;
+ // Otherwise, be conservative. There are crazy ways to capture pointers
+ // using comparisons.
+ if (Tracker->captured(U))
+ return;
+ break;
+ }
+ default:
+ // Something else - be conservative and say it is captured.
+ if (Tracker->captured(U))
+ return;
+ break;
+ }
+ }
+
+ // All uses examined.
+}
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
new file mode 100644
index 000000000000..a5757be2c4f4
--- /dev/null
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -0,0 +1,143 @@
+//===- CmpInstAnalysis.cpp - Utils to help fold compares ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file holds routines to help analyse compare instructions
+// and fold them into constants or other compare instructions
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CmpInstAnalysis.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/PatternMatch.h"
+
+using namespace llvm;
+
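+// The code returned below packs the predicate into three bits: bit 0 stands
+// for "greater than", bit 1 for "equal" and bit 2 for "less than"; e.g.
+// ULE/SLE ("less than or equal") map to 110.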
+unsigned llvm::getICmpCode(const ICmpInst *ICI, bool InvertPred) {
+ ICmpInst::Predicate Pred = InvertPred ? ICI->getInversePredicate()
+ : ICI->getPredicate();
+ switch (Pred) {
+ // False -> 0
+ case ICmpInst::ICMP_UGT: return 1; // 001
+ case ICmpInst::ICMP_SGT: return 1; // 001
+ case ICmpInst::ICMP_EQ: return 2; // 010
+ case ICmpInst::ICMP_UGE: return 3; // 011
+ case ICmpInst::ICMP_SGE: return 3; // 011
+ case ICmpInst::ICMP_ULT: return 4; // 100
+ case ICmpInst::ICMP_SLT: return 4; // 100
+ case ICmpInst::ICMP_NE: return 5; // 101
+ case ICmpInst::ICMP_ULE: return 6; // 110
+ case ICmpInst::ICMP_SLE: return 6; // 110
+ // True -> 7
+ default:
+ llvm_unreachable("Invalid ICmp predicate!");
+ }
+}
+
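+// Inverse of getICmpCode: codes 0 and 7 fold directly to the constant false
+// and true results; for the other codes, Pred is set to the corresponding
+// predicate and nullptr is returned.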
+Constant *llvm::getPredForICmpCode(unsigned Code, bool Sign, Type *OpTy,
+ CmpInst::Predicate &Pred) {
+ switch (Code) {
+ default: llvm_unreachable("Illegal ICmp code!");
+ case 0: // False.
+ return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 0);
+ case 1: Pred = Sign ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; break;
+ case 2: Pred = ICmpInst::ICMP_EQ; break;
+ case 3: Pred = Sign ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; break;
+ case 4: Pred = Sign ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; break;
+ case 5: Pred = ICmpInst::ICMP_NE; break;
+ case 6: Pred = Sign ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; break;
+ case 7: // True.
+ return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 1);
+ }
+ return nullptr;
+}
+
+bool llvm::predicatesFoldable(ICmpInst::Predicate P1, ICmpInst::Predicate P2) {
+ return (CmpInst::isSigned(P1) == CmpInst::isSigned(P2)) ||
+ (CmpInst::isSigned(P1) && ICmpInst::isEquality(P2)) ||
+ (CmpInst::isSigned(P2) && ICmpInst::isEquality(P1));
+}
+
+bool llvm::decomposeBitTestICmp(Value *LHS, Value *RHS,
+ CmpInst::Predicate &Pred,
+ Value *&X, APInt &Mask, bool LookThruTrunc) {
+ using namespace PatternMatch;
+
+ const APInt *C;
+ if (!match(RHS, m_APInt(C)))
+ return false;
+
+ switch (Pred) {
+ default:
+ return false;
+ case ICmpInst::ICMP_SLT:
+ // X < 0 is equivalent to (X & SignMask) != 0.
+ if (!C->isNullValue())
+ return false;
+ Mask = APInt::getSignMask(C->getBitWidth());
+ Pred = ICmpInst::ICMP_NE;
+ break;
+ case ICmpInst::ICMP_SLE:
+ // X <= -1 is equivalent to (X & SignMask) != 0.
+ if (!C->isAllOnesValue())
+ return false;
+ Mask = APInt::getSignMask(C->getBitWidth());
+ Pred = ICmpInst::ICMP_NE;
+ break;
+ case ICmpInst::ICMP_SGT:
+ // X > -1 is equivalent to (X & SignMask) == 0.
+ if (!C->isAllOnesValue())
+ return false;
+ Mask = APInt::getSignMask(C->getBitWidth());
+ Pred = ICmpInst::ICMP_EQ;
+ break;
+ case ICmpInst::ICMP_SGE:
+ // X >= 0 is equivalent to (X & SignMask) == 0.
+ if (!C->isNullValue())
+ return false;
+ Mask = APInt::getSignMask(C->getBitWidth());
+ Pred = ICmpInst::ICMP_EQ;
+ break;
+ case ICmpInst::ICMP_ULT:
+ // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
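+    // (Note that -C == ~(C - 1), which is exactly that mask.)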
+ if (!C->isPowerOf2())
+ return false;
+ Mask = -*C;
+ Pred = ICmpInst::ICMP_EQ;
+ break;
+ case ICmpInst::ICMP_ULE:
+ // X <=u 2^n-1 is equivalent to (X & ~(2^n-1)) == 0.
+ if (!(*C + 1).isPowerOf2())
+ return false;
+ Mask = ~*C;
+ Pred = ICmpInst::ICMP_EQ;
+ break;
+ case ICmpInst::ICMP_UGT:
+ // X >u 2^n-1 is equivalent to (X & ~(2^n-1)) != 0.
+ if (!(*C + 1).isPowerOf2())
+ return false;
+ Mask = ~*C;
+ Pred = ICmpInst::ICMP_NE;
+ break;
+ case ICmpInst::ICMP_UGE:
+ // X >=u 2^n is equivalent to (X & ~(2^n-1)) != 0.
+ if (!C->isPowerOf2())
+ return false;
+ Mask = -*C;
+ Pred = ICmpInst::ICMP_NE;
+ break;
+ }
+
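+  // If requested, look through a truncation of LHS: report the wider source
+  // value as X and zero-extend the mask to its width.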
+ if (LookThruTrunc && match(LHS, m_Trunc(m_Value(X)))) {
+ Mask = Mask.zext(X->getType()->getScalarSizeInBits());
+ } else {
+ X = LHS;
+ }
+
+ return true;
+}
diff --git a/llvm/lib/Analysis/CodeMetrics.cpp b/llvm/lib/Analysis/CodeMetrics.cpp
new file mode 100644
index 000000000000..627d955c865f
--- /dev/null
+++ b/llvm/lib/Analysis/CodeMetrics.cpp
@@ -0,0 +1,195 @@
+//===- CodeMetrics.cpp - Code cost measurements ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements code cost measurement utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "code-metrics"
+
+using namespace llvm;
+
+static void
+appendSpeculatableOperands(const Value *V,
+ SmallPtrSetImpl<const Value *> &Visited,
+ SmallVectorImpl<const Value *> &Worklist) {
+ const User *U = dyn_cast<User>(V);
+ if (!U)
+ return;
+
+ for (const Value *Operand : U->operands())
+ if (Visited.insert(Operand).second)
+ if (isSafeToSpeculativelyExecute(Operand))
+ Worklist.push_back(Operand);
+}
+
+static void completeEphemeralValues(SmallPtrSetImpl<const Value *> &Visited,
+ SmallVectorImpl<const Value *> &Worklist,
+ SmallPtrSetImpl<const Value *> &EphValues) {
+ // Note: We don't speculate PHIs here, so we'll miss instruction chains kept
+ // alive only by ephemeral values.
+
+ // Walk the worklist using an index but without caching the size so we can
+ // append more entries as we process the worklist. This forms a queue without
+ // quadratic behavior by just leaving processed nodes at the head of the
+ // worklist forever.
+ for (int i = 0; i < (int)Worklist.size(); ++i) {
+ const Value *V = Worklist[i];
+
+ assert(Visited.count(V) &&
+ "Failed to add a worklist entry to our visited set!");
+
+ // If all uses of this value are ephemeral, then so is this value.
+ if (!all_of(V->users(), [&](const User *U) { return EphValues.count(U); }))
+ continue;
+
+ EphValues.insert(V);
+ LLVM_DEBUG(dbgs() << "Ephemeral Value: " << *V << "\n");
+
+ // Append any more operands to consider.
+ appendSpeculatableOperands(V, Visited, Worklist);
+ }
+}
+
+// Find all ephemeral values.
+void CodeMetrics::collectEphemeralValues(
+ const Loop *L, AssumptionCache *AC,
+ SmallPtrSetImpl<const Value *> &EphValues) {
+ SmallPtrSet<const Value *, 32> Visited;
+ SmallVector<const Value *, 16> Worklist;
+
+ for (auto &AssumeVH : AC->assumptions()) {
+ if (!AssumeVH)
+ continue;
+ Instruction *I = cast<Instruction>(AssumeVH);
+
+ // Filter out call sites outside of the loop so we don't do a function's
+ // worth of work for each of its loops (and, in the common case, ephemeral
+ // values in the loop are likely due to @llvm.assume calls in the loop).
+ if (!L->contains(I->getParent()))
+ continue;
+
+ if (EphValues.insert(I).second)
+ appendSpeculatableOperands(I, Visited, Worklist);
+ }
+
+ completeEphemeralValues(Visited, Worklist, EphValues);
+}
+
+void CodeMetrics::collectEphemeralValues(
+ const Function *F, AssumptionCache *AC,
+ SmallPtrSetImpl<const Value *> &EphValues) {
+ SmallPtrSet<const Value *, 32> Visited;
+ SmallVector<const Value *, 16> Worklist;
+
+ for (auto &AssumeVH : AC->assumptions()) {
+ if (!AssumeVH)
+ continue;
+ Instruction *I = cast<Instruction>(AssumeVH);
+ assert(I->getParent()->getParent() == F &&
+ "Found assumption for the wrong function!");
+
+ if (EphValues.insert(I).second)
+ appendSpeculatableOperands(I, Visited, Worklist);
+ }
+
+ completeEphemeralValues(Visited, Worklist, EphValues);
+}
+
+/// Fill in the current structure with information gleaned from the specified
+/// block.
+void CodeMetrics::analyzeBasicBlock(const BasicBlock *BB,
+ const TargetTransformInfo &TTI,
+ const SmallPtrSetImpl<const Value*> &EphValues) {
+ ++NumBlocks;
+ unsigned NumInstsBeforeThisBB = NumInsts;
+ for (const Instruction &I : *BB) {
+ // Skip ephemeral values.
+ if (EphValues.count(&I))
+ continue;
+
+ // Special handling for calls.
+ if (const auto *Call = dyn_cast<CallBase>(&I)) {
+ if (const Function *F = Call->getCalledFunction()) {
+ // If a function is both internal and has a single use, then it is
+ // extremely likely to get inlined in the future (it was probably
+ // exposed by an interleaved devirtualization pass).
+ if (!Call->isNoInline() && F->hasInternalLinkage() && F->hasOneUse())
+ ++NumInlineCandidates;
+
+        // If this call is to the function itself, then the function is recursive.
+ // Inlining it into other functions is a bad idea, because this is
+ // basically just a form of loop peeling, and our metrics aren't useful
+ // for that case.
+ if (F == BB->getParent())
+ isRecursive = true;
+
+ if (TTI.isLoweredToCall(F))
+ ++NumCalls;
+ } else {
+ // We don't want inline asm to count as a call - that would prevent loop
+ // unrolling. The argument setup cost is still real, though.
+ if (!Call->isInlineAsm())
+ ++NumCalls;
+ }
+ }
+
+ if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
+ if (!AI->isStaticAlloca())
+ this->usesDynamicAlloca = true;
+ }
+
+ if (isa<ExtractElementInst>(I) || I.getType()->isVectorTy())
+ ++NumVectorInsts;
+
+ if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB))
+ notDuplicatable = true;
+
+ if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
+ if (CI->cannotDuplicate())
+ notDuplicatable = true;
+ if (CI->isConvergent())
+ convergent = true;
+ }
+
+ if (const InvokeInst *InvI = dyn_cast<InvokeInst>(&I))
+ if (InvI->cannotDuplicate())
+ notDuplicatable = true;
+
+ NumInsts += TTI.getUserCost(&I);
+ }
+
+ if (isa<ReturnInst>(BB->getTerminator()))
+ ++NumRets;
+
+ // We never want to inline functions that contain an indirectbr. This is
+  // incorrect because all the blockaddresses (in static global initializers
+ // for example) would be referring to the original function, and this indirect
+ // jump would jump from the inlined copy of the function into the original
+  // function, which is extremely undefined behavior.
+ // FIXME: This logic isn't really right; we can safely inline functions
+ // with indirectbr's as long as no other function or global references the
+ // blockaddress of a block within the current function. And as a QOI issue,
+ // if someone is using a blockaddress without an indirectbr, and that
+ // reference somehow ends up in another function or global, we probably
+ // don't want to inline this function.
+ notDuplicatable |= isa<IndirectBrInst>(BB->getTerminator());
+
+ // Remember NumInsts for this BB.
+ NumBBInsts[BB] = NumInsts - NumInstsBeforeThisBB;
+}
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
new file mode 100644
index 000000000000..8dbcf7034fda
--- /dev/null
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -0,0 +1,2637 @@
+//===-- ConstantFolding.cpp - Fold instructions into constants ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines routines for folding instructions into constants.
+//
+// Also, to supplement the basic IR ConstantExpr simplifications,
+// this file defines some additional folding routines that can make use of
+// DataLayout information. These functions cannot go in IR due to library
+// dependency issues.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/Config/config.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/KnownBits.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+#include <cerrno>
+#include <cfenv>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+
+using namespace llvm;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Constant Folding internal helper functions
+//===----------------------------------------------------------------------===//
+
+static Constant *foldConstVectorToAPInt(APInt &Result, Type *DestTy,
+ Constant *C, Type *SrcEltTy,
+ unsigned NumSrcElts,
+ const DataLayout &DL) {
+ // Now that we know that the input value is a vector of integers, just shift
+ // and insert them into our result.
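+  // For example, on a little-endian target <2 x i16> <i16 1, i16 2> folds to
+  // the i32 value 0x00020001.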
+ unsigned BitShift = DL.getTypeSizeInBits(SrcEltTy);
+ for (unsigned i = 0; i != NumSrcElts; ++i) {
+ Constant *Element;
+ if (DL.isLittleEndian())
+ Element = C->getAggregateElement(NumSrcElts - i - 1);
+ else
+ Element = C->getAggregateElement(i);
+
+ if (Element && isa<UndefValue>(Element)) {
+ Result <<= BitShift;
+ continue;
+ }
+
+ auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element);
+ if (!ElementCI)
+ return ConstantExpr::getBitCast(C, DestTy);
+
+ Result <<= BitShift;
+ Result |= ElementCI->getValue().zextOrSelf(Result.getBitWidth());
+ }
+
+ return nullptr;
+}
+
+/// Constant fold bitcast, symbolically evaluating it with DataLayout.
+/// This always returns a non-null constant, but it may be a
+/// ConstantExpr if unfoldable.
+Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &DL) {
+ assert(CastInst::castIsValid(Instruction::BitCast, C, DestTy) &&
+ "Invalid constantexpr bitcast!");
+
+ // Catch the obvious splat cases.
+ if (C->isNullValue() && !DestTy->isX86_MMXTy())
+ return Constant::getNullValue(DestTy);
+ if (C->isAllOnesValue() && !DestTy->isX86_MMXTy() &&
+ !DestTy->isPtrOrPtrVectorTy()) // Don't get ones for ptr types!
+ return Constant::getAllOnesValue(DestTy);
+
+ if (auto *VTy = dyn_cast<VectorType>(C->getType())) {
+ // Handle a vector->scalar integer/fp cast.
+ if (isa<IntegerType>(DestTy) || DestTy->isFloatingPointTy()) {
+ unsigned NumSrcElts = VTy->getNumElements();
+ Type *SrcEltTy = VTy->getElementType();
+
+      // If the vector is a vector of floating point, convert it to a vector
+      // of integers to simplify things.
+ if (SrcEltTy->isFloatingPointTy()) {
+ unsigned FPWidth = SrcEltTy->getPrimitiveSizeInBits();
+ Type *SrcIVTy =
+ VectorType::get(IntegerType::get(C->getContext(), FPWidth), NumSrcElts);
+ // Ask IR to do the conversion now that #elts line up.
+ C = ConstantExpr::getBitCast(C, SrcIVTy);
+ }
+
+ APInt Result(DL.getTypeSizeInBits(DestTy), 0);
+ if (Constant *CE = foldConstVectorToAPInt(Result, DestTy, C,
+ SrcEltTy, NumSrcElts, DL))
+ return CE;
+
+ if (isa<IntegerType>(DestTy))
+ return ConstantInt::get(DestTy, Result);
+
+ APFloat FP(DestTy->getFltSemantics(), Result);
+ return ConstantFP::get(DestTy->getContext(), FP);
+ }
+ }
+
+ // The code below only handles casts to vectors currently.
+ auto *DestVTy = dyn_cast<VectorType>(DestTy);
+ if (!DestVTy)
+ return ConstantExpr::getBitCast(C, DestTy);
+
+ // If this is a scalar -> vector cast, convert the input into a <1 x scalar>
+ // vector so the code below can handle it uniformly.
+ if (isa<ConstantFP>(C) || isa<ConstantInt>(C)) {
+ Constant *Ops = C; // don't take the address of C!
+ return FoldBitCast(ConstantVector::get(Ops), DestTy, DL);
+ }
+
+ // If this is a bitcast from constant vector -> vector, fold it.
+ if (!isa<ConstantDataVector>(C) && !isa<ConstantVector>(C))
+ return ConstantExpr::getBitCast(C, DestTy);
+
+ // If the element types match, IR can fold it.
+ unsigned NumDstElt = DestVTy->getNumElements();
+ unsigned NumSrcElt = C->getType()->getVectorNumElements();
+ if (NumDstElt == NumSrcElt)
+ return ConstantExpr::getBitCast(C, DestTy);
+
+ Type *SrcEltTy = C->getType()->getVectorElementType();
+ Type *DstEltTy = DestVTy->getElementType();
+
+ // Otherwise, we're changing the number of elements in a vector, which
+ // requires endianness information to do the right thing. For example,
+ // bitcast (<2 x i64> <i64 0, i64 1> to <4 x i32>)
+ // folds to (little endian):
+ // <4 x i32> <i32 0, i32 0, i32 1, i32 0>
+ // and to (big endian):
+ // <4 x i32> <i32 0, i32 0, i32 0, i32 1>
+
+  // First things first. We only want to think about integers here, so if
+ // we have something in FP form, recast it as integer.
+ if (DstEltTy->isFloatingPointTy()) {
+    // Fold to a vector of integers with the same size as our FP type.
+ unsigned FPWidth = DstEltTy->getPrimitiveSizeInBits();
+ Type *DestIVTy =
+ VectorType::get(IntegerType::get(C->getContext(), FPWidth), NumDstElt);
+ // Recursively handle this integer conversion, if possible.
+ C = FoldBitCast(C, DestIVTy, DL);
+
+ // Finally, IR can handle this now that #elts line up.
+ return ConstantExpr::getBitCast(C, DestTy);
+ }
+
+ // Okay, we know the destination is integer, if the input is FP, convert
+ // it to integer first.
+ if (SrcEltTy->isFloatingPointTy()) {
+ unsigned FPWidth = SrcEltTy->getPrimitiveSizeInBits();
+ Type *SrcIVTy =
+ VectorType::get(IntegerType::get(C->getContext(), FPWidth), NumSrcElt);
+ // Ask IR to do the conversion now that #elts line up.
+ C = ConstantExpr::getBitCast(C, SrcIVTy);
+ // If IR wasn't able to fold it, bail out.
+ if (!isa<ConstantVector>(C) && // FIXME: Remove ConstantVector.
+ !isa<ConstantDataVector>(C))
+ return C;
+ }
+
+ // Now we know that the input and output vectors are both integer vectors
+ // of the same size, and that their #elements is not the same. Do the
+ // conversion here, which depends on whether the input or output has
+ // more elements.
+ bool isLittleEndian = DL.isLittleEndian();
+
+ SmallVector<Constant*, 32> Result;
+ if (NumDstElt < NumSrcElt) {
+ // Handle: bitcast (<4 x i32> <i32 0, i32 1, i32 2, i32 3> to <2 x i64>)
+ Constant *Zero = Constant::getNullValue(DstEltTy);
+ unsigned Ratio = NumSrcElt/NumDstElt;
+ unsigned SrcBitSize = SrcEltTy->getPrimitiveSizeInBits();
+ unsigned SrcElt = 0;
+ for (unsigned i = 0; i != NumDstElt; ++i) {
+ // Build each element of the result.
+ Constant *Elt = Zero;
+ unsigned ShiftAmt = isLittleEndian ? 0 : SrcBitSize*(Ratio-1);
+ for (unsigned j = 0; j != Ratio; ++j) {
+ Constant *Src = C->getAggregateElement(SrcElt++);
+ if (Src && isa<UndefValue>(Src))
+ Src = Constant::getNullValue(C->getType()->getVectorElementType());
+ else
+ Src = dyn_cast_or_null<ConstantInt>(Src);
+ if (!Src) // Reject constantexpr elements.
+ return ConstantExpr::getBitCast(C, DestTy);
+
+ // Zero extend the element to the right size.
+ Src = ConstantExpr::getZExt(Src, Elt->getType());
+
+ // Shift it to the right place, depending on endianness.
+ Src = ConstantExpr::getShl(Src,
+ ConstantInt::get(Src->getType(), ShiftAmt));
+ ShiftAmt += isLittleEndian ? SrcBitSize : -SrcBitSize;
+
+ // Mix it in.
+ Elt = ConstantExpr::getOr(Elt, Src);
+ }
+ Result.push_back(Elt);
+ }
+ return ConstantVector::get(Result);
+ }
+
+ // Handle: bitcast (<2 x i64> <i64 0, i64 1> to <4 x i32>)
+ unsigned Ratio = NumDstElt/NumSrcElt;
+ unsigned DstBitSize = DL.getTypeSizeInBits(DstEltTy);
+
+ // Loop over each source value, expanding into multiple results.
+ for (unsigned i = 0; i != NumSrcElt; ++i) {
+ auto *Element = C->getAggregateElement(i);
+
+ if (!Element) // Reject constantexpr elements.
+ return ConstantExpr::getBitCast(C, DestTy);
+
+ if (isa<UndefValue>(Element)) {
+      // Correctly propagate undef values.
+ Result.append(Ratio, UndefValue::get(DstEltTy));
+ continue;
+ }
+
+ auto *Src = dyn_cast<ConstantInt>(Element);
+ if (!Src)
+ return ConstantExpr::getBitCast(C, DestTy);
+
+ unsigned ShiftAmt = isLittleEndian ? 0 : DstBitSize*(Ratio-1);
+ for (unsigned j = 0; j != Ratio; ++j) {
+ // Shift the piece of the value into the right place, depending on
+ // endianness.
+ Constant *Elt = ConstantExpr::getLShr(Src,
+ ConstantInt::get(Src->getType(), ShiftAmt));
+ ShiftAmt += isLittleEndian ? DstBitSize : -DstBitSize;
+
+ // Truncate the element to an integer with the same pointer size and
+      // convert the element back to a pointer using an inttoptr.
+ if (DstEltTy->isPointerTy()) {
+ IntegerType *DstIntTy = Type::getIntNTy(C->getContext(), DstBitSize);
+ Constant *CE = ConstantExpr::getTrunc(Elt, DstIntTy);
+ Result.push_back(ConstantExpr::getIntToPtr(CE, DstEltTy));
+ continue;
+ }
+
+ // Truncate and remember this piece.
+ Result.push_back(ConstantExpr::getTrunc(Elt, DstEltTy));
+ }
+ }
+
+ return ConstantVector::get(Result);
+}
+
+} // end anonymous namespace
+
+/// If this constant is a constant offset from a global, return the global and
+/// the constant. Because of constantexprs, this function is recursive.
+bool llvm::IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV,
+ APInt &Offset, const DataLayout &DL) {
+ // Trivial case, constant is the global.
+ if ((GV = dyn_cast<GlobalValue>(C))) {
+ unsigned BitWidth = DL.getIndexTypeSizeInBits(GV->getType());
+ Offset = APInt(BitWidth, 0);
+ return true;
+ }
+
+ // Otherwise, if this isn't a constant expr, bail out.
+ auto *CE = dyn_cast<ConstantExpr>(C);
+ if (!CE) return false;
+
+ // Look through ptr->int and ptr->ptr casts.
+ if (CE->getOpcode() == Instruction::PtrToInt ||
+ CE->getOpcode() == Instruction::BitCast)
+ return IsConstantOffsetFromGlobal(CE->getOperand(0), GV, Offset, DL);
+
+ // i32* getelementptr ([5 x i32]* @a, i32 0, i32 5)
+ auto *GEP = dyn_cast<GEPOperator>(CE);
+ if (!GEP)
+ return false;
+
+ unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP->getType());
+ APInt TmpOffset(BitWidth, 0);
+
+ // If the base isn't a global+constant, we aren't either.
+ if (!IsConstantOffsetFromGlobal(CE->getOperand(0), GV, TmpOffset, DL))
+ return false;
+
+ // Otherwise, add any offset that our operands provide.
+ if (!GEP->accumulateConstantOffset(DL, TmpOffset))
+ return false;
+
+ Offset = TmpOffset;
+ return true;
+}
+
+Constant *llvm::ConstantFoldLoadThroughBitcast(Constant *C, Type *DestTy,
+ const DataLayout &DL) {
+ do {
+ Type *SrcTy = C->getType();
+
+ // If the type sizes are the same and a cast is legal, just directly
+ // cast the constant.
+ if (DL.getTypeSizeInBits(DestTy) == DL.getTypeSizeInBits(SrcTy)) {
+ Instruction::CastOps Cast = Instruction::BitCast;
+ // If we are going from a pointer to int or vice versa, we spell the cast
+ // differently.
+ if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
+ Cast = Instruction::IntToPtr;
+ else if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
+ Cast = Instruction::PtrToInt;
+
+ if (CastInst::castIsValid(Cast, C, DestTy))
+ return ConstantExpr::getCast(Cast, C, DestTy);
+ }
+
+ // If this isn't an aggregate type, there is nothing we can do to drill down
+ // and find a bitcastable constant.
+ if (!SrcTy->isAggregateType())
+ return nullptr;
+
+ // We're simulating a load through a pointer that was bitcast to point to
+ // a different type, so we can try to walk down through the initial
+ // elements of an aggregate to see if some part of the aggregate is
+ // castable to implement the "load" semantic model.
+ if (SrcTy->isStructTy()) {
+ // Struct types might have leading zero-length elements like [0 x i32],
+ // which are certainly not what we are looking for, so skip them.
+ unsigned Elem = 0;
+ Constant *ElemC;
+ do {
+ ElemC = C->getAggregateElement(Elem++);
+ } while (ElemC && DL.getTypeSizeInBits(ElemC->getType()) == 0);
+ C = ElemC;
+ } else {
+ C = C->getAggregateElement(0u);
+ }
+ } while (C);
+
+ return nullptr;
+}
+
+namespace {
+
+/// Recursive helper to read bits out of a global. C is the constant being copied
+/// out of. ByteOffset is an offset into C. CurPtr is the pointer to copy
+/// results into and BytesLeft is the number of bytes left in
+/// the CurPtr buffer. DL is the DataLayout.
+bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, unsigned char *CurPtr,
+ unsigned BytesLeft, const DataLayout &DL) {
+ assert(ByteOffset <= DL.getTypeAllocSize(C->getType()) &&
+ "Out of range access");
+
+ // If this element is zero or undefined, we can just return since *CurPtr is
+ // zero initialized.
+ if (isa<ConstantAggregateZero>(C) || isa<UndefValue>(C))
+ return true;
+
+ if (auto *CI = dyn_cast<ConstantInt>(C)) {
+ if (CI->getBitWidth() > 64 ||
+ (CI->getBitWidth() & 7) != 0)
+ return false;
+
+ uint64_t Val = CI->getZExtValue();
+ unsigned IntBytes = unsigned(CI->getBitWidth()/8);
+
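+    // Copy the bytes out in memory order; e.g. on a little-endian target an
+    // i32 with the value 0x01020304 is written to CurPtr as 04 03 02 01.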
+ for (unsigned i = 0; i != BytesLeft && ByteOffset != IntBytes; ++i) {
+ int n = ByteOffset;
+ if (!DL.isLittleEndian())
+ n = IntBytes - n - 1;
+ CurPtr[i] = (unsigned char)(Val >> (n * 8));
+ ++ByteOffset;
+ }
+ return true;
+ }
+
+ if (auto *CFP = dyn_cast<ConstantFP>(C)) {
+ if (CFP->getType()->isDoubleTy()) {
+ C = FoldBitCast(C, Type::getInt64Ty(C->getContext()), DL);
+ return ReadDataFromGlobal(C, ByteOffset, CurPtr, BytesLeft, DL);
+ }
+ if (CFP->getType()->isFloatTy()){
+ C = FoldBitCast(C, Type::getInt32Ty(C->getContext()), DL);
+ return ReadDataFromGlobal(C, ByteOffset, CurPtr, BytesLeft, DL);
+ }
+ if (CFP->getType()->isHalfTy()){
+ C = FoldBitCast(C, Type::getInt16Ty(C->getContext()), DL);
+ return ReadDataFromGlobal(C, ByteOffset, CurPtr, BytesLeft, DL);
+ }
+ return false;
+ }
+
+ if (auto *CS = dyn_cast<ConstantStruct>(C)) {
+ const StructLayout *SL = DL.getStructLayout(CS->getType());
+ unsigned Index = SL->getElementContainingOffset(ByteOffset);
+ uint64_t CurEltOffset = SL->getElementOffset(Index);
+ ByteOffset -= CurEltOffset;
+
+ while (true) {
+ // If the element access is to the element itself and not to tail padding,
+ // read the bytes from the element.
+ uint64_t EltSize = DL.getTypeAllocSize(CS->getOperand(Index)->getType());
+
+ if (ByteOffset < EltSize &&
+ !ReadDataFromGlobal(CS->getOperand(Index), ByteOffset, CurPtr,
+ BytesLeft, DL))
+ return false;
+
+ ++Index;
+
+ // Check to see if we read from the last struct element, if so we're done.
+ if (Index == CS->getType()->getNumElements())
+ return true;
+
+ // If we read all of the bytes we needed from this element we're done.
+ uint64_t NextEltOffset = SL->getElementOffset(Index);
+
+ if (BytesLeft <= NextEltOffset - CurEltOffset - ByteOffset)
+ return true;
+
+ // Move to the next element of the struct.
+ CurPtr += NextEltOffset - CurEltOffset - ByteOffset;
+ BytesLeft -= NextEltOffset - CurEltOffset - ByteOffset;
+ ByteOffset = 0;
+ CurEltOffset = NextEltOffset;
+ }
+ // not reached.
+ }
+
+ if (isa<ConstantArray>(C) || isa<ConstantVector>(C) ||
+ isa<ConstantDataSequential>(C)) {
+ Type *EltTy = C->getType()->getSequentialElementType();
+ uint64_t EltSize = DL.getTypeAllocSize(EltTy);
+ uint64_t Index = ByteOffset / EltSize;
+ uint64_t Offset = ByteOffset - Index * EltSize;
+ uint64_t NumElts;
+ if (auto *AT = dyn_cast<ArrayType>(C->getType()))
+ NumElts = AT->getNumElements();
+ else
+ NumElts = C->getType()->getVectorNumElements();
+
+ for (; Index != NumElts; ++Index) {
+ if (!ReadDataFromGlobal(C->getAggregateElement(Index), Offset, CurPtr,
+ BytesLeft, DL))
+ return false;
+
+ uint64_t BytesWritten = EltSize - Offset;
+ assert(BytesWritten <= EltSize && "Not indexing into this element?");
+ if (BytesWritten >= BytesLeft)
+ return true;
+
+ Offset = 0;
+ BytesLeft -= BytesWritten;
+ CurPtr += BytesWritten;
+ }
+ return true;
+ }
+
+ if (auto *CE = dyn_cast<ConstantExpr>(C)) {
+ if (CE->getOpcode() == Instruction::IntToPtr &&
+ CE->getOperand(0)->getType() == DL.getIntPtrType(CE->getType())) {
+ return ReadDataFromGlobal(CE->getOperand(0), ByteOffset, CurPtr,
+ BytesLeft, DL);
+ }
+ }
+
+ // Otherwise, unknown initializer type.
+ return false;
+}
+
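+/// Fold a load of LoadTy from the constant pointer C by reading the raw bytes
+/// of the underlying global initializer, e.g. folding an i32 load from a
+/// pointer into [4 x i8] initializer data. Returns null if the bytes cannot
+/// be determined.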
+Constant *FoldReinterpretLoadFromConstPtr(Constant *C, Type *LoadTy,
+ const DataLayout &DL) {
+ auto *PTy = cast<PointerType>(C->getType());
+ auto *IntType = dyn_cast<IntegerType>(LoadTy);
+
+ // If this isn't an integer load we can't fold it directly.
+ if (!IntType) {
+ unsigned AS = PTy->getAddressSpace();
+
+ // If this is a float/double load, we can try folding it as an int32/64 load
+ // and then bitcast the result. This can be useful for union cases. Note
+ // that address spaces don't matter here since this won't result in an
+ // actual new load.
+ Type *MapTy;
+ if (LoadTy->isHalfTy())
+ MapTy = Type::getInt16Ty(C->getContext());
+ else if (LoadTy->isFloatTy())
+ MapTy = Type::getInt32Ty(C->getContext());
+ else if (LoadTy->isDoubleTy())
+ MapTy = Type::getInt64Ty(C->getContext());
+ else if (LoadTy->isVectorTy()) {
+ MapTy = PointerType::getIntNTy(C->getContext(),
+ DL.getTypeSizeInBits(LoadTy));
+ } else
+ return nullptr;
+
+ C = FoldBitCast(C, MapTy->getPointerTo(AS), DL);
+ if (Constant *Res = FoldReinterpretLoadFromConstPtr(C, MapTy, DL)) {
+ if (Res->isNullValue() && !LoadTy->isX86_MMXTy())
+ // Materializing a zero can be done trivially without a bitcast
+ return Constant::getNullValue(LoadTy);
+ Type *CastTy =
+ LoadTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(LoadTy) : LoadTy;
+ Res = FoldBitCast(Res, CastTy, DL);
+ if (LoadTy->isPtrOrPtrVectorTy()) {
+ // For a vector of pointers, we had to first convert to a vector of
+ // integers, and now do the vector inttoptr.
+ if (Res->isNullValue() && !LoadTy->isX86_MMXTy())
+ return Constant::getNullValue(LoadTy);
+ if (DL.isNonIntegralPointerType(LoadTy->getScalarType()))
+ // Be careful not to replace a load of a non-integral address-space
+ // pointer with an inttoptr here.
+ return nullptr;
+ Res = ConstantExpr::getCast(Instruction::IntToPtr, Res, LoadTy);
+ }
+ return Res;
+ }
+ return nullptr;
+ }
+
+ unsigned BytesLoaded = (IntType->getBitWidth() + 7) / 8;
+ if (BytesLoaded > 32 || BytesLoaded == 0)
+ return nullptr;
+
+ GlobalValue *GVal;
+ APInt OffsetAI;
+ if (!IsConstantOffsetFromGlobal(C, GVal, OffsetAI, DL))
+ return nullptr;
+
+ auto *GV = dyn_cast<GlobalVariable>(GVal);
+ if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() ||
+ !GV->getInitializer()->getType()->isSized())
+ return nullptr;
+
+ int64_t Offset = OffsetAI.getSExtValue();
+ int64_t InitializerSize =
+ DL.getTypeAllocSize(GV->getInitializer()->getType());
+
+ // If the load ends at or before the start of the global, we're not accessing
+ // anything in this constant, so the result is undefined.
+ if (Offset <= -1 * static_cast<int64_t>(BytesLoaded))
+ return UndefValue::get(IntType);
+
+ // If the load starts at or past the end of the global, we're likewise not
+ // accessing anything in this constant, so the result is undefined.
+ if (Offset >= InitializerSize)
+ return UndefValue::get(IntType);
+
+ unsigned char RawBytes[32] = {0};
+ unsigned char *CurPtr = RawBytes;
+ unsigned BytesLeft = BytesLoaded;
+
+ // If we're loading off the beginning of the global, some bytes may be valid.
+ if (Offset < 0) {
+ CurPtr += -Offset;
+ BytesLeft += Offset;
+ Offset = 0;
+ }
+
+ if (!ReadDataFromGlobal(GV->getInitializer(), Offset, CurPtr, BytesLeft, DL))
+ return nullptr;
+
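+ // Assemble the loaded bytes into an integer according to the target's
+ // endianness, e.g. a 4-byte little-endian load of {0x44, 0x33, 0x22, 0x11}
+ // yields the value 0x11223344.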
+ APInt ResultVal = APInt(IntType->getBitWidth(), 0);
+ if (DL.isLittleEndian()) {
+ ResultVal = RawBytes[BytesLoaded - 1];
+ for (unsigned i = 1; i != BytesLoaded; ++i) {
+ ResultVal <<= 8;
+ ResultVal |= RawBytes[BytesLoaded - 1 - i];
+ }
+ } else {
+ ResultVal = RawBytes[0];
+ for (unsigned i = 1; i != BytesLoaded; ++i) {
+ ResultVal <<= 8;
+ ResultVal |= RawBytes[i];
+ }
+ }
+
+ return ConstantInt::get(IntType->getContext(), ResultVal);
+}
+
+Constant *ConstantFoldLoadThroughBitcastExpr(ConstantExpr *CE, Type *DestTy,
+ const DataLayout &DL) {
+ auto *SrcPtr = CE->getOperand(0);
+ auto *SrcPtrTy = dyn_cast<PointerType>(SrcPtr->getType());
+ if (!SrcPtrTy)
+ return nullptr;
+ Type *SrcTy = SrcPtrTy->getPointerElementType();
+
+ Constant *C = ConstantFoldLoadFromConstPtr(SrcPtr, SrcTy, DL);
+ if (!C)
+ return nullptr;
+
+ return llvm::ConstantFoldLoadThroughBitcast(C, DestTy, DL);
+}
+
+} // end anonymous namespace
+
+Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, Type *Ty,
+ const DataLayout &DL) {
+ // First, try the easy cases:
+ if (auto *GV = dyn_cast<GlobalVariable>(C))
+ if (GV->isConstant() && GV->hasDefinitiveInitializer())
+ return GV->getInitializer();
+
+ if (auto *GA = dyn_cast<GlobalAlias>(C))
+ if (GA->getAliasee() && !GA->isInterposable())
+ return ConstantFoldLoadFromConstPtr(GA->getAliasee(), Ty, DL);
+
+ // If the pointer operand isn't a constant expr, we can't handle it.
+ auto *CE = dyn_cast<ConstantExpr>(C);
+ if (!CE)
+ return nullptr;
+
+ if (CE->getOpcode() == Instruction::GetElementPtr) {
+ if (auto *GV = dyn_cast<GlobalVariable>(CE->getOperand(0))) {
+ if (GV->isConstant() && GV->hasDefinitiveInitializer()) {
+ if (Constant *V =
+ ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE))
+ return V;
+ }
+ }
+ }
+
+ if (CE->getOpcode() == Instruction::BitCast)
+ if (Constant *LoadedC = ConstantFoldLoadThroughBitcastExpr(CE, Ty, DL))
+ return LoadedC;
+
+ // Instead of loading a constant C string, use the corresponding integer
+ // value directly if the string length is small enough.
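+ // For example, on a little-endian target an i32 load through a pointer to
+ // the constant string "abc" (plus its NUL terminator) folds to the integer
+ // 0x00636261.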
+ StringRef Str;
+ if (getConstantStringInfo(CE, Str) && !Str.empty()) {
+ size_t StrLen = Str.size();
+ unsigned NumBits = Ty->getPrimitiveSizeInBits();
+ // Replace load with immediate integer if the result is an integer or fp
+ // value.
+ if ((NumBits >> 3) == StrLen + 1 && (NumBits & 7) == 0 &&
+ (isa<IntegerType>(Ty) || Ty->isFloatingPointTy())) {
+ APInt StrVal(NumBits, 0);
+ APInt SingleChar(NumBits, 0);
+ if (DL.isLittleEndian()) {
+ for (unsigned char C : reverse(Str.bytes())) {
+ SingleChar = static_cast<uint64_t>(C);
+ StrVal = (StrVal << 8) | SingleChar;
+ }
+ } else {
+ for (unsigned char C : Str.bytes()) {
+ SingleChar = static_cast<uint64_t>(C);
+ StrVal = (StrVal << 8) | SingleChar;
+ }
+ // Append the terminating NUL byte at the end.
+ SingleChar = 0;
+ StrVal = (StrVal << 8) | SingleChar;
+ }
+
+ Constant *Res = ConstantInt::get(CE->getContext(), StrVal);
+ if (Ty->isFloatingPointTy())
+ Res = ConstantExpr::getBitCast(Res, Ty);
+ return Res;
+ }
+ }
+
+ // If this load comes from anywhere in a constant global, and if the global
+ // is all undef or zero, we know what it loads.
+ if (auto *GV = dyn_cast<GlobalVariable>(GetUnderlyingObject(CE, DL))) {
+ if (GV->isConstant() && GV->hasDefinitiveInitializer()) {
+ if (GV->getInitializer()->isNullValue())
+ return Constant::getNullValue(Ty);
+ if (isa<UndefValue>(GV->getInitializer()))
+ return UndefValue::get(Ty);
+ }
+ }
+
+ // Try hard to fold loads from bitcasted strange and non-type-safe things.
+ return FoldReinterpretLoadFromConstPtr(CE, Ty, DL);
+}
+
+namespace {
+
+Constant *ConstantFoldLoadInst(const LoadInst *LI, const DataLayout &DL) {
+ if (LI->isVolatile()) return nullptr;
+
+ if (auto *C = dyn_cast<Constant>(LI->getOperand(0)))
+ return ConstantFoldLoadFromConstPtr(C, LI->getType(), DL);
+
+ return nullptr;
+}
+
+/// One of Op0/Op1 is a constant expression.
+/// Attempt to symbolically evaluate the result of a binary operator merging
+/// these together, using the layout information in DL.
+Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0, Constant *Op1,
+ const DataLayout &DL) {
+ // SROA
+
+ // Fold (and 0xffffffff00000000, (shl x, 32)) -> shl.
+ // Fold (lshr (or X, Y), 32) -> (lshr [X/Y], 32) if one doesn't contribute
+ // bits.
+
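+ // E.g. in (and (shl X, 32), 0xffffffff00000000) every bit the mask could
+ // clear is already known to be zero in the shifted value, so the 'and'
+ // folds to the shl operand itself.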
+ if (Opc == Instruction::And) {
+ KnownBits Known0 = computeKnownBits(Op0, DL);
+ KnownBits Known1 = computeKnownBits(Op1, DL);
+ if ((Known1.One | Known0.Zero).isAllOnesValue()) {
+ // All the bits of Op0 that the 'and' could be masking are already zero.
+ return Op0;
+ }
+ if ((Known0.One | Known1.Zero).isAllOnesValue()) {
+ // All the bits of Op1 that the 'and' could be masking are already zero.
+ return Op1;
+ }
+
+ Known0.Zero |= Known1.Zero;
+ Known0.One &= Known1.One;
+ if (Known0.isConstant())
+ return ConstantInt::get(Op0->getType(), Known0.getConstant());
+ }
+
+ // If the constant expr is something like &A[123] - &A[4].f, fold this into a
+ // constant. This happens frequently when iterating over a global array.
+ if (Opc == Instruction::Sub) {
+ GlobalValue *GV1, *GV2;
+ APInt Offs1, Offs2;
+
+ if (IsConstantOffsetFromGlobal(Op0, GV1, Offs1, DL))
+ if (IsConstantOffsetFromGlobal(Op1, GV2, Offs2, DL) && GV1 == GV2) {
+ unsigned OpSize = DL.getTypeSizeInBits(Op0->getType());
+
+ // (&GV+C1) - (&GV+C2) -> C1-C2, pointer arithmetic cannot overflow.
+ // PtrToInt may change the bitwidth, so we have to convert to the right
+ // size first.
+ return ConstantInt::get(Op0->getType(), Offs1.zextOrTrunc(OpSize) -
+ Offs2.zextOrTrunc(OpSize));
+ }
+ }
+
+ return nullptr;
+}
+
+/// If array indices are not pointer-sized integers, explicitly cast them so
+/// that they aren't implicitly cast by the getelementptr.
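+/// E.g. an i32 index used with a 64-bit pointer type is sign-extended to the
+/// pointer-sized integer type before the GEP is re-formed below.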
+Constant *CastGEPIndices(Type *SrcElemTy, ArrayRef<Constant *> Ops,
+ Type *ResultTy, Optional<unsigned> InRangeIndex,
+ const DataLayout &DL, const TargetLibraryInfo *TLI) {
+ Type *IntPtrTy = DL.getIntPtrType(ResultTy);
+ Type *IntPtrScalarTy = IntPtrTy->getScalarType();
+
+ bool Any = false;
+ SmallVector<Constant*, 32> NewIdxs;
+ for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
+ if ((i == 1 ||
+ !isa<StructType>(GetElementPtrInst::getIndexedType(
+ SrcElemTy, Ops.slice(1, i - 1)))) &&
+ Ops[i]->getType()->getScalarType() != IntPtrScalarTy) {
+ Any = true;
+ Type *NewType = Ops[i]->getType()->isVectorTy()
+ ? IntPtrTy
+ : IntPtrTy->getScalarType();
+ NewIdxs.push_back(ConstantExpr::getCast(CastInst::getCastOpcode(Ops[i],
+ true,
+ NewType,
+ true),
+ Ops[i], NewType));
+ } else
+ NewIdxs.push_back(Ops[i]);
+ }
+
+ if (!Any)
+ return nullptr;
+
+ Constant *C = ConstantExpr::getGetElementPtr(
+ SrcElemTy, Ops[0], NewIdxs, /*InBounds=*/false, InRangeIndex);
+ if (Constant *Folded = ConstantFoldConstant(C, DL, TLI))
+ C = Folded;
+
+ return C;
+}
+
+/// Strip the pointer casts, but preserve the address space information.
+Constant *StripPtrCastKeepAS(Constant *Ptr, Type *&ElemTy) {
+ assert(Ptr->getType()->isPointerTy() && "Not a pointer type");
+ auto *OldPtrTy = cast<PointerType>(Ptr->getType());
+ Ptr = cast<Constant>(Ptr->stripPointerCasts());
+ auto *NewPtrTy = cast<PointerType>(Ptr->getType());
+
+ ElemTy = NewPtrTy->getPointerElementType();
+
+ // Preserve the address space number of the pointer.
+ if (NewPtrTy->getAddressSpace() != OldPtrTy->getAddressSpace()) {
+ NewPtrTy = ElemTy->getPointerTo(OldPtrTy->getAddressSpace());
+ Ptr = ConstantExpr::getPointerCast(Ptr, NewPtrTy);
+ }
+ return Ptr;
+}
+
+/// If we can symbolically evaluate the GEP constant expression, do so.
+Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP,
+ ArrayRef<Constant *> Ops,
+ const DataLayout &DL,
+ const TargetLibraryInfo *TLI) {
+ const GEPOperator *InnermostGEP = GEP;
+ bool InBounds = GEP->isInBounds();
+
+ Type *SrcElemTy = GEP->getSourceElementType();
+ Type *ResElemTy = GEP->getResultElementType();
+ Type *ResTy = GEP->getType();
+ if (!SrcElemTy->isSized())
+ return nullptr;
+
+ if (Constant *C = CastGEPIndices(SrcElemTy, Ops, ResTy,
+ GEP->getInRangeIndex(), DL, TLI))
+ return C;
+
+ Constant *Ptr = Ops[0];
+ if (!Ptr->getType()->isPointerTy())
+ return nullptr;
+
+ Type *IntPtrTy = DL.getIntPtrType(Ptr->getType());
+
+ // If this is a constant expr gep that is effectively computing an
+ // "offsetof", fold it into 'cast int Size to T*' instead of 'gep 0, 0, 12'
+ for (unsigned i = 1, e = Ops.size(); i != e; ++i)
+ if (!isa<ConstantInt>(Ops[i])) {
+
+ // If this is "gep i8* Ptr, (sub 0, V)", fold this as:
+ // "inttoptr (sub (ptrtoint Ptr), V)"
+ if (Ops.size() == 2 && ResElemTy->isIntegerTy(8)) {
+ auto *CE = dyn_cast<ConstantExpr>(Ops[1]);
+ assert((!CE || CE->getType() == IntPtrTy) &&
+ "CastGEPIndices didn't canonicalize index types!");
+ if (CE && CE->getOpcode() == Instruction::Sub &&
+ CE->getOperand(0)->isNullValue()) {
+ Constant *Res = ConstantExpr::getPtrToInt(Ptr, CE->getType());
+ Res = ConstantExpr::getSub(Res, CE->getOperand(1));
+ Res = ConstantExpr::getIntToPtr(Res, ResTy);
+ if (auto *FoldedRes = ConstantFoldConstant(Res, DL, TLI))
+ Res = FoldedRes;
+ return Res;
+ }
+ }
+ return nullptr;
+ }
+
+ unsigned BitWidth = DL.getTypeSizeInBits(IntPtrTy);
+ APInt Offset =
+ APInt(BitWidth,
+ DL.getIndexedOffsetInType(
+ SrcElemTy,
+ makeArrayRef((Value * const *)Ops.data() + 1, Ops.size() - 1)));
+ Ptr = StripPtrCastKeepAS(Ptr, SrcElemTy);
+
+ // If this is a GEP of a GEP, fold it all into a single GEP.
+ while (auto *GEP = dyn_cast<GEPOperator>(Ptr)) {
+ InnermostGEP = GEP;
+ InBounds &= GEP->isInBounds();
+
+ SmallVector<Value *, 4> NestedOps(GEP->op_begin() + 1, GEP->op_end());
+
+ // Do not try to incorporate the sub-GEP if some index is not a number.
+ bool AllConstantInt = true;
+ for (Value *NestedOp : NestedOps)
+ if (!isa<ConstantInt>(NestedOp)) {
+ AllConstantInt = false;
+ break;
+ }
+ if (!AllConstantInt)
+ break;
+
+ Ptr = cast<Constant>(GEP->getOperand(0));
+ SrcElemTy = GEP->getSourceElementType();
+ Offset += APInt(BitWidth, DL.getIndexedOffsetInType(SrcElemTy, NestedOps));
+ Ptr = StripPtrCastKeepAS(Ptr, SrcElemTy);
+ }
+
+ // If the base value for this address is a literal integer value, fold the
+ // getelementptr to the resulting integer value cast to the pointer type.
+ APInt BasePtr(BitWidth, 0);
+ if (auto *CE = dyn_cast<ConstantExpr>(Ptr)) {
+ if (CE->getOpcode() == Instruction::IntToPtr) {
+ if (auto *Base = dyn_cast<ConstantInt>(CE->getOperand(0)))
+ BasePtr = Base->getValue().zextOrTrunc(BitWidth);
+ }
+ }
+
+ auto *PTy = cast<PointerType>(Ptr->getType());
+ if ((Ptr->isNullValue() || BasePtr != 0) &&
+ !DL.isNonIntegralPointerType(PTy)) {
+ Constant *C = ConstantInt::get(Ptr->getContext(), Offset + BasePtr);
+ return ConstantExpr::getIntToPtr(C, ResTy);
+ }
+
+ // Otherwise form a regular getelementptr. Recompute the indices so that
+ // we eliminate over-indexing of the notional static type array bounds.
+ // This makes it easy to determine if the getelementptr is "inbounds".
+ // Also, this helps GlobalOpt do SROA on GlobalVariables.
+ Type *Ty = PTy;
+ SmallVector<Constant *, 32> NewIdxs;
+
+ do {
+ if (!Ty->isStructTy()) {
+ if (Ty->isPointerTy()) {
+ // The only pointer indexing we'll do is on the first index of the GEP.
+ if (!NewIdxs.empty())
+ break;
+
+ Ty = SrcElemTy;
+
+ // Only handle pointers to sized types, not pointers to functions.
+ if (!Ty->isSized())
+ return nullptr;
+ } else if (auto *ATy = dyn_cast<SequentialType>(Ty)) {
+ Ty = ATy->getElementType();
+ } else {
+ // We've reached some non-indexable type.
+ break;
+ }
+
+ // Determine which element of the array the offset points into.
+ APInt ElemSize(BitWidth, DL.getTypeAllocSize(Ty));
+ if (ElemSize == 0) {
+ // The element size is 0. This may be [0 x Ty]*, so just use a zero
+ // index for this level and proceed to the next level to see if it can
+ // accommodate the offset.
+ NewIdxs.push_back(ConstantInt::get(IntPtrTy, 0));
+ } else {
+ // The element size is non-zero; divide the offset by the element
+ // size (rounding down) to compute the index at this level.
+ bool Overflow;
+ APInt NewIdx = Offset.sdiv_ov(ElemSize, Overflow);
+ if (Overflow)
+ break;
+ Offset -= NewIdx * ElemSize;
+ NewIdxs.push_back(ConstantInt::get(IntPtrTy, NewIdx));
+ }
+ } else {
+ auto *STy = cast<StructType>(Ty);
+ // If we end up with an offset that isn't valid for this struct type, we
+ // can't re-form this GEP in a regular form, so bail out. The pointer
+ // operand likely went through casts that are necessary to make the GEP
+ // sensible.
+ const StructLayout &SL = *DL.getStructLayout(STy);
+ if (Offset.isNegative() || Offset.uge(SL.getSizeInBytes()))
+ break;
+
+ // Determine which field of the struct the offset points into. The
+ // getZExtValue is fine as we've already ensured that the offset is
+ // within the range representable by the StructLayout API.
+ unsigned ElIdx = SL.getElementContainingOffset(Offset.getZExtValue());
+ NewIdxs.push_back(ConstantInt::get(Type::getInt32Ty(Ty->getContext()),
+ ElIdx));
+ Offset -= APInt(BitWidth, SL.getElementOffset(ElIdx));
+ Ty = STy->getTypeAtIndex(ElIdx);
+ }
+ } while (Ty != ResElemTy);
+
+ // If we haven't used up the entire offset by descending the static
+ // type, then the offset is pointing into the middle of an indivisible
+ // member, so we can't simplify it.
+ if (Offset != 0)
+ return nullptr;
+
+ // Preserve the inrange index from the innermost GEP if possible. We must
+ // have calculated the same indices up to and including the inrange index.
+ Optional<unsigned> InRangeIndex;
+ if (Optional<unsigned> LastIRIndex = InnermostGEP->getInRangeIndex())
+ if (SrcElemTy == InnermostGEP->getSourceElementType() &&
+ NewIdxs.size() > *LastIRIndex) {
+ InRangeIndex = LastIRIndex;
+ for (unsigned I = 0; I <= *LastIRIndex; ++I)
+ if (NewIdxs[I] != InnermostGEP->getOperand(I + 1))
+ return nullptr;
+ }
+
+ // Create a GEP.
+ Constant *C = ConstantExpr::getGetElementPtr(SrcElemTy, Ptr, NewIdxs,
+ InBounds, InRangeIndex);
+ assert(C->getType()->getPointerElementType() == Ty &&
+ "Computed GetElementPtr has unexpected type!");
+
+ // If we ended up indexing a member with a type that doesn't match
+ // the type of what the original indices indexed, add a cast.
+ if (Ty != ResElemTy)
+ C = FoldBitCast(C, ResTy, DL);
+
+ return C;
+}
+
+/// Attempt to constant fold an instruction with the
+/// specified opcode and operands. If successful, the constant result is
+/// returned; if not, null is returned. Note that this function can fail when
+/// attempting to fold instructions like loads and stores, which have no
+/// constant expression form.
+Constant *ConstantFoldInstOperandsImpl(const Value *InstOrCE, unsigned Opcode,
+ ArrayRef<Constant *> Ops,
+ const DataLayout &DL,
+ const TargetLibraryInfo *TLI) {
+ Type *DestTy = InstOrCE->getType();
+
+ if (Instruction::isUnaryOp(Opcode))
+ return ConstantFoldUnaryOpOperand(Opcode, Ops[0], DL);
+
+ if (Instruction::isBinaryOp(Opcode))
+ return ConstantFoldBinaryOpOperands(Opcode, Ops[0], Ops[1], DL);
+
+ if (Instruction::isCast(Opcode))
+ return ConstantFoldCastOperand(Opcode, Ops[0], DestTy, DL);
+
+ if (auto *GEP = dyn_cast<GEPOperator>(InstOrCE)) {
+ if (Constant *C = SymbolicallyEvaluateGEP(GEP, Ops, DL, TLI))
+ return C;
+
+ return ConstantExpr::getGetElementPtr(GEP->getSourceElementType(), Ops[0],
+ Ops.slice(1), GEP->isInBounds(),
+ GEP->getInRangeIndex());
+ }
+
+ if (auto *CE = dyn_cast<ConstantExpr>(InstOrCE))
+ return CE->getWithOperands(Ops);
+
+ switch (Opcode) {
+ default: return nullptr;
+ case Instruction::ICmp:
+ case Instruction::FCmp: llvm_unreachable("Invalid for compares");
+ case Instruction::Call:
+ if (auto *F = dyn_cast<Function>(Ops.back())) {
+ const auto *Call = cast<CallBase>(InstOrCE);
+ if (canConstantFoldCallTo(Call, F))
+ return ConstantFoldCall(Call, F, Ops.slice(0, Ops.size() - 1), TLI);
+ }
+ return nullptr;
+ case Instruction::Select:
+ return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2]);
+ case Instruction::ExtractElement:
+ return ConstantExpr::getExtractElement(Ops[0], Ops[1]);
+ case Instruction::ExtractValue:
+ return ConstantExpr::getExtractValue(
+ Ops[0], cast<ExtractValueInst>(InstOrCE)->getIndices());
+ case Instruction::InsertElement:
+ return ConstantExpr::getInsertElement(Ops[0], Ops[1], Ops[2]);
+ case Instruction::ShuffleVector:
+ return ConstantExpr::getShuffleVector(Ops[0], Ops[1], Ops[2]);
+ }
+}
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Constant Folding public APIs
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+Constant *
+ConstantFoldConstantImpl(const Constant *C, const DataLayout &DL,
+ const TargetLibraryInfo *TLI,
+ SmallDenseMap<Constant *, Constant *> &FoldedOps) {
+ if (!isa<ConstantVector>(C) && !isa<ConstantExpr>(C))
+ return nullptr;
+
+ SmallVector<Constant *, 8> Ops;
+ for (const Use &NewU : C->operands()) {
+ auto *NewC = cast<Constant>(&NewU);
+ // Recursively fold the ConstantExpr's operands. If we have already folded
+ // a ConstantExpr, we don't have to process it again.
+ if (isa<ConstantVector>(NewC) || isa<ConstantExpr>(NewC)) {
+ auto It = FoldedOps.find(NewC);
+ if (It == FoldedOps.end()) {
+ if (auto *FoldedC =
+ ConstantFoldConstantImpl(NewC, DL, TLI, FoldedOps)) {
+ FoldedOps.insert({NewC, FoldedC});
+ NewC = FoldedC;
+ } else {
+ FoldedOps.insert({NewC, NewC});
+ }
+ } else {
+ NewC = It->second;
+ }
+ }
+ Ops.push_back(NewC);
+ }
+
+ if (auto *CE = dyn_cast<ConstantExpr>(C)) {
+ if (CE->isCompare())
+ return ConstantFoldCompareInstOperands(CE->getPredicate(), Ops[0], Ops[1],
+ DL, TLI);
+
+ return ConstantFoldInstOperandsImpl(CE, CE->getOpcode(), Ops, DL, TLI);
+ }
+
+ assert(isa<ConstantVector>(C));
+ return ConstantVector::get(Ops);
+}
+
+} // end anonymous namespace
+
+Constant *llvm::ConstantFoldInstruction(Instruction *I, const DataLayout &DL,
+ const TargetLibraryInfo *TLI) {
+ // Handle PHI nodes quickly here...
+ if (auto *PN = dyn_cast<PHINode>(I)) {
+ Constant *CommonValue = nullptr;
+
+ SmallDenseMap<Constant *, Constant *> FoldedOps;
+ for (Value *Incoming : PN->incoming_values()) {
+ // If the incoming value is undef then skip it. Note that while we could
+ // skip the value if it is equal to the phi node itself, we choose not to
+ // because that would break the rule that constant folding only applies if
+ // all operands are constants.
+ if (isa<UndefValue>(Incoming))
+ continue;
+ // If the incoming value is not a constant, then give up.
+ auto *C = dyn_cast<Constant>(Incoming);
+ if (!C)
+ return nullptr;
+ // Fold the PHI's operands.
+ if (auto *FoldedC = ConstantFoldConstantImpl(C, DL, TLI, FoldedOps))
+ C = FoldedC;
+ // If the incoming value is a different constant to
+ // the one we saw previously, then give up.
+ if (CommonValue && C != CommonValue)
+ return nullptr;
+ CommonValue = C;
+ }
+
+ // If we reach here, all incoming values are the same constant or undef.
+ return CommonValue ? CommonValue : UndefValue::get(PN->getType());
+ }
+
+ // Scan the operand list, checking to see if they are all constants, if so,
+ // hand off to ConstantFoldInstOperandsImpl.
+ if (!all_of(I->operands(), [](Use &U) { return isa<Constant>(U); }))
+ return nullptr;
+
+ SmallDenseMap<Constant *, Constant *> FoldedOps;
+ SmallVector<Constant *, 8> Ops;
+ for (const Use &OpU : I->operands()) {
+ auto *Op = cast<Constant>(&OpU);
+ // Fold the Instruction's operands.
+ if (auto *FoldedOp = ConstantFoldConstantImpl(Op, DL, TLI, FoldedOps))
+ Op = FoldedOp;
+
+ Ops.push_back(Op);
+ }
+
+ if (const auto *CI = dyn_cast<CmpInst>(I))
+ return ConstantFoldCompareInstOperands(CI->getPredicate(), Ops[0], Ops[1],
+ DL, TLI);
+
+ if (const auto *LI = dyn_cast<LoadInst>(I))
+ return ConstantFoldLoadInst(LI, DL);
+
+ if (auto *IVI = dyn_cast<InsertValueInst>(I)) {
+ return ConstantExpr::getInsertValue(
+ cast<Constant>(IVI->getAggregateOperand()),
+ cast<Constant>(IVI->getInsertedValueOperand()),
+ IVI->getIndices());
+ }
+
+ if (auto *EVI = dyn_cast<ExtractValueInst>(I)) {
+ return ConstantExpr::getExtractValue(
+ cast<Constant>(EVI->getAggregateOperand()),
+ EVI->getIndices());
+ }
+
+ return ConstantFoldInstOperands(I, Ops, DL, TLI);
+}
+
+Constant *llvm::ConstantFoldConstant(const Constant *C, const DataLayout &DL,
+ const TargetLibraryInfo *TLI) {
+ SmallDenseMap<Constant *, Constant *> FoldedOps;
+ return ConstantFoldConstantImpl(C, DL, TLI, FoldedOps);
+}
+
+Constant *llvm::ConstantFoldInstOperands(Instruction *I,
+ ArrayRef<Constant *> Ops,
+ const DataLayout &DL,
+ const TargetLibraryInfo *TLI) {
+ return ConstantFoldInstOperandsImpl(I, I->getOpcode(), Ops, DL, TLI);
+}
+
+Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
+ Constant *Ops0, Constant *Ops1,
+ const DataLayout &DL,
+ const TargetLibraryInfo *TLI) {
+ // fold: icmp (inttoptr x), null -> icmp x, 0
+ // fold: icmp null, (inttoptr x) -> icmp 0, x
+ // fold: icmp (ptrtoint x), 0 -> icmp x, null
+ // fold: icmp 0, (ptrtoint x) -> icmp null, x
+ // fold: icmp (inttoptr x), (inttoptr y) -> icmp trunc/zext x, trunc/zext y
+ // fold: icmp (ptrtoint x), (ptrtoint y) -> icmp x, y
+ //
+ // FIXME: The following comment is out of date; the DataLayout is available now.
+ // ConstantExpr::getCompare cannot do this, because it doesn't have DL
+ // around to know if bit truncation is happening.
+ if (auto *CE0 = dyn_cast<ConstantExpr>(Ops0)) {
+ if (Ops1->isNullValue()) {
+ if (CE0->getOpcode() == Instruction::IntToPtr) {
+ Type *IntPtrTy = DL.getIntPtrType(CE0->getType());
+ // Convert the integer value to the right size to ensure we get the
+ // proper extension or truncation.
+ Constant *C = ConstantExpr::getIntegerCast(CE0->getOperand(0),
+ IntPtrTy, false);
+ Constant *Null = Constant::getNullValue(C->getType());
+ return ConstantFoldCompareInstOperands(Predicate, C, Null, DL, TLI);
+ }
+
+ // Only do this transformation if the int is the same size as IntPtrTy,
+ // otherwise there is a truncation or extension that we aren't modeling.
+ if (CE0->getOpcode() == Instruction::PtrToInt) {
+ Type *IntPtrTy = DL.getIntPtrType(CE0->getOperand(0)->getType());
+ if (CE0->getType() == IntPtrTy) {
+ Constant *C = CE0->getOperand(0);
+ Constant *Null = Constant::getNullValue(C->getType());
+ return ConstantFoldCompareInstOperands(Predicate, C, Null, DL, TLI);
+ }
+ }
+ }
+
+ if (auto *CE1 = dyn_cast<ConstantExpr>(Ops1)) {
+ if (CE0->getOpcode() == CE1->getOpcode()) {
+ if (CE0->getOpcode() == Instruction::IntToPtr) {
+ Type *IntPtrTy = DL.getIntPtrType(CE0->getType());
+
+ // Convert the integer value to the right size to ensure we get the
+ // proper extension or truncation.
+ Constant *C0 = ConstantExpr::getIntegerCast(CE0->getOperand(0),
+ IntPtrTy, false);
+ Constant *C1 = ConstantExpr::getIntegerCast(CE1->getOperand(0),
+ IntPtrTy, false);
+ return ConstantFoldCompareInstOperands(Predicate, C0, C1, DL, TLI);
+ }
+
+ // Only do this transformation if the int is the same size as IntPtrTy,
+ // otherwise there is a truncation or extension that we aren't modeling.
+ if (CE0->getOpcode() == Instruction::PtrToInt) {
+ Type *IntPtrTy = DL.getIntPtrType(CE0->getOperand(0)->getType());
+ if (CE0->getType() == IntPtrTy &&
+ CE0->getOperand(0)->getType() == CE1->getOperand(0)->getType()) {
+ return ConstantFoldCompareInstOperands(
+ Predicate, CE0->getOperand(0), CE1->getOperand(0), DL, TLI);
+ }
+ }
+ }
+ }
+
+ // icmp eq (or x, y), 0 -> (icmp eq x, 0) & (icmp eq y, 0)
+ // icmp ne (or x, y), 0 -> (icmp ne x, 0) | (icmp ne y, 0)
+ if ((Predicate == ICmpInst::ICMP_EQ || Predicate == ICmpInst::ICMP_NE) &&
+ CE0->getOpcode() == Instruction::Or && Ops1->isNullValue()) {
+ Constant *LHS = ConstantFoldCompareInstOperands(
+ Predicate, CE0->getOperand(0), Ops1, DL, TLI);
+ Constant *RHS = ConstantFoldCompareInstOperands(
+ Predicate, CE0->getOperand(1), Ops1, DL, TLI);
+ unsigned OpC =
+ Predicate == ICmpInst::ICMP_EQ ? Instruction::And : Instruction::Or;
+ return ConstantFoldBinaryOpOperands(OpC, LHS, RHS, DL);
+ }
+ } else if (isa<ConstantExpr>(Ops1)) {
+ // If RHS is a constant expression, but the left side isn't, swap the
+ // operands and try again.
+ Predicate = ICmpInst::getSwappedPredicate((ICmpInst::Predicate)Predicate);
+ return ConstantFoldCompareInstOperands(Predicate, Ops1, Ops0, DL, TLI);
+ }
+
+ return ConstantExpr::getCompare(Predicate, Ops0, Ops1);
+}
+
+Constant *llvm::ConstantFoldUnaryOpOperand(unsigned Opcode, Constant *Op,
+ const DataLayout &DL) {
+ assert(Instruction::isUnaryOp(Opcode));
+
+ return ConstantExpr::get(Opcode, Op);
+}
+
+Constant *llvm::ConstantFoldBinaryOpOperands(unsigned Opcode, Constant *LHS,
+ Constant *RHS,
+ const DataLayout &DL) {
+ assert(Instruction::isBinaryOp(Opcode));
+ if (isa<ConstantExpr>(LHS) || isa<ConstantExpr>(RHS))
+ if (Constant *C = SymbolicallyEvaluateBinop(Opcode, LHS, RHS, DL))
+ return C;
+
+ return ConstantExpr::get(Opcode, LHS, RHS);
+}
+
+Constant *llvm::ConstantFoldCastOperand(unsigned Opcode, Constant *C,
+ Type *DestTy, const DataLayout &DL) {
+ assert(Instruction::isCast(Opcode));
+ switch (Opcode) {
+ default:
+ llvm_unreachable("Missing case");
+ case Instruction::PtrToInt:
+ // If the input is an inttoptr, eliminate the pair. This requires knowing
+ // the width of a pointer, so it can't be done in ConstantExpr::getCast.
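+ // E.g. (ptrtoint (inttoptr (i64 X to i8*)) to i64) folds back to X when
+ // pointers are 64 bits wide; a narrower pointer requires masking first.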
+ if (auto *CE = dyn_cast<ConstantExpr>(C)) {
+ if (CE->getOpcode() == Instruction::IntToPtr) {
+ Constant *Input = CE->getOperand(0);
+ unsigned InWidth = Input->getType()->getScalarSizeInBits();
+ unsigned PtrWidth = DL.getPointerTypeSizeInBits(CE->getType());
+ if (PtrWidth < InWidth) {
+ Constant *Mask =
+ ConstantInt::get(CE->getContext(),
+ APInt::getLowBitsSet(InWidth, PtrWidth));
+ Input = ConstantExpr::getAnd(Input, Mask);
+ }
+ // Do a zext or trunc to get to the dest size.
+ return ConstantExpr::getIntegerCast(Input, DestTy, false);
+ }
+ }
+ return ConstantExpr::getCast(Opcode, C, DestTy);
+ case Instruction::IntToPtr:
+ // If the input is a ptrtoint, turn the pair into a ptr to ptr bitcast if
+ // the int size is >= the ptr size and the address spaces are the same.
+ // This requires knowing the width of a pointer, so it can't be done in
+ // ConstantExpr::getCast.
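+ // E.g. (inttoptr (ptrtoint (i32* X to i64)) to i32*) becomes a plain
+ // pointer bitcast of X when the address spaces match and the source
+ // pointer is no wider than the intermediate integer.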
+ if (auto *CE = dyn_cast<ConstantExpr>(C)) {
+ if (CE->getOpcode() == Instruction::PtrToInt) {
+ Constant *SrcPtr = CE->getOperand(0);
+ unsigned SrcPtrSize = DL.getPointerTypeSizeInBits(SrcPtr->getType());
+ unsigned MidIntSize = CE->getType()->getScalarSizeInBits();
+
+ if (MidIntSize >= SrcPtrSize) {
+ unsigned SrcAS = SrcPtr->getType()->getPointerAddressSpace();
+ if (SrcAS == DestTy->getPointerAddressSpace())
+ return FoldBitCast(CE->getOperand(0), DestTy, DL);
+ }
+ }
+ }
+
+ return ConstantExpr::getCast(Opcode, C, DestTy);
+ case Instruction::Trunc:
+ case Instruction::ZExt:
+ case Instruction::SExt:
+ case Instruction::FPTrunc:
+ case Instruction::FPExt:
+ case Instruction::UIToFP:
+ case Instruction::SIToFP:
+ case Instruction::FPToUI:
+ case Instruction::FPToSI:
+ case Instruction::AddrSpaceCast:
+ return ConstantExpr::getCast(Opcode, C, DestTy);
+ case Instruction::BitCast:
+ return FoldBitCast(C, DestTy, DL);
+ }
+}
+
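+/// Look up the value addressed by a "getelementptr C, 0, i1, i2, ..." constant
+/// expression directly inside the aggregate constant C, e.g. returning the
+/// third array element for the index sequence (0, 2).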
+Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C,
+ ConstantExpr *CE) {
+ if (!CE->getOperand(1)->isNullValue())
+ return nullptr; // Do not allow stepping over the value!
+
+ // Loop over all of the operands, tracking down which value we are
+ // addressing.
+ for (unsigned i = 2, e = CE->getNumOperands(); i != e; ++i) {
+ C = C->getAggregateElement(CE->getOperand(i));
+ if (!C)
+ return nullptr;
+ }
+ return C;
+}
+
+Constant *
+llvm::ConstantFoldLoadThroughGEPIndices(Constant *C,
+ ArrayRef<Constant *> Indices) {
+ // Loop over all of the operands, tracking down which value we are
+ // addressing.
+ for (Constant *Index : Indices) {
+ C = C->getAggregateElement(Index);
+ if (!C)
+ return nullptr;
+ }
+ return C;
+}
+
+//===----------------------------------------------------------------------===//
+// Constant Folding for Calls
+//
+
+bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
+ if (Call->isNoBuiltin() || Call->isStrictFP())
+ return false;
+ switch (F->getIntrinsicID()) {
+ case Intrinsic::fabs:
+ case Intrinsic::minnum:
+ case Intrinsic::maxnum:
+ case Intrinsic::minimum:
+ case Intrinsic::maximum:
+ case Intrinsic::log:
+ case Intrinsic::log2:
+ case Intrinsic::log10:
+ case Intrinsic::exp:
+ case Intrinsic::exp2:
+ case Intrinsic::floor:
+ case Intrinsic::ceil:
+ case Intrinsic::sqrt:
+ case Intrinsic::sin:
+ case Intrinsic::cos:
+ case Intrinsic::trunc:
+ case Intrinsic::rint:
+ case Intrinsic::nearbyint:
+ case Intrinsic::pow:
+ case Intrinsic::powi:
+ case Intrinsic::bswap:
+ case Intrinsic::ctpop:
+ case Intrinsic::ctlz:
+ case Intrinsic::cttz:
+ case Intrinsic::fshl:
+ case Intrinsic::fshr:
+ case Intrinsic::fma:
+ case Intrinsic::fmuladd:
+ case Intrinsic::copysign:
+ case Intrinsic::launder_invariant_group:
+ case Intrinsic::strip_invariant_group:
+ case Intrinsic::round:
+ case Intrinsic::masked_load:
+ case Intrinsic::sadd_with_overflow:
+ case Intrinsic::uadd_with_overflow:
+ case Intrinsic::ssub_with_overflow:
+ case Intrinsic::usub_with_overflow:
+ case Intrinsic::smul_with_overflow:
+ case Intrinsic::umul_with_overflow:
+ case Intrinsic::sadd_sat:
+ case Intrinsic::uadd_sat:
+ case Intrinsic::ssub_sat:
+ case Intrinsic::usub_sat:
+ case Intrinsic::smul_fix:
+ case Intrinsic::smul_fix_sat:
+ case Intrinsic::convert_from_fp16:
+ case Intrinsic::convert_to_fp16:
+ case Intrinsic::bitreverse:
+ case Intrinsic::x86_sse_cvtss2si:
+ case Intrinsic::x86_sse_cvtss2si64:
+ case Intrinsic::x86_sse_cvttss2si:
+ case Intrinsic::x86_sse_cvttss2si64:
+ case Intrinsic::x86_sse2_cvtsd2si:
+ case Intrinsic::x86_sse2_cvtsd2si64:
+ case Intrinsic::x86_sse2_cvttsd2si:
+ case Intrinsic::x86_sse2_cvttsd2si64:
+ case Intrinsic::x86_avx512_vcvtss2si32:
+ case Intrinsic::x86_avx512_vcvtss2si64:
+ case Intrinsic::x86_avx512_cvttss2si:
+ case Intrinsic::x86_avx512_cvttss2si64:
+ case Intrinsic::x86_avx512_vcvtsd2si32:
+ case Intrinsic::x86_avx512_vcvtsd2si64:
+ case Intrinsic::x86_avx512_cvttsd2si:
+ case Intrinsic::x86_avx512_cvttsd2si64:
+ case Intrinsic::x86_avx512_vcvtss2usi32:
+ case Intrinsic::x86_avx512_vcvtss2usi64:
+ case Intrinsic::x86_avx512_cvttss2usi:
+ case Intrinsic::x86_avx512_cvttss2usi64:
+ case Intrinsic::x86_avx512_vcvtsd2usi32:
+ case Intrinsic::x86_avx512_vcvtsd2usi64:
+ case Intrinsic::x86_avx512_cvttsd2usi:
+ case Intrinsic::x86_avx512_cvttsd2usi64:
+ case Intrinsic::is_constant:
+ return true;
+ default:
+ return false;
+ case Intrinsic::not_intrinsic: break;
+ }
+
+ if (!F->hasName())
+ return false;
+
+ // In these cases, the check of the length is required. We don't want to
+ // return true for a name like "cos\0blah", which strcmp would consider
+ // equal to "cos" but which has length 8.
+ StringRef Name = F->getName();
+ switch (Name[0]) {
+ default:
+ return false;
+ case 'a':
+ return Name == "acos" || Name == "acosf" ||
+ Name == "asin" || Name == "asinf" ||
+ Name == "atan" || Name == "atanf" ||
+ Name == "atan2" || Name == "atan2f";
+ case 'c':
+ return Name == "ceil" || Name == "ceilf" ||
+ Name == "cos" || Name == "cosf" ||
+ Name == "cosh" || Name == "coshf";
+ case 'e':
+ return Name == "exp" || Name == "expf" ||
+ Name == "exp2" || Name == "exp2f";
+ case 'f':
+ return Name == "fabs" || Name == "fabsf" ||
+ Name == "floor" || Name == "floorf" ||
+ Name == "fmod" || Name == "fmodf";
+ case 'l':
+ return Name == "log" || Name == "logf" ||
+ Name == "log2" || Name == "log2f" ||
+ Name == "log10" || Name == "log10f";
+ case 'n':
+ return Name == "nearbyint" || Name == "nearbyintf";
+ case 'p':
+ return Name == "pow" || Name == "powf";
+ case 'r':
+ return Name == "rint" || Name == "rintf" ||
+ Name == "round" || Name == "roundf";
+ case 's':
+ return Name == "sin" || Name == "sinf" ||
+ Name == "sinh" || Name == "sinhf" ||
+ Name == "sqrt" || Name == "sqrtf";
+ case 't':
+ return Name == "tan" || Name == "tanf" ||
+ Name == "tanh" || Name == "tanhf" ||
+ Name == "trunc" || Name == "truncf";
+ case '_':
+ // Check for various function names that get used for the math functions
+ // when the header files are preprocessed with the macro
+ // __FINITE_MATH_ONLY__ enabled.
+ // The '12' here is the length of the shortest name that can match.
+ // We need to check the size before looking at Name[1] and Name[2]
+ // so we may as well check a limit that will eliminate mismatches.
+ if (Name.size() < 12 || Name[1] != '_')
+ return false;
+ switch (Name[2]) {
+ default:
+ return false;
+ case 'a':
+ return Name == "__acos_finite" || Name == "__acosf_finite" ||
+ Name == "__asin_finite" || Name == "__asinf_finite" ||
+ Name == "__atan2_finite" || Name == "__atan2f_finite";
+ case 'c':
+ return Name == "__cosh_finite" || Name == "__coshf_finite";
+ case 'e':
+ return Name == "__exp_finite" || Name == "__expf_finite" ||
+ Name == "__exp2_finite" || Name == "__exp2f_finite";
+ case 'l':
+ return Name == "__log_finite" || Name == "__logf_finite" ||
+ Name == "__log10_finite" || Name == "__log10f_finite";
+ case 'p':
+ return Name == "__pow_finite" || Name == "__powf_finite";
+ case 's':
+ return Name == "__sinh_finite" || Name == "__sinhf_finite";
+ }
+ }
+}
+
+namespace {
+
+Constant *GetConstantFoldFPValue(double V, Type *Ty) {
+ if (Ty->isHalfTy() || Ty->isFloatTy()) {
+ APFloat APF(V);
+ bool unused;
+ APF.convert(Ty->getFltSemantics(), APFloat::rmNearestTiesToEven, &unused);
+ return ConstantFP::get(Ty->getContext(), APF);
+ }
+ if (Ty->isDoubleTy())
+ return ConstantFP::get(Ty->getContext(), APFloat(V));
+ llvm_unreachable("Can only constant fold half/float/double");
+}
+
+/// Clear the floating-point exception state.
+inline void llvm_fenv_clearexcept() {
+#if defined(HAVE_FENV_H) && HAVE_DECL_FE_ALL_EXCEPT
+ feclearexcept(FE_ALL_EXCEPT);
+#endif
+ errno = 0;
+}
+
+/// Test if a floating-point exception was raised.
+inline bool llvm_fenv_testexcept() {
+ int errno_val = errno;
+ if (errno_val == ERANGE || errno_val == EDOM)
+ return true;
+#if defined(HAVE_FENV_H) && HAVE_DECL_FE_ALL_EXCEPT && HAVE_DECL_FE_INEXACT
+ if (fetestexcept(FE_ALL_EXCEPT & ~FE_INEXACT))
+ return true;
+#endif
+ return false;
+}
+
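+/// Evaluate the host math function NativeFP on V with the floating-point
+/// exception state cleared first, and give up (returning null) if the call
+/// raised an exception or set errno, e.g. log(-1.0) setting EDOM.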
+Constant *ConstantFoldFP(double (*NativeFP)(double), double V, Type *Ty) {
+ llvm_fenv_clearexcept();
+ V = NativeFP(V);
+ if (llvm_fenv_testexcept()) {
+ llvm_fenv_clearexcept();
+ return nullptr;
+ }
+
+ return GetConstantFoldFPValue(V, Ty);
+}
+
+Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double), double V,
+ double W, Type *Ty) {
+ llvm_fenv_clearexcept();
+ V = NativeFP(V, W);
+ if (llvm_fenv_testexcept()) {
+ llvm_fenv_clearexcept();
+ return nullptr;
+ }
+
+ return GetConstantFoldFPValue(V, Ty);
+}
+
+/// Attempt to fold an SSE floating point to integer conversion of a constant
+/// floating point. If roundTowardZero is false, the default IEEE rounding is
+/// used (toward nearest, ties to even). This matches the behavior of the
+/// non-truncating SSE instructions in the default rounding mode. The desired
+/// integer type Ty is used to select how many bits are available for the
+/// result. Returns null if the conversion cannot be performed, otherwise
+/// returns the Constant value resulting from the conversion.
+Constant *ConstantFoldSSEConvertToInt(const APFloat &Val, bool roundTowardZero,
+ Type *Ty, bool IsSigned) {
+ // All of these conversion intrinsics form an integer of at most 64 bits.
+ unsigned ResultWidth = Ty->getIntegerBitWidth();
+ assert(ResultWidth <= 64 &&
+ "Can only constant fold conversions to 64 and 32 bit ints");
+
+ uint64_t UIntVal;
+ bool isExact = false;
+ APFloat::roundingMode mode = roundTowardZero? APFloat::rmTowardZero
+ : APFloat::rmNearestTiesToEven;
+ APFloat::opStatus status =
+ Val.convertToInteger(makeMutableArrayRef(UIntVal), ResultWidth,
+ IsSigned, mode, &isExact);
+ if (status != APFloat::opOK &&
+ (!roundTowardZero || status != APFloat::opInexact))
+ return nullptr;
+ return ConstantInt::get(Ty, UIntVal, IsSigned);
+}
+
+double getValueAsDouble(ConstantFP *Op) {
+ Type *Ty = Op->getType();
+
+ if (Ty->isFloatTy())
+ return Op->getValueAPF().convertToFloat();
+
+ if (Ty->isDoubleTy())
+ return Op->getValueAPF().convertToDouble();
+
+ bool unused;
+ APFloat APF = Op->getValueAPF();
+ APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &unused);
+ return APF.convertToDouble();
+}
+
+static bool isManifestConstant(const Constant *c) {
+ if (isa<ConstantData>(c)) {
+ return true;
+ } else if (isa<ConstantAggregate>(c) || isa<ConstantExpr>(c)) {
+ for (const Value *subc : c->operand_values()) {
+ if (!isManifestConstant(cast<Constant>(subc)))
+ return false;
+ }
+ return true;
+ }
+ return false;
+}
+
+static bool getConstIntOrUndef(Value *Op, const APInt *&C) {
+ if (auto *CI = dyn_cast<ConstantInt>(Op)) {
+ C = &CI->getValue();
+ return true;
+ }
+ if (isa<UndefValue>(Op)) {
+ C = nullptr;
+ return true;
+ }
+ return false;
+}
+
+static Constant *ConstantFoldScalarCall1(StringRef Name,
+ Intrinsic::ID IntrinsicID,
+ Type *Ty,
+ ArrayRef<Constant *> Operands,
+ const TargetLibraryInfo *TLI,
+ const CallBase *Call) {
+ assert(Operands.size() == 1 && "Wrong number of operands.");
+
+ if (IntrinsicID == Intrinsic::is_constant) {
+ // We know we have a "Constant" argument. But we want to only
+ // return true for manifest constants, not those that depend on
+ // constants with unknowable values, e.g. GlobalValue or BlockAddress.
+ if (isManifestConstant(Operands[0]))
+ return ConstantInt::getTrue(Ty->getContext());
+ return nullptr;
+ }
+ if (isa<UndefValue>(Operands[0])) {
+ // cosine(arg) is between -1 and 1. cosine(invalid arg) is NaN.
+ // ctpop() is between 0 and bitwidth, pick 0 for undef.
+ if (IntrinsicID == Intrinsic::cos ||
+ IntrinsicID == Intrinsic::ctpop)
+ return Constant::getNullValue(Ty);
+ if (IntrinsicID == Intrinsic::bswap ||
+ IntrinsicID == Intrinsic::bitreverse ||
+ IntrinsicID == Intrinsic::launder_invariant_group ||
+ IntrinsicID == Intrinsic::strip_invariant_group)
+ return Operands[0];
+ }
+
+ if (isa<ConstantPointerNull>(Operands[0])) {
+ // launder(null) == null == strip(null) iff in addrspace 0
+ if (IntrinsicID == Intrinsic::launder_invariant_group ||
+ IntrinsicID == Intrinsic::strip_invariant_group) {
+ // If the instruction has not yet been inserted into a basic block (e.g.
+ // when cloning a function during inlining), Call's caller may not be
+ // available. So check Call's parent block first before querying
+ // Call->getCaller.
+ const Function *Caller =
+ Call->getParent() ? Call->getCaller() : nullptr;
+ if (Caller &&
+ !NullPointerIsDefined(
+ Caller, Operands[0]->getType()->getPointerAddressSpace())) {
+ return Operands[0];
+ }
+ return nullptr;
+ }
+ }
+
+ if (auto *Op = dyn_cast<ConstantFP>(Operands[0])) {
+ if (IntrinsicID == Intrinsic::convert_to_fp16) {
+ APFloat Val(Op->getValueAPF());
+
+ bool lost = false;
+ Val.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &lost);
+
+ return ConstantInt::get(Ty->getContext(), Val.bitcastToAPInt());
+ }
+
+ if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
+ return nullptr;
+
+ // Use internal versions of these intrinsics.
+ APFloat U = Op->getValueAPF();
+
+ if (IntrinsicID == Intrinsic::nearbyint || IntrinsicID == Intrinsic::rint) {
+ U.roundToIntegral(APFloat::rmNearestTiesToEven);
+ return ConstantFP::get(Ty->getContext(), U);
+ }
+
+ if (IntrinsicID == Intrinsic::round) {
+ U.roundToIntegral(APFloat::rmNearestTiesToAway);
+ return ConstantFP::get(Ty->getContext(), U);
+ }
+
+ if (IntrinsicID == Intrinsic::ceil) {
+ U.roundToIntegral(APFloat::rmTowardPositive);
+ return ConstantFP::get(Ty->getContext(), U);
+ }
+
+ if (IntrinsicID == Intrinsic::floor) {
+ U.roundToIntegral(APFloat::rmTowardNegative);
+ return ConstantFP::get(Ty->getContext(), U);
+ }
+
+ if (IntrinsicID == Intrinsic::trunc) {
+ U.roundToIntegral(APFloat::rmTowardZero);
+ return ConstantFP::get(Ty->getContext(), U);
+ }
+
+ if (IntrinsicID == Intrinsic::fabs) {
+ U.clearSign();
+ return ConstantFP::get(Ty->getContext(), U);
+ }
+
+ /// We only fold functions with finite arguments. Folding NaN and inf is
+ /// likely to abort with an exception anyway, and some host libms have
+ /// known errors raising exceptions.
+ if (Op->getValueAPF().isNaN() || Op->getValueAPF().isInfinity())
+ return nullptr;
+
+ /// Currently APFloat versions of these functions do not exist, so we use
+ /// the host's native double versions. The float versions are not called
+ /// directly; for all of these functions it holds that
+ /// (float)(f((double)arg)) == f(arg). Long double is not supported yet.
+ double V = getValueAsDouble(Op);
+
+ switch (IntrinsicID) {
+ default: break;
+ case Intrinsic::log:
+ return ConstantFoldFP(log, V, Ty);
+ case Intrinsic::log2:
+ // TODO: What about hosts that lack a C99 library?
+ return ConstantFoldFP(Log2, V, Ty);
+ case Intrinsic::log10:
+ // TODO: What about hosts that lack a C99 library?
+ return ConstantFoldFP(log10, V, Ty);
+ case Intrinsic::exp:
+ return ConstantFoldFP(exp, V, Ty);
+ case Intrinsic::exp2:
+ // Fold exp2(x) as pow(2, x), in case the host lacks a C99 library.
+ return ConstantFoldBinaryFP(pow, 2.0, V, Ty);
+ case Intrinsic::sin:
+ return ConstantFoldFP(sin, V, Ty);
+ case Intrinsic::cos:
+ return ConstantFoldFP(cos, V, Ty);
+ case Intrinsic::sqrt:
+ return ConstantFoldFP(sqrt, V, Ty);
+ }
+
+ if (!TLI)
+ return nullptr;
+
+ LibFunc Func = NotLibFunc;
+ TLI->getLibFunc(Name, Func);
+ switch (Func) {
+ default:
+ break;
+ case LibFunc_acos:
+ case LibFunc_acosf:
+ case LibFunc_acos_finite:
+ case LibFunc_acosf_finite:
+ if (TLI->has(Func))
+ return ConstantFoldFP(acos, V, Ty);
+ break;
+ case LibFunc_asin:
+ case LibFunc_asinf:
+ case LibFunc_asin_finite:
+ case LibFunc_asinf_finite:
+ if (TLI->has(Func))
+ return ConstantFoldFP(asin, V, Ty);
+ break;
+ case LibFunc_atan:
+ case LibFunc_atanf:
+ if (TLI->has(Func))
+ return ConstantFoldFP(atan, V, Ty);
+ break;
+ case LibFunc_ceil:
+ case LibFunc_ceilf:
+ if (TLI->has(Func)) {
+ U.roundToIntegral(APFloat::rmTowardPositive);
+ return ConstantFP::get(Ty->getContext(), U);
+ }
+ break;
+ case LibFunc_cos:
+ case LibFunc_cosf:
+ if (TLI->has(Func))
+ return ConstantFoldFP(cos, V, Ty);
+ break;
+ case LibFunc_cosh:
+ case LibFunc_coshf:
+ case LibFunc_cosh_finite:
+ case LibFunc_coshf_finite:
+ if (TLI->has(Func))
+ return ConstantFoldFP(cosh, V, Ty);
+ break;
+ case LibFunc_exp:
+ case LibFunc_expf:
+ case LibFunc_exp_finite:
+ case LibFunc_expf_finite:
+ if (TLI->has(Func))
+ return ConstantFoldFP(exp, V, Ty);
+ break;
+ case LibFunc_exp2:
+ case LibFunc_exp2f:
+ case LibFunc_exp2_finite:
+ case LibFunc_exp2f_finite:
+ if (TLI->has(Func))
+ // Fold exp2(x) as pow(2, x), in case the host lacks a C99 library.
+ return ConstantFoldBinaryFP(pow, 2.0, V, Ty);
+ break;
+ case LibFunc_fabs:
+ case LibFunc_fabsf:
+ if (TLI->has(Func)) {
+ U.clearSign();
+ return ConstantFP::get(Ty->getContext(), U);
+ }
+ break;
+ case LibFunc_floor:
+ case LibFunc_floorf:
+ if (TLI->has(Func)) {
+ U.roundToIntegral(APFloat::rmTowardNegative);
+ return ConstantFP::get(Ty->getContext(), U);
+ }
+ break;
+ case LibFunc_log:
+ case LibFunc_logf:
+ case LibFunc_log_finite:
+ case LibFunc_logf_finite:
+ if (V > 0.0 && TLI->has(Func))
+ return ConstantFoldFP(log, V, Ty);
+ break;
+ case LibFunc_log2:
+ case LibFunc_log2f:
+ case LibFunc_log2_finite:
+ case LibFunc_log2f_finite:
+ if (V > 0.0 && TLI->has(Func))
+ // TODO: What about hosts that lack a C99 library?
+ return ConstantFoldFP(Log2, V, Ty);
+ break;
+ case LibFunc_log10:
+ case LibFunc_log10f:
+ case LibFunc_log10_finite:
+ case LibFunc_log10f_finite:
+ if (V > 0.0 && TLI->has(Func))
+ // TODO: What about hosts that lack a C99 library?
+ return ConstantFoldFP(log10, V, Ty);
+ break;
+ case LibFunc_nearbyint:
+ case LibFunc_nearbyintf:
+ case LibFunc_rint:
+ case LibFunc_rintf:
+ if (TLI->has(Func)) {
+ U.roundToIntegral(APFloat::rmNearestTiesToEven);
+ return ConstantFP::get(Ty->getContext(), U);
+ }
+ break;
+ case LibFunc_round:
+ case LibFunc_roundf:
+ if (TLI->has(Func)) {
+ U.roundToIntegral(APFloat::rmNearestTiesToAway);
+ return ConstantFP::get(Ty->getContext(), U);
+ }
+ break;
+ case LibFunc_sin:
+ case LibFunc_sinf:
+ if (TLI->has(Func))
+ return ConstantFoldFP(sin, V, Ty);
+ break;
+ case LibFunc_sinh:
+ case LibFunc_sinhf:
+ case LibFunc_sinh_finite:
+ case LibFunc_sinhf_finite:
+ if (TLI->has(Func))
+ return ConstantFoldFP(sinh, V, Ty);
+ break;
+ case LibFunc_sqrt:
+ case LibFunc_sqrtf:
+ if (V >= 0.0 && TLI->has(Func))
+ return ConstantFoldFP(sqrt, V, Ty);
+ break;
+ case LibFunc_tan:
+ case LibFunc_tanf:
+ if (TLI->has(Func))
+ return ConstantFoldFP(tan, V, Ty);
+ break;
+ case LibFunc_tanh:
+ case LibFunc_tanhf:
+ if (TLI->has(Func))
+ return ConstantFoldFP(tanh, V, Ty);
+ break;
+ case LibFunc_trunc:
+ case LibFunc_truncf:
+ if (TLI->has(Func)) {
+ U.roundToIntegral(APFloat::rmTowardZero);
+ return ConstantFP::get(Ty->getContext(), U);
+ }
+ break;
+ }
+ return nullptr;
+ }
+
+ if (auto *Op = dyn_cast<ConstantInt>(Operands[0])) {
+ switch (IntrinsicID) {
+ case Intrinsic::bswap:
+ return ConstantInt::get(Ty->getContext(), Op->getValue().byteSwap());
+ case Intrinsic::ctpop:
+ return ConstantInt::get(Ty, Op->getValue().countPopulation());
+ case Intrinsic::bitreverse:
+ return ConstantInt::get(Ty->getContext(), Op->getValue().reverseBits());
+ case Intrinsic::convert_from_fp16: {
+ APFloat Val(APFloat::IEEEhalf(), Op->getValue());
+
+ bool lost = false;
+ APFloat::opStatus status = Val.convert(
+ Ty->getFltSemantics(), APFloat::rmNearestTiesToEven, &lost);
+
+ // Conversion is always precise.
+ (void)status;
+ assert(status == APFloat::opOK && !lost &&
+ "Precision lost during fp16 constfolding");
+
+ return ConstantFP::get(Ty->getContext(), Val);
+ }
+ default:
+ return nullptr;
+ }
+ }
+
+ // Support ConstantVector in case we have an Undef in the top.
+ if (isa<ConstantVector>(Operands[0]) ||
+ isa<ConstantDataVector>(Operands[0])) {
+ auto *Op = cast<Constant>(Operands[0]);
+ switch (IntrinsicID) {
+ default: break;
+ case Intrinsic::x86_sse_cvtss2si:
+ case Intrinsic::x86_sse_cvtss2si64:
+ case Intrinsic::x86_sse2_cvtsd2si:
+ case Intrinsic::x86_sse2_cvtsd2si64:
+ if (ConstantFP *FPOp =
+ dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U)))
+ return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(),
+ /*roundTowardZero=*/false, Ty,
+ /*IsSigned*/true);
+ break;
+ case Intrinsic::x86_sse_cvttss2si:
+ case Intrinsic::x86_sse_cvttss2si64:
+ case Intrinsic::x86_sse2_cvttsd2si:
+ case Intrinsic::x86_sse2_cvttsd2si64:
+ if (ConstantFP *FPOp =
+ dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U)))
+ return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(),
+ /*roundTowardZero=*/true, Ty,
+ /*IsSigned*/true);
+ break;
+ }
+ }
+
+ return nullptr;
+}
+
+static Constant *ConstantFoldScalarCall2(StringRef Name,
+ Intrinsic::ID IntrinsicID,
+ Type *Ty,
+ ArrayRef<Constant *> Operands,
+ const TargetLibraryInfo *TLI,
+ const CallBase *Call) {
+ assert(Operands.size() == 2 && "Wrong number of operands.");
+
+ if (auto *Op1 = dyn_cast<ConstantFP>(Operands[0])) {
+ if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
+ return nullptr;
+ double Op1V = getValueAsDouble(Op1);
+
+ if (auto *Op2 = dyn_cast<ConstantFP>(Operands[1])) {
+ if (Op2->getType() != Op1->getType())
+ return nullptr;
+
+ double Op2V = getValueAsDouble(Op2);
+ if (IntrinsicID == Intrinsic::pow) {
+ return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty);
+ }
+ if (IntrinsicID == Intrinsic::copysign) {
+ APFloat V1 = Op1->getValueAPF();
+ const APFloat &V2 = Op2->getValueAPF();
+ V1.copySign(V2);
+ return ConstantFP::get(Ty->getContext(), V1);
+ }
+
+ if (IntrinsicID == Intrinsic::minnum) {
+ const APFloat &C1 = Op1->getValueAPF();
+ const APFloat &C2 = Op2->getValueAPF();
+ return ConstantFP::get(Ty->getContext(), minnum(C1, C2));
+ }
+
+ if (IntrinsicID == Intrinsic::maxnum) {
+ const APFloat &C1 = Op1->getValueAPF();
+ const APFloat &C2 = Op2->getValueAPF();
+ return ConstantFP::get(Ty->getContext(), maxnum(C1, C2));
+ }
+
+ if (IntrinsicID == Intrinsic::minimum) {
+ const APFloat &C1 = Op1->getValueAPF();
+ const APFloat &C2 = Op2->getValueAPF();
+ return ConstantFP::get(Ty->getContext(), minimum(C1, C2));
+ }
+
+ if (IntrinsicID == Intrinsic::maximum) {
+ const APFloat &C1 = Op1->getValueAPF();
+ const APFloat &C2 = Op2->getValueAPF();
+ return ConstantFP::get(Ty->getContext(), maximum(C1, C2));
+ }
+
+ if (!TLI)
+ return nullptr;
+
+ LibFunc Func = NotLibFunc;
+ TLI->getLibFunc(Name, Func);
+ switch (Func) {
+ default:
+ break;
+ case LibFunc_pow:
+ case LibFunc_powf:
+ case LibFunc_pow_finite:
+ case LibFunc_powf_finite:
+ if (TLI->has(Func))
+ return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty);
+ break;
+ case LibFunc_fmod:
+ case LibFunc_fmodf:
+ if (TLI->has(Func)) {
+ APFloat V = Op1->getValueAPF();
+ if (APFloat::opStatus::opOK == V.mod(Op2->getValueAPF()))
+ return ConstantFP::get(Ty->getContext(), V);
+ }
+ break;
+ case LibFunc_atan2:
+ case LibFunc_atan2f:
+ case LibFunc_atan2_finite:
+ case LibFunc_atan2f_finite:
+ if (TLI->has(Func))
+ return ConstantFoldBinaryFP(atan2, Op1V, Op2V, Ty);
+ break;
+ }
+ } else if (auto *Op2C = dyn_cast<ConstantInt>(Operands[1])) {
+ if (IntrinsicID == Intrinsic::powi && Ty->isHalfTy())
+ return ConstantFP::get(Ty->getContext(),
+ APFloat((float)std::pow((float)Op1V,
+ (int)Op2C->getZExtValue())));
+ if (IntrinsicID == Intrinsic::powi && Ty->isFloatTy())
+ return ConstantFP::get(Ty->getContext(),
+ APFloat((float)std::pow((float)Op1V,
+ (int)Op2C->getZExtValue())));
+ if (IntrinsicID == Intrinsic::powi && Ty->isDoubleTy())
+ return ConstantFP::get(Ty->getContext(),
+ APFloat((double)std::pow((double)Op1V,
+ (int)Op2C->getZExtValue())));
+ }
+ return nullptr;
+ }
+
+ if (Operands[0]->getType()->isIntegerTy() &&
+ Operands[1]->getType()->isIntegerTy()) {
+ const APInt *C0, *C1;
+ if (!getConstIntOrUndef(Operands[0], C0) ||
+ !getConstIntOrUndef(Operands[1], C1))
+ return nullptr;
+
+ switch (IntrinsicID) {
+ default: break;
+ case Intrinsic::usub_with_overflow:
+ case Intrinsic::ssub_with_overflow:
+ case Intrinsic::uadd_with_overflow:
+ case Intrinsic::sadd_with_overflow:
+ // X - undef -> { undef, false }
+ // undef - X -> { undef, false }
+ // X + undef -> { undef, false }
+ // undef + x -> { undef, false }
+ if (!C0 || !C1) {
+ return ConstantStruct::get(
+ cast<StructType>(Ty),
+ {UndefValue::get(Ty->getStructElementType(0)),
+ Constant::getNullValue(Ty->getStructElementType(1))});
+ }
+ LLVM_FALLTHROUGH;
+ case Intrinsic::smul_with_overflow:
+ case Intrinsic::umul_with_overflow: {
+ // undef * X -> { 0, false }
+ // X * undef -> { 0, false }
+ if (!C0 || !C1)
+ return Constant::getNullValue(Ty);
+
+ APInt Res;
+ bool Overflow;
+ switch (IntrinsicID) {
+ default: llvm_unreachable("Invalid case");
+ case Intrinsic::sadd_with_overflow:
+ Res = C0->sadd_ov(*C1, Overflow);
+ break;
+ case Intrinsic::uadd_with_overflow:
+ Res = C0->uadd_ov(*C1, Overflow);
+ break;
+ case Intrinsic::ssub_with_overflow:
+ Res = C0->ssub_ov(*C1, Overflow);
+ break;
+ case Intrinsic::usub_with_overflow:
+ Res = C0->usub_ov(*C1, Overflow);
+ break;
+ case Intrinsic::smul_with_overflow:
+ Res = C0->smul_ov(*C1, Overflow);
+ break;
+ case Intrinsic::umul_with_overflow:
+ Res = C0->umul_ov(*C1, Overflow);
+ break;
+ }
+ Constant *Ops[] = {
+ ConstantInt::get(Ty->getContext(), Res),
+ ConstantInt::get(Type::getInt1Ty(Ty->getContext()), Overflow)
+ };
+ return ConstantStruct::get(cast<StructType>(Ty), Ops);
+ }
+ case Intrinsic::uadd_sat:
+ case Intrinsic::sadd_sat:
+ if (!C0 && !C1)
+ return UndefValue::get(Ty);
+ if (!C0 || !C1)
+ return Constant::getAllOnesValue(Ty);
+ if (IntrinsicID == Intrinsic::uadd_sat)
+ return ConstantInt::get(Ty, C0->uadd_sat(*C1));
+ else
+ return ConstantInt::get(Ty, C0->sadd_sat(*C1));
+ case Intrinsic::usub_sat:
+ case Intrinsic::ssub_sat:
+ if (!C0 && !C1)
+ return UndefValue::get(Ty);
+ if (!C0 || !C1)
+ return Constant::getNullValue(Ty);
+ if (IntrinsicID == Intrinsic::usub_sat)
+ return ConstantInt::get(Ty, C0->usub_sat(*C1));
+ else
+ return ConstantInt::get(Ty, C0->ssub_sat(*C1));
+ case Intrinsic::cttz:
+ case Intrinsic::ctlz:
+ assert(C1 && "Must be constant int");
+
+ // cttz(0, 1) and ctlz(0, 1) are undef.
+ if (C1->isOneValue() && (!C0 || C0->isNullValue()))
+ return UndefValue::get(Ty);
+ if (!C0)
+ return Constant::getNullValue(Ty);
+ if (IntrinsicID == Intrinsic::cttz)
+ return ConstantInt::get(Ty, C0->countTrailingZeros());
+ else
+ return ConstantInt::get(Ty, C0->countLeadingZeros());
+ }
+
+ return nullptr;
+ }
+
+ // Support ConstantVector in case we have an Undef in the top.
+ if ((isa<ConstantVector>(Operands[0]) ||
+ isa<ConstantDataVector>(Operands[0])) &&
+ // Check for default rounding mode.
+ // FIXME: Support other rounding modes?
+ isa<ConstantInt>(Operands[1]) &&
+ cast<ConstantInt>(Operands[1])->getValue() == 4) {
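+    // (A rounding-mode immediate of 4 corresponds to _MM_FROUND_CUR_DIRECTION
+    // in the x86 intrinsic headers, i.e. "use the current rounding mode".)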
+ auto *Op = cast<Constant>(Operands[0]);
+ switch (IntrinsicID) {
+ default: break;
+ case Intrinsic::x86_avx512_vcvtss2si32:
+ case Intrinsic::x86_avx512_vcvtss2si64:
+ case Intrinsic::x86_avx512_vcvtsd2si32:
+ case Intrinsic::x86_avx512_vcvtsd2si64:
+ if (ConstantFP *FPOp =
+ dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U)))
+ return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(),
+ /*roundTowardZero=*/false, Ty,
+ /*IsSigned*/true);
+ break;
+ case Intrinsic::x86_avx512_vcvtss2usi32:
+ case Intrinsic::x86_avx512_vcvtss2usi64:
+ case Intrinsic::x86_avx512_vcvtsd2usi32:
+ case Intrinsic::x86_avx512_vcvtsd2usi64:
+ if (ConstantFP *FPOp =
+ dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U)))
+ return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(),
+ /*roundTowardZero=*/false, Ty,
+ /*IsSigned*/false);
+ break;
+ case Intrinsic::x86_avx512_cvttss2si:
+ case Intrinsic::x86_avx512_cvttss2si64:
+ case Intrinsic::x86_avx512_cvttsd2si:
+ case Intrinsic::x86_avx512_cvttsd2si64:
+ if (ConstantFP *FPOp =
+ dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U)))
+ return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(),
+ /*roundTowardZero=*/true, Ty,
+ /*IsSigned*/true);
+ break;
+ case Intrinsic::x86_avx512_cvttss2usi:
+ case Intrinsic::x86_avx512_cvttss2usi64:
+ case Intrinsic::x86_avx512_cvttsd2usi:
+ case Intrinsic::x86_avx512_cvttsd2usi64:
+ if (ConstantFP *FPOp =
+ dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U)))
+ return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(),
+ /*roundTowardZero=*/true, Ty,
+ /*IsSigned*/false);
+ break;
+ }
+ }
+ return nullptr;
+}
+
+static Constant *ConstantFoldScalarCall3(StringRef Name,
+ Intrinsic::ID IntrinsicID,
+ Type *Ty,
+ ArrayRef<Constant *> Operands,
+ const TargetLibraryInfo *TLI,
+ const CallBase *Call) {
+ assert(Operands.size() == 3 && "Wrong number of operands.");
+
+ if (const auto *Op1 = dyn_cast<ConstantFP>(Operands[0])) {
+ if (const auto *Op2 = dyn_cast<ConstantFP>(Operands[1])) {
+ if (const auto *Op3 = dyn_cast<ConstantFP>(Operands[2])) {
+ switch (IntrinsicID) {
+ default: break;
+ case Intrinsic::fma:
+ case Intrinsic::fmuladd: {
+ APFloat V = Op1->getValueAPF();
+ V.fusedMultiplyAdd(Op2->getValueAPF(), Op3->getValueAPF(),
+ APFloat::rmNearestTiesToEven);
+ return ConstantFP::get(Ty->getContext(), V);
+ }
+ }
+ }
+ }
+ }
+
+ if (const auto *Op1 = dyn_cast<ConstantInt>(Operands[0])) {
+ if (const auto *Op2 = dyn_cast<ConstantInt>(Operands[1])) {
+ if (const auto *Op3 = dyn_cast<ConstantInt>(Operands[2])) {
+ switch (IntrinsicID) {
+ default: break;
+ case Intrinsic::smul_fix:
+ case Intrinsic::smul_fix_sat: {
+ // This code performs rounding towards negative infinity in case the
+ // result cannot be represented exactly for the given scale. Targets
+ // that do care about rounding should use a target hook for specifying
+ // how rounding should be done, and provide their own folding to be
+ // consistent with rounding. This is the same approach as used by
+ // DAGTypeLegalizer::ExpandIntRes_MULFIX.
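+          // As an illustrative example (values invented for this comment):
+          // with Width = 8 and Scale = 4, smul_fix(i8 -1, i8 3, 4) forms the
+          // extended product -1 * 3 = -3 and shifts it right arithmetically
+          // by 4, yielding -1; that is, the exact value -3/256 is rounded
+          // down to -1/16 rather than up to 0.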
+ APInt Lhs = Op1->getValue();
+ APInt Rhs = Op2->getValue();
+ unsigned Scale = Op3->getValue().getZExtValue();
+ unsigned Width = Lhs.getBitWidth();
+ assert(Scale < Width && "Illegal scale.");
+ unsigned ExtendedWidth = Width * 2;
+ APInt Product = (Lhs.sextOrSelf(ExtendedWidth) *
+ Rhs.sextOrSelf(ExtendedWidth)).ashr(Scale);
+ if (IntrinsicID == Intrinsic::smul_fix_sat) {
+ APInt MaxValue =
+ APInt::getSignedMaxValue(Width).sextOrSelf(ExtendedWidth);
+ APInt MinValue =
+ APInt::getSignedMinValue(Width).sextOrSelf(ExtendedWidth);
+ Product = APIntOps::smin(Product, MaxValue);
+ Product = APIntOps::smax(Product, MinValue);
+ }
+ return ConstantInt::get(Ty->getContext(),
+ Product.sextOrTrunc(Width));
+ }
+ }
+ }
+ }
+ }
+
+ if (IntrinsicID == Intrinsic::fshl || IntrinsicID == Intrinsic::fshr) {
+ const APInt *C0, *C1, *C2;
+ if (!getConstIntOrUndef(Operands[0], C0) ||
+ !getConstIntOrUndef(Operands[1], C1) ||
+ !getConstIntOrUndef(Operands[2], C2))
+ return nullptr;
+
+ bool IsRight = IntrinsicID == Intrinsic::fshr;
+ if (!C2)
+ return Operands[IsRight ? 1 : 0];
+ if (!C0 && !C1)
+ return UndefValue::get(Ty);
+
+ // The shift amount is interpreted as modulo the bitwidth. If the shift
+ // amount is effectively 0, avoid UB due to oversized inverse shift below.
+ unsigned BitWidth = C2->getBitWidth();
+ unsigned ShAmt = C2->urem(BitWidth);
+ if (!ShAmt)
+ return Operands[IsRight ? 1 : 0];
+
+ // (C0 << ShlAmt) | (C1 >> LshrAmt)
+ unsigned LshrAmt = IsRight ? ShAmt : BitWidth - ShAmt;
+ unsigned ShlAmt = !IsRight ? ShAmt : BitWidth - ShAmt;
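+    // As an illustrative example (constants invented for this comment):
+    // fshl(i8 0x12, i8 0x34, 4) concatenates 0x1234, shifts left by 4 and
+    // keeps the high byte, so the folds below compute
+    // (0x12 << 4) | (0x34 >> 4) == 0x20 | 0x03 == 0x23.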
+ if (!C0)
+ return ConstantInt::get(Ty, C1->lshr(LshrAmt));
+ if (!C1)
+ return ConstantInt::get(Ty, C0->shl(ShlAmt));
+ return ConstantInt::get(Ty, C0->shl(ShlAmt) | C1->lshr(LshrAmt));
+ }
+
+ return nullptr;
+}
+
+static Constant *ConstantFoldScalarCall(StringRef Name,
+ Intrinsic::ID IntrinsicID,
+ Type *Ty,
+ ArrayRef<Constant *> Operands,
+ const TargetLibraryInfo *TLI,
+ const CallBase *Call) {
+ if (Operands.size() == 1)
+ return ConstantFoldScalarCall1(Name, IntrinsicID, Ty, Operands, TLI, Call);
+
+ if (Operands.size() == 2)
+ return ConstantFoldScalarCall2(Name, IntrinsicID, Ty, Operands, TLI, Call);
+
+ if (Operands.size() == 3)
+ return ConstantFoldScalarCall3(Name, IntrinsicID, Ty, Operands, TLI, Call);
+
+ return nullptr;
+}
+
+static Constant *ConstantFoldVectorCall(StringRef Name,
+ Intrinsic::ID IntrinsicID,
+ VectorType *VTy,
+ ArrayRef<Constant *> Operands,
+ const DataLayout &DL,
+ const TargetLibraryInfo *TLI,
+ const CallBase *Call) {
+ SmallVector<Constant *, 4> Result(VTy->getNumElements());
+ SmallVector<Constant *, 4> Lane(Operands.size());
+ Type *Ty = VTy->getElementType();
+
+ if (IntrinsicID == Intrinsic::masked_load) {
+ auto *SrcPtr = Operands[0];
+ auto *Mask = Operands[2];
+ auto *Passthru = Operands[3];
+
+ Constant *VecData = ConstantFoldLoadFromConstPtr(SrcPtr, VTy, DL);
+
+ SmallVector<Constant *, 32> NewElements;
+ for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) {
+ auto *MaskElt = Mask->getAggregateElement(I);
+ if (!MaskElt)
+ break;
+ auto *PassthruElt = Passthru->getAggregateElement(I);
+ auto *VecElt = VecData ? VecData->getAggregateElement(I) : nullptr;
+      if (isa<UndefValue>(MaskElt)) {
+        if (PassthruElt)
+          NewElements.push_back(PassthruElt);
+        else if (VecElt)
+          NewElements.push_back(VecElt);
+        else
+          return nullptr;
+        // The undef mask lane has been handled; don't fall through to the
+        // null/one checks below, which would reject the fold.
+        continue;
+      }
+      if (MaskElt->isNullValue()) {
+ if (!PassthruElt)
+ return nullptr;
+ NewElements.push_back(PassthruElt);
+ } else if (MaskElt->isOneValue()) {
+ if (!VecElt)
+ return nullptr;
+ NewElements.push_back(VecElt);
+ } else {
+ return nullptr;
+ }
+ }
+ if (NewElements.size() != VTy->getNumElements())
+ return nullptr;
+ return ConstantVector::get(NewElements);
+ }
+
+ for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) {
+ // Gather a column of constants.
+ for (unsigned J = 0, JE = Operands.size(); J != JE; ++J) {
+ // Some intrinsics use a scalar type for certain arguments.
+ if (hasVectorInstrinsicScalarOpd(IntrinsicID, J)) {
+ Lane[J] = Operands[J];
+ continue;
+ }
+
+ Constant *Agg = Operands[J]->getAggregateElement(I);
+ if (!Agg)
+ return nullptr;
+
+ Lane[J] = Agg;
+ }
+
+ // Use the regular scalar folding to simplify this column.
+ Constant *Folded =
+ ConstantFoldScalarCall(Name, IntrinsicID, Ty, Lane, TLI, Call);
+ if (!Folded)
+ return nullptr;
+ Result[I] = Folded;
+ }
+
+ return ConstantVector::get(Result);
+}
+
+} // end anonymous namespace
+
+Constant *llvm::ConstantFoldCall(const CallBase *Call, Function *F,
+ ArrayRef<Constant *> Operands,
+ const TargetLibraryInfo *TLI) {
+ if (Call->isNoBuiltin() || Call->isStrictFP())
+ return nullptr;
+ if (!F->hasName())
+ return nullptr;
+ StringRef Name = F->getName();
+
+ Type *Ty = F->getReturnType();
+
+ if (auto *VTy = dyn_cast<VectorType>(Ty))
+ return ConstantFoldVectorCall(Name, F->getIntrinsicID(), VTy, Operands,
+ F->getParent()->getDataLayout(), TLI, Call);
+
+ return ConstantFoldScalarCall(Name, F->getIntrinsicID(), Ty, Operands, TLI,
+ Call);
+}
+
+bool llvm::isMathLibCallNoop(const CallBase *Call,
+ const TargetLibraryInfo *TLI) {
+ // FIXME: Refactor this code; this duplicates logic in LibCallsShrinkWrap
+ // (and to some extent ConstantFoldScalarCall).
+ if (Call->isNoBuiltin() || Call->isStrictFP())
+ return false;
+ Function *F = Call->getCalledFunction();
+ if (!F)
+ return false;
+
+ LibFunc Func;
+ if (!TLI || !TLI->getLibFunc(*F, Func))
+ return false;
+
+ if (Call->getNumArgOperands() == 1) {
+ if (ConstantFP *OpC = dyn_cast<ConstantFP>(Call->getArgOperand(0))) {
+ const APFloat &Op = OpC->getValueAPF();
+ switch (Func) {
+ case LibFunc_logl:
+ case LibFunc_log:
+ case LibFunc_logf:
+ case LibFunc_log2l:
+ case LibFunc_log2:
+ case LibFunc_log2f:
+ case LibFunc_log10l:
+ case LibFunc_log10:
+ case LibFunc_log10f:
+ return Op.isNaN() || (!Op.isZero() && !Op.isNegative());
+
+ case LibFunc_expl:
+ case LibFunc_exp:
+ case LibFunc_expf:
+ // FIXME: These boundaries are slightly conservative.
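+        // (Roughly: exp overflows above ln(DBL_MAX) ~= 709.78 and underflows
+        // to zero below ln(smallest denormal double) ~= -744.44; the float
+        // limits are ~= 88.72 and ~= -103.28, hence the constants below.)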
+ if (OpC->getType()->isDoubleTy())
+ return Op.compare(APFloat(-745.0)) != APFloat::cmpLessThan &&
+ Op.compare(APFloat(709.0)) != APFloat::cmpGreaterThan;
+ if (OpC->getType()->isFloatTy())
+ return Op.compare(APFloat(-103.0f)) != APFloat::cmpLessThan &&
+ Op.compare(APFloat(88.0f)) != APFloat::cmpGreaterThan;
+ break;
+
+ case LibFunc_exp2l:
+ case LibFunc_exp2:
+ case LibFunc_exp2f:
+ // FIXME: These boundaries are slightly conservative.
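+        // (exp2 overflows just above log2(DBL_MAX) ~= 1024 and the smallest
+        // denormal double is 2^-1074; for floats the limits are 2^128 and
+        // 2^-149, which is where the constants below come from.)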
+ if (OpC->getType()->isDoubleTy())
+ return Op.compare(APFloat(-1074.0)) != APFloat::cmpLessThan &&
+ Op.compare(APFloat(1023.0)) != APFloat::cmpGreaterThan;
+ if (OpC->getType()->isFloatTy())
+ return Op.compare(APFloat(-149.0f)) != APFloat::cmpLessThan &&
+ Op.compare(APFloat(127.0f)) != APFloat::cmpGreaterThan;
+ break;
+
+ case LibFunc_sinl:
+ case LibFunc_sin:
+ case LibFunc_sinf:
+ case LibFunc_cosl:
+ case LibFunc_cos:
+ case LibFunc_cosf:
+ return !Op.isInfinity();
+
+ case LibFunc_tanl:
+ case LibFunc_tan:
+ case LibFunc_tanf: {
+ // FIXME: Stop using the host math library.
+ // FIXME: The computation isn't done in the right precision.
+ Type *Ty = OpC->getType();
+ if (Ty->isDoubleTy() || Ty->isFloatTy() || Ty->isHalfTy()) {
+ double OpV = getValueAsDouble(OpC);
+ return ConstantFoldFP(tan, OpV, Ty) != nullptr;
+ }
+ break;
+ }
+
+ case LibFunc_asinl:
+ case LibFunc_asin:
+ case LibFunc_asinf:
+ case LibFunc_acosl:
+ case LibFunc_acos:
+ case LibFunc_acosf:
+ return Op.compare(APFloat(Op.getSemantics(), "-1")) !=
+ APFloat::cmpLessThan &&
+ Op.compare(APFloat(Op.getSemantics(), "1")) !=
+ APFloat::cmpGreaterThan;
+
+ case LibFunc_sinh:
+ case LibFunc_cosh:
+ case LibFunc_sinhf:
+ case LibFunc_coshf:
+ case LibFunc_sinhl:
+ case LibFunc_coshl:
+ // FIXME: These boundaries are slightly conservative.
+ if (OpC->getType()->isDoubleTy())
+ return Op.compare(APFloat(-710.0)) != APFloat::cmpLessThan &&
+ Op.compare(APFloat(710.0)) != APFloat::cmpGreaterThan;
+ if (OpC->getType()->isFloatTy())
+ return Op.compare(APFloat(-89.0f)) != APFloat::cmpLessThan &&
+ Op.compare(APFloat(89.0f)) != APFloat::cmpGreaterThan;
+ break;
+
+ case LibFunc_sqrtl:
+ case LibFunc_sqrt:
+ case LibFunc_sqrtf:
+ return Op.isNaN() || Op.isZero() || !Op.isNegative();
+
+ // FIXME: Add more functions: sqrt_finite, atanh, expm1, log1p,
+ // maybe others?
+ default:
+ break;
+ }
+ }
+ }
+
+ if (Call->getNumArgOperands() == 2) {
+ ConstantFP *Op0C = dyn_cast<ConstantFP>(Call->getArgOperand(0));
+ ConstantFP *Op1C = dyn_cast<ConstantFP>(Call->getArgOperand(1));
+ if (Op0C && Op1C) {
+ const APFloat &Op0 = Op0C->getValueAPF();
+ const APFloat &Op1 = Op1C->getValueAPF();
+
+ switch (Func) {
+ case LibFunc_powl:
+ case LibFunc_pow:
+ case LibFunc_powf: {
+ // FIXME: Stop using the host math library.
+ // FIXME: The computation isn't done in the right precision.
+ Type *Ty = Op0C->getType();
+ if (Ty->isDoubleTy() || Ty->isFloatTy() || Ty->isHalfTy()) {
+ if (Ty == Op1C->getType()) {
+ double Op0V = getValueAsDouble(Op0C);
+ double Op1V = getValueAsDouble(Op1C);
+ return ConstantFoldBinaryFP(pow, Op0V, Op1V, Ty) != nullptr;
+ }
+ }
+ break;
+ }
+
+ case LibFunc_fmodl:
+ case LibFunc_fmod:
+ case LibFunc_fmodf:
+ return Op0.isNaN() || Op1.isNaN() ||
+ (!Op0.isInfinity() && !Op1.isZero());
+
+ default:
+ break;
+ }
+ }
+ }
+
+ return false;
+}
diff --git a/llvm/lib/Analysis/CostModel.cpp b/llvm/lib/Analysis/CostModel.cpp
new file mode 100644
index 000000000000..bf0cdbfd0c8b
--- /dev/null
+++ b/llvm/lib/Analysis/CostModel.cpp
@@ -0,0 +1,111 @@
+//===- CostModel.cpp ------ Cost Model Analysis ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the cost model analysis. It provides a very basic cost
+// estimation for LLVM-IR. This analysis uses the services of the codegen
+// to approximate the cost of any IR instruction when lowered to machine
+// instructions. The cost results are unit-less and the cost number represents
+// the throughput of the machine assuming that all loads hit the cache, all
+// branches are predicted, etc. The cost numbers can be added in order to
+// compare two or more transformation alternatives.
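+//
+// As an illustrative note (command line not verified here), the printed
+// estimates can be inspected with the legacy pass manager, for example:
+//   opt -cost-model -analyze -cost-kind=latency input.ll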
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Function.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace llvm;
+
+static cl::opt<TargetTransformInfo::TargetCostKind> CostKind(
+ "cost-kind", cl::desc("Target cost kind"),
+ cl::init(TargetTransformInfo::TCK_RecipThroughput),
+ cl::values(clEnumValN(TargetTransformInfo::TCK_RecipThroughput,
+ "throughput", "Reciprocal throughput"),
+ clEnumValN(TargetTransformInfo::TCK_Latency,
+ "latency", "Instruction latency"),
+ clEnumValN(TargetTransformInfo::TCK_CodeSize,
+ "code-size", "Code size")));
+
+#define CM_NAME "cost-model"
+#define DEBUG_TYPE CM_NAME
+
+namespace {
+ class CostModelAnalysis : public FunctionPass {
+
+ public:
+ static char ID; // Class identification, replacement for typeinfo
+ CostModelAnalysis() : FunctionPass(ID), F(nullptr), TTI(nullptr) {
+ initializeCostModelAnalysisPass(
+ *PassRegistry::getPassRegistry());
+ }
+
+ /// Returns the expected cost of the instruction.
+ /// Returns -1 if the cost is unknown.
+ /// Note, this method does not cache the cost calculation and it
+ /// can be expensive in some cases.
+ unsigned getInstructionCost(const Instruction *I) const {
+ return TTI->getInstructionCost(I, TargetTransformInfo::TCK_RecipThroughput);
+ }
+
+ private:
+ void getAnalysisUsage(AnalysisUsage &AU) const override;
+ bool runOnFunction(Function &F) override;
+ void print(raw_ostream &OS, const Module*) const override;
+
+ /// The function that we analyze.
+ Function *F;
+ /// Target information.
+ const TargetTransformInfo *TTI;
+ };
+} // End of anonymous namespace
+
+// Register this pass.
+char CostModelAnalysis::ID = 0;
+static const char cm_name[] = "Cost Model Analysis";
+INITIALIZE_PASS_BEGIN(CostModelAnalysis, CM_NAME, cm_name, false, true)
+INITIALIZE_PASS_END (CostModelAnalysis, CM_NAME, cm_name, false, true)
+
+FunctionPass *llvm::createCostModelAnalysisPass() {
+ return new CostModelAnalysis();
+}
+
+void
+CostModelAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+}
+
+bool
+CostModelAnalysis::runOnFunction(Function &F) {
+ this->F = &F;
+ auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
+ TTI = TTIWP ? &TTIWP->getTTI(F) : nullptr;
+
+ return false;
+}
+
+void CostModelAnalysis::print(raw_ostream &OS, const Module*) const {
+ if (!F)
+ return;
+
+ for (BasicBlock &B : *F) {
+ for (Instruction &Inst : B) {
+ unsigned Cost = TTI->getInstructionCost(&Inst, CostKind);
+ if (Cost != (unsigned)-1)
+ OS << "Cost Model: Found an estimated cost of " << Cost;
+ else
+ OS << "Cost Model: Unknown cost";
+
+ OS << " for instruction: " << Inst << "\n";
+ }
+ }
+}
diff --git a/llvm/lib/Analysis/DDG.cpp b/llvm/lib/Analysis/DDG.cpp
new file mode 100644
index 000000000000..b5c3c761ad98
--- /dev/null
+++ b/llvm/lib/Analysis/DDG.cpp
@@ -0,0 +1,203 @@
+//===- DDG.cpp - Data Dependence Graph -------------------------------------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The implementation for the data dependence graph.
+//===----------------------------------------------------------------------===//
+#include "llvm/Analysis/DDG.h"
+#include "llvm/Analysis/LoopInfo.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "ddg"
+
+template class llvm::DGEdge<DDGNode, DDGEdge>;
+template class llvm::DGNode<DDGNode, DDGEdge>;
+template class llvm::DirectedGraph<DDGNode, DDGEdge>;
+
+//===--------------------------------------------------------------------===//
+// DDGNode implementation
+//===--------------------------------------------------------------------===//
+DDGNode::~DDGNode() {}
+
+bool DDGNode::collectInstructions(
+ llvm::function_ref<bool(Instruction *)> const &Pred,
+ InstructionListType &IList) const {
+ assert(IList.empty() && "Expected the IList to be empty on entry.");
+ if (isa<SimpleDDGNode>(this)) {
+ for (auto *I : cast<const SimpleDDGNode>(this)->getInstructions())
+ if (Pred(I))
+ IList.push_back(I);
+ } else
+ llvm_unreachable("unimplemented type of node");
+ return !IList.empty();
+}
+
+raw_ostream &llvm::operator<<(raw_ostream &OS, const DDGNode::NodeKind K) {
+ const char *Out;
+ switch (K) {
+ case DDGNode::NodeKind::SingleInstruction:
+ Out = "single-instruction";
+ break;
+ case DDGNode::NodeKind::MultiInstruction:
+ Out = "multi-instruction";
+ break;
+ case DDGNode::NodeKind::Root:
+ Out = "root";
+ break;
+ case DDGNode::NodeKind::Unknown:
+ Out = "??";
+ break;
+ }
+ OS << Out;
+ return OS;
+}
+
+raw_ostream &llvm::operator<<(raw_ostream &OS, const DDGNode &N) {
+ OS << "Node Address:" << &N << ":" << N.getKind() << "\n";
+ if (isa<SimpleDDGNode>(N)) {
+ OS << " Instructions:\n";
+ for (auto *I : cast<const SimpleDDGNode>(N).getInstructions())
+ OS.indent(2) << *I << "\n";
+ } else if (!isa<RootDDGNode>(N))
+ llvm_unreachable("unimplemented type of node");
+
+ OS << (N.getEdges().empty() ? " Edges:none!\n" : " Edges:\n");
+ for (auto &E : N.getEdges())
+ OS.indent(2) << *E;
+ return OS;
+}
+
+//===--------------------------------------------------------------------===//
+// SimpleDDGNode implementation
+//===--------------------------------------------------------------------===//
+
+SimpleDDGNode::SimpleDDGNode(Instruction &I)
+ : DDGNode(NodeKind::SingleInstruction), InstList() {
+ assert(InstList.empty() && "Expected empty list.");
+ InstList.push_back(&I);
+}
+
+SimpleDDGNode::SimpleDDGNode(const SimpleDDGNode &N)
+ : DDGNode(N), InstList(N.InstList) {
+ assert(((getKind() == NodeKind::SingleInstruction && InstList.size() == 1) ||
+ (getKind() == NodeKind::MultiInstruction && InstList.size() > 1)) &&
+ "constructing from invalid simple node.");
+}
+
+SimpleDDGNode::SimpleDDGNode(SimpleDDGNode &&N)
+ : DDGNode(std::move(N)), InstList(std::move(N.InstList)) {
+ assert(((getKind() == NodeKind::SingleInstruction && InstList.size() == 1) ||
+ (getKind() == NodeKind::MultiInstruction && InstList.size() > 1)) &&
+ "constructing from invalid simple node.");
+}
+
+SimpleDDGNode::~SimpleDDGNode() { InstList.clear(); }
+
+//===--------------------------------------------------------------------===//
+// DDGEdge implementation
+//===--------------------------------------------------------------------===//
+
+raw_ostream &llvm::operator<<(raw_ostream &OS, const DDGEdge::EdgeKind K) {
+ const char *Out;
+ switch (K) {
+ case DDGEdge::EdgeKind::RegisterDefUse:
+ Out = "def-use";
+ break;
+ case DDGEdge::EdgeKind::MemoryDependence:
+ Out = "memory";
+ break;
+ case DDGEdge::EdgeKind::Rooted:
+ Out = "rooted";
+ break;
+ case DDGEdge::EdgeKind::Unknown:
+ Out = "??";
+ break;
+ }
+ OS << Out;
+ return OS;
+}
+
+raw_ostream &llvm::operator<<(raw_ostream &OS, const DDGEdge &E) {
+ OS << "[" << E.getKind() << "] to " << &E.getTargetNode() << "\n";
+ return OS;
+}
+
+//===--------------------------------------------------------------------===//
+// DataDependenceGraph implementation
+//===--------------------------------------------------------------------===//
+using BasicBlockListType = SmallVector<BasicBlock *, 8>;
+
+DataDependenceGraph::DataDependenceGraph(Function &F, DependenceInfo &D)
+ : DependenceGraphInfo(F.getName().str(), D) {
+ BasicBlockListType BBList;
+ for (auto &BB : F.getBasicBlockList())
+ BBList.push_back(&BB);
+ DDGBuilder(*this, D, BBList).populate();
+}
+
+DataDependenceGraph::DataDependenceGraph(const Loop &L, DependenceInfo &D)
+ : DependenceGraphInfo(Twine(L.getHeader()->getParent()->getName() + "." +
+ L.getHeader()->getName())
+ .str(),
+ D) {
+ BasicBlockListType BBList;
+ for (BasicBlock *BB : L.blocks())
+ BBList.push_back(BB);
+ DDGBuilder(*this, D, BBList).populate();
+}
+
+DataDependenceGraph::~DataDependenceGraph() {
+ for (auto *N : Nodes) {
+ for (auto *E : *N)
+ delete E;
+ delete N;
+ }
+}
+
+bool DataDependenceGraph::addNode(DDGNode &N) {
+ if (!DDGBase::addNode(N))
+ return false;
+
+ // In general, if the root node is already created and linked, it is not safe
+  // to add new nodes since they may be unreachable from the root.
+  // TODO: Allow adding Pi-block nodes after root is created. Pi-blocks are an
+  // exception because they represent components that are already reachable
+  // from the root.
+ assert(!Root && "Root node is already added. No more nodes can be added.");
+ if (isa<RootDDGNode>(N))
+ Root = &N;
+
+ return true;
+}
+
+raw_ostream &llvm::operator<<(raw_ostream &OS, const DataDependenceGraph &G) {
+ for (auto *Node : G)
+ OS << *Node << "\n";
+ return OS;
+}
+
+//===--------------------------------------------------------------------===//
+// DDG Analysis Passes
+//===--------------------------------------------------------------------===//
+
+/// DDG as a loop pass.
+DDGAnalysis::Result DDGAnalysis::run(Loop &L, LoopAnalysisManager &AM,
+ LoopStandardAnalysisResults &AR) {
+ Function *F = L.getHeader()->getParent();
+ DependenceInfo DI(F, &AR.AA, &AR.SE, &AR.LI);
+ return std::make_unique<DataDependenceGraph>(L, DI);
+}
+AnalysisKey DDGAnalysis::Key;
+
+PreservedAnalyses DDGAnalysisPrinterPass::run(Loop &L, LoopAnalysisManager &AM,
+ LoopStandardAnalysisResults &AR,
+ LPMUpdater &U) {
+ OS << "'DDG' for loop '" << L.getHeader()->getName() << "':\n";
+ OS << *AM.getResult<DDGAnalysis>(L, AR);
+ return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Analysis/Delinearization.cpp b/llvm/lib/Analysis/Delinearization.cpp
new file mode 100644
index 000000000000..c1043e446beb
--- /dev/null
+++ b/llvm/lib/Analysis/Delinearization.cpp
@@ -0,0 +1,129 @@
+//===---- Delinearization.cpp - MultiDimensional Index Delinearization ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This implements an analysis pass that tries to delinearize all GEP
+// instructions in all loops using the SCEV analysis functionality. This pass is
+// only used for testing purposes: if your pass needs delinearization, please
+// use the on-demand ScalarEvolution::delinearize() function.
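+// (Regression tests typically drive this pass with something along the lines
+// of "opt -analyze -delinearize input.ll"; shown here only as an
+// illustration.)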
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+#define DL_NAME "delinearize"
+#define DEBUG_TYPE DL_NAME
+
+namespace {
+
+class Delinearization : public FunctionPass {
+ Delinearization(const Delinearization &); // do not implement
+protected:
+ Function *F;
+ LoopInfo *LI;
+ ScalarEvolution *SE;
+
+public:
+ static char ID; // Pass identification, replacement for typeid
+
+ Delinearization() : FunctionPass(ID) {
+ initializeDelinearizationPass(*PassRegistry::getPassRegistry());
+ }
+ bool runOnFunction(Function &F) override;
+ void getAnalysisUsage(AnalysisUsage &AU) const override;
+ void print(raw_ostream &O, const Module *M = nullptr) const override;
+};
+
+} // end anonymous namespace
+
+void Delinearization::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.addRequired<ScalarEvolutionWrapperPass>();
+}
+
+bool Delinearization::runOnFunction(Function &F) {
+ this->F = &F;
+ SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+ LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ return false;
+}
+
+void Delinearization::print(raw_ostream &O, const Module *) const {
+ O << "Delinearization on function " << F->getName() << ":\n";
+ for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
+ Instruction *Inst = &(*I);
+
+ // Only analyze loads and stores.
+ if (!isa<StoreInst>(Inst) && !isa<LoadInst>(Inst) &&
+ !isa<GetElementPtrInst>(Inst))
+ continue;
+
+ const BasicBlock *BB = Inst->getParent();
+ // Delinearize the memory access as analyzed in all the surrounding loops.
+ // Do not analyze memory accesses outside loops.
+ for (Loop *L = LI->getLoopFor(BB); L != nullptr; L = L->getParentLoop()) {
+ const SCEV *AccessFn = SE->getSCEVAtScope(getPointerOperand(Inst), L);
+
+ const SCEVUnknown *BasePointer =
+ dyn_cast<SCEVUnknown>(SE->getPointerBase(AccessFn));
+ // Do not delinearize if we cannot find the base pointer.
+ if (!BasePointer)
+ break;
+ AccessFn = SE->getMinusSCEV(AccessFn, BasePointer);
+
+ O << "\n";
+ O << "Inst:" << *Inst << "\n";
+ O << "In Loop with Header: " << L->getHeader()->getName() << "\n";
+ O << "AccessFunction: " << *AccessFn << "\n";
+
+ SmallVector<const SCEV *, 3> Subscripts, Sizes;
+ SE->delinearize(AccessFn, Subscripts, Sizes, SE->getElementSize(Inst));
+ if (Subscripts.size() == 0 || Sizes.size() == 0 ||
+ Subscripts.size() != Sizes.size()) {
+ O << "failed to delinearize\n";
+ continue;
+ }
+
+ O << "Base offset: " << *BasePointer << "\n";
+ O << "ArrayDecl[UnknownSize]";
+ int Size = Subscripts.size();
+ for (int i = 0; i < Size - 1; i++)
+ O << "[" << *Sizes[i] << "]";
+ O << " with elements of " << *Sizes[Size - 1] << " bytes.\n";
+
+ O << "ArrayRef";
+ for (int i = 0; i < Size; i++)
+ O << "[" << *Subscripts[i] << "]";
+ O << "\n";
+ }
+ }
+}
+
+char Delinearization::ID = 0;
+static const char delinearization_name[] = "Delinearization";
+INITIALIZE_PASS_BEGIN(Delinearization, DL_NAME, delinearization_name, true,
+ true)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(Delinearization, DL_NAME, delinearization_name, true, true)
+
+FunctionPass *llvm::createDelinearizationPass() { return new Delinearization; }
diff --git a/llvm/lib/Analysis/DemandedBits.cpp b/llvm/lib/Analysis/DemandedBits.cpp
new file mode 100644
index 000000000000..01b8ff10d355
--- /dev/null
+++ b/llvm/lib/Analysis/DemandedBits.cpp
@@ -0,0 +1,488 @@
+//===- DemandedBits.cpp - Determine demanded bits -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements a demanded bits analysis. A demanded bit is one that
+// contributes to a result; bits that are not demanded can be either zero or
+// one without affecting control or data flow. For example in this sequence:
+//
+// %1 = add i32 %x, %y
+// %2 = trunc i32 %1 to i16
+//
+// Only the lowest 16 bits of %1 are demanded; the rest are removed by the
+// trunc.
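+//
+// A minimal usage sketch (F, AC, DT and I are placeholders for a function,
+// its AssumptionCache, its DominatorTree and an instruction; assembled here
+// for illustration only):
+//
+//   DemandedBits DB(F, AC, DT);
+//   APInt Demanded = DB.getDemandedBits(&I); // e.g. 0xffff for %1 above
+//   bool Dead = DB.isInstructionDead(&I);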
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DemandedBits.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/KnownBits.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cstdint>
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+#define DEBUG_TYPE "demanded-bits"
+
+char DemandedBitsWrapperPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(DemandedBitsWrapperPass, "demanded-bits",
+ "Demanded bits analysis", false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_END(DemandedBitsWrapperPass, "demanded-bits",
+ "Demanded bits analysis", false, false)
+
+DemandedBitsWrapperPass::DemandedBitsWrapperPass() : FunctionPass(ID) {
+ initializeDemandedBitsWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+void DemandedBitsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesCFG();
+ AU.addRequired<AssumptionCacheTracker>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.setPreservesAll();
+}
+
+void DemandedBitsWrapperPass::print(raw_ostream &OS, const Module *M) const {
+ DB->print(OS);
+}
+
+static bool isAlwaysLive(Instruction *I) {
+ return I->isTerminator() || isa<DbgInfoIntrinsic>(I) || I->isEHPad() ||
+ I->mayHaveSideEffects();
+}
+
+void DemandedBits::determineLiveOperandBits(
+ const Instruction *UserI, const Value *Val, unsigned OperandNo,
+ const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2,
+ bool &KnownBitsComputed) {
+ unsigned BitWidth = AB.getBitWidth();
+
+ // We're called once per operand, but for some instructions, we need to
+ // compute known bits of both operands in order to determine the live bits of
+ // either (when both operands are instructions themselves). We don't,
+  // however, want to do this twice, so we cache the results in the KnownBits
+  // objects that live in the caller. For the two-relevant-operands case, both
+  // operand values are provided here.
+ auto ComputeKnownBits =
+ [&](unsigned BitWidth, const Value *V1, const Value *V2) {
+ if (KnownBitsComputed)
+ return;
+ KnownBitsComputed = true;
+
+ const DataLayout &DL = UserI->getModule()->getDataLayout();
+ Known = KnownBits(BitWidth);
+ computeKnownBits(V1, Known, DL, 0, &AC, UserI, &DT);
+
+ if (V2) {
+ Known2 = KnownBits(BitWidth);
+ computeKnownBits(V2, Known2, DL, 0, &AC, UserI, &DT);
+ }
+ };
+
+ switch (UserI->getOpcode()) {
+ default: break;
+ case Instruction::Call:
+ case Instruction::Invoke:
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(UserI))
+ switch (II->getIntrinsicID()) {
+ default: break;
+ case Intrinsic::bswap:
+ // The alive bits of the input are the swapped alive bits of
+ // the output.
+ AB = AOut.byteSwap();
+ break;
+ case Intrinsic::bitreverse:
+ // The alive bits of the input are the reversed alive bits of
+ // the output.
+ AB = AOut.reverseBits();
+ break;
+ case Intrinsic::ctlz:
+ if (OperandNo == 0) {
+ // We need some output bits, so we need all bits of the
+ // input to the left of, and including, the leftmost bit
+ // known to be one.
+ ComputeKnownBits(BitWidth, Val, nullptr);
+ AB = APInt::getHighBitsSet(BitWidth,
+ std::min(BitWidth, Known.countMaxLeadingZeros()+1));
+ }
+ break;
+ case Intrinsic::cttz:
+ if (OperandNo == 0) {
+ // We need some output bits, so we need all bits of the
+ // input to the right of, and including, the rightmost bit
+ // known to be one.
+ ComputeKnownBits(BitWidth, Val, nullptr);
+ AB = APInt::getLowBitsSet(BitWidth,
+ std::min(BitWidth, Known.countMaxTrailingZeros()+1));
+ }
+ break;
+ case Intrinsic::fshl:
+ case Intrinsic::fshr: {
+ const APInt *SA;
+ if (OperandNo == 2) {
+ // Shift amount is modulo the bitwidth. For powers of two we have
+ // SA % BW == SA & (BW - 1).
+ if (isPowerOf2_32(BitWidth))
+ AB = BitWidth - 1;
+ } else if (match(II->getOperand(2), m_APInt(SA))) {
+ // Normalize to funnel shift left. APInt shifts of BitWidth are well-
+ // defined, so no need to special-case zero shifts here.
+ uint64_t ShiftAmt = SA->urem(BitWidth);
+ if (II->getIntrinsicID() == Intrinsic::fshr)
+ ShiftAmt = BitWidth - ShiftAmt;
+
+ if (OperandNo == 0)
+ AB = AOut.lshr(ShiftAmt);
+ else if (OperandNo == 1)
+ AB = AOut.shl(BitWidth - ShiftAmt);
+ }
+ break;
+ }
+ }
+ break;
+ case Instruction::Add:
+ case Instruction::Sub:
+ case Instruction::Mul:
+ // Find the highest live output bit. We don't need any more input
+ // bits than that (adds, and thus subtracts, ripple only to the
+ // left).
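+    // For instance (an invented example): if only the low 16 bits of an i32
+    // add are alive, AOut.getActiveBits() is 16, so only the low 16 bits of
+    // each operand are marked alive.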
+ AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits());
+ break;
+ case Instruction::Shl:
+ if (OperandNo == 0) {
+ const APInt *ShiftAmtC;
+ if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
+ uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
+ AB = AOut.lshr(ShiftAmt);
+
+ // If the shift is nuw/nsw, then the high bits are not dead
+ // (because we've promised that they *must* be zero).
+ const ShlOperator *S = cast<ShlOperator>(UserI);
+ if (S->hasNoSignedWrap())
+ AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
+ else if (S->hasNoUnsignedWrap())
+ AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
+ }
+ }
+ break;
+ case Instruction::LShr:
+ if (OperandNo == 0) {
+ const APInt *ShiftAmtC;
+ if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
+ uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
+ AB = AOut.shl(ShiftAmt);
+
+ // If the shift is exact, then the low bits are not dead
+ // (they must be zero).
+ if (cast<LShrOperator>(UserI)->isExact())
+ AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
+ }
+ }
+ break;
+ case Instruction::AShr:
+ if (OperandNo == 0) {
+ const APInt *ShiftAmtC;
+ if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
+ uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
+ AB = AOut.shl(ShiftAmt);
+ // Because the high input bit is replicated into the
+ // high-order bits of the result, if we need any of those
+ // bits, then we must keep the highest input bit.
+ if ((AOut & APInt::getHighBitsSet(BitWidth, ShiftAmt))
+ .getBoolValue())
+ AB.setSignBit();
+
+ // If the shift is exact, then the low bits are not dead
+ // (they must be zero).
+ if (cast<AShrOperator>(UserI)->isExact())
+ AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
+ }
+ }
+ break;
+ case Instruction::And:
+ AB = AOut;
+
+ // For bits that are known zero, the corresponding bits in the
+ // other operand are dead (unless they're both zero, in which
+ // case they can't both be dead, so just mark the LHS bits as
+ // dead).
+ ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
+ if (OperandNo == 0)
+ AB &= ~Known2.Zero;
+ else
+ AB &= ~(Known.Zero & ~Known2.Zero);
+ break;
+ case Instruction::Or:
+ AB = AOut;
+
+ // For bits that are known one, the corresponding bits in the
+ // other operand are dead (unless they're both one, in which
+ // case they can't both be dead, so just mark the LHS bits as
+ // dead).
+ ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
+ if (OperandNo == 0)
+ AB &= ~Known2.One;
+ else
+ AB &= ~(Known.One & ~Known2.One);
+ break;
+ case Instruction::Xor:
+ case Instruction::PHI:
+ AB = AOut;
+ break;
+ case Instruction::Trunc:
+ AB = AOut.zext(BitWidth);
+ break;
+ case Instruction::ZExt:
+ AB = AOut.trunc(BitWidth);
+ break;
+ case Instruction::SExt:
+ AB = AOut.trunc(BitWidth);
+ // Because the high input bit is replicated into the
+ // high-order bits of the result, if we need any of those
+ // bits, then we must keep the highest input bit.
+ if ((AOut & APInt::getHighBitsSet(AOut.getBitWidth(),
+ AOut.getBitWidth() - BitWidth))
+ .getBoolValue())
+ AB.setSignBit();
+ break;
+ case Instruction::Select:
+ if (OperandNo != 0)
+ AB = AOut;
+ break;
+ case Instruction::ExtractElement:
+ if (OperandNo == 0)
+ AB = AOut;
+ break;
+ case Instruction::InsertElement:
+ case Instruction::ShuffleVector:
+ if (OperandNo == 0 || OperandNo == 1)
+ AB = AOut;
+ break;
+ }
+}
+
+bool DemandedBitsWrapperPass::runOnFunction(Function &F) {
+ auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ DB.emplace(F, AC, DT);
+ return false;
+}
+
+void DemandedBitsWrapperPass::releaseMemory() {
+ DB.reset();
+}
+
+void DemandedBits::performAnalysis() {
+ if (Analyzed)
+ // Analysis already completed for this function.
+ return;
+ Analyzed = true;
+
+ Visited.clear();
+ AliveBits.clear();
+ DeadUses.clear();
+
+ SmallSetVector<Instruction*, 16> Worklist;
+
+ // Collect the set of "root" instructions that are known live.
+ for (Instruction &I : instructions(F)) {
+ if (!isAlwaysLive(&I))
+ continue;
+
+ LLVM_DEBUG(dbgs() << "DemandedBits: Root: " << I << "\n");
+ // For integer-valued instructions, set up an initial empty set of alive
+ // bits and add the instruction to the work list. For other instructions
+ // add their operands to the work list (for integer values operands, mark
+ // all bits as live).
+ Type *T = I.getType();
+ if (T->isIntOrIntVectorTy()) {
+ if (AliveBits.try_emplace(&I, T->getScalarSizeInBits(), 0).second)
+ Worklist.insert(&I);
+
+ continue;
+ }
+
+ // Non-integer-typed instructions...
+ for (Use &OI : I.operands()) {
+ if (Instruction *J = dyn_cast<Instruction>(OI)) {
+ Type *T = J->getType();
+ if (T->isIntOrIntVectorTy())
+ AliveBits[J] = APInt::getAllOnesValue(T->getScalarSizeInBits());
+ else
+ Visited.insert(J);
+ Worklist.insert(J);
+ }
+ }
+ // To save memory, we don't add I to the Visited set here. Instead, we
+ // check isAlwaysLive on every instruction when searching for dead
+ // instructions later (we need to check isAlwaysLive for the
+ // integer-typed instructions anyway).
+ }
+
+ // Propagate liveness backwards to operands.
+ while (!Worklist.empty()) {
+ Instruction *UserI = Worklist.pop_back_val();
+
+ LLVM_DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI);
+ APInt AOut;
+ bool InputIsKnownDead = false;
+ if (UserI->getType()->isIntOrIntVectorTy()) {
+ AOut = AliveBits[UserI];
+ LLVM_DEBUG(dbgs() << " Alive Out: 0x"
+ << Twine::utohexstr(AOut.getLimitedValue()));
+
+ // If all bits of the output are dead, then all bits of the input
+ // are also dead.
+ InputIsKnownDead = !AOut && !isAlwaysLive(UserI);
+ }
+ LLVM_DEBUG(dbgs() << "\n");
+
+ KnownBits Known, Known2;
+ bool KnownBitsComputed = false;
+    // Compute the set of alive bits for each operand. These are ORed into the
+ // existing set, if any, and if that changes the set of alive bits, the
+ // operand is added to the work-list.
+ for (Use &OI : UserI->operands()) {
+ // We also want to detect dead uses of arguments, but will only store
+ // demanded bits for instructions.
+ Instruction *I = dyn_cast<Instruction>(OI);
+ if (!I && !isa<Argument>(OI))
+ continue;
+
+ Type *T = OI->getType();
+ if (T->isIntOrIntVectorTy()) {
+ unsigned BitWidth = T->getScalarSizeInBits();
+ APInt AB = APInt::getAllOnesValue(BitWidth);
+ if (InputIsKnownDead) {
+ AB = APInt(BitWidth, 0);
+ } else {
+ // Bits of each operand that are used to compute alive bits of the
+ // output are alive, all others are dead.
+ determineLiveOperandBits(UserI, OI, OI.getOperandNo(), AOut, AB,
+ Known, Known2, KnownBitsComputed);
+
+ // Keep track of uses which have no demanded bits.
+ if (AB.isNullValue())
+ DeadUses.insert(&OI);
+ else
+ DeadUses.erase(&OI);
+ }
+
+ if (I) {
+ // If we've added to the set of alive bits (or the operand has not
+ // been previously visited), then re-queue the operand to be visited
+ // again.
+ auto Res = AliveBits.try_emplace(I);
+ if (Res.second || (AB |= Res.first->second) != Res.first->second) {
+ Res.first->second = std::move(AB);
+ Worklist.insert(I);
+ }
+ }
+ } else if (I && Visited.insert(I).second) {
+ Worklist.insert(I);
+ }
+ }
+ }
+}
+
+APInt DemandedBits::getDemandedBits(Instruction *I) {
+ performAnalysis();
+
+ auto Found = AliveBits.find(I);
+ if (Found != AliveBits.end())
+ return Found->second;
+
+ const DataLayout &DL = I->getModule()->getDataLayout();
+ return APInt::getAllOnesValue(
+ DL.getTypeSizeInBits(I->getType()->getScalarType()));
+}
+
+bool DemandedBits::isInstructionDead(Instruction *I) {
+ performAnalysis();
+
+ return !Visited.count(I) && AliveBits.find(I) == AliveBits.end() &&
+ !isAlwaysLive(I);
+}
+
+bool DemandedBits::isUseDead(Use *U) {
+ // We only track integer uses, everything else is assumed live.
+ if (!(*U)->getType()->isIntOrIntVectorTy())
+ return false;
+
+ // Uses by always-live instructions are never dead.
+ Instruction *UserI = cast<Instruction>(U->getUser());
+ if (isAlwaysLive(UserI))
+ return false;
+
+ performAnalysis();
+ if (DeadUses.count(U))
+ return true;
+
+ // If no output bits are demanded, no input bits are demanded and the use
+ // is dead. These uses might not be explicitly present in the DeadUses map.
+ if (UserI->getType()->isIntOrIntVectorTy()) {
+ auto Found = AliveBits.find(UserI);
+ if (Found != AliveBits.end() && Found->second.isNullValue())
+ return true;
+ }
+
+ return false;
+}
+
+void DemandedBits::print(raw_ostream &OS) {
+ performAnalysis();
+ for (auto &KV : AliveBits) {
+ OS << "DemandedBits: 0x" << Twine::utohexstr(KV.second.getLimitedValue())
+ << " for " << *KV.first << '\n';
+ }
+}
+
+FunctionPass *llvm::createDemandedBitsWrapperPass() {
+ return new DemandedBitsWrapperPass();
+}
+
+AnalysisKey DemandedBitsAnalysis::Key;
+
+DemandedBits DemandedBitsAnalysis::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ auto &AC = AM.getResult<AssumptionAnalysis>(F);
+ auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ return DemandedBits(F, AC, DT);
+}
+
+PreservedAnalyses DemandedBitsPrinterPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ AM.getResult<DemandedBitsAnalysis>(F).print(OS);
+ return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Analysis/DependenceAnalysis.cpp b/llvm/lib/Analysis/DependenceAnalysis.cpp
new file mode 100644
index 000000000000..0038c9fb9ce4
--- /dev/null
+++ b/llvm/lib/Analysis/DependenceAnalysis.cpp
@@ -0,0 +1,4009 @@
+//===-- DependenceAnalysis.cpp - DA Implementation --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// DependenceAnalysis is an LLVM pass that analyses dependences between memory
+// accesses. Currently, it is an (incomplete) implementation of the approach
+// described in
+//
+// Practical Dependence Testing
+// Goff, Kennedy, Tseng
+// PLDI 1991
+//
+// There's a single entry point that analyzes the dependence between a pair
+// of memory references in a function, returning either NULL, for no dependence,
+// or a more-or-less detailed description of the dependence between them.
+//
+// Currently, the implementation cannot propagate constraints between
+// coupled RDIV subscripts and lacks a multi-subscript MIV test.
+// Both of these are conservative weaknesses;
+// that is, not a source of correctness problems.
+//
+// Since Clang linearizes some array subscripts, the dependence
+// analysis uses SCEV->delinearize to recover the representation of multiple
+// subscripts and thus avoid the more expensive and less precise MIV tests. The
+// delinearization is controlled by the flag -da-delinearize.
+//
+// We should pay some careful attention to the possibility of integer overflow
+// in the implementation of the various tests. This could happen with Add,
+// Subtract, or Multiply, with both APInt's and SCEV's.
+//
+// Some non-linear subscript pairs can be handled by the GCD test
+// (and perhaps other tests).
+// We should explore how often these things occur.
+//
+// Finally, it seems like certain test cases expose weaknesses in the SCEV
+// simplification, especially in the handling of sign and zero extensions.
+// It could be useful to spend time exploring these.
+//
+// Please note that this is work in progress and the interface is subject to
+// change.
+//
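+// A sketch of the typical query pattern, pieced together from this file for
+// illustration (Src, Dst and OS are placeholders for two memory instructions
+// and an output stream; not a verbatim, compiled example):
+//
+//   DependenceInfo &DI = getAnalysis<DependenceAnalysisWrapperPass>().getDI();
+//   if (auto D = DI.depends(Src, Dst, /*PossiblyLoopIndependent=*/true)) {
+//     D->dump(OS);                       // direction/distance summary
+//     unsigned Dir = D->getDirection(1); // per-level queries, 1-based
+//   }
+//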
+//===----------------------------------------------------------------------===//
+// //
+// In memory of Ken Kennedy, 1945 - 2007 //
+// //
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DependenceAnalysis.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "da"
+
+//===----------------------------------------------------------------------===//
+// statistics
+
+STATISTIC(TotalArrayPairs, "Array pairs tested");
+STATISTIC(SeparableSubscriptPairs, "Separable subscript pairs");
+STATISTIC(CoupledSubscriptPairs, "Coupled subscript pairs");
+STATISTIC(NonlinearSubscriptPairs, "Nonlinear subscript pairs");
+STATISTIC(ZIVapplications, "ZIV applications");
+STATISTIC(ZIVindependence, "ZIV independence");
+STATISTIC(StrongSIVapplications, "Strong SIV applications");
+STATISTIC(StrongSIVsuccesses, "Strong SIV successes");
+STATISTIC(StrongSIVindependence, "Strong SIV independence");
+STATISTIC(WeakCrossingSIVapplications, "Weak-Crossing SIV applications");
+STATISTIC(WeakCrossingSIVsuccesses, "Weak-Crossing SIV successes");
+STATISTIC(WeakCrossingSIVindependence, "Weak-Crossing SIV independence");
+STATISTIC(ExactSIVapplications, "Exact SIV applications");
+STATISTIC(ExactSIVsuccesses, "Exact SIV successes");
+STATISTIC(ExactSIVindependence, "Exact SIV independence");
+STATISTIC(WeakZeroSIVapplications, "Weak-Zero SIV applications");
+STATISTIC(WeakZeroSIVsuccesses, "Weak-Zero SIV successes");
+STATISTIC(WeakZeroSIVindependence, "Weak-Zero SIV independence");
+STATISTIC(ExactRDIVapplications, "Exact RDIV applications");
+STATISTIC(ExactRDIVindependence, "Exact RDIV independence");
+STATISTIC(SymbolicRDIVapplications, "Symbolic RDIV applications");
+STATISTIC(SymbolicRDIVindependence, "Symbolic RDIV independence");
+STATISTIC(DeltaApplications, "Delta applications");
+STATISTIC(DeltaSuccesses, "Delta successes");
+STATISTIC(DeltaIndependence, "Delta independence");
+STATISTIC(DeltaPropagations, "Delta propagations");
+STATISTIC(GCDapplications, "GCD applications");
+STATISTIC(GCDsuccesses, "GCD successes");
+STATISTIC(GCDindependence, "GCD independence");
+STATISTIC(BanerjeeApplications, "Banerjee applications");
+STATISTIC(BanerjeeIndependence, "Banerjee independence");
+STATISTIC(BanerjeeSuccesses, "Banerjee successes");
+
+static cl::opt<bool>
+ Delinearize("da-delinearize", cl::init(true), cl::Hidden, cl::ZeroOrMore,
+ cl::desc("Try to delinearize array references."));
+static cl::opt<bool> DisableDelinearizationChecks(
+ "da-disable-delinearization-checks", cl::init(false), cl::Hidden,
+ cl::ZeroOrMore,
+ cl::desc(
+ "Disable checks that try to statically verify validity of "
+ "delinearized subscripts. Enabling this option may result in incorrect "
+ "dependence vectors for languages that allow the subscript of one "
+ "dimension to underflow or overflow into another dimension."));
+
+//===----------------------------------------------------------------------===//
+// basics
+
+DependenceAnalysis::Result
+DependenceAnalysis::run(Function &F, FunctionAnalysisManager &FAM) {
+ auto &AA = FAM.getResult<AAManager>(F);
+ auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(F);
+ auto &LI = FAM.getResult<LoopAnalysis>(F);
+ return DependenceInfo(&F, &AA, &SE, &LI);
+}
+
+AnalysisKey DependenceAnalysis::Key;
+
+INITIALIZE_PASS_BEGIN(DependenceAnalysisWrapperPass, "da",
+ "Dependence Analysis", true, true)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_END(DependenceAnalysisWrapperPass, "da", "Dependence Analysis",
+ true, true)
+
+char DependenceAnalysisWrapperPass::ID = 0;
+
+FunctionPass *llvm::createDependenceAnalysisWrapperPass() {
+ return new DependenceAnalysisWrapperPass();
+}
+
+bool DependenceAnalysisWrapperPass::runOnFunction(Function &F) {
+ auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
+ auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+ auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ info.reset(new DependenceInfo(&F, &AA, &SE, &LI));
+ return false;
+}
+
+DependenceInfo &DependenceAnalysisWrapperPass::getDI() const { return *info; }
+
+void DependenceAnalysisWrapperPass::releaseMemory() { info.reset(); }
+
+void DependenceAnalysisWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequiredTransitive<AAResultsWrapperPass>();
+ AU.addRequiredTransitive<ScalarEvolutionWrapperPass>();
+ AU.addRequiredTransitive<LoopInfoWrapperPass>();
+}
+
+
+// Used to test the dependence analyzer.
+// Looks through the function, noting loads and stores.
+// Calls depends() on every possible pair and prints out the result.
+// Ignores all other instructions.
+static void dumpExampleDependence(raw_ostream &OS, DependenceInfo *DA) {
+ auto *F = DA->getFunction();
+ for (inst_iterator SrcI = inst_begin(F), SrcE = inst_end(F); SrcI != SrcE;
+ ++SrcI) {
+ if (isa<StoreInst>(*SrcI) || isa<LoadInst>(*SrcI)) {
+ for (inst_iterator DstI = SrcI, DstE = inst_end(F);
+ DstI != DstE; ++DstI) {
+ if (isa<StoreInst>(*DstI) || isa<LoadInst>(*DstI)) {
+ OS << "da analyze - ";
+ if (auto D = DA->depends(&*SrcI, &*DstI, true)) {
+ D->dump(OS);
+ for (unsigned Level = 1; Level <= D->getLevels(); Level++) {
+ if (D->isSplitable(Level)) {
+ OS << "da analyze - split level = " << Level;
+ OS << ", iteration = " << *DA->getSplitIteration(*D, Level);
+ OS << "!\n";
+ }
+ }
+ }
+ else
+ OS << "none!\n";
+ }
+ }
+ }
+ }
+}
+
+void DependenceAnalysisWrapperPass::print(raw_ostream &OS,
+ const Module *) const {
+ dumpExampleDependence(OS, info.get());
+}
+
+PreservedAnalyses
+DependenceAnalysisPrinterPass::run(Function &F, FunctionAnalysisManager &FAM) {
+ OS << "'Dependence Analysis' for function '" << F.getName() << "':\n";
+ dumpExampleDependence(OS, &FAM.getResult<DependenceAnalysis>(F));
+ return PreservedAnalyses::all();
+}
+
+//===----------------------------------------------------------------------===//
+// Dependence methods
+
+// Returns true if this is an input dependence.
+bool Dependence::isInput() const {
+ return Src->mayReadFromMemory() && Dst->mayReadFromMemory();
+}
+
+
+// Returns true if this is an output dependence.
+bool Dependence::isOutput() const {
+ return Src->mayWriteToMemory() && Dst->mayWriteToMemory();
+}
+
+
+// Returns true if this is a flow (aka true) dependence.
+bool Dependence::isFlow() const {
+ return Src->mayWriteToMemory() && Dst->mayReadFromMemory();
+}
+
+
+// Returns true if this is an anti dependence.
+bool Dependence::isAnti() const {
+ return Src->mayReadFromMemory() && Dst->mayWriteToMemory();
+}
+
+
+// Returns true if a particular level is scalar; that is,
+// if no subscript in the source or destination mentions the induction
+// variable associated with the loop at this level.
+// Leave this out of line, so it will serve as a virtual method anchor
+bool Dependence::isScalar(unsigned level) const {
+ return false;
+}
+
+
+//===----------------------------------------------------------------------===//
+// FullDependence methods
+
+FullDependence::FullDependence(Instruction *Source, Instruction *Destination,
+ bool PossiblyLoopIndependent,
+ unsigned CommonLevels)
+ : Dependence(Source, Destination), Levels(CommonLevels),
+ LoopIndependent(PossiblyLoopIndependent) {
+ Consistent = true;
+ if (CommonLevels)
+ DV = std::make_unique<DVEntry[]>(CommonLevels);
+}
+
+// The rest are simple getters that hide the implementation.
+
+// getDirection - Returns the direction associated with a particular level.
+unsigned FullDependence::getDirection(unsigned Level) const {
+ assert(0 < Level && Level <= Levels && "Level out of range");
+ return DV[Level - 1].Direction;
+}
+
+
+// Returns the distance (or NULL) associated with a particular level.
+const SCEV *FullDependence::getDistance(unsigned Level) const {
+ assert(0 < Level && Level <= Levels && "Level out of range");
+ return DV[Level - 1].Distance;
+}
+
+
+// Returns true if a particular level is scalar; that is,
+// if no subscript in the source or destination mentions the induction
+// variable associated with the loop at this level.
+bool FullDependence::isScalar(unsigned Level) const {
+ assert(0 < Level && Level <= Levels && "Level out of range");
+ return DV[Level - 1].Scalar;
+}
+
+
+// Returns true if peeling the first iteration from this loop
+// will break this dependence.
+bool FullDependence::isPeelFirst(unsigned Level) const {
+ assert(0 < Level && Level <= Levels && "Level out of range");
+ return DV[Level - 1].PeelFirst;
+}
+
+
+// Returns true if peeling the last iteration from this loop
+// will break this dependence.
+bool FullDependence::isPeelLast(unsigned Level) const {
+ assert(0 < Level && Level <= Levels && "Level out of range");
+ return DV[Level - 1].PeelLast;
+}
+
+
+// Returns true if splitting this loop will break the dependence.
+bool FullDependence::isSplitable(unsigned Level) const {
+ assert(0 < Level && Level <= Levels && "Level out of range");
+ return DV[Level - 1].Splitable;
+}
+
+
+//===----------------------------------------------------------------------===//
+// DependenceInfo::Constraint methods
+
+// If constraint is a point <X, Y>, returns X.
+// Otherwise assert.
+const SCEV *DependenceInfo::Constraint::getX() const {
+ assert(Kind == Point && "Kind should be Point");
+ return A;
+}
+
+
+// If constraint is a point <X, Y>, returns Y.
+// Otherwise assert.
+const SCEV *DependenceInfo::Constraint::getY() const {
+ assert(Kind == Point && "Kind should be Point");
+ return B;
+}
+
+
+// If constraint is a line AX + BY = C, returns A.
+// Otherwise assert.
+const SCEV *DependenceInfo::Constraint::getA() const {
+ assert((Kind == Line || Kind == Distance) &&
+ "Kind should be Line (or Distance)");
+ return A;
+}
+
+
+// If constraint is a line AX + BY = C, returns B.
+// Otherwise assert.
+const SCEV *DependenceInfo::Constraint::getB() const {
+ assert((Kind == Line || Kind == Distance) &&
+ "Kind should be Line (or Distance)");
+ return B;
+}
+
+
+// If constraint is a line AX + BY = C, returns C.
+// Otherwise assert.
+const SCEV *DependenceInfo::Constraint::getC() const {
+ assert((Kind == Line || Kind == Distance) &&
+ "Kind should be Line (or Distance)");
+ return C;
+}
+
+
+// If constraint is a distance, returns D.
+// Otherwise assert.
+const SCEV *DependenceInfo::Constraint::getD() const {
+ assert(Kind == Distance && "Kind should be Distance");
+ return SE->getNegativeSCEV(C);
+}
+
+
+// Returns the loop associated with this constraint.
+const Loop *DependenceInfo::Constraint::getAssociatedLoop() const {
+ assert((Kind == Distance || Kind == Line || Kind == Point) &&
+ "Kind should be Distance, Line, or Point");
+ return AssociatedLoop;
+}
+
+void DependenceInfo::Constraint::setPoint(const SCEV *X, const SCEV *Y,
+ const Loop *CurLoop) {
+ Kind = Point;
+ A = X;
+ B = Y;
+ AssociatedLoop = CurLoop;
+}
+
+void DependenceInfo::Constraint::setLine(const SCEV *AA, const SCEV *BB,
+ const SCEV *CC, const Loop *CurLoop) {
+ Kind = Line;
+ A = AA;
+ B = BB;
+ C = CC;
+ AssociatedLoop = CurLoop;
+}
+
+void DependenceInfo::Constraint::setDistance(const SCEV *D,
+ const Loop *CurLoop) {
+ Kind = Distance;
+ A = SE->getOne(D->getType());
+ B = SE->getNegativeSCEV(A);
+ C = SE->getNegativeSCEV(D);
+ AssociatedLoop = CurLoop;
+}
+
+void DependenceInfo::Constraint::setEmpty() { Kind = Empty; }
+
+void DependenceInfo::Constraint::setAny(ScalarEvolution *NewSE) {
+ SE = NewSE;
+ Kind = Any;
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+// For debugging purposes. Dumps the constraint out to OS.
+LLVM_DUMP_METHOD void DependenceInfo::Constraint::dump(raw_ostream &OS) const {
+ if (isEmpty())
+ OS << " Empty\n";
+ else if (isAny())
+ OS << " Any\n";
+ else if (isPoint())
+ OS << " Point is <" << *getX() << ", " << *getY() << ">\n";
+ else if (isDistance())
+ OS << " Distance is " << *getD() <<
+ " (" << *getA() << "*X + " << *getB() << "*Y = " << *getC() << ")\n";
+ else if (isLine())
+ OS << " Line is " << *getA() << "*X + " <<
+ *getB() << "*Y = " << *getC() << "\n";
+ else
+ llvm_unreachable("unknown constraint type in Constraint::dump");
+}
+#endif
+
+
+// Updates X with the intersection
+// of the Constraints X and Y. Returns true if X has changed.
+// Corresponds to Figure 4 from the paper
+//
+// Practical Dependence Testing
+// Goff, Kennedy, Tseng
+// PLDI 1991
+bool DependenceInfo::intersectConstraints(Constraint *X, const Constraint *Y) {
+ ++DeltaApplications;
+ LLVM_DEBUG(dbgs() << "\tintersect constraints\n");
+ LLVM_DEBUG(dbgs() << "\t X ="; X->dump(dbgs()));
+ LLVM_DEBUG(dbgs() << "\t Y ="; Y->dump(dbgs()));
+ assert(!Y->isPoint() && "Y must not be a Point");
+ if (X->isAny()) {
+ if (Y->isAny())
+ return false;
+ *X = *Y;
+ return true;
+ }
+ if (X->isEmpty())
+ return false;
+ if (Y->isEmpty()) {
+ X->setEmpty();
+ return true;
+ }
+
+ if (X->isDistance() && Y->isDistance()) {
+ LLVM_DEBUG(dbgs() << "\t intersect 2 distances\n");
+ if (isKnownPredicate(CmpInst::ICMP_EQ, X->getD(), Y->getD()))
+ return false;
+ if (isKnownPredicate(CmpInst::ICMP_NE, X->getD(), Y->getD())) {
+ X->setEmpty();
+ ++DeltaSuccesses;
+ return true;
+ }
+ // Neither equality nor inequality is provable.
+ // If Y's distance is a constant, prefer it; otherwise keep X.
+ if (isa<SCEVConstant>(Y->getD())) {
+ *X = *Y;
+ return true;
+ }
+ return false;
+ }
+
+ // At this point, the pseudo-code in Figure 4 of the paper
+ // checks if (X->isPoint() && Y->isPoint()).
+ // This case can't occur in our implementation,
+ // since a Point can only arise as the result of intersecting
+ // two Line constraints, and the right-hand value, Y, is never
+ // the result of an intersection.
+ assert(!(X->isPoint() && Y->isPoint()) &&
+ "We shouldn't ever see X->isPoint() && Y->isPoint()");
+
+ if (X->isLine() && Y->isLine()) {
+ LLVM_DEBUG(dbgs() << "\t intersect 2 lines\n");
+ const SCEV *Prod1 = SE->getMulExpr(X->getA(), Y->getB());
+ const SCEV *Prod2 = SE->getMulExpr(X->getB(), Y->getA());
+ if (isKnownPredicate(CmpInst::ICMP_EQ, Prod1, Prod2)) {
+ // slopes are equal, so lines are parallel
+ LLVM_DEBUG(dbgs() << "\t\tsame slope\n");
+ Prod1 = SE->getMulExpr(X->getC(), Y->getB());
+ Prod2 = SE->getMulExpr(X->getB(), Y->getC());
+ if (isKnownPredicate(CmpInst::ICMP_EQ, Prod1, Prod2))
+ return false;
+ if (isKnownPredicate(CmpInst::ICMP_NE, Prod1, Prod2)) {
+ X->setEmpty();
+ ++DeltaSuccesses;
+ return true;
+ }
+ return false;
+ }
+ if (isKnownPredicate(CmpInst::ICMP_NE, Prod1, Prod2)) {
+ // slopes differ, so lines intersect
+ LLVM_DEBUG(dbgs() << "\t\tdifferent slopes\n");
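+ // The slopes differ, so the lines A1*X + B1*Y = C1 and A2*X + B2*Y = C2
+ // meet at a single point, given by Cramer's rule:
+ //   X = (C1*B2 - C2*B1)/(A1*B2 - A2*B1)
+ //   Y = (C1*A2 - C2*A1)/(A2*B1 - A1*B2)
+ // For example, intersecting 2X + 3Y = 12 with 3X + 2Y = 13 gives
+ // X = (24 - 39)/(4 - 9) = 3 and Y = (36 - 26)/(9 - 4) = 2.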
+ const SCEV *C1B2 = SE->getMulExpr(X->getC(), Y->getB());
+ const SCEV *C1A2 = SE->getMulExpr(X->getC(), Y->getA());
+ const SCEV *C2B1 = SE->getMulExpr(Y->getC(), X->getB());
+ const SCEV *C2A1 = SE->getMulExpr(Y->getC(), X->getA());
+ const SCEV *A1B2 = SE->getMulExpr(X->getA(), Y->getB());
+ const SCEV *A2B1 = SE->getMulExpr(Y->getA(), X->getB());
+ const SCEVConstant *C1A2_C2A1 =
+ dyn_cast<SCEVConstant>(SE->getMinusSCEV(C1A2, C2A1));
+ const SCEVConstant *C1B2_C2B1 =
+ dyn_cast<SCEVConstant>(SE->getMinusSCEV(C1B2, C2B1));
+ const SCEVConstant *A1B2_A2B1 =
+ dyn_cast<SCEVConstant>(SE->getMinusSCEV(A1B2, A2B1));
+ const SCEVConstant *A2B1_A1B2 =
+ dyn_cast<SCEVConstant>(SE->getMinusSCEV(A2B1, A1B2));
+ if (!C1B2_C2B1 || !C1A2_C2A1 ||
+ !A1B2_A2B1 || !A2B1_A1B2)
+ return false;
+ APInt Xtop = C1B2_C2B1->getAPInt();
+ APInt Xbot = A1B2_A2B1->getAPInt();
+ APInt Ytop = C1A2_C2A1->getAPInt();
+ APInt Ybot = A2B1_A1B2->getAPInt();
+ LLVM_DEBUG(dbgs() << "\t\tXtop = " << Xtop << "\n");
+ LLVM_DEBUG(dbgs() << "\t\tXbot = " << Xbot << "\n");
+ LLVM_DEBUG(dbgs() << "\t\tYtop = " << Ytop << "\n");
+ LLVM_DEBUG(dbgs() << "\t\tYbot = " << Ybot << "\n");
+ APInt Xq = Xtop; // these need to be initialized, even
+ APInt Xr = Xtop; // though they're just going to be overwritten
+ APInt::sdivrem(Xtop, Xbot, Xq, Xr);
+ APInt Yq = Ytop;
+ APInt Yr = Ytop;
+ APInt::sdivrem(Ytop, Ybot, Yq, Yr);
+ if (Xr != 0 || Yr != 0) {
+ X->setEmpty();
+ ++DeltaSuccesses;
+ return true;
+ }
+ LLVM_DEBUG(dbgs() << "\t\tX = " << Xq << ", Y = " << Yq << "\n");
+ if (Xq.slt(0) || Yq.slt(0)) {
+ X->setEmpty();
+ ++DeltaSuccesses;
+ return true;
+ }
+ if (const SCEVConstant *CUB =
+ collectConstantUpperBound(X->getAssociatedLoop(), Prod1->getType())) {
+ const APInt &UpperBound = CUB->getAPInt();
+ LLVM_DEBUG(dbgs() << "\t\tupper bound = " << UpperBound << "\n");
+ if (Xq.sgt(UpperBound) || Yq.sgt(UpperBound)) {
+ X->setEmpty();
+ ++DeltaSuccesses;
+ return true;
+ }
+ }
+ X->setPoint(SE->getConstant(Xq),
+ SE->getConstant(Yq),
+ X->getAssociatedLoop());
+ ++DeltaSuccesses;
+ return true;
+ }
+ return false;
+ }
+
+ // if (X->isLine() && Y->isPoint()) This case can't occur.
+ assert(!(X->isLine() && Y->isPoint()) && "This case should never occur");
+
+ if (X->isPoint() && Y->isLine()) {
+ LLVM_DEBUG(dbgs() << "\t intersect Point and Line\n");
+ const SCEV *A1X1 = SE->getMulExpr(Y->getA(), X->getX());
+ const SCEV *B1Y1 = SE->getMulExpr(Y->getB(), X->getY());
+ const SCEV *Sum = SE->getAddExpr(A1X1, B1Y1);
+ if (isKnownPredicate(CmpInst::ICMP_EQ, Sum, Y->getC()))
+ return false;
+ if (isKnownPredicate(CmpInst::ICMP_NE, Sum, Y->getC())) {
+ X->setEmpty();
+ ++DeltaSuccesses;
+ return true;
+ }
+ return false;
+ }
+
+ llvm_unreachable("shouldn't reach the end of Constraint intersection");
+ return false;
+}
+
+
+//===----------------------------------------------------------------------===//
+// DependenceInfo methods
+
+// For debugging purposes. Dumps a dependence to OS.
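+// For example, a confused dependence prints as "confused!", while a
+// consistent flow dependence with distance 2 at its only common level
+// prints as "consistent flow [2]!".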
+void Dependence::dump(raw_ostream &OS) const {
+ bool Splitable = false;
+ if (isConfused())
+ OS << "confused";
+ else {
+ if (isConsistent())
+ OS << "consistent ";
+ if (isFlow())
+ OS << "flow";
+ else if (isOutput())
+ OS << "output";
+ else if (isAnti())
+ OS << "anti";
+ else if (isInput())
+ OS << "input";
+ unsigned Levels = getLevels();
+ OS << " [";
+ for (unsigned II = 1; II <= Levels; ++II) {
+ if (isSplitable(II))
+ Splitable = true;
+ if (isPeelFirst(II))
+ OS << 'p';
+ const SCEV *Distance = getDistance(II);
+ if (Distance)
+ OS << *Distance;
+ else if (isScalar(II))
+ OS << "S";
+ else {
+ unsigned Direction = getDirection(II);
+ if (Direction == DVEntry::ALL)
+ OS << "*";
+ else {
+ if (Direction & DVEntry::LT)
+ OS << "<";
+ if (Direction & DVEntry::EQ)
+ OS << "=";
+ if (Direction & DVEntry::GT)
+ OS << ">";
+ }
+ }
+ if (isPeelLast(II))
+ OS << 'p';
+ if (II < Levels)
+ OS << " ";
+ }
+ if (isLoopIndependent())
+ OS << "|<";
+ OS << "]";
+ if (Splitable)
+ OS << " splitable";
+ }
+ OS << "!\n";
+}
+
+// Returns NoAlias/MayAlias/MustAlias for two memory locations based upon their
+// underlying objects. If LocA and LocB are known not to alias (for any reason:
+// TBAA, non-overlapping regions, etc.), then it is known there is no dependency.
+// Otherwise the underlying objects are checked to see if they point to
+// different identifiable objects.
+static AliasResult underlyingObjectsAlias(AliasAnalysis *AA,
+ const DataLayout &DL,
+ const MemoryLocation &LocA,
+ const MemoryLocation &LocB) {
+ // Check the original locations (minus size) for noalias, which can happen for
+ // TBAA, incompatible underlying object locations, etc.
+ MemoryLocation LocAS(LocA.Ptr, LocationSize::unknown(), LocA.AATags);
+ MemoryLocation LocBS(LocB.Ptr, LocationSize::unknown(), LocB.AATags);
+ if (AA->alias(LocAS, LocBS) == NoAlias)
+ return NoAlias;
+
+ // Check the underlying objects are the same
+ const Value *AObj = GetUnderlyingObject(LocA.Ptr, DL);
+ const Value *BObj = GetUnderlyingObject(LocB.Ptr, DL);
+
+ // If the underlying objects are the same, they must alias
+ if (AObj == BObj)
+ return MustAlias;
+
+ // We may have hit the recursion limit for underlying objects, or have
+ // underlying objects where we don't know they will alias.
+ if (!isIdentifiedObject(AObj) || !isIdentifiedObject(BObj))
+ return MayAlias;
+
+ // Otherwise we know the objects are different and both identified objects so
+ // must not alias.
+ return NoAlias;
+}
+
+
+// Returns true if the load or store can be analyzed. Atomic and volatile
+// operations have properties which this analysis does not understand.
+static
+bool isLoadOrStore(const Instruction *I) {
+ if (const LoadInst *LI = dyn_cast<LoadInst>(I))
+ return LI->isUnordered();
+ else if (const StoreInst *SI = dyn_cast<StoreInst>(I))
+ return SI->isUnordered();
+ return false;
+}
+
+
+// Examines the loop nesting of the Src and Dst
+// instructions and establishes their shared loops. Sets the variables
+// CommonLevels, SrcLevels, and MaxLevels.
+// The source and destination instructions needn't be contained in the same
+// loop. The routine establishNestingLevels finds the level of the most deeply
+// nested loop that contains them both, CommonLevels. An instruction that's
+// not contained in a loop is at level = 0. MaxLevels is equal to the level
+// of the source plus the level of the destination, minus CommonLevels.
+// This lets us allocate vectors MaxLevels in length, with room for every
+// distinct loop referenced in both the source and destination subscripts.
+// The variable SrcLevels is the nesting depth of the source instruction.
+// It's used to help calculate distinct loops referenced by the destination.
+// Here's the map from loops to levels:
+// 0 - unused
+// 1 - outermost common loop
+// ... - other common loops
+// CommonLevels - innermost common loop
+// ... - loops containing Src but not Dst
+// SrcLevels - innermost loop containing Src but not Dst
+// ... - loops containing Dst but not Src
+// MaxLevels - innermost loop containing Dst but not Src
+// Consider the following code fragment:
+// for (a = ...) {
+// for (b = ...) {
+// for (c = ...) {
+// for (d = ...) {
+// A[] = ...;
+// }
+// }
+// for (e = ...) {
+// for (f = ...) {
+// for (g = ...) {
+// ... = A[];
+// }
+// }
+// }
+// }
+// }
+// If we're looking at the possibility of a dependence between the store
+// to A (the Src) and the load from A (the Dst), we'll note that they
+// have 2 loops in common, so CommonLevels will equal 2 and the direction
+// vector for Result will have 2 entries. SrcLevels = 4 and MaxLevels = 7.
+// A map from loop names to loop numbers would look like
+// a - 1
+// b - 2 = CommonLevels
+// c - 3
+// d - 4 = SrcLevels
+// e - 5
+// f - 6
+// g - 7 = MaxLevels
+void DependenceInfo::establishNestingLevels(const Instruction *Src,
+ const Instruction *Dst) {
+ const BasicBlock *SrcBlock = Src->getParent();
+ const BasicBlock *DstBlock = Dst->getParent();
+ unsigned SrcLevel = LI->getLoopDepth(SrcBlock);
+ unsigned DstLevel = LI->getLoopDepth(DstBlock);
+ const Loop *SrcLoop = LI->getLoopFor(SrcBlock);
+ const Loop *DstLoop = LI->getLoopFor(DstBlock);
+ SrcLevels = SrcLevel;
+ MaxLevels = SrcLevel + DstLevel;
+ while (SrcLevel > DstLevel) {
+ SrcLoop = SrcLoop->getParentLoop();
+ SrcLevel--;
+ }
+ while (DstLevel > SrcLevel) {
+ DstLoop = DstLoop->getParentLoop();
+ DstLevel--;
+ }
+ while (SrcLoop != DstLoop) {
+ SrcLoop = SrcLoop->getParentLoop();
+ DstLoop = DstLoop->getParentLoop();
+ SrcLevel--;
+ }
+ CommonLevels = SrcLevel;
+ MaxLevels -= CommonLevels;
+}
+
+
+// Given one of the loops containing the source, return
+// its level index in our numbering scheme.
+unsigned DependenceInfo::mapSrcLoop(const Loop *SrcLoop) const {
+ return SrcLoop->getLoopDepth();
+}
+
+
+// Given one of the loops containing the destination,
+// return its level index in our numbering scheme.
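+// For example, in the loop nest shown above establishNestingLevels
+// (CommonLevels = 2, SrcLevels = 4), loop e has depth 3 and maps to
+// 3 - 2 + 4 = 5, and loop g has depth 5 and maps to 7, matching the table.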
+unsigned DependenceInfo::mapDstLoop(const Loop *DstLoop) const {
+ unsigned D = DstLoop->getLoopDepth();
+ if (D > CommonLevels)
+ return D - CommonLevels + SrcLevels;
+ else
+ return D;
+}
+
+
+// Returns true if Expression is loop invariant in LoopNest.
+bool DependenceInfo::isLoopInvariant(const SCEV *Expression,
+ const Loop *LoopNest) const {
+ if (!LoopNest)
+ return true;
+ return SE->isLoopInvariant(Expression, LoopNest) &&
+ isLoopInvariant(Expression, LoopNest->getParentLoop());
+}
+
+
+
+// Finds the set of loops from the LoopNest that
+// have a level <= CommonLevels and are referred to by the SCEV Expression.
+void DependenceInfo::collectCommonLoops(const SCEV *Expression,
+ const Loop *LoopNest,
+ SmallBitVector &Loops) const {
+ while (LoopNest) {
+ unsigned Level = LoopNest->getLoopDepth();
+ if (Level <= CommonLevels && !SE->isLoopInvariant(Expression, LoopNest))
+ Loops.set(Level);
+ LoopNest = LoopNest->getParentLoop();
+ }
+}
+
+void DependenceInfo::unifySubscriptType(ArrayRef<Subscript *> Pairs) {
+
+ unsigned widestWidthSeen = 0;
+ Type *widestType = nullptr;
+
+ // Go through each pair and find the widest bit width to which we need
+ // to extend all of them.
+ for (Subscript *Pair : Pairs) {
+ const SCEV *Src = Pair->Src;
+ const SCEV *Dst = Pair->Dst;
+ IntegerType *SrcTy = dyn_cast<IntegerType>(Src->getType());
+ IntegerType *DstTy = dyn_cast<IntegerType>(Dst->getType());
+ if (SrcTy == nullptr || DstTy == nullptr) {
+ assert(SrcTy == DstTy && "This function only unifies integer types and "
+ "expects Src and Dst to share the same type "
+ "otherwise.");
+ continue;
+ }
+ if (SrcTy->getBitWidth() > widestWidthSeen) {
+ widestWidthSeen = SrcTy->getBitWidth();
+ widestType = SrcTy;
+ }
+ if (DstTy->getBitWidth() > widestWidthSeen) {
+ widestWidthSeen = DstTy->getBitWidth();
+ widestType = DstTy;
+ }
+ }
+
+
+ assert(widestWidthSeen > 0);
+
+ // Now extend each pair to the widest seen.
+ for (Subscript *Pair : Pairs) {
+ const SCEV *Src = Pair->Src;
+ const SCEV *Dst = Pair->Dst;
+ IntegerType *SrcTy = dyn_cast<IntegerType>(Src->getType());
+ IntegerType *DstTy = dyn_cast<IntegerType>(Dst->getType());
+ if (SrcTy == nullptr || DstTy == nullptr) {
+ assert(SrcTy == DstTy && "This function only unifies integer types and "
+ "expects Src and Dst to share the same type "
+ "otherwise.");
+ continue;
+ }
+ if (SrcTy->getBitWidth() < widestWidthSeen)
+ // Sign-extend Src to widestType
+ Pair->Src = SE->getSignExtendExpr(Src, widestType);
+ if (DstTy->getBitWidth() < widestWidthSeen) {
+ // Sign-extend Dst to widestType
+ Pair->Dst = SE->getSignExtendExpr(Dst, widestType);
+ }
+ }
+}
+
+// removeMatchingExtensions - Examines a subscript pair.
+// If the source and destination are identically sign (or zero)
+// extended, it strips off the extension in an effort to simplify
+// the actual analysis.
+void DependenceInfo::removeMatchingExtensions(Subscript *Pair) {
+ const SCEV *Src = Pair->Src;
+ const SCEV *Dst = Pair->Dst;
+ if ((isa<SCEVZeroExtendExpr>(Src) && isa<SCEVZeroExtendExpr>(Dst)) ||
+ (isa<SCEVSignExtendExpr>(Src) && isa<SCEVSignExtendExpr>(Dst))) {
+ const SCEVCastExpr *SrcCast = cast<SCEVCastExpr>(Src);
+ const SCEVCastExpr *DstCast = cast<SCEVCastExpr>(Dst);
+ const SCEV *SrcCastOp = SrcCast->getOperand();
+ const SCEV *DstCastOp = DstCast->getOperand();
+ if (SrcCastOp->getType() == DstCastOp->getType()) {
+ Pair->Src = SrcCastOp;
+ Pair->Dst = DstCastOp;
+ }
+ }
+}
+
+
+// Examine the scev and return true iff it's linear.
+// Collect any loops mentioned in the set of "Loops".
+bool DependenceInfo::checkSrcSubscript(const SCEV *Src, const Loop *LoopNest,
+ SmallBitVector &Loops) {
+ const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Src);
+ if (!AddRec)
+ return isLoopInvariant(Src, LoopNest);
+ const SCEV *Start = AddRec->getStart();
+ const SCEV *Step = AddRec->getStepRecurrence(*SE);
+ const SCEV *UB = SE->getBackedgeTakenCount(AddRec->getLoop());
+ if (!isa<SCEVCouldNotCompute>(UB)) {
+ if (SE->getTypeSizeInBits(Start->getType()) <
+ SE->getTypeSizeInBits(UB->getType())) {
+ if (!AddRec->getNoWrapFlags())
+ return false;
+ }
+ }
+ if (!isLoopInvariant(Step, LoopNest))
+ return false;
+ Loops.set(mapSrcLoop(AddRec->getLoop()));
+ return checkSrcSubscript(Start, LoopNest, Loops);
+}
+
+
+
+// Examine the scev and return true iff it's linear.
+// Collect any loops mentioned in the set of "Loops".
+bool DependenceInfo::checkDstSubscript(const SCEV *Dst, const Loop *LoopNest,
+ SmallBitVector &Loops) {
+ const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Dst);
+ if (!AddRec)
+ return isLoopInvariant(Dst, LoopNest);
+ const SCEV *Start = AddRec->getStart();
+ const SCEV *Step = AddRec->getStepRecurrence(*SE);
+ const SCEV *UB = SE->getBackedgeTakenCount(AddRec->getLoop());
+ if (!isa<SCEVCouldNotCompute>(UB)) {
+ if (SE->getTypeSizeInBits(Start->getType()) <
+ SE->getTypeSizeInBits(UB->getType())) {
+ if (!AddRec->getNoWrapFlags())
+ return false;
+ }
+ }
+ if (!isLoopInvariant(Step, LoopNest))
+ return false;
+ Loops.set(mapDstLoop(AddRec->getLoop()));
+ return checkDstSubscript(Start, LoopNest, Loops);
+}
+
+
+// Examines the subscript pair (the Src and Dst SCEVs)
+// and classifies it as either ZIV, SIV, RDIV, MIV, or Nonlinear.
+// Collects the associated loops in a set.
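+// For example, [5] vs. [n] is ZIV, [i] vs. [i + 1] is SIV, [i] vs. [j]
+// (where i and j come from different loops) is RDIV, and [i + j] vs. [i]
+// is MIV.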
+DependenceInfo::Subscript::ClassificationKind
+DependenceInfo::classifyPair(const SCEV *Src, const Loop *SrcLoopNest,
+ const SCEV *Dst, const Loop *DstLoopNest,
+ SmallBitVector &Loops) {
+ SmallBitVector SrcLoops(MaxLevels + 1);
+ SmallBitVector DstLoops(MaxLevels + 1);
+ if (!checkSrcSubscript(Src, SrcLoopNest, SrcLoops))
+ return Subscript::NonLinear;
+ if (!checkDstSubscript(Dst, DstLoopNest, DstLoops))
+ return Subscript::NonLinear;
+ Loops = SrcLoops;
+ Loops |= DstLoops;
+ unsigned N = Loops.count();
+ if (N == 0)
+ return Subscript::ZIV;
+ if (N == 1)
+ return Subscript::SIV;
+ if (N == 2 && (SrcLoops.count() == 0 ||
+ DstLoops.count() == 0 ||
+ (SrcLoops.count() == 1 && DstLoops.count() == 1)))
+ return Subscript::RDIV;
+ return Subscript::MIV;
+}
+
+
+// A wrapper around SCEV::isKnownPredicate.
+// Looks for cases where we're interested in comparing for equality.
+// If both X and Y have been identically sign or zero extended,
+// it strips off the (confusing) extensions before invoking
+// SCEV::isKnownPredicate. Perhaps, someday, the ScalarEvolution package
+// will be similarly updated.
+//
+// If SCEV::isKnownPredicate can't prove the predicate,
+// we try simple subtraction, which seems to help in some cases
+// involving symbolics.
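+// For example, with X = n + 4 and Y = n for some unknown n, the
+// difference folds to the constant 4, so ICMP_SGT can be settled by
+// isKnownPositive even if SE->isKnownPredicate cannot prove it directly.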
+bool DependenceInfo::isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *X,
+ const SCEV *Y) const {
+ if (Pred == CmpInst::ICMP_EQ ||
+ Pred == CmpInst::ICMP_NE) {
+ if ((isa<SCEVSignExtendExpr>(X) &&
+ isa<SCEVSignExtendExpr>(Y)) ||
+ (isa<SCEVZeroExtendExpr>(X) &&
+ isa<SCEVZeroExtendExpr>(Y))) {
+ const SCEVCastExpr *CX = cast<SCEVCastExpr>(X);
+ const SCEVCastExpr *CY = cast<SCEVCastExpr>(Y);
+ const SCEV *Xop = CX->getOperand();
+ const SCEV *Yop = CY->getOperand();
+ if (Xop->getType() == Yop->getType()) {
+ X = Xop;
+ Y = Yop;
+ }
+ }
+ }
+ if (SE->isKnownPredicate(Pred, X, Y))
+ return true;
+ // If SE->isKnownPredicate can't prove the condition,
+ // we try the brute-force approach of subtracting
+ // and testing the difference.
+ // By testing with SE->isKnownPredicate first, we avoid
+ // the possibility of overflow when the arguments are constants.
+ const SCEV *Delta = SE->getMinusSCEV(X, Y);
+ switch (Pred) {
+ case CmpInst::ICMP_EQ:
+ return Delta->isZero();
+ case CmpInst::ICMP_NE:
+ return SE->isKnownNonZero(Delta);
+ case CmpInst::ICMP_SGE:
+ return SE->isKnownNonNegative(Delta);
+ case CmpInst::ICMP_SLE:
+ return SE->isKnownNonPositive(Delta);
+ case CmpInst::ICMP_SGT:
+ return SE->isKnownPositive(Delta);
+ case CmpInst::ICMP_SLT:
+ return SE->isKnownNegative(Delta);
+ default:
+ llvm_unreachable("unexpected predicate in isKnownPredicate");
+ }
+}
+
+/// Compare to see if S is less than Size, using isKnownNegative(S - max(Size, 1))
+/// with some extra checking if S is an AddRec and we can prove less-than using
+/// the loop bounds.
+bool DependenceInfo::isKnownLessThan(const SCEV *S, const SCEV *Size) const {
+ // First unify to the same type
+ auto *SType = dyn_cast<IntegerType>(S->getType());
+ auto *SizeType = dyn_cast<IntegerType>(Size->getType());
+ if (!SType || !SizeType)
+ return false;
+ Type *MaxType =
+ (SType->getBitWidth() >= SizeType->getBitWidth()) ? SType : SizeType;
+ S = SE->getTruncateOrZeroExtend(S, MaxType);
+ Size = SE->getTruncateOrZeroExtend(Size, MaxType);
+
+ // Special check for addrecs using BE taken count
+ const SCEV *Bound = SE->getMinusSCEV(S, Size);
+ if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Bound)) {
+ if (AddRec->isAffine()) {
+ const SCEV *BECount = SE->getBackedgeTakenCount(AddRec->getLoop());
+ if (!isa<SCEVCouldNotCompute>(BECount)) {
+ const SCEV *Limit = AddRec->evaluateAtIteration(BECount, *SE);
+ if (SE->isKnownNegative(Limit))
+ return true;
+ }
+ }
+ }
+
+ // Check using normal isKnownNegative
+ const SCEV *LimitedBound =
+ SE->getMinusSCEV(S, SE->getSMaxExpr(Size, SE->getOne(Size->getType())));
+ return SE->isKnownNegative(LimitedBound);
+}
+
+bool DependenceInfo::isKnownNonNegative(const SCEV *S, const Value *Ptr) const {
+ bool Inbounds = false;
+ if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(Ptr))
+ Inbounds = SrcGEP->isInBounds();
+ if (Inbounds) {
+ if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
+ if (AddRec->isAffine()) {
+ // We know S is for Ptr, the operand on a load/store, so doesn't wrap.
+ // If both parts are NonNegative, the end result will be NonNegative
+ if (SE->isKnownNonNegative(AddRec->getStart()) &&
+ SE->isKnownNonNegative(AddRec->getOperand(1)))
+ return true;
+ }
+ }
+ }
+
+ return SE->isKnownNonNegative(S);
+}
+
+// All subscripts are the same type.
+// Loop bound may be smaller (e.g., a char).
+// Should zero extend loop bound, since it's always >= 0.
+// This routine collects upper bound and extends or truncates if needed.
+// Truncating is safe when subscripts are known not to wrap. Cases without
+// nowrap flags should have been rejected earlier.
+// Return null if no bound available.
+const SCEV *DependenceInfo::collectUpperBound(const Loop *L, Type *T) const {
+ if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
+ const SCEV *UB = SE->getBackedgeTakenCount(L);
+ return SE->getTruncateOrZeroExtend(UB, T);
+ }
+ return nullptr;
+}
+
+
+// Calls collectUpperBound(), then attempts to cast it to SCEVConstant.
+// If the cast fails, returns NULL.
+const SCEVConstant *DependenceInfo::collectConstantUpperBound(const Loop *L,
+ Type *T) const {
+ if (const SCEV *UB = collectUpperBound(L, T))
+ return dyn_cast<SCEVConstant>(UB);
+ return nullptr;
+}
+
+
+// testZIV -
+// When we have a pair of subscripts of the form [c1] and [c2],
+// where c1 and c2 are both loop invariant, we attack it using
+// the ZIV test. Basically, we test by comparing the two values,
+// but there are actually three possible results:
+// 1) the values are equal, so there's a dependence
+// 2) the values are different, so there's no dependence
+// 3) the values might be equal, so we have to assume a dependence.
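+//
+// For example, [5] vs. [6] is provably independent, [n] vs. [n] is
+// provably dependent, and [n] vs. [m] for unrelated n and m must be
+// treated as possibly dependent.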
+//
+// Return true if dependence disproved.
+bool DependenceInfo::testZIV(const SCEV *Src, const SCEV *Dst,
+ FullDependence &Result) const {
+ LLVM_DEBUG(dbgs() << " src = " << *Src << "\n");
+ LLVM_DEBUG(dbgs() << " dst = " << *Dst << "\n");
+ ++ZIVapplications;
+ if (isKnownPredicate(CmpInst::ICMP_EQ, Src, Dst)) {
+ LLVM_DEBUG(dbgs() << " provably dependent\n");
+ return false; // provably dependent
+ }
+ if (isKnownPredicate(CmpInst::ICMP_NE, Src, Dst)) {
+ LLVM_DEBUG(dbgs() << " provably independent\n");
+ ++ZIVindependence;
+ return true; // provably independent
+ }
+ LLVM_DEBUG(dbgs() << " possibly dependent\n");
+ Result.Consistent = false;
+ return false; // possibly dependent
+}
+
+
+// strongSIVtest -
+// From the paper, Practical Dependence Testing, Section 4.2.1
+//
+// When we have a pair of subscripts of the form [c1 + a*i] and [c2 + a*i],
+// where i is an induction variable, c1 and c2 are loop invariant,
+// and a is a constant, we can solve it exactly using the Strong SIV test.
+//
+// Can prove independence. Failing that, can compute distance (and direction).
+// In the presence of symbolic terms, we can sometimes make progress.
+//
+// If there's a dependence,
+//
+// c1 + a*i = c2 + a*i'
+//
+// The dependence distance is
+//
+// d = i' - i = (c1 - c2)/a
+//
+// A dependence only exists if d is an integer and abs(d) <= U, where U is the
+// loop's upper bound. If a dependence exists, the dependence direction is
+// defined as
+//
+// { < if d > 0
+// direction = { = if d = 0
+// { > if d < 0
+//
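+// For example, with source subscript [2 + i] and destination [0 + i]
+// (c1 = 2, c2 = 0, a = 1), d = (2 - 0)/1 = 2, so the dependence has
+// distance 2 and direction <.
+//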
+// Return true if dependence disproved.
+bool DependenceInfo::strongSIVtest(const SCEV *Coeff, const SCEV *SrcConst,
+ const SCEV *DstConst, const Loop *CurLoop,
+ unsigned Level, FullDependence &Result,
+ Constraint &NewConstraint) const {
+ LLVM_DEBUG(dbgs() << "\tStrong SIV test\n");
+ LLVM_DEBUG(dbgs() << "\t Coeff = " << *Coeff);
+ LLVM_DEBUG(dbgs() << ", " << *Coeff->getType() << "\n");
+ LLVM_DEBUG(dbgs() << "\t SrcConst = " << *SrcConst);
+ LLVM_DEBUG(dbgs() << ", " << *SrcConst->getType() << "\n");
+ LLVM_DEBUG(dbgs() << "\t DstConst = " << *DstConst);
+ LLVM_DEBUG(dbgs() << ", " << *DstConst->getType() << "\n");
+ ++StrongSIVapplications;
+ assert(0 < Level && Level <= CommonLevels && "level out of range");
+ Level--;
+
+ const SCEV *Delta = SE->getMinusSCEV(SrcConst, DstConst);
+ LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta);
+ LLVM_DEBUG(dbgs() << ", " << *Delta->getType() << "\n");
+
+ // check that |Delta| < iteration count
+ if (const SCEV *UpperBound = collectUpperBound(CurLoop, Delta->getType())) {
+ LLVM_DEBUG(dbgs() << "\t UpperBound = " << *UpperBound);
+ LLVM_DEBUG(dbgs() << ", " << *UpperBound->getType() << "\n");
+ const SCEV *AbsDelta =
+ SE->isKnownNonNegative(Delta) ? Delta : SE->getNegativeSCEV(Delta);
+ const SCEV *AbsCoeff =
+ SE->isKnownNonNegative(Coeff) ? Coeff : SE->getNegativeSCEV(Coeff);
+ const SCEV *Product = SE->getMulExpr(UpperBound, AbsCoeff);
+ if (isKnownPredicate(CmpInst::ICMP_SGT, AbsDelta, Product)) {
+ // Distance greater than trip count - no dependence
+ ++StrongSIVindependence;
+ ++StrongSIVsuccesses;
+ return true;
+ }
+ }
+
+ // Can we compute distance?
+ if (isa<SCEVConstant>(Delta) && isa<SCEVConstant>(Coeff)) {
+ APInt ConstDelta = cast<SCEVConstant>(Delta)->getAPInt();
+ APInt ConstCoeff = cast<SCEVConstant>(Coeff)->getAPInt();
+ APInt Distance = ConstDelta; // these need to be initialized
+ APInt Remainder = ConstDelta;
+ APInt::sdivrem(ConstDelta, ConstCoeff, Distance, Remainder);
+ LLVM_DEBUG(dbgs() << "\t Distance = " << Distance << "\n");
+ LLVM_DEBUG(dbgs() << "\t Remainder = " << Remainder << "\n");
+ // Make sure Coeff divides Delta exactly
+ if (Remainder != 0) {
+ // Coeff doesn't divide Distance, no dependence
+ ++StrongSIVindependence;
+ ++StrongSIVsuccesses;
+ return true;
+ }
+ Result.DV[Level].Distance = SE->getConstant(Distance);
+ NewConstraint.setDistance(SE->getConstant(Distance), CurLoop);
+ if (Distance.sgt(0))
+ Result.DV[Level].Direction &= Dependence::DVEntry::LT;
+ else if (Distance.slt(0))
+ Result.DV[Level].Direction &= Dependence::DVEntry::GT;
+ else
+ Result.DV[Level].Direction &= Dependence::DVEntry::EQ;
+ ++StrongSIVsuccesses;
+ }
+ else if (Delta->isZero()) {
+ // since 0/X == 0
+ Result.DV[Level].Distance = Delta;
+ NewConstraint.setDistance(Delta, CurLoop);
+ Result.DV[Level].Direction &= Dependence::DVEntry::EQ;
+ ++StrongSIVsuccesses;
+ }
+ else {
+ if (Coeff->isOne()) {
+ LLVM_DEBUG(dbgs() << "\t Distance = " << *Delta << "\n");
+ Result.DV[Level].Distance = Delta; // since X/1 == X
+ NewConstraint.setDistance(Delta, CurLoop);
+ }
+ else {
+ Result.Consistent = false;
+ NewConstraint.setLine(Coeff,
+ SE->getNegativeSCEV(Coeff),
+ SE->getNegativeSCEV(Delta), CurLoop);
+ }
+
+ // maybe we can get a useful direction
+ bool DeltaMaybeZero = !SE->isKnownNonZero(Delta);
+ bool DeltaMaybePositive = !SE->isKnownNonPositive(Delta);
+ bool DeltaMaybeNegative = !SE->isKnownNonNegative(Delta);
+ bool CoeffMaybePositive = !SE->isKnownNonPositive(Coeff);
+ bool CoeffMaybeNegative = !SE->isKnownNonNegative(Coeff);
+ // The double negatives above are confusing.
+ // It helps to read !SE->isKnownNonZero(Delta)
+ // as "Delta might be Zero"
+ unsigned NewDirection = Dependence::DVEntry::NONE;
+ if ((DeltaMaybePositive && CoeffMaybePositive) ||
+ (DeltaMaybeNegative && CoeffMaybeNegative))
+ NewDirection = Dependence::DVEntry::LT;
+ if (DeltaMaybeZero)
+ NewDirection |= Dependence::DVEntry::EQ;
+ if ((DeltaMaybeNegative && CoeffMaybePositive) ||
+ (DeltaMaybePositive && CoeffMaybeNegative))
+ NewDirection |= Dependence::DVEntry::GT;
+ if (NewDirection < Result.DV[Level].Direction)
+ ++StrongSIVsuccesses;
+ Result.DV[Level].Direction &= NewDirection;
+ }
+ return false;
+}
+
+
+// weakCrossingSIVtest -
+// From the paper, Practical Dependence Testing, Section 4.2.2
+//
+// When we have a pair of subscripts of the form [c1 + a*i] and [c2 - a*i],
+// where i is an induction variable, c1 and c2 are loop invariant,
+// and a is a constant, we can solve it exactly using the
+// Weak-Crossing SIV test.
+//
+// Given c1 + a*i = c2 - a*i', we can look for the intersection of
+// the two lines, where i = i', yielding
+//
+// c1 + a*i = c2 - a*i
+// 2a*i = c2 - c1
+// i = (c2 - c1)/2a
+//
+// If i < 0, there is no dependence.
+// If i > upperbound, there is no dependence.
+// If i = 0 (i.e., if c1 = c2), there's a dependence with distance = 0.
+// If i = upperbound, there's a dependence with distance = 0.
+// If i is integral, there's a dependence (all directions).
+// If the non-integer part = 1/2, there's a dependence (<> directions).
+// Otherwise, there's no dependence.
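+//
+// For example, with [0 + i] and [10 - i] (a = 1, c2 - c1 = 10), the lines
+// cross at i = 10/2 = 5, so splitting the loop at iteration 5 breaks the
+// crossing dependence; with c2 - c1 = 5 the crossing point is 2.5, leaving
+// only the <> directions.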
+//
+// Can prove independence. Failing that,
+// can sometimes refine the directions.
+// Can determine iteration for splitting.
+//
+// Return true if dependence disproved.
+bool DependenceInfo::weakCrossingSIVtest(
+ const SCEV *Coeff, const SCEV *SrcConst, const SCEV *DstConst,
+ const Loop *CurLoop, unsigned Level, FullDependence &Result,
+ Constraint &NewConstraint, const SCEV *&SplitIter) const {
+ LLVM_DEBUG(dbgs() << "\tWeak-Crossing SIV test\n");
+ LLVM_DEBUG(dbgs() << "\t Coeff = " << *Coeff << "\n");
+ LLVM_DEBUG(dbgs() << "\t SrcConst = " << *SrcConst << "\n");
+ LLVM_DEBUG(dbgs() << "\t DstConst = " << *DstConst << "\n");
+ ++WeakCrossingSIVapplications;
+ assert(0 < Level && Level <= CommonLevels && "Level out of range");
+ Level--;
+ Result.Consistent = false;
+ const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst);
+ LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta << "\n");
+ NewConstraint.setLine(Coeff, Coeff, Delta, CurLoop);
+ if (Delta->isZero()) {
+ Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::LT);
+ Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::GT);
+ ++WeakCrossingSIVsuccesses;
+ if (!Result.DV[Level].Direction) {
+ ++WeakCrossingSIVindependence;
+ return true;
+ }
+ Result.DV[Level].Distance = Delta; // = 0
+ return false;
+ }
+ const SCEVConstant *ConstCoeff = dyn_cast<SCEVConstant>(Coeff);
+ if (!ConstCoeff)
+ return false;
+
+ Result.DV[Level].Splitable = true;
+ if (SE->isKnownNegative(ConstCoeff)) {
+ ConstCoeff = dyn_cast<SCEVConstant>(SE->getNegativeSCEV(ConstCoeff));
+ assert(ConstCoeff &&
+ "dynamic cast of negative of ConstCoeff should yield constant");
+ Delta = SE->getNegativeSCEV(Delta);
+ }
+ assert(SE->isKnownPositive(ConstCoeff) && "ConstCoeff should be positive");
+
+ // compute SplitIter for use by DependenceInfo::getSplitIteration()
+ SplitIter = SE->getUDivExpr(
+ SE->getSMaxExpr(SE->getZero(Delta->getType()), Delta),
+ SE->getMulExpr(SE->getConstant(Delta->getType(), 2), ConstCoeff));
+ LLVM_DEBUG(dbgs() << "\t Split iter = " << *SplitIter << "\n");
+
+ const SCEVConstant *ConstDelta = dyn_cast<SCEVConstant>(Delta);
+ if (!ConstDelta)
+ return false;
+
+ // We're certain that ConstCoeff > 0; therefore,
+ // if Delta < 0, then no dependence.
+ LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta << "\n");
+ LLVM_DEBUG(dbgs() << "\t ConstCoeff = " << *ConstCoeff << "\n");
+ if (SE->isKnownNegative(Delta)) {
+ // No dependence, Delta < 0
+ ++WeakCrossingSIVindependence;
+ ++WeakCrossingSIVsuccesses;
+ return true;
+ }
+
+ // We're certain that Delta > 0 and ConstCoeff > 0.
+ // Check Delta/(2*ConstCoeff) against upper loop bound
+ if (const SCEV *UpperBound = collectUpperBound(CurLoop, Delta->getType())) {
+ LLVM_DEBUG(dbgs() << "\t UpperBound = " << *UpperBound << "\n");
+ const SCEV *ConstantTwo = SE->getConstant(UpperBound->getType(), 2);
+ const SCEV *ML = SE->getMulExpr(SE->getMulExpr(ConstCoeff, UpperBound),
+ ConstantTwo);
+ LLVM_DEBUG(dbgs() << "\t ML = " << *ML << "\n");
+ if (isKnownPredicate(CmpInst::ICMP_SGT, Delta, ML)) {
+ // Delta too big, no dependence
+ ++WeakCrossingSIVindependence;
+ ++WeakCrossingSIVsuccesses;
+ return true;
+ }
+ if (isKnownPredicate(CmpInst::ICMP_EQ, Delta, ML)) {
+ // i = i' = UB
+ Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::LT);
+ Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::GT);
+ ++WeakCrossingSIVsuccesses;
+ if (!Result.DV[Level].Direction) {
+ ++WeakCrossingSIVindependence;
+ return true;
+ }
+ Result.DV[Level].Splitable = false;
+ Result.DV[Level].Distance = SE->getZero(Delta->getType());
+ return false;
+ }
+ }
+
+ // check that Coeff divides Delta
+ APInt APDelta = ConstDelta->getAPInt();
+ APInt APCoeff = ConstCoeff->getAPInt();
+ APInt Distance = APDelta; // these need to be initialized
+ APInt Remainder = APDelta;
+ APInt::sdivrem(APDelta, APCoeff, Distance, Remainder);
+ LLVM_DEBUG(dbgs() << "\t Remainder = " << Remainder << "\n");
+ if (Remainder != 0) {
+ // Coeff doesn't divide Delta, no dependence
+ ++WeakCrossingSIVindependence;
+ ++WeakCrossingSIVsuccesses;
+ return true;
+ }
+ LLVM_DEBUG(dbgs() << "\t Distance = " << Distance << "\n");
+
+ // if 2*Coeff doesn't divide Delta, then the equal direction isn't possible
+ APInt Two = APInt(Distance.getBitWidth(), 2, true);
+ Remainder = Distance.srem(Two);
+ LLVM_DEBUG(dbgs() << "\t Remainder = " << Remainder << "\n");
+ if (Remainder != 0) {
+ // Equal direction isn't possible
+ Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::EQ);
+ ++WeakCrossingSIVsuccesses;
+ }
+ return false;
+}
+
+
+// Kirch's algorithm, from
+//
+// Optimizing Supercompilers for Supercomputers
+// Michael Wolfe
+// MIT Press, 1989
+//
+// Program 2.1, page 29.
+// Computes the GCD of AM and BM.
+// Also finds a solution to the equation ax - by = gcd(a, b).
+// Returns true if dependence disproved; i.e., gcd does not divide Delta.
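+//
+// For example, with AM = 6, BM = 4, and Delta = 3, G = gcd(6, 4) = 2 does
+// not divide 3, so any dependence is disproved; with Delta = 2, the
+// algorithm returns X = 1, Y = 1, since 6*1 - 4*1 = 2.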
+static bool findGCD(unsigned Bits, const APInt &AM, const APInt &BM,
+ const APInt &Delta, APInt &G, APInt &X, APInt &Y) {
+ APInt A0(Bits, 1, true), A1(Bits, 0, true);
+ APInt B0(Bits, 0, true), B1(Bits, 1, true);
+ APInt G0 = AM.abs();
+ APInt G1 = BM.abs();
+ APInt Q = G0; // these need to be initialized
+ APInt R = G0;
+ APInt::sdivrem(G0, G1, Q, R);
+ while (R != 0) {
+ APInt A2 = A0 - Q*A1; A0 = A1; A1 = A2;
+ APInt B2 = B0 - Q*B1; B0 = B1; B1 = B2;
+ G0 = G1; G1 = R;
+ APInt::sdivrem(G0, G1, Q, R);
+ }
+ G = G1;
+ LLVM_DEBUG(dbgs() << "\t GCD = " << G << "\n");
+ X = AM.slt(0) ? -A1 : A1;
+ Y = BM.slt(0) ? B1 : -B1;
+
+ // make sure gcd divides Delta
+ R = Delta.srem(G);
+ if (R != 0)
+ return true; // gcd doesn't divide Delta, no dependence
+ Q = Delta.sdiv(G);
+ X *= Q;
+ Y *= Q;
+ return false;
+}
+
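+// Signed floor and ceiling of the quotient A/B. APInt::sdivrem truncates
+// toward zero, so for example floorOfQuotient(-7, 2) = -4 and
+// ceilingOfQuotient(-7, 2) = -3, whereas plain sdiv would give -3 for both.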
+static APInt floorOfQuotient(const APInt &A, const APInt &B) {
+ APInt Q = A; // these need to be initialized
+ APInt R = A;
+ APInt::sdivrem(A, B, Q, R);
+ if (R == 0)
+ return Q;
+ if ((A.sgt(0) && B.sgt(0)) ||
+ (A.slt(0) && B.slt(0)))
+ return Q;
+ else
+ return Q - 1;
+}
+
+static APInt ceilingOfQuotient(const APInt &A, const APInt &B) {
+ APInt Q = A; // these need to be initialized
+ APInt R = A;
+ APInt::sdivrem(A, B, Q, R);
+ if (R == 0)
+ return Q;
+ if ((A.sgt(0) && B.sgt(0)) ||
+ (A.slt(0) && B.slt(0)))
+ return Q + 1;
+ else
+ return Q;
+}
+
+
+static
+APInt maxAPInt(APInt A, APInt B) {
+ return A.sgt(B) ? A : B;
+}
+
+
+static
+APInt minAPInt(APInt A, APInt B) {
+ return A.slt(B) ? A : B;
+}
+
+
+// exactSIVtest -
+// When we have a pair of subscripts of the form [c1 + a1*i] and [c2 + a2*i],
+// where i is an induction variable, c1 and c2 are loop invariant, and a1
+// and a2 are constant, we can solve it exactly using an algorithm developed
+// by Banerjee and Wolfe. See Section 2.5.3 in
+//
+// Optimizing Supercompilers for Supercomputers
+// Michael Wolfe
+// MIT Press, 1989
+//
+// It's slower than the specialized tests (strong SIV, weak-zero SIV, etc),
+// so use them if possible. They're also a bit better with symbolics and,
+// in the case of the strong SIV test, can compute Distances.
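+//
+// For example, for the pair [1 + 2*i] and [0 + 4*i], Delta = -1 and
+// gcd(2, 4) = 2 does not divide -1, so the subscripts can never be equal
+// and the dependence is disproved.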
+//
+// Return true if dependence disproved.
+bool DependenceInfo::exactSIVtest(const SCEV *SrcCoeff, const SCEV *DstCoeff,
+ const SCEV *SrcConst, const SCEV *DstConst,
+ const Loop *CurLoop, unsigned Level,
+ FullDependence &Result,
+ Constraint &NewConstraint) const {
+ LLVM_DEBUG(dbgs() << "\tExact SIV test\n");
+ LLVM_DEBUG(dbgs() << "\t SrcCoeff = " << *SrcCoeff << " = AM\n");
+ LLVM_DEBUG(dbgs() << "\t DstCoeff = " << *DstCoeff << " = BM\n");
+ LLVM_DEBUG(dbgs() << "\t SrcConst = " << *SrcConst << "\n");
+ LLVM_DEBUG(dbgs() << "\t DstConst = " << *DstConst << "\n");
+ ++ExactSIVapplications;
+ assert(0 < Level && Level <= CommonLevels && "Level out of range");
+ Level--;
+ Result.Consistent = false;
+ const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst);
+ LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta << "\n");
+ NewConstraint.setLine(SrcCoeff, SE->getNegativeSCEV(DstCoeff),
+ Delta, CurLoop);
+ const SCEVConstant *ConstDelta = dyn_cast<SCEVConstant>(Delta);
+ const SCEVConstant *ConstSrcCoeff = dyn_cast<SCEVConstant>(SrcCoeff);
+ const SCEVConstant *ConstDstCoeff = dyn_cast<SCEVConstant>(DstCoeff);
+ if (!ConstDelta || !ConstSrcCoeff || !ConstDstCoeff)
+ return false;
+
+ // find gcd
+ APInt G, X, Y;
+ APInt AM = ConstSrcCoeff->getAPInt();
+ APInt BM = ConstDstCoeff->getAPInt();
+ unsigned Bits = AM.getBitWidth();
+ if (findGCD(Bits, AM, BM, ConstDelta->getAPInt(), G, X, Y)) {
+ // gcd doesn't divide Delta, no dependence
+ ++ExactSIVindependence;
+ ++ExactSIVsuccesses;
+ return true;
+ }
+
+ LLVM_DEBUG(dbgs() << "\t X = " << X << ", Y = " << Y << "\n");
+
+ // since SCEV construction normalizes, LM = 0
+ APInt UM(Bits, 1, true);
+ bool UMvalid = false;
+ // UM is perhaps unavailable, let's check
+ if (const SCEVConstant *CUB =
+ collectConstantUpperBound(CurLoop, Delta->getType())) {
+ UM = CUB->getAPInt();
+ LLVM_DEBUG(dbgs() << "\t UM = " << UM << "\n");
+ UMvalid = true;
+ }
+
+ APInt TU(APInt::getSignedMaxValue(Bits));
+ APInt TL(APInt::getSignedMinValue(Bits));
+
+ // test(BM/G, LM-X) and test(-BM/G, X-UM)
+ APInt TMUL = BM.sdiv(G);
+ if (TMUL.sgt(0)) {
+ TL = maxAPInt(TL, ceilingOfQuotient(-X, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n");
+ if (UMvalid) {
+ TU = minAPInt(TU, floorOfQuotient(UM - X, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n");
+ }
+ }
+ else {
+ TU = minAPInt(TU, floorOfQuotient(-X, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n");
+ if (UMvalid) {
+ TL = maxAPInt(TL, ceilingOfQuotient(UM - X, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n");
+ }
+ }
+
+ // test(AM/G, LM-Y) and test(-AM/G, Y-UM)
+ TMUL = AM.sdiv(G);
+ if (TMUL.sgt(0)) {
+ TL = maxAPInt(TL, ceilingOfQuotient(-Y, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n");
+ if (UMvalid) {
+ TU = minAPInt(TU, floorOfQuotient(UM - Y, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n");
+ }
+ }
+ else {
+ TU = minAPInt(TU, floorOfQuotient(-Y, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n");
+ if (UMvalid) {
+ TL = maxAPInt(TL, ceilingOfQuotient(UM - Y, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n");
+ }
+ }
+ if (TL.sgt(TU)) {
+ ++ExactSIVindependence;
+ ++ExactSIVsuccesses;
+ return true;
+ }
+
+ // explore directions
+ unsigned NewDirection = Dependence::DVEntry::NONE;
+
+ // less than
+ APInt SaveTU(TU); // save these
+ APInt SaveTL(TL);
+ LLVM_DEBUG(dbgs() << "\t exploring LT direction\n");
+ TMUL = AM - BM;
+ if (TMUL.sgt(0)) {
+ TL = maxAPInt(TL, ceilingOfQuotient(X - Y + 1, TMUL));
+ LLVM_DEBUG(dbgs() << "\t\t TL = " << TL << "\n");
+ }
+ else {
+ TU = minAPInt(TU, floorOfQuotient(X - Y + 1, TMUL));
+ LLVM_DEBUG(dbgs() << "\t\t TU = " << TU << "\n");
+ }
+ if (TL.sle(TU)) {
+ NewDirection |= Dependence::DVEntry::LT;
+ ++ExactSIVsuccesses;
+ }
+
+ // equal
+ TU = SaveTU; // restore
+ TL = SaveTL;
+ LLVM_DEBUG(dbgs() << "\t exploring EQ direction\n");
+ if (TMUL.sgt(0)) {
+ TL = maxAPInt(TL, ceilingOfQuotient(X - Y, TMUL));
+ LLVM_DEBUG(dbgs() << "\t\t TL = " << TL << "\n");
+ }
+ else {
+ TU = minAPInt(TU, floorOfQuotient(X - Y, TMUL));
+ LLVM_DEBUG(dbgs() << "\t\t TU = " << TU << "\n");
+ }
+ TMUL = BM - AM;
+ if (TMUL.sgt(0)) {
+ TL = maxAPInt(TL, ceilingOfQuotient(Y - X, TMUL));
+ LLVM_DEBUG(dbgs() << "\t\t TL = " << TL << "\n");
+ }
+ else {
+ TU = minAPInt(TU, floorOfQuotient(Y - X, TMUL));
+ LLVM_DEBUG(dbgs() << "\t\t TU = " << TU << "\n");
+ }
+ if (TL.sle(TU)) {
+ NewDirection |= Dependence::DVEntry::EQ;
+ ++ExactSIVsuccesses;
+ }
+
+ // greater than
+ TU = SaveTU; // restore
+ TL = SaveTL;
+ LLVM_DEBUG(dbgs() << "\t exploring GT direction\n");
+ if (TMUL.sgt(0)) {
+ TL = maxAPInt(TL, ceilingOfQuotient(Y - X + 1, TMUL));
+ LLVM_DEBUG(dbgs() << "\t\t TL = " << TL << "\n");
+ }
+ else {
+ TU = minAPInt(TU, floorOfQuotient(Y - X + 1, TMUL));
+ LLVM_DEBUG(dbgs() << "\t\t TU = " << TU << "\n");
+ }
+ if (TL.sle(TU)) {
+ NewDirection |= Dependence::DVEntry::GT;
+ ++ExactSIVsuccesses;
+ }
+
+ // finished
+ Result.DV[Level].Direction &= NewDirection;
+ if (Result.DV[Level].Direction == Dependence::DVEntry::NONE)
+ ++ExactSIVindependence;
+ return Result.DV[Level].Direction == Dependence::DVEntry::NONE;
+}
+
+
+
+// Return true if the divisor evenly divides the dividend.
+static
+bool isRemainderZero(const SCEVConstant *Dividend,
+ const SCEVConstant *Divisor) {
+ const APInt &ConstDividend = Dividend->getAPInt();
+ const APInt &ConstDivisor = Divisor->getAPInt();
+ return ConstDividend.srem(ConstDivisor) == 0;
+}
+
+
+// weakZeroSrcSIVtest -
+// From the paper, Practical Dependence Testing, Section 4.2.2
+//
+// When we have a pair of subscripts of the form [c1] and [c2 + a*i],
+// where i is an induction variable, c1 and c2 are loop invariant,
+// and a is a constant, we can solve it exactly using the
+// Weak-Zero SIV test.
+//
+// Given
+//
+// c1 = c2 + a*i
+//
+// we get
+//
+// (c1 - c2)/a = i
+//
+// If i is not an integer, there's no dependence.
+// If i < 0 or > UB, there's no dependence.
+// If i = 0, the direction is >= and peeling the
+// 1st iteration will break the dependence.
+// If i = UB, the direction is <= and peeling the
+// last iteration will break the dependence.
+// Otherwise, the direction is *.
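+//
+// For example, with source [0] and destination [0 + i] (c1 = c2 = 0,
+// a = 1), the subscripts match only when i = 0, so every dependence
+// involves the first iteration and peeling it breaks them all.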
+//
+// Can prove independence. Failing that, we can sometimes refine
+// the directions. Can sometimes show that first or last
+// iteration carries all the dependences (so worth peeling).
+//
+// (see also weakZeroDstSIVtest)
+//
+// Return true if dependence disproved.
+bool DependenceInfo::weakZeroSrcSIVtest(const SCEV *DstCoeff,
+ const SCEV *SrcConst,
+ const SCEV *DstConst,
+ const Loop *CurLoop, unsigned Level,
+ FullDependence &Result,
+ Constraint &NewConstraint) const {
+ // For the WeakSIV test, it's possible the loop isn't common to
+ // the Src and Dst loops. If it isn't, then there's no need to
+ // record a direction.
+ LLVM_DEBUG(dbgs() << "\tWeak-Zero (src) SIV test\n");
+ LLVM_DEBUG(dbgs() << "\t DstCoeff = " << *DstCoeff << "\n");
+ LLVM_DEBUG(dbgs() << "\t SrcConst = " << *SrcConst << "\n");
+ LLVM_DEBUG(dbgs() << "\t DstConst = " << *DstConst << "\n");
+ ++WeakZeroSIVapplications;
+ assert(0 < Level && Level <= MaxLevels && "Level out of range");
+ Level--;
+ Result.Consistent = false;
+ const SCEV *Delta = SE->getMinusSCEV(SrcConst, DstConst);
+ NewConstraint.setLine(SE->getZero(Delta->getType()), DstCoeff, Delta,
+ CurLoop);
+ LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta << "\n");
+ if (isKnownPredicate(CmpInst::ICMP_EQ, SrcConst, DstConst)) {
+ if (Level < CommonLevels) {
+ Result.DV[Level].Direction &= Dependence::DVEntry::GE;
+ Result.DV[Level].PeelFirst = true;
+ ++WeakZeroSIVsuccesses;
+ }
+ return false; // dependences caused by first iteration
+ }
+ const SCEVConstant *ConstCoeff = dyn_cast<SCEVConstant>(DstCoeff);
+ if (!ConstCoeff)
+ return false;
+ const SCEV *AbsCoeff =
+ SE->isKnownNegative(ConstCoeff) ?
+ SE->getNegativeSCEV(ConstCoeff) : ConstCoeff;
+ const SCEV *NewDelta =
+ SE->isKnownNegative(ConstCoeff) ? SE->getNegativeSCEV(Delta) : Delta;
+
+ // check that Delta/SrcCoeff < iteration count
+ // really check NewDelta < count*AbsCoeff
+ if (const SCEV *UpperBound = collectUpperBound(CurLoop, Delta->getType())) {
+ LLVM_DEBUG(dbgs() << "\t UpperBound = " << *UpperBound << "\n");
+ const SCEV *Product = SE->getMulExpr(AbsCoeff, UpperBound);
+ if (isKnownPredicate(CmpInst::ICMP_SGT, NewDelta, Product)) {
+ ++WeakZeroSIVindependence;
+ ++WeakZeroSIVsuccesses;
+ return true;
+ }
+ if (isKnownPredicate(CmpInst::ICMP_EQ, NewDelta, Product)) {
+ // dependences caused by last iteration
+ if (Level < CommonLevels) {
+ Result.DV[Level].Direction &= Dependence::DVEntry::LE;
+ Result.DV[Level].PeelLast = true;
+ ++WeakZeroSIVsuccesses;
+ }
+ return false;
+ }
+ }
+
+ // check that Delta/SrcCoeff >= 0
+ // really check that NewDelta >= 0
+ if (SE->isKnownNegative(NewDelta)) {
+ // No dependence, newDelta < 0
+ ++WeakZeroSIVindependence;
+ ++WeakZeroSIVsuccesses;
+ return true;
+ }
+
+ // if SrcCoeff doesn't divide Delta, then no dependence
+ if (isa<SCEVConstant>(Delta) &&
+ !isRemainderZero(cast<SCEVConstant>(Delta), ConstCoeff)) {
+ ++WeakZeroSIVindependence;
+ ++WeakZeroSIVsuccesses;
+ return true;
+ }
+ return false;
+}
+
+
+// weakZeroDstSIVtest -
+// From the paper, Practical Dependence Testing, Section 4.2.2
+//
+// When we have a pair of subscripts of the form [c1 + a*i] and [c2],
+// where i is an induction variable, c1 and c2 are loop invariant,
+// and a is a constant, we can solve it exactly using the
+// Weak-Zero SIV test.
+//
+// Given
+//
+// c1 + a*i = c2
+//
+// we get
+//
+// i = (c2 - c1)/a
+//
+// If i is not an integer, there's no dependence.
+// If i < 0 or > UB, there's no dependence.
+// If i = 0, the direction is <= and peeling the
+// 1st iteration will break the dependence.
+// If i = UB, the direction is >= and peeling the
+// last iteration will break the dependence.
+// Otherwise, the direction is *.
+//
+// Can prove independence. Failing that, we can sometimes refine
+// the directions. Can sometimes show that first or last
+// iteration carries all the dependences (so worth peeling).
+//
+// (see also weakZeroSrcSIVtest)
+//
+// Return true if dependence disproved.
+bool DependenceInfo::weakZeroDstSIVtest(const SCEV *SrcCoeff,
+ const SCEV *SrcConst,
+ const SCEV *DstConst,
+ const Loop *CurLoop, unsigned Level,
+ FullDependence &Result,
+ Constraint &NewConstraint) const {
+ // For the WeakSIV test, it's possible the loop isn't common to the
+ // Src and Dst loops. If it isn't, then there's no need to record a direction.
+ LLVM_DEBUG(dbgs() << "\tWeak-Zero (dst) SIV test\n");
+ LLVM_DEBUG(dbgs() << "\t SrcCoeff = " << *SrcCoeff << "\n");
+ LLVM_DEBUG(dbgs() << "\t SrcConst = " << *SrcConst << "\n");
+ LLVM_DEBUG(dbgs() << "\t DstConst = " << *DstConst << "\n");
+ ++WeakZeroSIVapplications;
+ assert(0 < Level && Level <= SrcLevels && "Level out of range");
+ Level--;
+ Result.Consistent = false;
+ const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst);
+ NewConstraint.setLine(SrcCoeff, SE->getZero(Delta->getType()), Delta,
+ CurLoop);
+ LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta << "\n");
+ if (isKnownPredicate(CmpInst::ICMP_EQ, DstConst, SrcConst)) {
+ if (Level < CommonLevels) {
+ Result.DV[Level].Direction &= Dependence::DVEntry::LE;
+ Result.DV[Level].PeelFirst = true;
+ ++WeakZeroSIVsuccesses;
+ }
+ return false; // dependences caused by first iteration
+ }
+ const SCEVConstant *ConstCoeff = dyn_cast<SCEVConstant>(SrcCoeff);
+ if (!ConstCoeff)
+ return false;
+ const SCEV *AbsCoeff =
+ SE->isKnownNegative(ConstCoeff) ?
+ SE->getNegativeSCEV(ConstCoeff) : ConstCoeff;
+ const SCEV *NewDelta =
+ SE->isKnownNegative(ConstCoeff) ? SE->getNegativeSCEV(Delta) : Delta;
+
+ // check that Delta/SrcCoeff < iteration count
+ // really check NewDelta < count*AbsCoeff
+ if (const SCEV *UpperBound = collectUpperBound(CurLoop, Delta->getType())) {
+ LLVM_DEBUG(dbgs() << "\t UpperBound = " << *UpperBound << "\n");
+ const SCEV *Product = SE->getMulExpr(AbsCoeff, UpperBound);
+ if (isKnownPredicate(CmpInst::ICMP_SGT, NewDelta, Product)) {
+ ++WeakZeroSIVindependence;
+ ++WeakZeroSIVsuccesses;
+ return true;
+ }
+ if (isKnownPredicate(CmpInst::ICMP_EQ, NewDelta, Product)) {
+ // dependences caused by last iteration
+ if (Level < CommonLevels) {
+ Result.DV[Level].Direction &= Dependence::DVEntry::GE;
+ Result.DV[Level].PeelLast = true;
+ ++WeakZeroSIVsuccesses;
+ }
+ return false;
+ }
+ }
+
+ // check that Delta/SrcCoeff >= 0
+ // really check that NewDelta >= 0
+ if (SE->isKnownNegative(NewDelta)) {
+ // No dependence, newDelta < 0
+ ++WeakZeroSIVindependence;
+ ++WeakZeroSIVsuccesses;
+ return true;
+ }
+
+ // if SrcCoeff doesn't divide Delta, then no dependence
+ if (isa<SCEVConstant>(Delta) &&
+ !isRemainderZero(cast<SCEVConstant>(Delta), ConstCoeff)) {
+ ++WeakZeroSIVindependence;
+ ++WeakZeroSIVsuccesses;
+ return true;
+ }
+ return false;
+}
+
+
+// exactRDIVtest - Tests the RDIV subscript pair for dependence.
+// Things of the form [c1 + a*i] and [c2 + b*j],
+// where i and j are induction variables, c1 and c2 are loop invariant,
+// and a and b are constants.
+// Returns true if any possible dependence is disproved.
+// Marks the result as inconsistent.
+// Works in some cases that symbolicRDIVtest doesn't, and vice versa.
+bool DependenceInfo::exactRDIVtest(const SCEV *SrcCoeff, const SCEV *DstCoeff,
+ const SCEV *SrcConst, const SCEV *DstConst,
+ const Loop *SrcLoop, const Loop *DstLoop,
+ FullDependence &Result) const {
+ LLVM_DEBUG(dbgs() << "\tExact RDIV test\n");
+ LLVM_DEBUG(dbgs() << "\t SrcCoeff = " << *SrcCoeff << " = AM\n");
+ LLVM_DEBUG(dbgs() << "\t DstCoeff = " << *DstCoeff << " = BM\n");
+ LLVM_DEBUG(dbgs() << "\t SrcConst = " << *SrcConst << "\n");
+ LLVM_DEBUG(dbgs() << "\t DstConst = " << *DstConst << "\n");
+ ++ExactRDIVapplications;
+ Result.Consistent = false;
+ const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst);
+ LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta << "\n");
+ const SCEVConstant *ConstDelta = dyn_cast<SCEVConstant>(Delta);
+ const SCEVConstant *ConstSrcCoeff = dyn_cast<SCEVConstant>(SrcCoeff);
+ const SCEVConstant *ConstDstCoeff = dyn_cast<SCEVConstant>(DstCoeff);
+ if (!ConstDelta || !ConstSrcCoeff || !ConstDstCoeff)
+ return false;
+
+ // find gcd
+ APInt G, X, Y;
+ APInt AM = ConstSrcCoeff->getAPInt();
+ APInt BM = ConstDstCoeff->getAPInt();
+ unsigned Bits = AM.getBitWidth();
+ if (findGCD(Bits, AM, BM, ConstDelta->getAPInt(), G, X, Y)) {
+ // gcd doesn't divide Delta, no dependence
+ ++ExactRDIVindependence;
+ return true;
+ }
+
+ LLVM_DEBUG(dbgs() << "\t X = " << X << ", Y = " << Y << "\n");
+
+ // since SCEV construction seems to normalize, LM = 0
+ APInt SrcUM(Bits, 1, true);
+ bool SrcUMvalid = false;
+ // SrcUM is perhaps unavailable, let's check
+ if (const SCEVConstant *UpperBound =
+ collectConstantUpperBound(SrcLoop, Delta->getType())) {
+ SrcUM = UpperBound->getAPInt();
+ LLVM_DEBUG(dbgs() << "\t SrcUM = " << SrcUM << "\n");
+ SrcUMvalid = true;
+ }
+
+ APInt DstUM(Bits, 1, true);
+ bool DstUMvalid = false;
+ // DstUM is perhaps unavailable, let's check
+ if (const SCEVConstant *UpperBound =
+ collectConstantUpperBound(DstLoop, Delta->getType())) {
+ DstUM = UpperBound->getAPInt();
+ LLVM_DEBUG(dbgs() << "\t DstUM = " << DstUM << "\n");
+ DstUMvalid = true;
+ }
+
+ APInt TU(APInt::getSignedMaxValue(Bits));
+ APInt TL(APInt::getSignedMinValue(Bits));
+
+ // test(BM/G, LM-X) and test(-BM/G, X-UM)
+ APInt TMUL = BM.sdiv(G);
+ if (TMUL.sgt(0)) {
+ TL = maxAPInt(TL, ceilingOfQuotient(-X, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n");
+ if (SrcUMvalid) {
+ TU = minAPInt(TU, floorOfQuotient(SrcUM - X, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n");
+ }
+ }
+ else {
+ TU = minAPInt(TU, floorOfQuotient(-X, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n");
+ if (SrcUMvalid) {
+ TL = maxAPInt(TL, ceilingOfQuotient(SrcUM - X, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n");
+ }
+ }
+
+ // test(AM/G, LM-Y) and test(-AM/G, Y-UM)
+ TMUL = AM.sdiv(G);
+ if (TMUL.sgt(0)) {
+ TL = maxAPInt(TL, ceilingOfQuotient(-Y, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n");
+ if (DstUMvalid) {
+ TU = minAPInt(TU, floorOfQuotient(DstUM - Y, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n");
+ }
+ }
+ else {
+ TU = minAPInt(TU, floorOfQuotient(-Y, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TU = " << TU << "\n");
+ if (DstUMvalid) {
+ TL = maxAPInt(TL, ceilingOfQuotient(DstUM - Y, TMUL));
+ LLVM_DEBUG(dbgs() << "\t TL = " << TL << "\n");
+ }
+ }
+ if (TL.sgt(TU))
+ ++ExactRDIVindependence;
+ return TL.sgt(TU);
+}
+
+
+// symbolicRDIVtest -
+// In Section 4.5 of the Practical Dependence Testing paper, the authors
+// introduce a special case of Banerjee's Inequalities (also called the
+// Extreme-Value Test) that can handle some of the SIV and RDIV cases,
+// particularly cases with symbolics. Since it's only able to disprove
+// dependence (not compute distances or directions), we'll use it as a
+// fallback for the other tests.
+//
+// When we have a pair of subscripts of the form [c1 + a1*i] and [c2 + a2*j]
+// where i and j are induction variables and c1 and c2 are loop invariants,
+// we can use the symbolic tests to disprove some dependences, serving as a
+// backup for the RDIV test. Note that i and j can be the same variable,
+// letting this test serve as a backup for the various SIV tests.
+//
+// For a dependence to exist, c1 + a1*i must equal c2 + a2*j for some
+// 0 <= i <= N1 and some 0 <= j <= N2, where N1 and N2 are the (normalized)
+// loop bounds for the i and j loops, respectively. So, ...
+//
+// c1 + a1*i = c2 + a2*j
+// a1*i - a2*j = c2 - c1
+//
+// To test for a dependence, we compute c2 - c1 and make sure it's in the
+// range of the maximum and minimum possible values of a1*i - a2*j.
+// Considering the signs of a1 and a2, we have 4 possible cases:
+//
+// 1) If a1 >= 0 and a2 >= 0, then
+// a1*0 - a2*N2 <= c2 - c1 <= a1*N1 - a2*0
+// -a2*N2 <= c2 - c1 <= a1*N1
+//
+// 2) If a1 >= 0 and a2 <= 0, then
+// a1*0 - a2*0 <= c2 - c1 <= a1*N1 - a2*N2
+// 0 <= c2 - c1 <= a1*N1 - a2*N2
+//
+// 3) If a1 <= 0 and a2 >= 0, then
+// a1*N1 - a2*N2 <= c2 - c1 <= a1*0 - a2*0
+// a1*N1 - a2*N2 <= c2 - c1 <= 0
+//
+// 4) If a1 <= 0 and a2 <= 0, then
+// a1*N1 - a2*0 <= c2 - c1 <= a1*0 - a2*N2
+// a1*N1 <= c2 - c1 <= -a2*N2
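+//
+// For example, in case 1 with a1 = 1, a2 = 2, N1 = N2 = 10, and
+// c2 - c1 = 25, the upper bound requires c2 - c1 <= a1*N1 = 10, which
+// fails, so the dependence is disproved.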
+//
+// return true if dependence disproved
+bool DependenceInfo::symbolicRDIVtest(const SCEV *A1, const SCEV *A2,
+ const SCEV *C1, const SCEV *C2,
+ const Loop *Loop1,
+ const Loop *Loop2) const {
+ ++SymbolicRDIVapplications;
+ LLVM_DEBUG(dbgs() << "\ttry symbolic RDIV test\n");
+ LLVM_DEBUG(dbgs() << "\t A1 = " << *A1);
+ LLVM_DEBUG(dbgs() << ", type = " << *A1->getType() << "\n");
+ LLVM_DEBUG(dbgs() << "\t A2 = " << *A2 << "\n");
+ LLVM_DEBUG(dbgs() << "\t C1 = " << *C1 << "\n");
+ LLVM_DEBUG(dbgs() << "\t C2 = " << *C2 << "\n");
+ const SCEV *N1 = collectUpperBound(Loop1, A1->getType());
+ const SCEV *N2 = collectUpperBound(Loop2, A1->getType());
+ LLVM_DEBUG(if (N1) dbgs() << "\t N1 = " << *N1 << "\n");
+ LLVM_DEBUG(if (N2) dbgs() << "\t N2 = " << *N2 << "\n");
+ const SCEV *C2_C1 = SE->getMinusSCEV(C2, C1);
+ const SCEV *C1_C2 = SE->getMinusSCEV(C1, C2);
+ LLVM_DEBUG(dbgs() << "\t C2 - C1 = " << *C2_C1 << "\n");
+ LLVM_DEBUG(dbgs() << "\t C1 - C2 = " << *C1_C2 << "\n");
+ if (SE->isKnownNonNegative(A1)) {
+ if (SE->isKnownNonNegative(A2)) {
+ // A1 >= 0 && A2 >= 0
+ if (N1) {
+ // make sure that c2 - c1 <= a1*N1
+ const SCEV *A1N1 = SE->getMulExpr(A1, N1);
+ LLVM_DEBUG(dbgs() << "\t A1*N1 = " << *A1N1 << "\n");
+ if (isKnownPredicate(CmpInst::ICMP_SGT, C2_C1, A1N1)) {
+ ++SymbolicRDIVindependence;
+ return true;
+ }
+ }
+ if (N2) {
+ // make sure that -a2*N2 <= c2 - c1, or a2*N2 >= c1 - c2
+ const SCEV *A2N2 = SE->getMulExpr(A2, N2);
+ LLVM_DEBUG(dbgs() << "\t A2*N2 = " << *A2N2 << "\n");
+ if (isKnownPredicate(CmpInst::ICMP_SLT, A2N2, C1_C2)) {
+ ++SymbolicRDIVindependence;
+ return true;
+ }
+ }
+ }
+ else if (SE->isKnownNonPositive(A2)) {
+ // a1 >= 0 && a2 <= 0
+ if (N1 && N2) {
+ // make sure that c2 - c1 <= a1*N1 - a2*N2
+ const SCEV *A1N1 = SE->getMulExpr(A1, N1);
+ const SCEV *A2N2 = SE->getMulExpr(A2, N2);
+ const SCEV *A1N1_A2N2 = SE->getMinusSCEV(A1N1, A2N2);
+ LLVM_DEBUG(dbgs() << "\t A1*N1 - A2*N2 = " << *A1N1_A2N2 << "\n");
+ if (isKnownPredicate(CmpInst::ICMP_SGT, C2_C1, A1N1_A2N2)) {
+ ++SymbolicRDIVindependence;
+ return true;
+ }
+ }
+ // make sure that 0 <= c2 - c1
+ if (SE->isKnownNegative(C2_C1)) {
+ ++SymbolicRDIVindependence;
+ return true;
+ }
+ }
+ }
+ else if (SE->isKnownNonPositive(A1)) {
+ if (SE->isKnownNonNegative(A2)) {
+ // a1 <= 0 && a2 >= 0
+ if (N1 && N2) {
+ // make sure that a1*N1 - a2*N2 <= c2 - c1
+ const SCEV *A1N1 = SE->getMulExpr(A1, N1);
+ const SCEV *A2N2 = SE->getMulExpr(A2, N2);
+ const SCEV *A1N1_A2N2 = SE->getMinusSCEV(A1N1, A2N2);
+ LLVM_DEBUG(dbgs() << "\t A1*N1 - A2*N2 = " << *A1N1_A2N2 << "\n");
+ if (isKnownPredicate(CmpInst::ICMP_SGT, A1N1_A2N2, C2_C1)) {
+ ++SymbolicRDIVindependence;
+ return true;
+ }
+ }
+ // make sure that c2 - c1 <= 0
+ if (SE->isKnownPositive(C2_C1)) {
+ ++SymbolicRDIVindependence;
+ return true;
+ }
+ }
+ else if (SE->isKnownNonPositive(A2)) {
+ // a1 <= 0 && a2 <= 0
+ if (N1) {
+ // make sure that a1*N1 <= c2 - c1
+ const SCEV *A1N1 = SE->getMulExpr(A1, N1);
+ LLVM_DEBUG(dbgs() << "\t A1*N1 = " << *A1N1 << "\n");
+ if (isKnownPredicate(CmpInst::ICMP_SGT, A1N1, C2_C1)) {
+ ++SymbolicRDIVindependence;
+ return true;
+ }
+ }
+ if (N2) {
+ // make sure that c2 - c1 <= -a2*N2, or c1 - c2 >= a2*N2
+ const SCEV *A2N2 = SE->getMulExpr(A2, N2);
+ LLVM_DEBUG(dbgs() << "\t A2*N2 = " << *A2N2 << "\n");
+ if (isKnownPredicate(CmpInst::ICMP_SLT, C1_C2, A2N2)) {
+ ++SymbolicRDIVindependence;
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
+
+// testSIV -
+// When we have a pair of subscripts of the form [c1 + a1*i] and [c2 - a2*i]
+// where i is an induction variable, c1 and c2 are loop invariant, and a1 and
+// a2 are constant, we attack it with an SIV test. While they can all be
+// solved with the Exact SIV test, it's worthwhile to use simpler tests when
+// they apply; they're cheaper and sometimes more precise.
+//
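+// For instance (informal examples, not tied to any particular test case):
+// A[i + 2] vs. A[i] has equal coefficients and is handled by the Strong SIV
+// test; A[i] vs. A[-i + 6] has negated coefficients and goes to the
+// Weak-Crossing SIV test; anything else, e.g. A[2*i] vs. A[3*i + 1], falls
+// through to the Exact SIV test. If only one side varies with the loop,
+// the weak-zero SIV tests below apply instead.
+//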
+// Return true if dependence disproved.
+bool DependenceInfo::testSIV(const SCEV *Src, const SCEV *Dst, unsigned &Level,
+ FullDependence &Result, Constraint &NewConstraint,
+ const SCEV *&SplitIter) const {
+ LLVM_DEBUG(dbgs() << " src = " << *Src << "\n");
+ LLVM_DEBUG(dbgs() << " dst = " << *Dst << "\n");
+ const SCEVAddRecExpr *SrcAddRec = dyn_cast<SCEVAddRecExpr>(Src);
+ const SCEVAddRecExpr *DstAddRec = dyn_cast<SCEVAddRecExpr>(Dst);
+ if (SrcAddRec && DstAddRec) {
+ const SCEV *SrcConst = SrcAddRec->getStart();
+ const SCEV *DstConst = DstAddRec->getStart();
+ const SCEV *SrcCoeff = SrcAddRec->getStepRecurrence(*SE);
+ const SCEV *DstCoeff = DstAddRec->getStepRecurrence(*SE);
+ const Loop *CurLoop = SrcAddRec->getLoop();
+ assert(CurLoop == DstAddRec->getLoop() &&
+ "both loops in SIV should be same");
+ Level = mapSrcLoop(CurLoop);
+ bool disproven;
+ if (SrcCoeff == DstCoeff)
+ disproven = strongSIVtest(SrcCoeff, SrcConst, DstConst, CurLoop,
+ Level, Result, NewConstraint);
+ else if (SrcCoeff == SE->getNegativeSCEV(DstCoeff))
+ disproven = weakCrossingSIVtest(SrcCoeff, SrcConst, DstConst, CurLoop,
+ Level, Result, NewConstraint, SplitIter);
+ else
+ disproven = exactSIVtest(SrcCoeff, DstCoeff, SrcConst, DstConst, CurLoop,
+ Level, Result, NewConstraint);
+ return disproven ||
+ gcdMIVtest(Src, Dst, Result) ||
+ symbolicRDIVtest(SrcCoeff, DstCoeff, SrcConst, DstConst, CurLoop, CurLoop);
+ }
+ if (SrcAddRec) {
+ const SCEV *SrcConst = SrcAddRec->getStart();
+ const SCEV *SrcCoeff = SrcAddRec->getStepRecurrence(*SE);
+ const SCEV *DstConst = Dst;
+ const Loop *CurLoop = SrcAddRec->getLoop();
+ Level = mapSrcLoop(CurLoop);
+ return weakZeroDstSIVtest(SrcCoeff, SrcConst, DstConst, CurLoop,
+ Level, Result, NewConstraint) ||
+ gcdMIVtest(Src, Dst, Result);
+ }
+ if (DstAddRec) {
+ const SCEV *DstConst = DstAddRec->getStart();
+ const SCEV *DstCoeff = DstAddRec->getStepRecurrence(*SE);
+ const SCEV *SrcConst = Src;
+ const Loop *CurLoop = DstAddRec->getLoop();
+ Level = mapDstLoop(CurLoop);
+ return weakZeroSrcSIVtest(DstCoeff, SrcConst, DstConst,
+ CurLoop, Level, Result, NewConstraint) ||
+ gcdMIVtest(Src, Dst, Result);
+ }
+ llvm_unreachable("SIV test expected at least one AddRec");
+ return false;
+}
+
+
+// testRDIV -
+// When we have a pair of subscripts of the form [c1 + a1*i] and [c2 + a2*j]
+// where i and j are induction variables, c1 and c2 are loop invariant,
+// and a1 and a2 are constant, we can solve it exactly with an easy adaptation
+// of the Exact SIV test, the Restricted Double Index Variable (RDIV) test.
+// It doesn't make sense to talk about distance or direction in this case,
+// so there's no point in making special versions of the Strong SIV test or
+// the Weak-crossing SIV test.
+//
+// With minor algebra, this test can also be used for things like
+// [c1 + a1*i + a2*j][c2].
+//
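+// For example (hypothetical subscripts): A[c1 + 2*i] vs. A[c2 + 3*j], where
+// i and j run in different loops, is the straightforward case; a single
+// subscript such as A[c1 + 2*i + 3*j] compared against A[c2] is rearranged
+// into the same form by negating one of its coefficients (cases 2 and 3 in
+// the code below).
+//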
+// Return true if dependence disproved.
+bool DependenceInfo::testRDIV(const SCEV *Src, const SCEV *Dst,
+ FullDependence &Result) const {
+ // we have 3 possible situations here:
+ // 1) [a*i + b] and [c*j + d]
+ // 2) [a*i + c*j + b] and [d]
+ // 3) [b] and [a*i + c*j + d]
+ // We need to find what we've got and get organized
+
+ const SCEV *SrcConst, *DstConst;
+ const SCEV *SrcCoeff, *DstCoeff;
+ const Loop *SrcLoop, *DstLoop;
+
+ LLVM_DEBUG(dbgs() << " src = " << *Src << "\n");
+ LLVM_DEBUG(dbgs() << " dst = " << *Dst << "\n");
+ const SCEVAddRecExpr *SrcAddRec = dyn_cast<SCEVAddRecExpr>(Src);
+ const SCEVAddRecExpr *DstAddRec = dyn_cast<SCEVAddRecExpr>(Dst);
+ if (SrcAddRec && DstAddRec) {
+ SrcConst = SrcAddRec->getStart();
+ SrcCoeff = SrcAddRec->getStepRecurrence(*SE);
+ SrcLoop = SrcAddRec->getLoop();
+ DstConst = DstAddRec->getStart();
+ DstCoeff = DstAddRec->getStepRecurrence(*SE);
+ DstLoop = DstAddRec->getLoop();
+ }
+ else if (SrcAddRec) {
+ if (const SCEVAddRecExpr *tmpAddRec =
+ dyn_cast<SCEVAddRecExpr>(SrcAddRec->getStart())) {
+ SrcConst = tmpAddRec->getStart();
+ SrcCoeff = tmpAddRec->getStepRecurrence(*SE);
+ SrcLoop = tmpAddRec->getLoop();
+ DstConst = Dst;
+ DstCoeff = SE->getNegativeSCEV(SrcAddRec->getStepRecurrence(*SE));
+ DstLoop = SrcAddRec->getLoop();
+ }
+ else
+ llvm_unreachable("RDIV reached by surprising SCEVs");
+ }
+ else if (DstAddRec) {
+ if (const SCEVAddRecExpr *tmpAddRec =
+ dyn_cast<SCEVAddRecExpr>(DstAddRec->getStart())) {
+ DstConst = tmpAddRec->getStart();
+ DstCoeff = tmpAddRec->getStepRecurrence(*SE);
+ DstLoop = tmpAddRec->getLoop();
+ SrcConst = Src;
+ SrcCoeff = SE->getNegativeSCEV(DstAddRec->getStepRecurrence(*SE));
+ SrcLoop = DstAddRec->getLoop();
+ }
+ else
+ llvm_unreachable("RDIV reached by surprising SCEVs");
+ }
+ else
+ llvm_unreachable("RDIV expected at least one AddRec");
+ return exactRDIVtest(SrcCoeff, DstCoeff,
+ SrcConst, DstConst,
+ SrcLoop, DstLoop,
+ Result) ||
+ gcdMIVtest(Src, Dst, Result) ||
+ symbolicRDIVtest(SrcCoeff, DstCoeff,
+ SrcConst, DstConst,
+ SrcLoop, DstLoop);
+}
+
+
+// Tests the single-subscript MIV pair (Src and Dst) for dependence.
+// Return true if dependence disproved.
+// Can sometimes refine direction vectors.
+bool DependenceInfo::testMIV(const SCEV *Src, const SCEV *Dst,
+ const SmallBitVector &Loops,
+ FullDependence &Result) const {
+ LLVM_DEBUG(dbgs() << " src = " << *Src << "\n");
+ LLVM_DEBUG(dbgs() << " dst = " << *Dst << "\n");
+ Result.Consistent = false;
+ return gcdMIVtest(Src, Dst, Result) ||
+ banerjeeMIVtest(Src, Dst, Loops, Result);
+}
+
+
+// Given a product, e.g., 10*X*Y, returns the first constant operand,
+// in this case 10. If there is no constant part, returns NULL.
+static
+const SCEVConstant *getConstantPart(const SCEV *Expr) {
+ if (const auto *Constant = dyn_cast<SCEVConstant>(Expr))
+ return Constant;
+ else if (const auto *Product = dyn_cast<SCEVMulExpr>(Expr))
+ if (const auto *Constant = dyn_cast<SCEVConstant>(Product->getOperand(0)))
+ return Constant;
+ return nullptr;
+}
+
+
+//===----------------------------------------------------------------------===//
+// gcdMIVtest -
+// Tests an MIV subscript pair for dependence.
+// Returns true if any possible dependence is disproved.
+// Marks the result as inconsistent.
+// Can sometimes disprove the equal direction for 1 or more loops,
+// as discussed in Michael Wolfe's book,
+// High Performance Compilers for Parallel Computing, page 235.
+//
+// We spend some effort (code!) to handle cases like
+// [10*i + 5*N*j + 15*M + 6], where i and j are induction variables,
+// but M and N are just loop-invariant variables.
+// This should help us handle linearized subscripts;
+// also makes this test a useful backup to the various SIV tests.
+//
+// It occurs to me that the presence of loop-invariant variables
+// changes the nature of the test from "greatest common divisor"
+// to "a common divisor".
+bool DependenceInfo::gcdMIVtest(const SCEV *Src, const SCEV *Dst,
+ FullDependence &Result) const {
+ LLVM_DEBUG(dbgs() << "starting gcd\n");
+ ++GCDapplications;
+ unsigned BitWidth = SE->getTypeSizeInBits(Src->getType());
+ APInt RunningGCD = APInt::getNullValue(BitWidth);
+
+ // Examine Src coefficients.
+ // Compute running GCD and record source constant.
+ // Because we're looking for the constant at the end of the chain,
+ // we can't quit the loop just because the GCD == 1.
+ const SCEV *Coefficients = Src;
+ while (const SCEVAddRecExpr *AddRec =
+ dyn_cast<SCEVAddRecExpr>(Coefficients)) {
+ const SCEV *Coeff = AddRec->getStepRecurrence(*SE);
+ // If the coefficient is the product of a constant and other stuff,
+ // we can use the constant in the GCD computation.
+ const auto *Constant = getConstantPart(Coeff);
+ if (!Constant)
+ return false;
+ APInt ConstCoeff = Constant->getAPInt();
+ RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs());
+ Coefficients = AddRec->getStart();
+ }
+ const SCEV *SrcConst = Coefficients;
+
+ // Examine Dst coefficients.
+ // Compute running GCD and record destination constant.
+ // Because we're looking for the constant at the end of the chain,
+ // we can't quit the loop just because the GCD == 1.
+ Coefficients = Dst;
+ while (const SCEVAddRecExpr *AddRec =
+ dyn_cast<SCEVAddRecExpr>(Coefficients)) {
+ const SCEV *Coeff = AddRec->getStepRecurrence(*SE);
+ // If the coefficient is the product of a constant and other stuff,
+ // we can use the constant in the GCD computation.
+ const auto *Constant = getConstantPart(Coeff);
+ if (!Constant)
+ return false;
+ APInt ConstCoeff = Constant->getAPInt();
+ RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs());
+ Coefficients = AddRec->getStart();
+ }
+ const SCEV *DstConst = Coefficients;
+
+ APInt ExtraGCD = APInt::getNullValue(BitWidth);
+ const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst);
+ LLVM_DEBUG(dbgs() << " Delta = " << *Delta << "\n");
+ const SCEVConstant *Constant = dyn_cast<SCEVConstant>(Delta);
+ if (const SCEVAddExpr *Sum = dyn_cast<SCEVAddExpr>(Delta)) {
+ // If Delta is a sum of products, we may be able to make further progress.
+ for (unsigned Op = 0, Ops = Sum->getNumOperands(); Op < Ops; Op++) {
+ const SCEV *Operand = Sum->getOperand(Op);
+ if (isa<SCEVConstant>(Operand)) {
+ assert(!Constant && "Surprised to find multiple constants");
+ Constant = cast<SCEVConstant>(Operand);
+ }
+ else if (const SCEVMulExpr *Product = dyn_cast<SCEVMulExpr>(Operand)) {
+        // Search for a constant operand to participate in the GCD;
+        // if none is found, return false.
+ const SCEVConstant *ConstOp = getConstantPart(Product);
+ if (!ConstOp)
+ return false;
+ APInt ConstOpValue = ConstOp->getAPInt();
+ ExtraGCD = APIntOps::GreatestCommonDivisor(ExtraGCD,
+ ConstOpValue.abs());
+ }
+ else
+ return false;
+ }
+ }
+ if (!Constant)
+ return false;
+ APInt ConstDelta = cast<SCEVConstant>(Constant)->getAPInt();
+ LLVM_DEBUG(dbgs() << " ConstDelta = " << ConstDelta << "\n");
+ if (ConstDelta == 0)
+ return false;
+ RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ExtraGCD);
+ LLVM_DEBUG(dbgs() << " RunningGCD = " << RunningGCD << "\n");
+ APInt Remainder = ConstDelta.srem(RunningGCD);
+ if (Remainder != 0) {
+ ++GCDindependence;
+ return true;
+ }
+
+ // Try to disprove equal directions.
+ // For example, given a subscript pair [3*i + 2*j] and [i' + 2*j' - 1],
+ // the code above can't disprove the dependence because the GCD = 1.
+  // So we consider what happens if i = i' and what happens if j = j'.
+ // If i = i', we can simplify the subscript to [2*i + 2*j] and [2*j' - 1],
+ // which is infeasible, so we can disallow the = direction for the i level.
+ // Setting j = j' doesn't help matters, so we end up with a direction vector
+ // of [<>, *]
+ //
+ // Given A[5*i + 10*j*M + 9*M*N] and A[15*i + 20*j*M - 21*N*M + 5],
+ // we need to remember that the constant part is 5 and the RunningGCD should
+ // be initialized to ExtraGCD = 30.
+ LLVM_DEBUG(dbgs() << " ExtraGCD = " << ExtraGCD << '\n');
+
+ bool Improved = false;
+ Coefficients = Src;
+ while (const SCEVAddRecExpr *AddRec =
+ dyn_cast<SCEVAddRecExpr>(Coefficients)) {
+ Coefficients = AddRec->getStart();
+ const Loop *CurLoop = AddRec->getLoop();
+ RunningGCD = ExtraGCD;
+ const SCEV *SrcCoeff = AddRec->getStepRecurrence(*SE);
+ const SCEV *DstCoeff = SE->getMinusSCEV(SrcCoeff, SrcCoeff);
+ const SCEV *Inner = Src;
+ while (RunningGCD != 1 && isa<SCEVAddRecExpr>(Inner)) {
+ AddRec = cast<SCEVAddRecExpr>(Inner);
+ const SCEV *Coeff = AddRec->getStepRecurrence(*SE);
+ if (CurLoop == AddRec->getLoop())
+ ; // SrcCoeff == Coeff
+ else {
+ // If the coefficient is the product of a constant and other stuff,
+ // we can use the constant in the GCD computation.
+ Constant = getConstantPart(Coeff);
+ if (!Constant)
+ return false;
+ APInt ConstCoeff = Constant->getAPInt();
+ RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs());
+ }
+ Inner = AddRec->getStart();
+ }
+ Inner = Dst;
+ while (RunningGCD != 1 && isa<SCEVAddRecExpr>(Inner)) {
+ AddRec = cast<SCEVAddRecExpr>(Inner);
+ const SCEV *Coeff = AddRec->getStepRecurrence(*SE);
+ if (CurLoop == AddRec->getLoop())
+ DstCoeff = Coeff;
+ else {
+ // If the coefficient is the product of a constant and other stuff,
+ // we can use the constant in the GCD computation.
+ Constant = getConstantPart(Coeff);
+ if (!Constant)
+ return false;
+ APInt ConstCoeff = Constant->getAPInt();
+ RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs());
+ }
+ Inner = AddRec->getStart();
+ }
+ Delta = SE->getMinusSCEV(SrcCoeff, DstCoeff);
+ // If the coefficient is the product of a constant and other stuff,
+ // we can use the constant in the GCD computation.
+ Constant = getConstantPart(Delta);
+ if (!Constant)
+ // The difference of the two coefficients might not be a product
+ // or constant, in which case we give up on this direction.
+ continue;
+ APInt ConstCoeff = Constant->getAPInt();
+ RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs());
+ LLVM_DEBUG(dbgs() << "\tRunningGCD = " << RunningGCD << "\n");
+ if (RunningGCD != 0) {
+ Remainder = ConstDelta.srem(RunningGCD);
+ LLVM_DEBUG(dbgs() << "\tRemainder = " << Remainder << "\n");
+ if (Remainder != 0) {
+ unsigned Level = mapSrcLoop(CurLoop);
+ Result.DV[Level - 1].Direction &= unsigned(~Dependence::DVEntry::EQ);
+ Improved = true;
+ }
+ }
+ }
+ if (Improved)
+ ++GCDsuccesses;
+ LLVM_DEBUG(dbgs() << "all done\n");
+ return false;
+}
+
+
+//===----------------------------------------------------------------------===//
+// banerjeeMIVtest -
+// Use Banerjee's Inequalities to test an MIV subscript pair.
+// (Wolfe, in the race-car book, calls this the Extreme Value Test.)
+// Generally follows the discussion in Section 2.5.2 of
+//
+// Optimizing Supercompilers for Supercomputers
+// Michael Wolfe
+//
+// The inequalities given on page 25 are simplified in that loops are
+// normalized so that the lower bound is always 0 and the stride is always 1.
+// For example, Wolfe gives
+//
+// LB^<_k = (A^-_k - B_k)^- (U_k - L_k - N_k) + (A_k - B_k)L_k - B_k N_k
+//
+// where A_k is the coefficient of the kth index in the source subscript,
+// B_k is the coefficient of the kth index in the destination subscript,
+// U_k is the upper bound of the kth index, L_k is the lower bound of the kth
+// index, and N_k is the stride of the kth index. Since all loops are normalized
+// by the SCEV package, N_k = 1 and L_k = 0, allowing us to simplify the
+// equation to
+//
+// LB^<_k = (A^-_k - B_k)^- (U_k - 0 - 1) + (A_k - B_k)0 - B_k 1
+// = (A^-_k - B_k)^- (U_k - 1) - B_k
+//
+// Similar simplifications are possible for the other equations.
+//
+// When we can't determine the number of iterations for a loop,
+// we use NULL as an indicator for the worst case, infinity.
+// When computing the upper bound, NULL denotes +inf;
+// for the lower bound, NULL denotes -inf.
+//
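+// As a rough illustration (made-up numbers): with A_k = 2, B_k = 3, and
+// U_k = 10, the simplified bound above gives
+//   LB^<_k = (2^- - 3)^- (10 - 1) - 3 = (-3)*(9) - 3 = -30,
+// since 2^- = min(2, 0) = 0 and (0 - 3)^- = -3.
+//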
+// Return true if dependence disproved.
+bool DependenceInfo::banerjeeMIVtest(const SCEV *Src, const SCEV *Dst,
+ const SmallBitVector &Loops,
+ FullDependence &Result) const {
+ LLVM_DEBUG(dbgs() << "starting Banerjee\n");
+ ++BanerjeeApplications;
+ LLVM_DEBUG(dbgs() << " Src = " << *Src << '\n');
+ const SCEV *A0;
+ CoefficientInfo *A = collectCoeffInfo(Src, true, A0);
+ LLVM_DEBUG(dbgs() << " Dst = " << *Dst << '\n');
+ const SCEV *B0;
+ CoefficientInfo *B = collectCoeffInfo(Dst, false, B0);
+ BoundInfo *Bound = new BoundInfo[MaxLevels + 1];
+ const SCEV *Delta = SE->getMinusSCEV(B0, A0);
+ LLVM_DEBUG(dbgs() << "\tDelta = " << *Delta << '\n');
+
+ // Compute bounds for all the * directions.
+ LLVM_DEBUG(dbgs() << "\tBounds[*]\n");
+ for (unsigned K = 1; K <= MaxLevels; ++K) {
+ Bound[K].Iterations = A[K].Iterations ? A[K].Iterations : B[K].Iterations;
+ Bound[K].Direction = Dependence::DVEntry::ALL;
+ Bound[K].DirSet = Dependence::DVEntry::NONE;
+ findBoundsALL(A, B, Bound, K);
+#ifndef NDEBUG
+ LLVM_DEBUG(dbgs() << "\t " << K << '\t');
+ if (Bound[K].Lower[Dependence::DVEntry::ALL])
+ LLVM_DEBUG(dbgs() << *Bound[K].Lower[Dependence::DVEntry::ALL] << '\t');
+ else
+ LLVM_DEBUG(dbgs() << "-inf\t");
+ if (Bound[K].Upper[Dependence::DVEntry::ALL])
+ LLVM_DEBUG(dbgs() << *Bound[K].Upper[Dependence::DVEntry::ALL] << '\n');
+ else
+ LLVM_DEBUG(dbgs() << "+inf\n");
+#endif
+ }
+
+ // Test the *, *, *, ... case.
+ bool Disproved = false;
+ if (testBounds(Dependence::DVEntry::ALL, 0, Bound, Delta)) {
+ // Explore the direction vector hierarchy.
+ unsigned DepthExpanded = 0;
+ unsigned NewDeps = exploreDirections(1, A, B, Bound,
+ Loops, DepthExpanded, Delta);
+ if (NewDeps > 0) {
+ bool Improved = false;
+ for (unsigned K = 1; K <= CommonLevels; ++K) {
+ if (Loops[K]) {
+ unsigned Old = Result.DV[K - 1].Direction;
+ Result.DV[K - 1].Direction = Old & Bound[K].DirSet;
+ Improved |= Old != Result.DV[K - 1].Direction;
+ if (!Result.DV[K - 1].Direction) {
+ Improved = false;
+ Disproved = true;
+ break;
+ }
+ }
+ }
+ if (Improved)
+ ++BanerjeeSuccesses;
+ }
+ else {
+ ++BanerjeeIndependence;
+ Disproved = true;
+ }
+ }
+ else {
+ ++BanerjeeIndependence;
+ Disproved = true;
+ }
+ delete [] Bound;
+ delete [] A;
+ delete [] B;
+ return Disproved;
+}
+
+
+// Hierarchically expands the direction vector
+// search space, combining the directions of discovered dependences
+// in the DirSet field of Bound. Returns the number of distinct
+// dependences discovered. If the dependence is disproved,
+// it will return 0.
+unsigned DependenceInfo::exploreDirections(unsigned Level, CoefficientInfo *A,
+ CoefficientInfo *B, BoundInfo *Bound,
+ const SmallBitVector &Loops,
+ unsigned &DepthExpanded,
+ const SCEV *Delta) const {
+ if (Level > CommonLevels) {
+ // record result
+ LLVM_DEBUG(dbgs() << "\t[");
+ for (unsigned K = 1; K <= CommonLevels; ++K) {
+ if (Loops[K]) {
+ Bound[K].DirSet |= Bound[K].Direction;
+#ifndef NDEBUG
+ switch (Bound[K].Direction) {
+ case Dependence::DVEntry::LT:
+ LLVM_DEBUG(dbgs() << " <");
+ break;
+ case Dependence::DVEntry::EQ:
+ LLVM_DEBUG(dbgs() << " =");
+ break;
+ case Dependence::DVEntry::GT:
+ LLVM_DEBUG(dbgs() << " >");
+ break;
+ case Dependence::DVEntry::ALL:
+ LLVM_DEBUG(dbgs() << " *");
+ break;
+ default:
+ llvm_unreachable("unexpected Bound[K].Direction");
+ }
+#endif
+ }
+ }
+ LLVM_DEBUG(dbgs() << " ]\n");
+ return 1;
+ }
+ if (Loops[Level]) {
+ if (Level > DepthExpanded) {
+ DepthExpanded = Level;
+ // compute bounds for <, =, > at current level
+ findBoundsLT(A, B, Bound, Level);
+ findBoundsGT(A, B, Bound, Level);
+ findBoundsEQ(A, B, Bound, Level);
+#ifndef NDEBUG
+ LLVM_DEBUG(dbgs() << "\tBound for level = " << Level << '\n');
+ LLVM_DEBUG(dbgs() << "\t <\t");
+ if (Bound[Level].Lower[Dependence::DVEntry::LT])
+ LLVM_DEBUG(dbgs() << *Bound[Level].Lower[Dependence::DVEntry::LT]
+ << '\t');
+ else
+ LLVM_DEBUG(dbgs() << "-inf\t");
+ if (Bound[Level].Upper[Dependence::DVEntry::LT])
+ LLVM_DEBUG(dbgs() << *Bound[Level].Upper[Dependence::DVEntry::LT]
+ << '\n');
+ else
+ LLVM_DEBUG(dbgs() << "+inf\n");
+ LLVM_DEBUG(dbgs() << "\t =\t");
+ if (Bound[Level].Lower[Dependence::DVEntry::EQ])
+ LLVM_DEBUG(dbgs() << *Bound[Level].Lower[Dependence::DVEntry::EQ]
+ << '\t');
+ else
+ LLVM_DEBUG(dbgs() << "-inf\t");
+ if (Bound[Level].Upper[Dependence::DVEntry::EQ])
+ LLVM_DEBUG(dbgs() << *Bound[Level].Upper[Dependence::DVEntry::EQ]
+ << '\n');
+ else
+ LLVM_DEBUG(dbgs() << "+inf\n");
+ LLVM_DEBUG(dbgs() << "\t >\t");
+ if (Bound[Level].Lower[Dependence::DVEntry::GT])
+ LLVM_DEBUG(dbgs() << *Bound[Level].Lower[Dependence::DVEntry::GT]
+ << '\t');
+ else
+ LLVM_DEBUG(dbgs() << "-inf\t");
+ if (Bound[Level].Upper[Dependence::DVEntry::GT])
+ LLVM_DEBUG(dbgs() << *Bound[Level].Upper[Dependence::DVEntry::GT]
+ << '\n');
+ else
+ LLVM_DEBUG(dbgs() << "+inf\n");
+#endif
+ }
+
+ unsigned NewDeps = 0;
+
+ // test bounds for <, *, *, ...
+ if (testBounds(Dependence::DVEntry::LT, Level, Bound, Delta))
+ NewDeps += exploreDirections(Level + 1, A, B, Bound,
+ Loops, DepthExpanded, Delta);
+
+ // Test bounds for =, *, *, ...
+ if (testBounds(Dependence::DVEntry::EQ, Level, Bound, Delta))
+ NewDeps += exploreDirections(Level + 1, A, B, Bound,
+ Loops, DepthExpanded, Delta);
+
+ // test bounds for >, *, *, ...
+ if (testBounds(Dependence::DVEntry::GT, Level, Bound, Delta))
+ NewDeps += exploreDirections(Level + 1, A, B, Bound,
+ Loops, DepthExpanded, Delta);
+
+ Bound[Level].Direction = Dependence::DVEntry::ALL;
+ return NewDeps;
+ }
+ else
+ return exploreDirections(Level + 1, A, B, Bound, Loops, DepthExpanded, Delta);
+}
+
+
+// Returns true iff the current bounds are plausible.
+bool DependenceInfo::testBounds(unsigned char DirKind, unsigned Level,
+ BoundInfo *Bound, const SCEV *Delta) const {
+ Bound[Level].Direction = DirKind;
+ if (const SCEV *LowerBound = getLowerBound(Bound))
+ if (isKnownPredicate(CmpInst::ICMP_SGT, LowerBound, Delta))
+ return false;
+ if (const SCEV *UpperBound = getUpperBound(Bound))
+ if (isKnownPredicate(CmpInst::ICMP_SGT, Delta, UpperBound))
+ return false;
+ return true;
+}
+
+
+// Computes the upper and lower bounds for level K
+// using the * direction. Records them in Bound.
+// Wolfe gives the equations
+//
+// LB^*_k = (A^-_k - B^+_k)(U_k - L_k) + (A_k - B_k)L_k
+// UB^*_k = (A^+_k - B^-_k)(U_k - L_k) + (A_k - B_k)L_k
+//
+// Since we normalize loops, we can simplify these equations to
+//
+// LB^*_k = (A^-_k - B^+_k)U_k
+// UB^*_k = (A^+_k - B^-_k)U_k
+//
+// We must be careful to handle the case where the upper bound is unknown.
+// Note that the lower bound is always <= 0
+// and the upper bound is always >= 0.
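+//
+// For example (hypothetical coefficients): with A_k = 2, B_k = -3, and
+// U_k = 10, we get LB^*_k = (0 - 0)*10 = 0 and UB^*_k = (2 - (-3))*10 = 50.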
+void DependenceInfo::findBoundsALL(CoefficientInfo *A, CoefficientInfo *B,
+ BoundInfo *Bound, unsigned K) const {
+ Bound[K].Lower[Dependence::DVEntry::ALL] = nullptr; // Default value = -infinity.
+ Bound[K].Upper[Dependence::DVEntry::ALL] = nullptr; // Default value = +infinity.
+ if (Bound[K].Iterations) {
+ Bound[K].Lower[Dependence::DVEntry::ALL] =
+ SE->getMulExpr(SE->getMinusSCEV(A[K].NegPart, B[K].PosPart),
+ Bound[K].Iterations);
+ Bound[K].Upper[Dependence::DVEntry::ALL] =
+ SE->getMulExpr(SE->getMinusSCEV(A[K].PosPart, B[K].NegPart),
+ Bound[K].Iterations);
+ }
+ else {
+ // If the difference is 0, we won't need to know the number of iterations.
+ if (isKnownPredicate(CmpInst::ICMP_EQ, A[K].NegPart, B[K].PosPart))
+ Bound[K].Lower[Dependence::DVEntry::ALL] =
+ SE->getZero(A[K].Coeff->getType());
+ if (isKnownPredicate(CmpInst::ICMP_EQ, A[K].PosPart, B[K].NegPart))
+ Bound[K].Upper[Dependence::DVEntry::ALL] =
+ SE->getZero(A[K].Coeff->getType());
+ }
+}
+
+
+// Computes the upper and lower bounds for level K
+// using the = direction. Records them in Bound.
+// Wolfe gives the equations
+//
+// LB^=_k = (A_k - B_k)^- (U_k - L_k) + (A_k - B_k)L_k
+// UB^=_k = (A_k - B_k)^+ (U_k - L_k) + (A_k - B_k)L_k
+//
+// Since we normalize loops, we can simplify these equations to
+//
+// LB^=_k = (A_k - B_k)^- U_k
+// UB^=_k = (A_k - B_k)^+ U_k
+//
+// We must be careful to handle the case where the upper bound is unknown.
+// Note that the lower bound is always <= 0
+// and the upper bound is always >= 0.
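+//
+// For example (hypothetical coefficients): with A_k = 2, B_k = 5, and
+// U_k = 10, A_k - B_k = -3, so LB^=_k = (-3)*10 = -30 and UB^=_k = 0*10 = 0.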
+void DependenceInfo::findBoundsEQ(CoefficientInfo *A, CoefficientInfo *B,
+ BoundInfo *Bound, unsigned K) const {
+ Bound[K].Lower[Dependence::DVEntry::EQ] = nullptr; // Default value = -infinity.
+ Bound[K].Upper[Dependence::DVEntry::EQ] = nullptr; // Default value = +infinity.
+ if (Bound[K].Iterations) {
+ const SCEV *Delta = SE->getMinusSCEV(A[K].Coeff, B[K].Coeff);
+ const SCEV *NegativePart = getNegativePart(Delta);
+ Bound[K].Lower[Dependence::DVEntry::EQ] =
+ SE->getMulExpr(NegativePart, Bound[K].Iterations);
+ const SCEV *PositivePart = getPositivePart(Delta);
+ Bound[K].Upper[Dependence::DVEntry::EQ] =
+ SE->getMulExpr(PositivePart, Bound[K].Iterations);
+ }
+ else {
+ // If the positive/negative part of the difference is 0,
+ // we won't need to know the number of iterations.
+ const SCEV *Delta = SE->getMinusSCEV(A[K].Coeff, B[K].Coeff);
+ const SCEV *NegativePart = getNegativePart(Delta);
+ if (NegativePart->isZero())
+ Bound[K].Lower[Dependence::DVEntry::EQ] = NegativePart; // Zero
+ const SCEV *PositivePart = getPositivePart(Delta);
+ if (PositivePart->isZero())
+ Bound[K].Upper[Dependence::DVEntry::EQ] = PositivePart; // Zero
+ }
+}
+
+
+// Computes the upper and lower bounds for level K
+// using the < direction. Records them in Bound.
+// Wolfe gives the equations
+//
+// LB^<_k = (A^-_k - B_k)^- (U_k - L_k - N_k) + (A_k - B_k)L_k - B_k N_k
+// UB^<_k = (A^+_k - B_k)^+ (U_k - L_k - N_k) + (A_k - B_k)L_k - B_k N_k
+//
+// Since we normalize loops, we can simplify these equations to
+//
+// LB^<_k = (A^-_k - B_k)^- (U_k - 1) - B_k
+// UB^<_k = (A^+_k - B_k)^+ (U_k - 1) - B_k
+//
+// We must be careful to handle the case where the upper bound is unknown.
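+//
+// For example (hypothetical coefficients): with A_k = 2, B_k = 3, and
+// U_k = 10, A^-_k = 0 and A^+_k = 2, so
+//   LB^<_k = (0 - 3)^- * 9 - 3 = -30 and UB^<_k = (2 - 3)^+ * 9 - 3 = -3.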
+void DependenceInfo::findBoundsLT(CoefficientInfo *A, CoefficientInfo *B,
+ BoundInfo *Bound, unsigned K) const {
+ Bound[K].Lower[Dependence::DVEntry::LT] = nullptr; // Default value = -infinity.
+ Bound[K].Upper[Dependence::DVEntry::LT] = nullptr; // Default value = +infinity.
+ if (Bound[K].Iterations) {
+ const SCEV *Iter_1 = SE->getMinusSCEV(
+ Bound[K].Iterations, SE->getOne(Bound[K].Iterations->getType()));
+ const SCEV *NegPart =
+ getNegativePart(SE->getMinusSCEV(A[K].NegPart, B[K].Coeff));
+ Bound[K].Lower[Dependence::DVEntry::LT] =
+ SE->getMinusSCEV(SE->getMulExpr(NegPart, Iter_1), B[K].Coeff);
+ const SCEV *PosPart =
+ getPositivePart(SE->getMinusSCEV(A[K].PosPart, B[K].Coeff));
+ Bound[K].Upper[Dependence::DVEntry::LT] =
+ SE->getMinusSCEV(SE->getMulExpr(PosPart, Iter_1), B[K].Coeff);
+ }
+ else {
+ // If the positive/negative part of the difference is 0,
+ // we won't need to know the number of iterations.
+ const SCEV *NegPart =
+ getNegativePart(SE->getMinusSCEV(A[K].NegPart, B[K].Coeff));
+ if (NegPart->isZero())
+ Bound[K].Lower[Dependence::DVEntry::LT] = SE->getNegativeSCEV(B[K].Coeff);
+ const SCEV *PosPart =
+ getPositivePart(SE->getMinusSCEV(A[K].PosPart, B[K].Coeff));
+ if (PosPart->isZero())
+ Bound[K].Upper[Dependence::DVEntry::LT] = SE->getNegativeSCEV(B[K].Coeff);
+ }
+}
+
+
+// Computes the upper and lower bounds for level K
+// using the > direction. Records them in Bound.
+// Wolfe gives the equations
+//
+// LB^>_k = (A_k - B^+_k)^- (U_k - L_k - N_k) + (A_k - B_k)L_k + A_k N_k
+// UB^>_k = (A_k - B^-_k)^+ (U_k - L_k - N_k) + (A_k - B_k)L_k + A_k N_k
+//
+// Since we normalize loops, we can simplify these equations to
+//
+// LB^>_k = (A_k - B^+_k)^- (U_k - 1) + A_k
+// UB^>_k = (A_k - B^-_k)^+ (U_k - 1) + A_k
+//
+// We must be careful to handle the case where the upper bound is unknown.
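+//
+// For example (hypothetical coefficients): with A_k = 2, B_k = 3, and
+// U_k = 10, B^+_k = 3 and B^-_k = 0, so
+//   LB^>_k = (2 - 3)^- * 9 + 2 = -7 and UB^>_k = (2 - 0)^+ * 9 + 2 = 20.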
+void DependenceInfo::findBoundsGT(CoefficientInfo *A, CoefficientInfo *B,
+ BoundInfo *Bound, unsigned K) const {
+ Bound[K].Lower[Dependence::DVEntry::GT] = nullptr; // Default value = -infinity.
+ Bound[K].Upper[Dependence::DVEntry::GT] = nullptr; // Default value = +infinity.
+ if (Bound[K].Iterations) {
+ const SCEV *Iter_1 = SE->getMinusSCEV(
+ Bound[K].Iterations, SE->getOne(Bound[K].Iterations->getType()));
+ const SCEV *NegPart =
+ getNegativePart(SE->getMinusSCEV(A[K].Coeff, B[K].PosPart));
+ Bound[K].Lower[Dependence::DVEntry::GT] =
+ SE->getAddExpr(SE->getMulExpr(NegPart, Iter_1), A[K].Coeff);
+ const SCEV *PosPart =
+ getPositivePart(SE->getMinusSCEV(A[K].Coeff, B[K].NegPart));
+ Bound[K].Upper[Dependence::DVEntry::GT] =
+ SE->getAddExpr(SE->getMulExpr(PosPart, Iter_1), A[K].Coeff);
+ }
+ else {
+ // If the positive/negative part of the difference is 0,
+ // we won't need to know the number of iterations.
+ const SCEV *NegPart = getNegativePart(SE->getMinusSCEV(A[K].Coeff, B[K].PosPart));
+ if (NegPart->isZero())
+ Bound[K].Lower[Dependence::DVEntry::GT] = A[K].Coeff;
+ const SCEV *PosPart = getPositivePart(SE->getMinusSCEV(A[K].Coeff, B[K].NegPart));
+ if (PosPart->isZero())
+ Bound[K].Upper[Dependence::DVEntry::GT] = A[K].Coeff;
+ }
+}
+
+
+// X^+ = max(X, 0)
+const SCEV *DependenceInfo::getPositivePart(const SCEV *X) const {
+ return SE->getSMaxExpr(X, SE->getZero(X->getType()));
+}
+
+
+// X^- = min(X, 0)
+const SCEV *DependenceInfo::getNegativePart(const SCEV *X) const {
+ return SE->getSMinExpr(X, SE->getZero(X->getType()));
+}
+
+
+// Walks through the subscript,
+// collecting each coefficient, the associated loop bounds,
+// and recording its positive and negative parts for later use.
+DependenceInfo::CoefficientInfo *
+DependenceInfo::collectCoeffInfo(const SCEV *Subscript, bool SrcFlag,
+ const SCEV *&Constant) const {
+ const SCEV *Zero = SE->getZero(Subscript->getType());
+ CoefficientInfo *CI = new CoefficientInfo[MaxLevels + 1];
+ for (unsigned K = 1; K <= MaxLevels; ++K) {
+ CI[K].Coeff = Zero;
+ CI[K].PosPart = Zero;
+ CI[K].NegPart = Zero;
+ CI[K].Iterations = nullptr;
+ }
+ while (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Subscript)) {
+ const Loop *L = AddRec->getLoop();
+ unsigned K = SrcFlag ? mapSrcLoop(L) : mapDstLoop(L);
+ CI[K].Coeff = AddRec->getStepRecurrence(*SE);
+ CI[K].PosPart = getPositivePart(CI[K].Coeff);
+ CI[K].NegPart = getNegativePart(CI[K].Coeff);
+ CI[K].Iterations = collectUpperBound(L, Subscript->getType());
+ Subscript = AddRec->getStart();
+ }
+ Constant = Subscript;
+#ifndef NDEBUG
+ LLVM_DEBUG(dbgs() << "\tCoefficient Info\n");
+ for (unsigned K = 1; K <= MaxLevels; ++K) {
+ LLVM_DEBUG(dbgs() << "\t " << K << "\t" << *CI[K].Coeff);
+ LLVM_DEBUG(dbgs() << "\tPos Part = ");
+ LLVM_DEBUG(dbgs() << *CI[K].PosPart);
+ LLVM_DEBUG(dbgs() << "\tNeg Part = ");
+ LLVM_DEBUG(dbgs() << *CI[K].NegPart);
+ LLVM_DEBUG(dbgs() << "\tUpper Bound = ");
+ if (CI[K].Iterations)
+ LLVM_DEBUG(dbgs() << *CI[K].Iterations);
+ else
+ LLVM_DEBUG(dbgs() << "+inf");
+ LLVM_DEBUG(dbgs() << '\n');
+ }
+ LLVM_DEBUG(dbgs() << "\t Constant = " << *Subscript << '\n');
+#endif
+ return CI;
+}
+
+
+// Looks through all the bounds info and
+// computes the lower bound given the current direction settings
+// at each level. If the lower bound for any level is -inf,
+// the result is -inf.
+const SCEV *DependenceInfo::getLowerBound(BoundInfo *Bound) const {
+ const SCEV *Sum = Bound[1].Lower[Bound[1].Direction];
+ for (unsigned K = 2; Sum && K <= MaxLevels; ++K) {
+ if (Bound[K].Lower[Bound[K].Direction])
+ Sum = SE->getAddExpr(Sum, Bound[K].Lower[Bound[K].Direction]);
+ else
+ Sum = nullptr;
+ }
+ return Sum;
+}
+
+
+// Looks through all the bounds info and
+// computes the upper bound given the current direction settings
+// at each level. If the upper bound at any level is +inf,
+// the result is +inf.
+const SCEV *DependenceInfo::getUpperBound(BoundInfo *Bound) const {
+ const SCEV *Sum = Bound[1].Upper[Bound[1].Direction];
+ for (unsigned K = 2; Sum && K <= MaxLevels; ++K) {
+ if (Bound[K].Upper[Bound[K].Direction])
+ Sum = SE->getAddExpr(Sum, Bound[K].Upper[Bound[K].Direction]);
+ else
+ Sum = nullptr;
+ }
+ return Sum;
+}
+
+
+//===----------------------------------------------------------------------===//
+// Constraint manipulation for Delta test.
+
+// Given a linear SCEV,
+// return the coefficient (the step)
+// corresponding to the specified loop.
+// If there isn't one, return a zero SCEV.
+// For example, given a*i + b*j + c*k, finding the coefficient
+// corresponding to the j loop would yield b.
+const SCEV *DependenceInfo::findCoefficient(const SCEV *Expr,
+ const Loop *TargetLoop) const {
+ const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Expr);
+ if (!AddRec)
+ return SE->getZero(Expr->getType());
+ if (AddRec->getLoop() == TargetLoop)
+ return AddRec->getStepRecurrence(*SE);
+ return findCoefficient(AddRec->getStart(), TargetLoop);
+}
+
+
+// Given a linear SCEV,
+// return the SCEV given by zeroing out the coefficient
+// corresponding to the specified loop.
+// For example, given a*i + b*j + c*k, zeroing the coefficient
+// corresponding to the j loop would yield a*i + c*k.
+const SCEV *DependenceInfo::zeroCoefficient(const SCEV *Expr,
+ const Loop *TargetLoop) const {
+ const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Expr);
+ if (!AddRec)
+ return Expr; // ignore
+ if (AddRec->getLoop() == TargetLoop)
+ return AddRec->getStart();
+ return SE->getAddRecExpr(zeroCoefficient(AddRec->getStart(), TargetLoop),
+ AddRec->getStepRecurrence(*SE),
+ AddRec->getLoop(),
+ AddRec->getNoWrapFlags());
+}
+
+
+// Given a linear SCEV Expr,
+// return the SCEV given by adding some Value to the
+// coefficient corresponding to the specified TargetLoop.
+// For example, given a*i + b*j + c*k, adding 1 to the coefficient
+// corresponding to the j loop would yield a*i + (b+1)*j + c*k.
+const SCEV *DependenceInfo::addToCoefficient(const SCEV *Expr,
+ const Loop *TargetLoop,
+ const SCEV *Value) const {
+ const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Expr);
+ if (!AddRec) // create a new addRec
+ return SE->getAddRecExpr(Expr,
+ Value,
+ TargetLoop,
+ SCEV::FlagAnyWrap); // Worst case, with no info.
+ if (AddRec->getLoop() == TargetLoop) {
+ const SCEV *Sum = SE->getAddExpr(AddRec->getStepRecurrence(*SE), Value);
+ if (Sum->isZero())
+ return AddRec->getStart();
+ return SE->getAddRecExpr(AddRec->getStart(),
+ Sum,
+ AddRec->getLoop(),
+ AddRec->getNoWrapFlags());
+ }
+ if (SE->isLoopInvariant(AddRec, TargetLoop))
+ return SE->getAddRecExpr(AddRec, Value, TargetLoop, SCEV::FlagAnyWrap);
+ return SE->getAddRecExpr(
+ addToCoefficient(AddRec->getStart(), TargetLoop, Value),
+ AddRec->getStepRecurrence(*SE), AddRec->getLoop(),
+ AddRec->getNoWrapFlags());
+}
+
+
+// Review the constraints, looking for opportunities
+// to simplify a subscript pair (Src and Dst).
+// Return true if some simplification occurs.
+// If the simplification isn't exact (that is, if it is conservative
+// in terms of dependence), set consistent to false.
+// Corresponds to Figure 5 from the paper
+//
+// Practical Dependence Testing
+// Goff, Kennedy, Tseng
+// PLDI 1991
+bool DependenceInfo::propagate(const SCEV *&Src, const SCEV *&Dst,
+ SmallBitVector &Loops,
+ SmallVectorImpl<Constraint> &Constraints,
+ bool &Consistent) {
+ bool Result = false;
+ for (unsigned LI : Loops.set_bits()) {
+ LLVM_DEBUG(dbgs() << "\t Constraint[" << LI << "] is");
+ LLVM_DEBUG(Constraints[LI].dump(dbgs()));
+ if (Constraints[LI].isDistance())
+ Result |= propagateDistance(Src, Dst, Constraints[LI], Consistent);
+ else if (Constraints[LI].isLine())
+ Result |= propagateLine(Src, Dst, Constraints[LI], Consistent);
+ else if (Constraints[LI].isPoint())
+ Result |= propagatePoint(Src, Dst, Constraints[LI]);
+ }
+ return Result;
+}
+
+
+// Attempt to propagate a distance
+// constraint into a subscript pair (Src and Dst).
+// Return true if some simplification occurs.
+// If the simplification isn't exact (that is, if it is conservative
+// in terms of dependence), set consistent to false.
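+//
+// For example (informal sketch): given a distance constraint with value d
+// for the current loop, a pair [c1 + a*i] and [c2 + a*i'] is rewritten by
+// subtracting a*d from Src, zeroing out Src's coefficient for that loop,
+// and subtracting a from Dst's coefficient for the same loop.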
+bool DependenceInfo::propagateDistance(const SCEV *&Src, const SCEV *&Dst,
+ Constraint &CurConstraint,
+ bool &Consistent) {
+ const Loop *CurLoop = CurConstraint.getAssociatedLoop();
+ LLVM_DEBUG(dbgs() << "\t\tSrc is " << *Src << "\n");
+ const SCEV *A_K = findCoefficient(Src, CurLoop);
+ if (A_K->isZero())
+ return false;
+ const SCEV *DA_K = SE->getMulExpr(A_K, CurConstraint.getD());
+ Src = SE->getMinusSCEV(Src, DA_K);
+ Src = zeroCoefficient(Src, CurLoop);
+ LLVM_DEBUG(dbgs() << "\t\tnew Src is " << *Src << "\n");
+ LLVM_DEBUG(dbgs() << "\t\tDst is " << *Dst << "\n");
+ Dst = addToCoefficient(Dst, CurLoop, SE->getNegativeSCEV(A_K));
+ LLVM_DEBUG(dbgs() << "\t\tnew Dst is " << *Dst << "\n");
+ if (!findCoefficient(Dst, CurLoop)->isZero())
+ Consistent = false;
+ return true;
+}
+
+
+// Attempt to propagate a line
+// constraint into a subscript pair (Src and Dst).
+// Return true if some simplification occurs.
+// If the simplification isn't exact (that is, if it is conservative
+// in terms of dependence), set consistent to false.
+bool DependenceInfo::propagateLine(const SCEV *&Src, const SCEV *&Dst,
+ Constraint &CurConstraint,
+ bool &Consistent) {
+ const Loop *CurLoop = CurConstraint.getAssociatedLoop();
+ const SCEV *A = CurConstraint.getA();
+ const SCEV *B = CurConstraint.getB();
+ const SCEV *C = CurConstraint.getC();
+ LLVM_DEBUG(dbgs() << "\t\tA = " << *A << ", B = " << *B << ", C = " << *C
+ << "\n");
+ LLVM_DEBUG(dbgs() << "\t\tSrc = " << *Src << "\n");
+ LLVM_DEBUG(dbgs() << "\t\tDst = " << *Dst << "\n");
+ if (A->isZero()) {
+ const SCEVConstant *Bconst = dyn_cast<SCEVConstant>(B);
+ const SCEVConstant *Cconst = dyn_cast<SCEVConstant>(C);
+ if (!Bconst || !Cconst) return false;
+ APInt Beta = Bconst->getAPInt();
+ APInt Charlie = Cconst->getAPInt();
+ APInt CdivB = Charlie.sdiv(Beta);
+ assert(Charlie.srem(Beta) == 0 && "C should be evenly divisible by B");
+ const SCEV *AP_K = findCoefficient(Dst, CurLoop);
+ // Src = SE->getAddExpr(Src, SE->getMulExpr(AP_K, SE->getConstant(CdivB)));
+ Src = SE->getMinusSCEV(Src, SE->getMulExpr(AP_K, SE->getConstant(CdivB)));
+ Dst = zeroCoefficient(Dst, CurLoop);
+ if (!findCoefficient(Src, CurLoop)->isZero())
+ Consistent = false;
+ }
+ else if (B->isZero()) {
+ const SCEVConstant *Aconst = dyn_cast<SCEVConstant>(A);
+ const SCEVConstant *Cconst = dyn_cast<SCEVConstant>(C);
+ if (!Aconst || !Cconst) return false;
+ APInt Alpha = Aconst->getAPInt();
+ APInt Charlie = Cconst->getAPInt();
+ APInt CdivA = Charlie.sdiv(Alpha);
+ assert(Charlie.srem(Alpha) == 0 && "C should be evenly divisible by A");
+ const SCEV *A_K = findCoefficient(Src, CurLoop);
+ Src = SE->getAddExpr(Src, SE->getMulExpr(A_K, SE->getConstant(CdivA)));
+ Src = zeroCoefficient(Src, CurLoop);
+ if (!findCoefficient(Dst, CurLoop)->isZero())
+ Consistent = false;
+ }
+ else if (isKnownPredicate(CmpInst::ICMP_EQ, A, B)) {
+ const SCEVConstant *Aconst = dyn_cast<SCEVConstant>(A);
+ const SCEVConstant *Cconst = dyn_cast<SCEVConstant>(C);
+ if (!Aconst || !Cconst) return false;
+ APInt Alpha = Aconst->getAPInt();
+ APInt Charlie = Cconst->getAPInt();
+ APInt CdivA = Charlie.sdiv(Alpha);
+ assert(Charlie.srem(Alpha) == 0 && "C should be evenly divisible by A");
+ const SCEV *A_K = findCoefficient(Src, CurLoop);
+ Src = SE->getAddExpr(Src, SE->getMulExpr(A_K, SE->getConstant(CdivA)));
+ Src = zeroCoefficient(Src, CurLoop);
+ Dst = addToCoefficient(Dst, CurLoop, A_K);
+ if (!findCoefficient(Dst, CurLoop)->isZero())
+ Consistent = false;
+ }
+ else {
+ // paper is incorrect here, or perhaps just misleading
+ const SCEV *A_K = findCoefficient(Src, CurLoop);
+ Src = SE->getMulExpr(Src, A);
+ Dst = SE->getMulExpr(Dst, A);
+ Src = SE->getAddExpr(Src, SE->getMulExpr(A_K, C));
+ Src = zeroCoefficient(Src, CurLoop);
+ Dst = addToCoefficient(Dst, CurLoop, SE->getMulExpr(A_K, B));
+ if (!findCoefficient(Dst, CurLoop)->isZero())
+ Consistent = false;
+ }
+ LLVM_DEBUG(dbgs() << "\t\tnew Src = " << *Src << "\n");
+ LLVM_DEBUG(dbgs() << "\t\tnew Dst = " << *Dst << "\n");
+ return true;
+}
+
+
+// Attempt to propagate a point
+// constraint into a subscript pair (Src and Dst).
+// Return true if some simplification occurs.
+bool DependenceInfo::propagatePoint(const SCEV *&Src, const SCEV *&Dst,
+ Constraint &CurConstraint) {
+ const Loop *CurLoop = CurConstraint.getAssociatedLoop();
+ const SCEV *A_K = findCoefficient(Src, CurLoop);
+ const SCEV *AP_K = findCoefficient(Dst, CurLoop);
+ const SCEV *XA_K = SE->getMulExpr(A_K, CurConstraint.getX());
+ const SCEV *YAP_K = SE->getMulExpr(AP_K, CurConstraint.getY());
+ LLVM_DEBUG(dbgs() << "\t\tSrc is " << *Src << "\n");
+ Src = SE->getAddExpr(Src, SE->getMinusSCEV(XA_K, YAP_K));
+ Src = zeroCoefficient(Src, CurLoop);
+ LLVM_DEBUG(dbgs() << "\t\tnew Src is " << *Src << "\n");
+ LLVM_DEBUG(dbgs() << "\t\tDst is " << *Dst << "\n");
+ Dst = zeroCoefficient(Dst, CurLoop);
+ LLVM_DEBUG(dbgs() << "\t\tnew Dst is " << *Dst << "\n");
+ return true;
+}
+
+
+// Update direction vector entry based on the current constraint.
+void DependenceInfo::updateDirection(Dependence::DVEntry &Level,
+ const Constraint &CurConstraint) const {
+ LLVM_DEBUG(dbgs() << "\tUpdate direction, constraint =");
+ LLVM_DEBUG(CurConstraint.dump(dbgs()));
+ if (CurConstraint.isAny())
+ ; // use defaults
+ else if (CurConstraint.isDistance()) {
+ // this one is consistent, the others aren't
+ Level.Scalar = false;
+ Level.Distance = CurConstraint.getD();
+ unsigned NewDirection = Dependence::DVEntry::NONE;
+ if (!SE->isKnownNonZero(Level.Distance)) // if may be zero
+ NewDirection = Dependence::DVEntry::EQ;
+ if (!SE->isKnownNonPositive(Level.Distance)) // if may be positive
+ NewDirection |= Dependence::DVEntry::LT;
+ if (!SE->isKnownNonNegative(Level.Distance)) // if may be negative
+ NewDirection |= Dependence::DVEntry::GT;
+ Level.Direction &= NewDirection;
+ }
+ else if (CurConstraint.isLine()) {
+ Level.Scalar = false;
+ Level.Distance = nullptr;
+ // direction should be accurate
+ }
+ else if (CurConstraint.isPoint()) {
+ Level.Scalar = false;
+ Level.Distance = nullptr;
+ unsigned NewDirection = Dependence::DVEntry::NONE;
+ if (!isKnownPredicate(CmpInst::ICMP_NE,
+ CurConstraint.getY(),
+ CurConstraint.getX()))
+ // if X may be = Y
+ NewDirection |= Dependence::DVEntry::EQ;
+ if (!isKnownPredicate(CmpInst::ICMP_SLE,
+ CurConstraint.getY(),
+ CurConstraint.getX()))
+ // if Y may be > X
+ NewDirection |= Dependence::DVEntry::LT;
+ if (!isKnownPredicate(CmpInst::ICMP_SGE,
+ CurConstraint.getY(),
+ CurConstraint.getX()))
+ // if Y may be < X
+ NewDirection |= Dependence::DVEntry::GT;
+ Level.Direction &= NewDirection;
+ }
+ else
+ llvm_unreachable("constraint has unexpected kind");
+}
+
+/// Check if we can delinearize the subscripts. If the SCEVs representing the
+/// source and destination array references are recurrences on a nested loop,
+/// this function flattens the nested recurrences into separate recurrences
+/// for each loop level.
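+/// For example (informal sketch, assuming delinearization succeeds): an
+/// access written as A[i*N + j] may be recovered as two subscripts [i][j]
+/// with an inner dimension of size N, turning a single MIV subscript pair
+/// into simpler SIV pairs.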
+bool DependenceInfo::tryDelinearize(Instruction *Src, Instruction *Dst,
+ SmallVectorImpl<Subscript> &Pair) {
+ assert(isLoadOrStore(Src) && "instruction is not load or store");
+ assert(isLoadOrStore(Dst) && "instruction is not load or store");
+ Value *SrcPtr = getLoadStorePointerOperand(Src);
+ Value *DstPtr = getLoadStorePointerOperand(Dst);
+
+ Loop *SrcLoop = LI->getLoopFor(Src->getParent());
+ Loop *DstLoop = LI->getLoopFor(Dst->getParent());
+
+  // The code below mimics the code in Delinearization.cpp.
+ const SCEV *SrcAccessFn =
+ SE->getSCEVAtScope(SrcPtr, SrcLoop);
+ const SCEV *DstAccessFn =
+ SE->getSCEVAtScope(DstPtr, DstLoop);
+
+ const SCEVUnknown *SrcBase =
+ dyn_cast<SCEVUnknown>(SE->getPointerBase(SrcAccessFn));
+ const SCEVUnknown *DstBase =
+ dyn_cast<SCEVUnknown>(SE->getPointerBase(DstAccessFn));
+
+ if (!SrcBase || !DstBase || SrcBase != DstBase)
+ return false;
+
+ const SCEV *ElementSize = SE->getElementSize(Src);
+ if (ElementSize != SE->getElementSize(Dst))
+ return false;
+
+ const SCEV *SrcSCEV = SE->getMinusSCEV(SrcAccessFn, SrcBase);
+ const SCEV *DstSCEV = SE->getMinusSCEV(DstAccessFn, DstBase);
+
+ const SCEVAddRecExpr *SrcAR = dyn_cast<SCEVAddRecExpr>(SrcSCEV);
+ const SCEVAddRecExpr *DstAR = dyn_cast<SCEVAddRecExpr>(DstSCEV);
+ if (!SrcAR || !DstAR || !SrcAR->isAffine() || !DstAR->isAffine())
+ return false;
+
+ // First step: collect parametric terms in both array references.
+ SmallVector<const SCEV *, 4> Terms;
+ SE->collectParametricTerms(SrcAR, Terms);
+ SE->collectParametricTerms(DstAR, Terms);
+
+ // Second step: find subscript sizes.
+ SmallVector<const SCEV *, 4> Sizes;
+ SE->findArrayDimensions(Terms, Sizes, ElementSize);
+
+ // Third step: compute the access functions for each subscript.
+ SmallVector<const SCEV *, 4> SrcSubscripts, DstSubscripts;
+ SE->computeAccessFunctions(SrcAR, SrcSubscripts, Sizes);
+ SE->computeAccessFunctions(DstAR, DstSubscripts, Sizes);
+
+  // Fail when there is only one subscript: that's a linearized access function.
+ if (SrcSubscripts.size() < 2 || DstSubscripts.size() < 2 ||
+ SrcSubscripts.size() != DstSubscripts.size())
+ return false;
+
+ int size = SrcSubscripts.size();
+
+  // Statically check that the array bounds are in-range. We don't have a size
+  // for the first subscript, and it cannot overflow into another subscript, so
+  // it is always safe. The others must satisfy 0 <= subscript[i] < bound, for
+  // both src and dst.
+ // FIXME: It may be better to record these sizes and add them as constraints
+ // to the dependency checks.
+ if (!DisableDelinearizationChecks)
+ for (int i = 1; i < size; ++i) {
+ if (!isKnownNonNegative(SrcSubscripts[i], SrcPtr))
+ return false;
+
+ if (!isKnownLessThan(SrcSubscripts[i], Sizes[i - 1]))
+ return false;
+
+ if (!isKnownNonNegative(DstSubscripts[i], DstPtr))
+ return false;
+
+ if (!isKnownLessThan(DstSubscripts[i], Sizes[i - 1]))
+ return false;
+ }
+
+ LLVM_DEBUG({
+ dbgs() << "\nSrcSubscripts: ";
+ for (int i = 0; i < size; i++)
+ dbgs() << *SrcSubscripts[i];
+ dbgs() << "\nDstSubscripts: ";
+ for (int i = 0; i < size; i++)
+ dbgs() << *DstSubscripts[i];
+ });
+
+ // The delinearization transforms a single-subscript MIV dependence test into
+ // a multi-subscript SIV dependence test that is easier to compute. So we
+ // resize Pair to contain as many pairs of subscripts as the delinearization
+ // has found, and then initialize the pairs following the delinearization.
+ Pair.resize(size);
+ for (int i = 0; i < size; ++i) {
+ Pair[i].Src = SrcSubscripts[i];
+ Pair[i].Dst = DstSubscripts[i];
+ unifySubscriptType(&Pair[i]);
+ }
+
+ return true;
+}
+
+//===----------------------------------------------------------------------===//
+
+#ifndef NDEBUG
+// For debugging purposes, dump a small bit vector to dbgs().
+static void dumpSmallBitVector(SmallBitVector &BV) {
+ dbgs() << "{";
+ for (unsigned VI : BV.set_bits()) {
+ dbgs() << VI;
+ if (BV.find_next(VI) >= 0)
+ dbgs() << ' ';
+ }
+ dbgs() << "}\n";
+}
+#endif
+
+bool DependenceInfo::invalidate(Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &Inv) {
+ // Check if the analysis itself has been invalidated.
+ auto PAC = PA.getChecker<DependenceAnalysis>();
+ if (!PAC.preserved() && !PAC.preservedSet<AllAnalysesOn<Function>>())
+ return true;
+
+ // Check transitive dependencies.
+ return Inv.invalidate<AAManager>(F, PA) ||
+ Inv.invalidate<ScalarEvolutionAnalysis>(F, PA) ||
+ Inv.invalidate<LoopAnalysis>(F, PA);
+}
+
+// depends -
+// Returns NULL if there is no dependence.
+// Otherwise, return a Dependence with as many details as possible.
+// Corresponds to Section 3.1 in the paper
+//
+// Practical Dependence Testing
+// Goff, Kennedy, Tseng
+// PLDI 1991
+//
+// Care is required to keep the routine below, getSplitIteration(),
+// up to date with respect to this routine.
+std::unique_ptr<Dependence>
+DependenceInfo::depends(Instruction *Src, Instruction *Dst,
+ bool PossiblyLoopIndependent) {
+ if (Src == Dst)
+ PossiblyLoopIndependent = false;
+
+ if ((!Src->mayReadFromMemory() && !Src->mayWriteToMemory()) ||
+ (!Dst->mayReadFromMemory() && !Dst->mayWriteToMemory()))
+    // If neither instruction references memory, there's no dependence.
+ return nullptr;
+
+ if (!isLoadOrStore(Src) || !isLoadOrStore(Dst)) {
+ // can only analyze simple loads and stores, i.e., no calls, invokes, etc.
+ LLVM_DEBUG(dbgs() << "can only handle simple loads and stores\n");
+ return std::make_unique<Dependence>(Src, Dst);
+ }
+
+ assert(isLoadOrStore(Src) && "instruction is not load or store");
+ assert(isLoadOrStore(Dst) && "instruction is not load or store");
+ Value *SrcPtr = getLoadStorePointerOperand(Src);
+ Value *DstPtr = getLoadStorePointerOperand(Dst);
+
+ switch (underlyingObjectsAlias(AA, F->getParent()->getDataLayout(),
+ MemoryLocation::get(Dst),
+ MemoryLocation::get(Src))) {
+ case MayAlias:
+ case PartialAlias:
+    // Cannot analyze objects if we don't understand their aliasing.
+ LLVM_DEBUG(dbgs() << "can't analyze may or partial alias\n");
+ return std::make_unique<Dependence>(Src, Dst);
+ case NoAlias:
+ // If the objects noalias, they are distinct, accesses are independent.
+ LLVM_DEBUG(dbgs() << "no alias\n");
+ return nullptr;
+ case MustAlias:
+ break; // The underlying objects alias; test accesses for dependence.
+ }
+
+ // establish loop nesting levels
+ establishNestingLevels(Src, Dst);
+ LLVM_DEBUG(dbgs() << " common nesting levels = " << CommonLevels << "\n");
+ LLVM_DEBUG(dbgs() << " maximum nesting levels = " << MaxLevels << "\n");
+
+ FullDependence Result(Src, Dst, PossiblyLoopIndependent, CommonLevels);
+ ++TotalArrayPairs;
+
+ unsigned Pairs = 1;
+ SmallVector<Subscript, 2> Pair(Pairs);
+ const SCEV *SrcSCEV = SE->getSCEV(SrcPtr);
+ const SCEV *DstSCEV = SE->getSCEV(DstPtr);
+ LLVM_DEBUG(dbgs() << " SrcSCEV = " << *SrcSCEV << "\n");
+ LLVM_DEBUG(dbgs() << " DstSCEV = " << *DstSCEV << "\n");
+ Pair[0].Src = SrcSCEV;
+ Pair[0].Dst = DstSCEV;
+
+ if (Delinearize) {
+ if (tryDelinearize(Src, Dst, Pair)) {
+ LLVM_DEBUG(dbgs() << " delinearized\n");
+ Pairs = Pair.size();
+ }
+ }
+
+ for (unsigned P = 0; P < Pairs; ++P) {
+ Pair[P].Loops.resize(MaxLevels + 1);
+ Pair[P].GroupLoops.resize(MaxLevels + 1);
+ Pair[P].Group.resize(Pairs);
+ removeMatchingExtensions(&Pair[P]);
+ Pair[P].Classification =
+ classifyPair(Pair[P].Src, LI->getLoopFor(Src->getParent()),
+ Pair[P].Dst, LI->getLoopFor(Dst->getParent()),
+ Pair[P].Loops);
+ Pair[P].GroupLoops = Pair[P].Loops;
+ Pair[P].Group.set(P);
+ LLVM_DEBUG(dbgs() << " subscript " << P << "\n");
+ LLVM_DEBUG(dbgs() << "\tsrc = " << *Pair[P].Src << "\n");
+ LLVM_DEBUG(dbgs() << "\tdst = " << *Pair[P].Dst << "\n");
+ LLVM_DEBUG(dbgs() << "\tclass = " << Pair[P].Classification << "\n");
+ LLVM_DEBUG(dbgs() << "\tloops = ");
+ LLVM_DEBUG(dumpSmallBitVector(Pair[P].Loops));
+ }
+
+ SmallBitVector Separable(Pairs);
+ SmallBitVector Coupled(Pairs);
+
+ // Partition subscripts into separable and minimally-coupled groups
+  // The algorithm in the paper is asymptotically better;
+ // this may be faster in practice. Check someday.
+ //
+ // Here's an example of how it works. Consider this code:
+ //
+ // for (i = ...) {
+ // for (j = ...) {
+ // for (k = ...) {
+ // for (l = ...) {
+ // for (m = ...) {
+ // A[i][j][k][m] = ...;
+ // ... = A[0][j][l][i + j];
+ // }
+ // }
+ // }
+ // }
+ // }
+ //
+ // There are 4 subscripts here:
+ // 0 [i] and [0]
+ // 1 [j] and [j]
+ // 2 [k] and [l]
+ // 3 [m] and [i + j]
+ //
+ // We've already classified each subscript pair as ZIV, SIV, etc.,
+ // and collected all the loops mentioned by pair P in Pair[P].Loops.
+ // In addition, we've initialized Pair[P].GroupLoops to Pair[P].Loops
+ // and set Pair[P].Group = {P}.
+ //
+ // Src Dst Classification Loops GroupLoops Group
+ // 0 [i] [0] SIV {1} {1} {0}
+ // 1 [j] [j] SIV {2} {2} {1}
+ // 2 [k] [l] RDIV {3,4} {3,4} {2}
+ // 3 [m] [i + j] MIV {1,2,5} {1,2,5} {3}
+ //
+ // For each subscript SI 0 .. 3, we consider each remaining subscript, SJ.
+ // So, 0 is compared against 1, 2, and 3; 1 is compared against 2 and 3, etc.
+ //
+ // We begin by comparing 0 and 1. The intersection of the GroupLoops is empty.
+ // Next, 0 and 2. Again, the intersection of their GroupLoops is empty.
+  // Next, 0 and 3. The intersection of their GroupLoops = {1}, not empty,
+ // so Pair[3].Group = {0,3} and Done = false (that is, 0 will not be added
+ // to either Separable or Coupled).
+ //
+ // Next, we consider 1 and 2. The intersection of the GroupLoops is empty.
+ // Next, 1 and 3. The intersection of their GroupLoops = {2}, not empty,
+ // so Pair[3].Group = {0, 1, 3} and Done = false.
+ //
+ // Next, we compare 2 against 3. The intersection of the GroupLoops is empty.
+ // Since Done remains true, we add 2 to the set of Separable pairs.
+ //
+ // Finally, we consider 3. There's nothing to compare it with,
+ // so Done remains true and we add it to the Coupled set.
+ // Pair[3].Group = {0, 1, 3} and GroupLoops = {1, 2, 5}.
+ //
+ // In the end, we've got 1 separable subscript and 1 coupled group.
+ for (unsigned SI = 0; SI < Pairs; ++SI) {
+ if (Pair[SI].Classification == Subscript::NonLinear) {
+ // ignore these, but collect loops for later
+ ++NonlinearSubscriptPairs;
+ collectCommonLoops(Pair[SI].Src,
+ LI->getLoopFor(Src->getParent()),
+ Pair[SI].Loops);
+ collectCommonLoops(Pair[SI].Dst,
+ LI->getLoopFor(Dst->getParent()),
+ Pair[SI].Loops);
+ Result.Consistent = false;
+ } else if (Pair[SI].Classification == Subscript::ZIV) {
+ // always separable
+ Separable.set(SI);
+ }
+ else {
+ // SIV, RDIV, or MIV, so check for coupled group
+ bool Done = true;
+ for (unsigned SJ = SI + 1; SJ < Pairs; ++SJ) {
+ SmallBitVector Intersection = Pair[SI].GroupLoops;
+ Intersection &= Pair[SJ].GroupLoops;
+ if (Intersection.any()) {
+ // accumulate set of all the loops in group
+ Pair[SJ].GroupLoops |= Pair[SI].GroupLoops;
+ // accumulate set of all subscripts in group
+ Pair[SJ].Group |= Pair[SI].Group;
+ Done = false;
+ }
+ }
+ if (Done) {
+ if (Pair[SI].Group.count() == 1) {
+ Separable.set(SI);
+ ++SeparableSubscriptPairs;
+ }
+ else {
+ Coupled.set(SI);
+ ++CoupledSubscriptPairs;
+ }
+ }
+ }
+ }
+
+ LLVM_DEBUG(dbgs() << " Separable = ");
+ LLVM_DEBUG(dumpSmallBitVector(Separable));
+ LLVM_DEBUG(dbgs() << " Coupled = ");
+ LLVM_DEBUG(dumpSmallBitVector(Coupled));
+
+ Constraint NewConstraint;
+ NewConstraint.setAny(SE);
+
+ // test separable subscripts
+ for (unsigned SI : Separable.set_bits()) {
+ LLVM_DEBUG(dbgs() << "testing subscript " << SI);
+ switch (Pair[SI].Classification) {
+ case Subscript::ZIV:
+ LLVM_DEBUG(dbgs() << ", ZIV\n");
+ if (testZIV(Pair[SI].Src, Pair[SI].Dst, Result))
+ return nullptr;
+ break;
+ case Subscript::SIV: {
+ LLVM_DEBUG(dbgs() << ", SIV\n");
+ unsigned Level;
+ const SCEV *SplitIter = nullptr;
+ if (testSIV(Pair[SI].Src, Pair[SI].Dst, Level, Result, NewConstraint,
+ SplitIter))
+ return nullptr;
+ break;
+ }
+ case Subscript::RDIV:
+ LLVM_DEBUG(dbgs() << ", RDIV\n");
+ if (testRDIV(Pair[SI].Src, Pair[SI].Dst, Result))
+ return nullptr;
+ break;
+ case Subscript::MIV:
+ LLVM_DEBUG(dbgs() << ", MIV\n");
+ if (testMIV(Pair[SI].Src, Pair[SI].Dst, Pair[SI].Loops, Result))
+ return nullptr;
+ break;
+ default:
+ llvm_unreachable("subscript has unexpected classification");
+ }
+ }
+
+ if (Coupled.count()) {
+ // test coupled subscript groups
+ LLVM_DEBUG(dbgs() << "starting on coupled subscripts\n");
+ LLVM_DEBUG(dbgs() << "MaxLevels + 1 = " << MaxLevels + 1 << "\n");
+ SmallVector<Constraint, 4> Constraints(MaxLevels + 1);
+ for (unsigned II = 0; II <= MaxLevels; ++II)
+ Constraints[II].setAny(SE);
+ for (unsigned SI : Coupled.set_bits()) {
+ LLVM_DEBUG(dbgs() << "testing subscript group " << SI << " { ");
+ SmallBitVector Group(Pair[SI].Group);
+ SmallBitVector Sivs(Pairs);
+ SmallBitVector Mivs(Pairs);
+ SmallBitVector ConstrainedLevels(MaxLevels + 1);
+ SmallVector<Subscript *, 4> PairsInGroup;
+ for (unsigned SJ : Group.set_bits()) {
+ LLVM_DEBUG(dbgs() << SJ << " ");
+ if (Pair[SJ].Classification == Subscript::SIV)
+ Sivs.set(SJ);
+ else
+ Mivs.set(SJ);
+ PairsInGroup.push_back(&Pair[SJ]);
+ }
+ unifySubscriptType(PairsInGroup);
+ LLVM_DEBUG(dbgs() << "}\n");
+ while (Sivs.any()) {
+ bool Changed = false;
+ for (unsigned SJ : Sivs.set_bits()) {
+ LLVM_DEBUG(dbgs() << "testing subscript " << SJ << ", SIV\n");
+ // SJ is an SIV subscript that's part of the current coupled group
+ unsigned Level;
+ const SCEV *SplitIter = nullptr;
+ LLVM_DEBUG(dbgs() << "SIV\n");
+ if (testSIV(Pair[SJ].Src, Pair[SJ].Dst, Level, Result, NewConstraint,
+ SplitIter))
+ return nullptr;
+ ConstrainedLevels.set(Level);
+ if (intersectConstraints(&Constraints[Level], &NewConstraint)) {
+ if (Constraints[Level].isEmpty()) {
+ ++DeltaIndependence;
+ return nullptr;
+ }
+ Changed = true;
+ }
+ Sivs.reset(SJ);
+ }
+ if (Changed) {
+ // propagate, possibly creating new SIVs and ZIVs
+ LLVM_DEBUG(dbgs() << " propagating\n");
+ LLVM_DEBUG(dbgs() << "\tMivs = ");
+ LLVM_DEBUG(dumpSmallBitVector(Mivs));
+ for (unsigned SJ : Mivs.set_bits()) {
+ // SJ is an MIV subscript that's part of the current coupled group
+ LLVM_DEBUG(dbgs() << "\tSJ = " << SJ << "\n");
+ if (propagate(Pair[SJ].Src, Pair[SJ].Dst, Pair[SJ].Loops,
+ Constraints, Result.Consistent)) {
+ LLVM_DEBUG(dbgs() << "\t Changed\n");
+ ++DeltaPropagations;
+ Pair[SJ].Classification =
+ classifyPair(Pair[SJ].Src, LI->getLoopFor(Src->getParent()),
+ Pair[SJ].Dst, LI->getLoopFor(Dst->getParent()),
+ Pair[SJ].Loops);
+ switch (Pair[SJ].Classification) {
+ case Subscript::ZIV:
+ LLVM_DEBUG(dbgs() << "ZIV\n");
+ if (testZIV(Pair[SJ].Src, Pair[SJ].Dst, Result))
+ return nullptr;
+ Mivs.reset(SJ);
+ break;
+ case Subscript::SIV:
+ Sivs.set(SJ);
+ Mivs.reset(SJ);
+ break;
+ case Subscript::RDIV:
+ case Subscript::MIV:
+ break;
+ default:
+ llvm_unreachable("bad subscript classification");
+ }
+ }
+ }
+ }
+ }
+
+ // test & propagate remaining RDIVs
+ for (unsigned SJ : Mivs.set_bits()) {
+ if (Pair[SJ].Classification == Subscript::RDIV) {
+ LLVM_DEBUG(dbgs() << "RDIV test\n");
+ if (testRDIV(Pair[SJ].Src, Pair[SJ].Dst, Result))
+ return nullptr;
+ // I don't yet understand how to propagate RDIV results
+ Mivs.reset(SJ);
+ }
+ }
+
+ // test remaining MIVs
+ // This code is temporary.
+ // Better to somehow test all remaining subscripts simultaneously.
+ for (unsigned SJ : Mivs.set_bits()) {
+ if (Pair[SJ].Classification == Subscript::MIV) {
+ LLVM_DEBUG(dbgs() << "MIV test\n");
+ if (testMIV(Pair[SJ].Src, Pair[SJ].Dst, Pair[SJ].Loops, Result))
+ return nullptr;
+ }
+ else
+ llvm_unreachable("expected only MIV subscripts at this point");
+ }
+
+ // update Result.DV from constraint vector
+ LLVM_DEBUG(dbgs() << " updating\n");
+ for (unsigned SJ : ConstrainedLevels.set_bits()) {
+ if (SJ > CommonLevels)
+ break;
+ updateDirection(Result.DV[SJ - 1], Constraints[SJ]);
+ if (Result.DV[SJ - 1].Direction == Dependence::DVEntry::NONE)
+ return nullptr;
+ }
+ }
+ }
+
+ // Make sure the Scalar flags are set correctly.
+ SmallBitVector CompleteLoops(MaxLevels + 1);
+ for (unsigned SI = 0; SI < Pairs; ++SI)
+ CompleteLoops |= Pair[SI].Loops;
+ for (unsigned II = 1; II <= CommonLevels; ++II)
+ if (CompleteLoops[II])
+ Result.DV[II - 1].Scalar = false;
+
+ if (PossiblyLoopIndependent) {
+ // Make sure the LoopIndependent flag is set correctly.
+ // All directions must include equal, otherwise no
+ // loop-independent dependence is possible.
+ for (unsigned II = 1; II <= CommonLevels; ++II) {
+ if (!(Result.getDirection(II) & Dependence::DVEntry::EQ)) {
+ Result.LoopIndependent = false;
+ break;
+ }
+ }
+ }
+ else {
+ // On the other hand, if all directions are equal and there's no
+ // loop-independent dependence possible, then no dependence exists.
+ bool AllEqual = true;
+ for (unsigned II = 1; II <= CommonLevels; ++II) {
+ if (Result.getDirection(II) != Dependence::DVEntry::EQ) {
+ AllEqual = false;
+ break;
+ }
+ }
+ if (AllEqual)
+ return nullptr;
+ }
+
+ return std::make_unique<FullDependence>(std::move(Result));
+}
+
+
+
+//===----------------------------------------------------------------------===//
+// getSplitIteration -
+// Rather than spend rarely-used space recording the splitting iteration
+// during the Weak-Crossing SIV test, we re-compute it on demand.
+// The re-computation is basically a repeat of the entire dependence test,
+// though simplified since we know that the dependence exists.
+// It's tedious, since we must go through all propagations, etc.
+//
+// Care is required to keep this code up to date with respect to the routine
+// above, depends().
+//
+// Generally, the dependence analyzer will be used to build
+// a dependence graph for a function (basically a map from instructions
+// to dependences). Looking for cycles in the graph shows us loops
+// that cannot be trivially vectorized/parallelized.
+//
+// We can try to improve the situation by examining all the dependences
+// that make up the cycle, looking for ones we can break.
+// Sometimes, peeling the first or last iteration of a loop will break
+// dependences, and we've got flags for those possibilities.
+// Sometimes, splitting a loop at some other iteration will do the trick,
+// and we've got a flag for that case. Rather than waste the space to
+// record the exact iteration (since we rarely know), we provide
+// a method that calculates the iteration. It's a drag that it must work
+// from scratch, but wonderful in that it's possible.
+//
+// Here's an example:
+//
+// for (i = 0; i < 10; i++)
+// A[i] = ...
+// ... = A[11 - i]
+//
+// There's a loop-carried flow dependence from the store to the load,
+// found by the weak-crossing SIV test: the references cross where
+// i = 11 - i, that is, at i = 5.5. The dependence will carry a flag
+// indicating that it can be broken by splitting the loop.
+// Calling getSplitIteration will return 5, the last iteration before
+// the crossing point. Splitting the loop, like so:
+//
+// for (i = 0; i <= 5; i++)
+// A[i] = ...
+// ... = A[11 - i]
+// for (i = 6; i < 10; i++)
+// A[i] = ...
+// ... = A[11 - i]
+//
+// breaks the dependence and allows us to vectorize/parallelize
+// both loops.
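+//
+// As an illustrative sketch of client-side use (DI is a DependenceInfo, Dep a
+// splittable Dependence found in a cycle; the loop transformation itself is
+// assumed to be done elsewhere):
+//
+//    for (unsigned Level = 1; Level <= Dep.getLevels(); ++Level)
+//      if (Dep.isSplitable(Level)) {
+//        const SCEV *Iter = DI.getSplitIteration(Dep, Level);
+//        // split the loop at nesting level Level just after iteration Iter
+//      }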
+const SCEV *DependenceInfo::getSplitIteration(const Dependence &Dep,
+ unsigned SplitLevel) {
+ assert(Dep.isSplitable(SplitLevel) &&
+ "Dep should be splitable at SplitLevel");
+ Instruction *Src = Dep.getSrc();
+ Instruction *Dst = Dep.getDst();
+ assert(Src->mayReadFromMemory() || Src->mayWriteToMemory());
+ assert(Dst->mayReadFromMemory() || Dst->mayWriteToMemory());
+ assert(isLoadOrStore(Src));
+ assert(isLoadOrStore(Dst));
+ Value *SrcPtr = getLoadStorePointerOperand(Src);
+ Value *DstPtr = getLoadStorePointerOperand(Dst);
+ assert(underlyingObjectsAlias(AA, F->getParent()->getDataLayout(),
+ MemoryLocation::get(Dst),
+ MemoryLocation::get(Src)) == MustAlias);
+
+ // establish loop nesting levels
+ establishNestingLevels(Src, Dst);
+
+ FullDependence Result(Src, Dst, false, CommonLevels);
+
+ unsigned Pairs = 1;
+ SmallVector<Subscript, 2> Pair(Pairs);
+ const SCEV *SrcSCEV = SE->getSCEV(SrcPtr);
+ const SCEV *DstSCEV = SE->getSCEV(DstPtr);
+ Pair[0].Src = SrcSCEV;
+ Pair[0].Dst = DstSCEV;
+
+ if (Delinearize) {
+ if (tryDelinearize(Src, Dst, Pair)) {
+ LLVM_DEBUG(dbgs() << " delinearized\n");
+ Pairs = Pair.size();
+ }
+ }
+
+ for (unsigned P = 0; P < Pairs; ++P) {
+ Pair[P].Loops.resize(MaxLevels + 1);
+ Pair[P].GroupLoops.resize(MaxLevels + 1);
+ Pair[P].Group.resize(Pairs);
+ removeMatchingExtensions(&Pair[P]);
+ Pair[P].Classification =
+ classifyPair(Pair[P].Src, LI->getLoopFor(Src->getParent()),
+ Pair[P].Dst, LI->getLoopFor(Dst->getParent()),
+ Pair[P].Loops);
+ Pair[P].GroupLoops = Pair[P].Loops;
+ Pair[P].Group.set(P);
+ }
+
+ SmallBitVector Separable(Pairs);
+ SmallBitVector Coupled(Pairs);
+
+ // partition subscripts into separable and minimally-coupled groups
+ for (unsigned SI = 0; SI < Pairs; ++SI) {
+ if (Pair[SI].Classification == Subscript::NonLinear) {
+ // ignore these, but collect loops for later
+ collectCommonLoops(Pair[SI].Src,
+ LI->getLoopFor(Src->getParent()),
+ Pair[SI].Loops);
+ collectCommonLoops(Pair[SI].Dst,
+ LI->getLoopFor(Dst->getParent()),
+ Pair[SI].Loops);
+ Result.Consistent = false;
+ }
+ else if (Pair[SI].Classification == Subscript::ZIV)
+ Separable.set(SI);
+ else {
+ // SIV, RDIV, or MIV, so check for coupled group
+ bool Done = true;
+ for (unsigned SJ = SI + 1; SJ < Pairs; ++SJ) {
+ SmallBitVector Intersection = Pair[SI].GroupLoops;
+ Intersection &= Pair[SJ].GroupLoops;
+ if (Intersection.any()) {
+ // accumulate set of all the loops in group
+ Pair[SJ].GroupLoops |= Pair[SI].GroupLoops;
+ // accumulate set of all subscripts in group
+ Pair[SJ].Group |= Pair[SI].Group;
+ Done = false;
+ }
+ }
+ if (Done) {
+ if (Pair[SI].Group.count() == 1)
+ Separable.set(SI);
+ else
+ Coupled.set(SI);
+ }
+ }
+ }
+
+ Constraint NewConstraint;
+ NewConstraint.setAny(SE);
+
+ // test separable subscripts
+ for (unsigned SI : Separable.set_bits()) {
+ switch (Pair[SI].Classification) {
+ case Subscript::SIV: {
+ unsigned Level;
+ const SCEV *SplitIter = nullptr;
+ (void) testSIV(Pair[SI].Src, Pair[SI].Dst, Level,
+ Result, NewConstraint, SplitIter);
+ if (Level == SplitLevel) {
+ assert(SplitIter != nullptr);
+ return SplitIter;
+ }
+ break;
+ }
+ case Subscript::ZIV:
+ case Subscript::RDIV:
+ case Subscript::MIV:
+ break;
+ default:
+ llvm_unreachable("subscript has unexpected classification");
+ }
+ }
+
+ if (Coupled.count()) {
+ // test coupled subscript groups
+ SmallVector<Constraint, 4> Constraints(MaxLevels + 1);
+ for (unsigned II = 0; II <= MaxLevels; ++II)
+ Constraints[II].setAny(SE);
+ for (unsigned SI : Coupled.set_bits()) {
+ SmallBitVector Group(Pair[SI].Group);
+ SmallBitVector Sivs(Pairs);
+ SmallBitVector Mivs(Pairs);
+ SmallBitVector ConstrainedLevels(MaxLevels + 1);
+ for (unsigned SJ : Group.set_bits()) {
+ if (Pair[SJ].Classification == Subscript::SIV)
+ Sivs.set(SJ);
+ else
+ Mivs.set(SJ);
+ }
+ while (Sivs.any()) {
+ bool Changed = false;
+ for (unsigned SJ : Sivs.set_bits()) {
+ // SJ is an SIV subscript that's part of the current coupled group
+ unsigned Level;
+ const SCEV *SplitIter = nullptr;
+ (void) testSIV(Pair[SJ].Src, Pair[SJ].Dst, Level,
+ Result, NewConstraint, SplitIter);
+ if (Level == SplitLevel && SplitIter)
+ return SplitIter;
+ ConstrainedLevels.set(Level);
+ if (intersectConstraints(&Constraints[Level], &NewConstraint))
+ Changed = true;
+ Sivs.reset(SJ);
+ }
+ if (Changed) {
+ // propagate, possibly creating new SIVs and ZIVs
+ for (unsigned SJ : Mivs.set_bits()) {
+ // SJ is an MIV subscript that's part of the current coupled group
+ if (propagate(Pair[SJ].Src, Pair[SJ].Dst,
+ Pair[SJ].Loops, Constraints, Result.Consistent)) {
+ Pair[SJ].Classification =
+ classifyPair(Pair[SJ].Src, LI->getLoopFor(Src->getParent()),
+ Pair[SJ].Dst, LI->getLoopFor(Dst->getParent()),
+ Pair[SJ].Loops);
+ switch (Pair[SJ].Classification) {
+ case Subscript::ZIV:
+ Mivs.reset(SJ);
+ break;
+ case Subscript::SIV:
+ Sivs.set(SJ);
+ Mivs.reset(SJ);
+ break;
+ case Subscript::RDIV:
+ case Subscript::MIV:
+ break;
+ default:
+ llvm_unreachable("bad subscript classification");
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ llvm_unreachable("somehow reached end of routine");
+ return nullptr;
+}
diff --git a/llvm/lib/Analysis/DependenceGraphBuilder.cpp b/llvm/lib/Analysis/DependenceGraphBuilder.cpp
new file mode 100644
index 000000000000..ed1d8351b2f0
--- /dev/null
+++ b/llvm/lib/Analysis/DependenceGraphBuilder.cpp
@@ -0,0 +1,228 @@
+//===- DependenceGraphBuilder.cpp ------------------------------------------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This file implements common steps of the build algorithm for construction
+// of dependence graphs such as DDG and PDG.
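+//
+// As an illustration (the exact driver lives in the graph-specific builder,
+// e.g. the DDG builder, and may differ), the common steps implemented here
+// are expected to be invoked roughly in this order:
+//
+//   createFineGrainedNodes();
+//   createDefUseEdges();
+//   createMemoryDependencyEdges();
+//   createAndConnectRootNode();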
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DependenceGraphBuilder.h"
+#include "llvm/ADT/SCCIterator.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/DDG.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "dgb"
+
+STATISTIC(TotalGraphs, "Number of dependence graphs created.");
+STATISTIC(TotalDefUseEdges, "Number of def-use edges created.");
+STATISTIC(TotalMemoryEdges, "Number of memory dependence edges created.");
+STATISTIC(TotalFineGrainedNodes, "Number of fine-grained nodes created.");
+STATISTIC(TotalConfusedEdges,
+ "Number of confused memory dependencies between two nodes.");
+STATISTIC(TotalEdgeReversals,
+ "Number of times the source and sink of dependence was reversed to "
+ "expose cycles in the graph.");
+
+using InstructionListType = SmallVector<Instruction *, 2>;
+
+//===--------------------------------------------------------------------===//
+// AbstractDependenceGraphBuilder implementation
+//===--------------------------------------------------------------------===//
+
+template <class G>
+void AbstractDependenceGraphBuilder<G>::createFineGrainedNodes() {
+ ++TotalGraphs;
+ assert(IMap.empty() && "Expected empty instruction map at start");
+ for (BasicBlock *BB : BBList)
+ for (Instruction &I : *BB) {
+ auto &NewNode = createFineGrainedNode(I);
+ IMap.insert(std::make_pair(&I, &NewNode));
+ ++TotalFineGrainedNodes;
+ }
+}
+
+template <class G>
+void AbstractDependenceGraphBuilder<G>::createAndConnectRootNode() {
+ // Create a root node that connects to every connected component of the graph.
+ // This is done to allow graph iterators to visit all the disjoint components
+ // of the graph, in a single walk.
+ //
+ // This algorithm works by going through each node of the graph and, for each
+ // node N, doing a DFS starting from N. A rooted edge is established between the
+ // root node and N (if N is not yet visited). All the nodes reachable from N
+ // are marked as visited and are skipped in the DFS of subsequent nodes.
+ //
+ // Note: This algorithm tries to limit the number of edges out of the root
+ // node to some extent, but there may be redundant edges created depending on
+ // the iteration order. For example for a graph {A -> B}, an edge from the
+ // root node is added to both nodes if B is visited before A. While it does
+ // not result in minimal number of edges, this approach saves compile-time
+ // while keeping the number of edges in check.
+ auto &RootNode = createRootNode();
+ df_iterator_default_set<const NodeType *, 4> Visited;
+ for (auto *N : Graph) {
+ if (*N == RootNode)
+ continue;
+ for (auto I : depth_first_ext(N, Visited))
+ if (I == N)
+ createRootedEdge(RootNode, *N);
+ }
+}
+
+template <class G> void AbstractDependenceGraphBuilder<G>::createDefUseEdges() {
+ for (NodeType *N : Graph) {
+ InstructionListType SrcIList;
+ N->collectInstructions([](const Instruction *I) { return true; }, SrcIList);
+
+ // Use a set to mark the targets that we link to N, so we don't add
+ // duplicate def-use edges when more than one instruction in a target node
+ // use results of instructions that are contained in N.
+ SmallPtrSet<NodeType *, 4> VisitedTargets;
+
+ for (Instruction *II : SrcIList) {
+ for (User *U : II->users()) {
+ Instruction *UI = dyn_cast<Instruction>(U);
+ if (!UI)
+ continue;
+ NodeType *DstNode = nullptr;
+ if (IMap.find(UI) != IMap.end())
+ DstNode = IMap.find(UI)->second;
+
+ // In the case of loops, the scope of the subgraph is all the
+ // basic blocks (and instructions within them) belonging to the loop. We
+ // simply ignore all the edges coming from (or going into) instructions
+ // or basic blocks outside of this range.
+ if (!DstNode) {
+ LLVM_DEBUG(
+ dbgs()
+ << "skipped def-use edge since the sink" << *UI
+ << " is outside the range of instructions being considered.\n");
+ continue;
+ }
+
+ // Self dependencies are ignored because they are redundant and
+ // uninteresting.
+ if (DstNode == N) {
+ LLVM_DEBUG(dbgs()
+ << "skipped def-use edge since the sink and the source ("
+ << N << ") are the same.\n");
+ continue;
+ }
+
+ if (VisitedTargets.insert(DstNode).second) {
+ createDefUseEdge(*N, *DstNode);
+ ++TotalDefUseEdges;
+ }
+ }
+ }
+ }
+}
+
+template <class G>
+void AbstractDependenceGraphBuilder<G>::createMemoryDependencyEdges() {
+ using DGIterator = typename G::iterator;
+ auto isMemoryAccess = [](const Instruction *I) {
+ return I->mayReadOrWriteMemory();
+ };
+ for (DGIterator SrcIt = Graph.begin(), E = Graph.end(); SrcIt != E; ++SrcIt) {
+ InstructionListType SrcIList;
+ (*SrcIt)->collectInstructions(isMemoryAccess, SrcIList);
+ if (SrcIList.empty())
+ continue;
+
+ for (DGIterator DstIt = SrcIt; DstIt != E; ++DstIt) {
+ if (**SrcIt == **DstIt)
+ continue;
+ InstructionListType DstIList;
+ (*DstIt)->collectInstructions(isMemoryAccess, DstIList);
+ if (DstIList.empty())
+ continue;
+ bool ForwardEdgeCreated = false;
+ bool BackwardEdgeCreated = false;
+ for (Instruction *ISrc : SrcIList) {
+ for (Instruction *IDst : DstIList) {
+ auto D = DI.depends(ISrc, IDst, true);
+ if (!D)
+ continue;
+
+ // If we have a dependence with its left-most non-'=' direction
+ // being '>' we need to reverse the direction of the edge, because
+ // the source of the dependence cannot occur after the sink. For
+ // confused dependencies, we will create edges in both directions to
+ // represent the possibility of a cycle.
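+ // For example (illustrative): a dependence whose direction vector is (=, >)
+ // has '>' as its left-most non-'=' entry, so the memory edge is created from
+ // the *DstIt node to the *SrcIt node (a backward edge) rather than from
+ // *SrcIt to *DstIt.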
+
+ auto createConfusedEdges = [&](NodeType &Src, NodeType &Dst) {
+ if (!ForwardEdgeCreated) {
+ createMemoryEdge(Src, Dst);
+ ++TotalMemoryEdges;
+ }
+ if (!BackwardEdgeCreated) {
+ createMemoryEdge(Dst, Src);
+ ++TotalMemoryEdges;
+ }
+ ForwardEdgeCreated = BackwardEdgeCreated = true;
+ ++TotalConfusedEdges;
+ };
+
+ auto createForwardEdge = [&](NodeType &Src, NodeType &Dst) {
+ if (!ForwardEdgeCreated) {
+ createMemoryEdge(Src, Dst);
+ ++TotalMemoryEdges;
+ }
+ ForwardEdgeCreated = true;
+ };
+
+ auto createBackwardEdge = [&](NodeType &Src, NodeType &Dst) {
+ if (!BackwardEdgeCreated) {
+ createMemoryEdge(Dst, Src);
+ ++TotalMemoryEdges;
+ }
+ BackwardEdgeCreated = true;
+ };
+
+ if (D->isConfused())
+ createConfusedEdges(**SrcIt, **DstIt);
+ else if (D->isOrdered() && !D->isLoopIndependent()) {
+ bool ReversedEdge = false;
+ for (unsigned Level = 1; Level <= D->getLevels(); ++Level) {
+ if (D->getDirection(Level) == Dependence::DVEntry::EQ)
+ continue;
+ else if (D->getDirection(Level) == Dependence::DVEntry::GT) {
+ createBackwardEdge(**SrcIt, **DstIt);
+ ReversedEdge = true;
+ ++TotalEdgeReversals;
+ break;
+ } else if (D->getDirection(Level) == Dependence::DVEntry::LT)
+ break;
+ else {
+ createConfusedEdges(**SrcIt, **DstIt);
+ break;
+ }
+ }
+ if (!ReversedEdge)
+ createForwardEdge(**SrcIt, **DstIt);
+ } else
+ createForwardEdge(**SrcIt, **DstIt);
+
+ // Avoid creating duplicate edges.
+ if (ForwardEdgeCreated && BackwardEdgeCreated)
+ break;
+ }
+
+ // If we've created edges in both directions, there is no more
+ // unique edge that we can create between these two nodes, so we
+ // can exit early.
+ if (ForwardEdgeCreated && BackwardEdgeCreated)
+ break;
+ }
+ }
+ }
+}
+
+template class llvm::AbstractDependenceGraphBuilder<DataDependenceGraph>;
+template class llvm::DependenceGraphInfo<DDGNode>;
diff --git a/llvm/lib/Analysis/DivergenceAnalysis.cpp b/llvm/lib/Analysis/DivergenceAnalysis.cpp
new file mode 100644
index 000000000000..3d1be1e1cce0
--- /dev/null
+++ b/llvm/lib/Analysis/DivergenceAnalysis.cpp
@@ -0,0 +1,466 @@
+//===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a general divergence analysis for loop vectorization
+// and GPU programs. It determines which branches and values in a loop or GPU
+// program are divergent. It can help branch optimizations such as jump
+// threading and loop unswitching to make better decisions.
+//
+// GPU programs typically use the SIMD execution model, where multiple threads
+// in the same execution group have to execute in lock-step. Therefore, if the
+// code contains divergent branches (i.e., threads in a group do not agree on
+// which path of the branch to take), the group of threads has to execute all
+// the paths from that branch with different subsets of threads enabled until
+// they re-converge.
+//
+// Due to this execution model, some optimizations such as jump
+// threading and loop unswitching can interfere with thread re-convergence.
+// Therefore, an analysis that computes which branches in a GPU program are
+// divergent can help the compiler to selectively run these optimizations.
+//
+// This implementation is derived from the Vectorization Analysis of the
+// Region Vectorizer (RV). That implementation in turn is based on the approach
+// described in
+//
+// Improving Performance of OpenCL on CPUs
+// Ralf Karrenberg and Sebastian Hack
+// CC '12
+//
+// This DivergenceAnalysis implementation is generic in the sense that it does
+// not itself identify original sources of divergence.
+// Instead, specialized adapter classes (LoopDivergenceAnalysis for loops and
+// GPUDivergenceAnalysis for GPU programs) identify the sources of divergence
+// (e.g., special variables that hold the thread ID or the iteration variable).
+//
+// The generic implementation propagates divergence to variables that are data
+// or sync dependent on a source of divergence.
+//
+// While data dependency is a well-known concept, the notion of sync dependency
+// is worth more explanation. Sync dependence characterizes the control flow
+// aspect of the propagation of branch divergence. For example,
+//
+// %cond = icmp slt i32 %tid, 10
+// br i1 %cond, label %then, label %else
+// then:
+// br label %merge
+// else:
+// br label %merge
+// merge:
+// %a = phi i32 [ 0, %then ], [ 1, %else ]
+//
+// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
+// because %tid is not on its use-def chains, %a is sync dependent on %tid
+// because the branch "br i1 %cond" depends on %tid and affects which value %a
+// is assigned to.
+//
+// The sync dependence detection (which branch induces divergence in which join
+// points) is implemented in the SyncDependenceAnalysis.
+//
+// The current DivergenceAnalysis implementation has the following limitations:
+// 1. intra-procedural. It conservatively considers the arguments of a
+// non-kernel-entry function and the return value of a function call as
+// divergent.
+// 2. memory as black box. It conservatively considers values loaded from
+// generic or local address space as divergent. This can be improved by leveraging
+// pointer analysis and/or by modelling non-escaping memory objects in SSA
+// as done in RV.
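+//
+// As an illustration (SomeValue is a hypothetical value of interest), a GPU
+// client would typically construct the adapter and query it like this:
+//
+//   GPUDivergenceAnalysis DA(F, DT, PDT, LI, TTI);
+//   if (DA.isDivergent(SomeValue)) {
+//     // treat SomeValue as varying across threads
+//   }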
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DivergenceAnalysis.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <vector>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "divergence-analysis"
+
+// class DivergenceAnalysis
+DivergenceAnalysis::DivergenceAnalysis(
+ const Function &F, const Loop *RegionLoop, const DominatorTree &DT,
+ const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm)
+ : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA),
+ IsLCSSAForm(IsLCSSAForm) {}
+
+void DivergenceAnalysis::markDivergent(const Value &DivVal) {
+ assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal));
+ assert(!isAlwaysUniform(DivVal) && "cannot mark an always-uniform value as divergent");
+ DivergentValues.insert(&DivVal);
+}
+
+void DivergenceAnalysis::addUniformOverride(const Value &UniVal) {
+ UniformOverrides.insert(&UniVal);
+}
+
+bool DivergenceAnalysis::updateTerminator(const Instruction &Term) const {
+ if (Term.getNumSuccessors() <= 1)
+ return false;
+ if (auto *BranchTerm = dyn_cast<BranchInst>(&Term)) {
+ assert(BranchTerm->isConditional());
+ return isDivergent(*BranchTerm->getCondition());
+ }
+ if (auto *SwitchTerm = dyn_cast<SwitchInst>(&Term)) {
+ return isDivergent(*SwitchTerm->getCondition());
+ }
+ if (isa<InvokeInst>(Term)) {
+ return false; // ignore abnormal executions through landingpad
+ }
+
+ llvm_unreachable("unexpected terminator");
+}
+
+bool DivergenceAnalysis::updateNormalInstruction(const Instruction &I) const {
+ // TODO function calls with side effects, etc
+ for (const auto &Op : I.operands()) {
+ if (isDivergent(*Op))
+ return true;
+ }
+ return false;
+}
+
+bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock,
+ const Value &Val) const {
+ const auto *Inst = dyn_cast<const Instruction>(&Val);
+ if (!Inst)
+ return false;
+ // check whether any divergent loop carrying Val terminates before control
+ // proceeds to ObservingBlock
+ for (const auto *Loop = LI.getLoopFor(Inst->getParent());
+ Loop != RegionLoop && !Loop->contains(&ObservingBlock);
+ Loop = Loop->getParentLoop()) {
+ if (DivergentLoops.find(Loop) != DivergentLoops.end())
+ return true;
+ }
+
+ return false;
+}
+
+bool DivergenceAnalysis::updatePHINode(const PHINode &Phi) const {
+ // joining divergent disjoint path in Phi parent block
+ if (!Phi.hasConstantOrUndefValue() && isJoinDivergent(*Phi.getParent())) {
+ return true;
+ }
+
+ // An incoming value could be divergent by itself.
+ // Otherwise, an incoming value could be uniform within the loop
+ // that carries its definition but it may appear divergent
+ // from outside the loop. This happens when divergent loop exits
+ // drop definitions of that uniform value in different iterations.
+ //
+ // for (int i = 0; i < n; ++i) { // 'i' is uniform inside the loop
+ // if (i % thread_id == 0) break; // divergent loop exit
+ // }
+ // int divI = i; // divI is divergent
+ for (size_t i = 0; i < Phi.getNumIncomingValues(); ++i) {
+ const auto *InVal = Phi.getIncomingValue(i);
+ if (isDivergent(*Phi.getIncomingValue(i)) ||
+ isTemporalDivergent(*Phi.getParent(), *InVal)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool DivergenceAnalysis::inRegion(const Instruction &I) const {
+ return I.getParent() && inRegion(*I.getParent());
+}
+
+bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const {
+ return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB);
+}
+
+// marks all users of loop-carried values of the loop headed by LoopHeader as
+// divergent
+void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) {
+ auto *DivLoop = LI.getLoopFor(&LoopHeader);
+ assert(DivLoop && "loopHeader is not actually part of a loop");
+
+ SmallVector<BasicBlock *, 8> TaintStack;
+ DivLoop->getExitBlocks(TaintStack);
+
+ // Without LCSSA form, potential users of loop-carried values could be
+ // anywhere in the dominance region of DivLoop (including its fringes for
+ // phi nodes).
+ DenseSet<const BasicBlock *> Visited;
+ for (auto *Block : TaintStack) {
+ Visited.insert(Block);
+ }
+ Visited.insert(&LoopHeader);
+
+ while (!TaintStack.empty()) {
+ auto *UserBlock = TaintStack.back();
+ TaintStack.pop_back();
+
+ // don't spread divergence beyond the region
+ if (!inRegion(*UserBlock))
+ continue;
+
+ assert(!DivLoop->contains(UserBlock) &&
+ "irreducible control flow detected");
+
+ // phi nodes at the fringes of the dominance region
+ if (!DT.dominates(&LoopHeader, UserBlock)) {
+ // all PHI nodes of UserBlock become divergent
+ for (auto &Phi : UserBlock->phis()) {
+ Worklist.push_back(&Phi);
+ }
+ continue;
+ }
+
+ // taint outside users of values carried by DivLoop
+ for (auto &I : *UserBlock) {
+ if (isAlwaysUniform(I))
+ continue;
+ if (isDivergent(I))
+ continue;
+
+ for (auto &Op : I.operands()) {
+ auto *OpInst = dyn_cast<Instruction>(&Op);
+ if (!OpInst)
+ continue;
+ if (DivLoop->contains(OpInst->getParent())) {
+ markDivergent(I);
+ pushUsers(I);
+ break;
+ }
+ }
+ }
+
+ // visit all blocks in the dominance region
+ for (auto *SuccBlock : successors(UserBlock)) {
+ if (!Visited.insert(SuccBlock).second) {
+ continue;
+ }
+ TaintStack.push_back(SuccBlock);
+ }
+ }
+}
+
+void DivergenceAnalysis::pushPHINodes(const BasicBlock &Block) {
+ for (const auto &Phi : Block.phis()) {
+ if (isDivergent(Phi))
+ continue;
+ Worklist.push_back(&Phi);
+ }
+}
+
+void DivergenceAnalysis::pushUsers(const Value &V) {
+ for (const auto *User : V.users()) {
+ const auto *UserInst = dyn_cast<const Instruction>(User);
+ if (!UserInst)
+ continue;
+
+ if (isDivergent(*UserInst))
+ continue;
+
+ // only compute divergence inside the region
+ if (!inRegion(*UserInst))
+ continue;
+ Worklist.push_back(UserInst);
+ }
+}
+
+bool DivergenceAnalysis::propagateJoinDivergence(const BasicBlock &JoinBlock,
+ const Loop *BranchLoop) {
+ LLVM_DEBUG(dbgs() << "\tpropJoinDiv " << JoinBlock.getName() << "\n");
+
+ // ignore divergence outside the region
+ if (!inRegion(JoinBlock)) {
+ return false;
+ }
+
+ // push non-divergent phi nodes in JoinBlock to the worklist
+ pushPHINodes(JoinBlock);
+
+ // JoinBlock is a divergent loop exit
+ if (BranchLoop && !BranchLoop->contains(&JoinBlock)) {
+ return true;
+ }
+
+ // disjoint-paths divergent at JoinBlock
+ markBlockJoinDivergent(JoinBlock);
+ return false;
+}
+
+void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) {
+ LLVM_DEBUG(dbgs() << "propBranchDiv " << Term.getParent()->getName() << "\n");
+
+ markDivergent(Term);
+
+ const auto *BranchLoop = LI.getLoopFor(Term.getParent());
+
+ // whether there is a divergent loop exit from BranchLoop (if any)
+ bool IsBranchLoopDivergent = false;
+
+ // iterate over all blocks reachable by disjoint paths from Term within the
+ // loop; this also visits loop exits that become divergent due to Term.
+ for (const auto *JoinBlock : SDA.join_blocks(Term)) {
+ IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop);
+ }
+
+ // Branch loop is a divergent loop due to the divergent branch in Term
+ if (IsBranchLoopDivergent) {
+ assert(BranchLoop);
+ if (!DivergentLoops.insert(BranchLoop).second) {
+ return;
+ }
+ propagateLoopDivergence(*BranchLoop);
+ }
+}
+
+void DivergenceAnalysis::propagateLoopDivergence(const Loop &ExitingLoop) {
+ LLVM_DEBUG(dbgs() << "propLoopDiv " << ExitingLoop.getName() << "\n");
+
+ // don't propagate beyond region
+ if (!inRegion(*ExitingLoop.getHeader()))
+ return;
+
+ const auto *BranchLoop = ExitingLoop.getParentLoop();
+
+ // Uses of loop-carried values could occur anywhere
+ // within the dominance region of the definition. All loop-carried
+ // definitions are dominated by the loop header (reducible control).
+ // Thus all users have to be in the dominance region of the loop header,
+ // except PHI nodes that can also live at the fringes of the dom region
+ // (incoming defining value).
+ if (!IsLCSSAForm)
+ taintLoopLiveOuts(*ExitingLoop.getHeader());
+
+ // whether there is a divergent loop exit from BranchLoop (if any)
+ bool IsBranchLoopDivergent = false;
+
+ // iterate over all blocks reachable by disjoint paths from exits of
+ // ExitingLoop also iterates over loop exits (of BranchLoop) that in turn
+ // become divergent.
+ for (const auto *JoinBlock : SDA.join_blocks(ExitingLoop)) {
+ IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop);
+ }
+
+ // BranchLoop is divergent due to a divergent loop exit in ExitingLoop
+ if (IsBranchLoopDivergent) {
+ assert(BranchLoop);
+ if (!DivergentLoops.insert(BranchLoop).second) {
+ return;
+ }
+ propagateLoopDivergence(*BranchLoop);
+ }
+}
+
+void DivergenceAnalysis::compute() {
+ for (auto *DivVal : DivergentValues) {
+ pushUsers(*DivVal);
+ }
+
+ // propagate divergence
+ while (!Worklist.empty()) {
+ const Instruction &I = *Worklist.back();
+ Worklist.pop_back();
+
+ // maintain uniformity of overrides
+ if (isAlwaysUniform(I))
+ continue;
+
+ bool WasDivergent = isDivergent(I);
+ if (WasDivergent)
+ continue;
+
+ // propagate divergence caused by terminator
+ if (I.isTerminator()) {
+ if (updateTerminator(I)) {
+ // propagate control divergence to affected instructions
+ propagateBranchDivergence(I);
+ continue;
+ }
+ }
+
+ // update divergence of I due to divergent operands
+ bool DivergentUpd = false;
+ const auto *Phi = dyn_cast<const PHINode>(&I);
+ if (Phi) {
+ DivergentUpd = updatePHINode(*Phi);
+ } else {
+ DivergentUpd = updateNormalInstruction(I);
+ }
+
+ // propagate value divergence to users
+ if (DivergentUpd) {
+ markDivergent(I);
+ pushUsers(I);
+ }
+ }
+}
+
+bool DivergenceAnalysis::isAlwaysUniform(const Value &V) const {
+ return UniformOverrides.find(&V) != UniformOverrides.end();
+}
+
+bool DivergenceAnalysis::isDivergent(const Value &V) const {
+ return DivergentValues.find(&V) != DivergentValues.end();
+}
+
+bool DivergenceAnalysis::isDivergentUse(const Use &U) const {
+ Value &V = *U.get();
+ Instruction &I = *cast<Instruction>(U.getUser());
+ return isDivergent(V) || isTemporalDivergent(*I.getParent(), V);
+}
+
+void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const {
+ if (DivergentValues.empty())
+ return;
+ // iterate instructions using instructions() to ensure a deterministic order.
+ for (auto &I : instructions(F)) {
+ if (isDivergent(I))
+ OS << "DIVERGENT:" << I << '\n';
+ }
+}
+
+// class GPUDivergenceAnalysis
+GPUDivergenceAnalysis::GPUDivergenceAnalysis(Function &F,
+ const DominatorTree &DT,
+ const PostDominatorTree &PDT,
+ const LoopInfo &LI,
+ const TargetTransformInfo &TTI)
+ : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, false) {
+ for (auto &I : instructions(F)) {
+ if (TTI.isSourceOfDivergence(&I)) {
+ DA.markDivergent(I);
+ } else if (TTI.isAlwaysUniform(&I)) {
+ DA.addUniformOverride(I);
+ }
+ }
+ for (auto &Arg : F.args()) {
+ if (TTI.isSourceOfDivergence(&Arg)) {
+ DA.markDivergent(Arg);
+ }
+ }
+
+ DA.compute();
+}
+
+bool GPUDivergenceAnalysis::isDivergent(const Value &val) const {
+ return DA.isDivergent(val);
+}
+
+bool GPUDivergenceAnalysis::isDivergentUse(const Use &use) const {
+ return DA.isDivergentUse(use);
+}
+
+void GPUDivergenceAnalysis::print(raw_ostream &OS, const Module *mod) const {
+ OS << "Divergence of kernel " << DA.getFunction().getName() << " {\n";
+ DA.print(OS, mod);
+ OS << "}\n";
+}
diff --git a/llvm/lib/Analysis/DomPrinter.cpp b/llvm/lib/Analysis/DomPrinter.cpp
new file mode 100644
index 000000000000..d9f43dd746ef
--- /dev/null
+++ b/llvm/lib/Analysis/DomPrinter.cpp
@@ -0,0 +1,297 @@
+//===- DomPrinter.cpp - DOT printer for the dominance trees ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines '-dot-dom' and '-dot-postdom' analysis passes, which emit
+// a dom.<fnname>.dot or postdom.<fnname>.dot file for each function in the
+// program, with a graph of the dominance/postdominance tree of that
+// function.
+//
+// There are also passes available to directly call dotty ('-view-dom' or
+// '-view-postdom'). By appending '-only', as in '-dot-dom-only', only the
+// names of the basic blocks are printed and their contents are hidden.
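+//
+// For illustration, with the legacy pass manager these passes are typically
+// run through opt (exact flag spelling may vary between releases), e.g.:
+//
+//   opt -dot-dom -disable-output input.ll
+//   opt -view-postdom-only -disable-output input.ll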
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DomPrinter.h"
+#include "llvm/Analysis/DOTGraphTraitsPass.h"
+#include "llvm/Analysis/PostDominators.h"
+
+using namespace llvm;
+
+namespace llvm {
+template<>
+struct DOTGraphTraits<DomTreeNode*> : public DefaultDOTGraphTraits {
+
+ DOTGraphTraits (bool isSimple=false)
+ : DefaultDOTGraphTraits(isSimple) {}
+
+ std::string getNodeLabel(DomTreeNode *Node, DomTreeNode *Graph) {
+
+ BasicBlock *BB = Node->getBlock();
+
+ if (!BB)
+ return "Post dominance root node";
+
+
+ if (isSimple())
+ return DOTGraphTraits<const Function*>
+ ::getSimpleNodeLabel(BB, BB->getParent());
+ else
+ return DOTGraphTraits<const Function*>
+ ::getCompleteNodeLabel(BB, BB->getParent());
+ }
+};
+
+template<>
+struct DOTGraphTraits<DominatorTree*> : public DOTGraphTraits<DomTreeNode*> {
+
+ DOTGraphTraits (bool isSimple=false)
+ : DOTGraphTraits<DomTreeNode*>(isSimple) {}
+
+ static std::string getGraphName(DominatorTree *DT) {
+ return "Dominator tree";
+ }
+
+ std::string getNodeLabel(DomTreeNode *Node, DominatorTree *G) {
+ return DOTGraphTraits<DomTreeNode*>::getNodeLabel(Node, G->getRootNode());
+ }
+};
+
+template<>
+struct DOTGraphTraits<PostDominatorTree*>
+ : public DOTGraphTraits<DomTreeNode*> {
+
+ DOTGraphTraits (bool isSimple=false)
+ : DOTGraphTraits<DomTreeNode*>(isSimple) {}
+
+ static std::string getGraphName(PostDominatorTree *DT) {
+ return "Post dominator tree";
+ }
+
+ std::string getNodeLabel(DomTreeNode *Node, PostDominatorTree *G ) {
+ return DOTGraphTraits<DomTreeNode*>::getNodeLabel(Node, G->getRootNode());
+ }
+};
+}
+
+void DominatorTree::viewGraph(const Twine &Name, const Twine &Title) {
+#ifndef NDEBUG
+ ViewGraph(this, Name, false, Title);
+#else
+ errs() << "DomTree dump not available, build with DEBUG\n";
+#endif // NDEBUG
+}
+
+void DominatorTree::viewGraph() {
+#ifndef NDEBUG
+ this->viewGraph("domtree", "Dominator Tree for function");
+#else
+ errs() << "DomTree dump not available, build with DEBUG\n";
+#endif // NDEBUG
+}
+
+namespace {
+struct DominatorTreeWrapperPassAnalysisGraphTraits {
+ static DominatorTree *getGraph(DominatorTreeWrapperPass *DTWP) {
+ return &DTWP->getDomTree();
+ }
+};
+
+struct DomViewer : public DOTGraphTraitsViewer<
+ DominatorTreeWrapperPass, false, DominatorTree *,
+ DominatorTreeWrapperPassAnalysisGraphTraits> {
+ static char ID;
+ DomViewer()
+ : DOTGraphTraitsViewer<DominatorTreeWrapperPass, false, DominatorTree *,
+ DominatorTreeWrapperPassAnalysisGraphTraits>(
+ "dom", ID) {
+ initializeDomViewerPass(*PassRegistry::getPassRegistry());
+ }
+};
+
+struct DomOnlyViewer : public DOTGraphTraitsViewer<
+ DominatorTreeWrapperPass, true, DominatorTree *,
+ DominatorTreeWrapperPassAnalysisGraphTraits> {
+ static char ID;
+ DomOnlyViewer()
+ : DOTGraphTraitsViewer<DominatorTreeWrapperPass, true, DominatorTree *,
+ DominatorTreeWrapperPassAnalysisGraphTraits>(
+ "domonly", ID) {
+ initializeDomOnlyViewerPass(*PassRegistry::getPassRegistry());
+ }
+};
+
+struct PostDominatorTreeWrapperPassAnalysisGraphTraits {
+ static PostDominatorTree *getGraph(PostDominatorTreeWrapperPass *PDTWP) {
+ return &PDTWP->getPostDomTree();
+ }
+};
+
+struct PostDomViewer : public DOTGraphTraitsViewer<
+ PostDominatorTreeWrapperPass, false,
+ PostDominatorTree *,
+ PostDominatorTreeWrapperPassAnalysisGraphTraits> {
+ static char ID;
+ PostDomViewer() :
+ DOTGraphTraitsViewer<PostDominatorTreeWrapperPass, false,
+ PostDominatorTree *,
+ PostDominatorTreeWrapperPassAnalysisGraphTraits>(
+ "postdom", ID){
+ initializePostDomViewerPass(*PassRegistry::getPassRegistry());
+ }
+};
+
+struct PostDomOnlyViewer : public DOTGraphTraitsViewer<
+ PostDominatorTreeWrapperPass, true,
+ PostDominatorTree *,
+ PostDominatorTreeWrapperPassAnalysisGraphTraits> {
+ static char ID;
+ PostDomOnlyViewer() :
+ DOTGraphTraitsViewer<PostDominatorTreeWrapperPass, true,
+ PostDominatorTree *,
+ PostDominatorTreeWrapperPassAnalysisGraphTraits>(
+ "postdomonly", ID){
+ initializePostDomOnlyViewerPass(*PassRegistry::getPassRegistry());
+ }
+};
+} // end anonymous namespace
+
+char DomViewer::ID = 0;
+INITIALIZE_PASS(DomViewer, "view-dom",
+ "View dominance tree of function", false, false)
+
+char DomOnlyViewer::ID = 0;
+INITIALIZE_PASS(DomOnlyViewer, "view-dom-only",
+ "View dominance tree of function (with no function bodies)",
+ false, false)
+
+char PostDomViewer::ID = 0;
+INITIALIZE_PASS(PostDomViewer, "view-postdom",
+ "View postdominance tree of function", false, false)
+
+char PostDomOnlyViewer::ID = 0;
+INITIALIZE_PASS(PostDomOnlyViewer, "view-postdom-only",
+ "View postdominance tree of function "
+ "(with no function bodies)",
+ false, false)
+
+namespace {
+struct DomPrinter : public DOTGraphTraitsPrinter<
+ DominatorTreeWrapperPass, false, DominatorTree *,
+ DominatorTreeWrapperPassAnalysisGraphTraits> {
+ static char ID;
+ DomPrinter()
+ : DOTGraphTraitsPrinter<DominatorTreeWrapperPass, false, DominatorTree *,
+ DominatorTreeWrapperPassAnalysisGraphTraits>(
+ "dom", ID) {
+ initializeDomPrinterPass(*PassRegistry::getPassRegistry());
+ }
+};
+
+struct DomOnlyPrinter : public DOTGraphTraitsPrinter<
+ DominatorTreeWrapperPass, true, DominatorTree *,
+ DominatorTreeWrapperPassAnalysisGraphTraits> {
+ static char ID;
+ DomOnlyPrinter()
+ : DOTGraphTraitsPrinter<DominatorTreeWrapperPass, true, DominatorTree *,
+ DominatorTreeWrapperPassAnalysisGraphTraits>(
+ "domonly", ID) {
+ initializeDomOnlyPrinterPass(*PassRegistry::getPassRegistry());
+ }
+};
+
+struct PostDomPrinter
+ : public DOTGraphTraitsPrinter<
+ PostDominatorTreeWrapperPass, false,
+ PostDominatorTree *,
+ PostDominatorTreeWrapperPassAnalysisGraphTraits> {
+ static char ID;
+ PostDomPrinter() :
+ DOTGraphTraitsPrinter<PostDominatorTreeWrapperPass, false,
+ PostDominatorTree *,
+ PostDominatorTreeWrapperPassAnalysisGraphTraits>(
+ "postdom", ID) {
+ initializePostDomPrinterPass(*PassRegistry::getPassRegistry());
+ }
+};
+
+struct PostDomOnlyPrinter
+ : public DOTGraphTraitsPrinter<
+ PostDominatorTreeWrapperPass, true,
+ PostDominatorTree *,
+ PostDominatorTreeWrapperPassAnalysisGraphTraits> {
+ static char ID;
+ PostDomOnlyPrinter() :
+ DOTGraphTraitsPrinter<PostDominatorTreeWrapperPass, true,
+ PostDominatorTree *,
+ PostDominatorTreeWrapperPassAnalysisGraphTraits>(
+ "postdomonly", ID) {
+ initializePostDomOnlyPrinterPass(*PassRegistry::getPassRegistry());
+ }
+};
+} // end anonymous namespace
+
+
+
+char DomPrinter::ID = 0;
+INITIALIZE_PASS(DomPrinter, "dot-dom",
+ "Print dominance tree of function to 'dot' file",
+ false, false)
+
+char DomOnlyPrinter::ID = 0;
+INITIALIZE_PASS(DomOnlyPrinter, "dot-dom-only",
+ "Print dominance tree of function to 'dot' file "
+ "(with no function bodies)",
+ false, false)
+
+char PostDomPrinter::ID = 0;
+INITIALIZE_PASS(PostDomPrinter, "dot-postdom",
+ "Print postdominance tree of function to 'dot' file",
+ false, false)
+
+char PostDomOnlyPrinter::ID = 0;
+INITIALIZE_PASS(PostDomOnlyPrinter, "dot-postdom-only",
+ "Print postdominance tree of function to 'dot' file "
+ "(with no function bodies)",
+ false, false)
+
+// Create methods available outside of this file, so that they can be
+// referenced from "include/llvm/LinkAllPasses.h". Otherwise the passes
+// would be eliminated by link-time optimization.
+
+FunctionPass *llvm::createDomPrinterPass() {
+ return new DomPrinter();
+}
+
+FunctionPass *llvm::createDomOnlyPrinterPass() {
+ return new DomOnlyPrinter();
+}
+
+FunctionPass *llvm::createDomViewerPass() {
+ return new DomViewer();
+}
+
+FunctionPass *llvm::createDomOnlyViewerPass() {
+ return new DomOnlyViewer();
+}
+
+FunctionPass *llvm::createPostDomPrinterPass() {
+ return new PostDomPrinter();
+}
+
+FunctionPass *llvm::createPostDomOnlyPrinterPass() {
+ return new PostDomOnlyPrinter();
+}
+
+FunctionPass *llvm::createPostDomViewerPass() {
+ return new PostDomViewer();
+}
+
+FunctionPass *llvm::createPostDomOnlyViewerPass() {
+ return new PostDomOnlyViewer();
+}
diff --git a/llvm/lib/Analysis/DomTreeUpdater.cpp b/llvm/lib/Analysis/DomTreeUpdater.cpp
new file mode 100644
index 000000000000..49215889cfd6
--- /dev/null
+++ b/llvm/lib/Analysis/DomTreeUpdater.cpp
@@ -0,0 +1,533 @@
+//===- DomTreeUpdater.cpp - DomTree/Post DomTree Updater --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the DomTreeUpdater class, which provides a uniform way
+// to update dominator tree related data structures.
+//
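+// An illustrative use under the Lazy strategy (BB and Succ stand for blocks
+// being transformed by some client pass):
+//
+//   DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
+//   DTU.applyUpdates({{DominatorTree::Delete, BB, Succ}});
+//   DTU.deleteBB(BB);                         // deferred under Lazy
+//   DominatorTree &NewDT = DTU.getDomTree();  // flushes pending DT updates
+//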
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/Support/GenericDomTree.h"
+#include <algorithm>
+#include <functional>
+#include <utility>
+
+namespace llvm {
+
+bool DomTreeUpdater::isUpdateValid(
+ const DominatorTree::UpdateType Update) const {
+ const auto *From = Update.getFrom();
+ const auto *To = Update.getTo();
+ const auto Kind = Update.getKind();
+
+ // Discard updates by inspecting the current state of successors of From.
+ // Since isUpdateValid() must be called *after* the Terminator of From is
+ // altered we can determine if the update is unnecessary for batch updates
+ // or invalid for a single update.
+ const bool HasEdge = llvm::any_of(
+ successors(From), [To](const BasicBlock *B) { return B == To; });
+
+ // If the IR does not match the update,
+ // 1. In batch updates, this update is unnecessary.
+ // 2. When called by insertEdge*()/deleteEdge*(), this update is invalid.
+ // Edge does not exist in IR.
+ if (Kind == DominatorTree::Insert && !HasEdge)
+ return false;
+
+ // Edge exists in IR.
+ if (Kind == DominatorTree::Delete && HasEdge)
+ return false;
+
+ return true;
+}
+
+bool DomTreeUpdater::isSelfDominance(
+ const DominatorTree::UpdateType Update) const {
+ // Won't affect DomTree and PostDomTree.
+ return Update.getFrom() == Update.getTo();
+}
+
+void DomTreeUpdater::applyDomTreeUpdates() {
+ // No pending DomTreeUpdates.
+ if (Strategy != UpdateStrategy::Lazy || !DT)
+ return;
+
+ // Only apply updates that have not yet been applied to the DomTree.
+ if (hasPendingDomTreeUpdates()) {
+ const auto I = PendUpdates.begin() + PendDTUpdateIndex;
+ const auto E = PendUpdates.end();
+ assert(I < E && "Iterator range invalid; there should be DomTree updates.");
+ DT->applyUpdates(ArrayRef<DominatorTree::UpdateType>(I, E));
+ PendDTUpdateIndex = PendUpdates.size();
+ }
+}
+
+void DomTreeUpdater::flush() {
+ applyDomTreeUpdates();
+ applyPostDomTreeUpdates();
+ dropOutOfDateUpdates();
+}
+
+void DomTreeUpdater::applyPostDomTreeUpdates() {
+ // No pending PostDomTreeUpdates.
+ if (Strategy != UpdateStrategy::Lazy || !PDT)
+ return;
+
+ // Only apply updates that have not yet been applied to the PostDomTree.
+ if (hasPendingPostDomTreeUpdates()) {
+ const auto I = PendUpdates.begin() + PendPDTUpdateIndex;
+ const auto E = PendUpdates.end();
+ assert(I < E &&
+ "Iterator range invalid; there should be PostDomTree updates.");
+ PDT->applyUpdates(ArrayRef<DominatorTree::UpdateType>(I, E));
+ PendPDTUpdateIndex = PendUpdates.size();
+ }
+}
+
+void DomTreeUpdater::tryFlushDeletedBB() {
+ if (!hasPendingUpdates())
+ forceFlushDeletedBB();
+}
+
+bool DomTreeUpdater::forceFlushDeletedBB() {
+ if (DeletedBBs.empty())
+ return false;
+
+ for (auto *BB : DeletedBBs) {
+ // After calling deleteBB or callbackDeleteBB under Lazy UpdateStrategy,
+ // validateDeleteBB() removes all instructions of DelBB and adds an
+ // UnreachableInst as its terminator. So we check whether the BasicBlock to
+ // delete only has an UnreachableInst inside.
+ assert(BB->getInstList().size() == 1 &&
+ isa<UnreachableInst>(BB->getTerminator()) &&
+ "DelBB has been modified while awaiting deletion.");
+ BB->removeFromParent();
+ eraseDelBBNode(BB);
+ delete BB;
+ }
+ DeletedBBs.clear();
+ Callbacks.clear();
+ return true;
+}
+
+void DomTreeUpdater::recalculate(Function &F) {
+
+ if (Strategy == UpdateStrategy::Eager) {
+ if (DT)
+ DT->recalculate(F);
+ if (PDT)
+ PDT->recalculate(F);
+ return;
+ }
+
+ // There is little performance gain if we pend the recalculation under
+ // Lazy UpdateStrategy, so we recalculate available trees immediately.
+
+ // Prevent forceFlushDeletedBB() from erasing DomTree or PostDomTree nodes.
+ IsRecalculatingDomTree = IsRecalculatingPostDomTree = true;
+
+ // Because all trees are going to be up-to-date after recalculation,
+ // flush awaiting deleted BasicBlocks.
+ forceFlushDeletedBB();
+ if (DT)
+ DT->recalculate(F);
+ if (PDT)
+ PDT->recalculate(F);
+
+ // Resume forceFlushDeletedBB() to erase DomTree or PostDomTree nodes.
+ IsRecalculatingDomTree = IsRecalculatingPostDomTree = false;
+ PendDTUpdateIndex = PendPDTUpdateIndex = PendUpdates.size();
+ dropOutOfDateUpdates();
+}
+
+bool DomTreeUpdater::hasPendingUpdates() const {
+ return hasPendingDomTreeUpdates() || hasPendingPostDomTreeUpdates();
+}
+
+bool DomTreeUpdater::hasPendingDomTreeUpdates() const {
+ if (!DT)
+ return false;
+ return PendUpdates.size() != PendDTUpdateIndex;
+}
+
+bool DomTreeUpdater::hasPendingPostDomTreeUpdates() const {
+ if (!PDT)
+ return false;
+ return PendUpdates.size() != PendPDTUpdateIndex;
+}
+
+bool DomTreeUpdater::isBBPendingDeletion(llvm::BasicBlock *DelBB) const {
+ if (Strategy == UpdateStrategy::Eager || DeletedBBs.empty())
+ return false;
+ return DeletedBBs.count(DelBB) != 0;
+}
+
+// The DT and PDT require that the nodes related to updates
+// not be deleted while update functions are being called.
+// So BasicBlock deletions must be pended when the
+// UpdateStrategy is Lazy. When the UpdateStrategy is
+// Eager, the BasicBlock is deleted immediately.
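+//
+// For example (illustrative), under the Lazy strategy
+//   DTU.deleteBB(BB);
+// only queues BB: its body is replaced by a single UnreachableInst, and the
+// block is removed and freed later, once no tree updates are pending
+// (see tryFlushDeletedBB / forceFlushDeletedBB below).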
+void DomTreeUpdater::deleteBB(BasicBlock *DelBB) {
+ validateDeleteBB(DelBB);
+ if (Strategy == UpdateStrategy::Lazy) {
+ DeletedBBs.insert(DelBB);
+ return;
+ }
+
+ DelBB->removeFromParent();
+ eraseDelBBNode(DelBB);
+ delete DelBB;
+}
+
+void DomTreeUpdater::callbackDeleteBB(
+ BasicBlock *DelBB, std::function<void(BasicBlock *)> Callback) {
+ validateDeleteBB(DelBB);
+ if (Strategy == UpdateStrategy::Lazy) {
+ Callbacks.push_back(CallBackOnDeletion(DelBB, Callback));
+ DeletedBBs.insert(DelBB);
+ return;
+ }
+
+ DelBB->removeFromParent();
+ eraseDelBBNode(DelBB);
+ Callback(DelBB);
+ delete DelBB;
+}
+
+void DomTreeUpdater::eraseDelBBNode(BasicBlock *DelBB) {
+ if (DT && !IsRecalculatingDomTree)
+ if (DT->getNode(DelBB))
+ DT->eraseNode(DelBB);
+
+ if (PDT && !IsRecalculatingPostDomTree)
+ if (PDT->getNode(DelBB))
+ PDT->eraseNode(DelBB);
+}
+
+void DomTreeUpdater::validateDeleteBB(BasicBlock *DelBB) {
+ assert(DelBB && "Invalid push_back of nullptr DelBB.");
+ assert(pred_empty(DelBB) && "DelBB has one or more predecessors.");
+ // DelBB is unreachable and all its instructions are dead.
+ while (!DelBB->empty()) {
+ Instruction &I = DelBB->back();
+ // Replace used instructions with an arbitrary value (undef).
+ if (!I.use_empty())
+ I.replaceAllUsesWith(llvm::UndefValue::get(I.getType()));
+ DelBB->getInstList().pop_back();
+ }
+ // Make sure DelBB has a valid terminator instruction. As long as DelBB is a
+ // Child of Function F it must contain valid IR.
+ new UnreachableInst(DelBB->getContext(), DelBB);
+}
+
+void DomTreeUpdater::applyUpdates(ArrayRef<DominatorTree::UpdateType> Updates) {
+ if (!DT && !PDT)
+ return;
+
+ if (Strategy == UpdateStrategy::Lazy) {
+ for (const auto U : Updates)
+ if (!isSelfDominance(U))
+ PendUpdates.push_back(U);
+
+ return;
+ }
+
+ if (DT)
+ DT->applyUpdates(Updates);
+ if (PDT)
+ PDT->applyUpdates(Updates);
+}
+
+void DomTreeUpdater::applyUpdatesPermissive(
+ ArrayRef<DominatorTree::UpdateType> Updates) {
+ if (!DT && !PDT)
+ return;
+
+ SmallSet<std::pair<BasicBlock *, BasicBlock *>, 8> Seen;
+ SmallVector<DominatorTree::UpdateType, 8> DeduplicatedUpdates;
+ for (const auto U : Updates) {
+ auto Edge = std::make_pair(U.getFrom(), U.getTo());
+ // Because it is illegal to submit updates that have already been applied
+ // and updates to an edge need to be strictly ordered,
+ // it is safe to infer the existence of an edge from the first update
+ // to this edge.
+ // If the first update to an edge is "Delete", it means that the edge
+ // existed before. If the first update to an edge is "Insert", it means
+ // that the edge didn't exist before.
+ //
+ // For example, if the user submits {{Delete, A, B}, {Insert, A, B}}, then,
+ // because
+ // 1. it is illegal to submit updates that have already been applied,
+ // i.e., the user cannot delete a nonexistent edge, and
+ // 2. updates to an edge need to be strictly ordered,
+ // the edge A -> B must have existed initially.
+ // We can then safely ignore future updates to this edge and directly
+ // inspect the current CFG:
+ // a. If the edge still exists, because the user cannot insert an existent
+ // edge, so both {Delete, A, B}, {Insert, A, B} actually happened and
+ // resulted in a no-op. DTU won't submit any update in this case.
+ // b. If the edge doesn't exist, we can then infer that {Delete, A, B}
+ // actually happened but {Insert, A, B} was an invalid update which never
+ // happened. DTU will submit {Delete, A, B} in this case.
+ if (!isSelfDominance(U) && Seen.count(Edge) == 0) {
+ Seen.insert(Edge);
+ // If the update doesn't appear in the CFG, it means that
+ // either the change isn't made or relevant operations
+ // result in a no-op.
+ if (isUpdateValid(U)) {
+ if (isLazy())
+ PendUpdates.push_back(U);
+ else
+ DeduplicatedUpdates.push_back(U);
+ }
+ }
+ }
+
+ if (Strategy == UpdateStrategy::Lazy)
+ return;
+
+ if (DT)
+ DT->applyUpdates(DeduplicatedUpdates);
+ if (PDT)
+ PDT->applyUpdates(DeduplicatedUpdates);
+}
+
+DominatorTree &DomTreeUpdater::getDomTree() {
+ assert(DT && "Invalid acquisition of a null DomTree");
+ applyDomTreeUpdates();
+ dropOutOfDateUpdates();
+ return *DT;
+}
+
+PostDominatorTree &DomTreeUpdater::getPostDomTree() {
+ assert(PDT && "Invalid acquisition of a null PostDomTree");
+ applyPostDomTreeUpdates();
+ dropOutOfDateUpdates();
+ return *PDT;
+}
+
+void DomTreeUpdater::insertEdge(BasicBlock *From, BasicBlock *To) {
+
+#ifndef NDEBUG
+ assert(isUpdateValid({DominatorTree::Insert, From, To}) &&
+ "Inserted edge does not appear in the CFG");
+#endif
+
+ if (!DT && !PDT)
+ return;
+
+ // Won't affect DomTree and PostDomTree; discard update.
+ if (From == To)
+ return;
+
+ if (Strategy == UpdateStrategy::Eager) {
+ if (DT)
+ DT->insertEdge(From, To);
+ if (PDT)
+ PDT->insertEdge(From, To);
+ return;
+ }
+
+ PendUpdates.push_back({DominatorTree::Insert, From, To});
+}
+
+void DomTreeUpdater::insertEdgeRelaxed(BasicBlock *From, BasicBlock *To) {
+ if (From == To)
+ return;
+
+ if (!DT && !PDT)
+ return;
+
+ if (!isUpdateValid({DominatorTree::Insert, From, To}))
+ return;
+
+ if (Strategy == UpdateStrategy::Eager) {
+ if (DT)
+ DT->insertEdge(From, To);
+ if (PDT)
+ PDT->insertEdge(From, To);
+ return;
+ }
+
+ PendUpdates.push_back({DominatorTree::Insert, From, To});
+}
+
+void DomTreeUpdater::deleteEdge(BasicBlock *From, BasicBlock *To) {
+
+#ifndef NDEBUG
+ assert(isUpdateValid({DominatorTree::Delete, From, To}) &&
+ "Deleted edge still exists in the CFG!");
+#endif
+
+ if (!DT && !PDT)
+ return;
+
+ // Won't affect DomTree and PostDomTree; discard update.
+ if (From == To)
+ return;
+
+ if (Strategy == UpdateStrategy::Eager) {
+ if (DT)
+ DT->deleteEdge(From, To);
+ if (PDT)
+ PDT->deleteEdge(From, To);
+ return;
+ }
+
+ PendUpdates.push_back({DominatorTree::Delete, From, To});
+}
+
+void DomTreeUpdater::deleteEdgeRelaxed(BasicBlock *From, BasicBlock *To) {
+ if (From == To)
+ return;
+
+ if (!DT && !PDT)
+ return;
+
+ if (!isUpdateValid({DominatorTree::Delete, From, To}))
+ return;
+
+ if (Strategy == UpdateStrategy::Eager) {
+ if (DT)
+ DT->deleteEdge(From, To);
+ if (PDT)
+ PDT->deleteEdge(From, To);
+ return;
+ }
+
+ PendUpdates.push_back({DominatorTree::Delete, From, To});
+}
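+
+// Illustrative contrast between the strict and relaxed entry points (a
+// sketch; A and B are arbitrary BasicBlock pointers): the strict variants
+// assert that the update matches the current CFG, while the *Relaxed
+// variants silently drop updates that do not.
+//
+//   DTU.deleteEdge(A, B);        // asserts A -> B is already gone from the CFG
+//   DTU.deleteEdgeRelaxed(A, B); // simply discarded if A -> B still exists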
+
+void DomTreeUpdater::dropOutOfDateUpdates() {
+ if (Strategy == DomTreeUpdater::UpdateStrategy::Eager)
+ return;
+
+ tryFlushDeletedBB();
+
+ // Drop all updates applied by both trees.
+ if (!DT)
+ PendDTUpdateIndex = PendUpdates.size();
+ if (!PDT)
+ PendPDTUpdateIndex = PendUpdates.size();
+
+ const size_t dropIndex = std::min(PendDTUpdateIndex, PendPDTUpdateIndex);
+ const auto B = PendUpdates.begin();
+ const auto E = PendUpdates.begin() + dropIndex;
+ assert(B <= E && "Iterator out of range.");
+ PendUpdates.erase(B, E);
+ // Calculate current index.
+ PendDTUpdateIndex -= dropIndex;
+ PendPDTUpdateIndex -= dropIndex;
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void DomTreeUpdater::dump() const {
+ raw_ostream &OS = llvm::dbgs();
+
+ OS << "Available Trees: ";
+ if (DT || PDT) {
+ if (DT)
+ OS << "DomTree ";
+ if (PDT)
+ OS << "PostDomTree ";
+ OS << "\n";
+ } else
+ OS << "None\n";
+
+ OS << "UpdateStrategy: ";
+ if (Strategy == UpdateStrategy::Eager) {
+ OS << "Eager\n";
+ return;
+ } else
+ OS << "Lazy\n";
+ int Index = 0;
+
+ auto printUpdates =
+ [&](ArrayRef<DominatorTree::UpdateType>::const_iterator begin,
+ ArrayRef<DominatorTree::UpdateType>::const_iterator end) {
+ if (begin == end)
+ OS << " None\n";
+ Index = 0;
+ for (auto It = begin, ItEnd = end; It != ItEnd; ++It) {
+ auto U = *It;
+ OS << " " << Index << " : ";
+ ++Index;
+ if (U.getKind() == DominatorTree::Insert)
+ OS << "Insert, ";
+ else
+ OS << "Delete, ";
+ BasicBlock *From = U.getFrom();
+ if (From) {
+ auto S = From->getName();
+ if (!From->hasName())
+ S = "(no name)";
+ OS << S << "(" << From << "), ";
+ } else {
+ OS << "(badref), ";
+ }
+ BasicBlock *To = U.getTo();
+ if (To) {
+ auto S = To->getName();
+ if (!To->hasName())
+ S = "(no_name)";
+ OS << S << "(" << To << ")\n";
+ } else {
+ OS << "(badref)\n";
+ }
+ }
+ };
+
+ if (DT) {
+ const auto I = PendUpdates.begin() + PendDTUpdateIndex;
+ assert(PendUpdates.begin() <= I && I <= PendUpdates.end() &&
+ "Iterator out of range.");
+ OS << "Applied but not cleared DomTreeUpdates:\n";
+ printUpdates(PendUpdates.begin(), I);
+ OS << "Pending DomTreeUpdates:\n";
+ printUpdates(I, PendUpdates.end());
+ }
+
+ if (PDT) {
+ const auto I = PendUpdates.begin() + PendPDTUpdateIndex;
+ assert(PendUpdates.begin() <= I && I <= PendUpdates.end() &&
+ "Iterator out of range.");
+ OS << "Applied but not cleared PostDomTreeUpdates:\n";
+ printUpdates(PendUpdates.begin(), I);
+ OS << "Pending PostDomTreeUpdates:\n";
+ printUpdates(I, PendUpdates.end());
+ }
+
+ OS << "Pending DeletedBBs:\n";
+ Index = 0;
+ for (auto BB : DeletedBBs) {
+ OS << " " << Index << " : ";
+ ++Index;
+ if (BB->hasName())
+ OS << BB->getName() << "(";
+ else
+ OS << "(no_name)(";
+ OS << BB << ")\n";
+ }
+
+ OS << "Pending Callbacks:\n";
+ Index = 0;
+ for (auto BB : Callbacks) {
+ OS << " " << Index << " : ";
+ ++Index;
+ if (BB->hasName())
+ OS << BB->getName() << "(";
+ else
+ OS << "(no_name)(";
+ OS << BB << ")\n";
+ }
+}
+#endif
+} // namespace llvm
diff --git a/llvm/lib/Analysis/DominanceFrontier.cpp b/llvm/lib/Analysis/DominanceFrontier.cpp
new file mode 100644
index 000000000000..f9a554acb7ea
--- /dev/null
+++ b/llvm/lib/Analysis/DominanceFrontier.cpp
@@ -0,0 +1,96 @@
+//===- DominanceFrontier.cpp - Dominance Frontier Calculation -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DominanceFrontier.h"
+#include "llvm/Analysis/DominanceFrontierImpl.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+namespace llvm {
+
+template class DominanceFrontierBase<BasicBlock, false>;
+template class DominanceFrontierBase<BasicBlock, true>;
+template class ForwardDominanceFrontierBase<BasicBlock>;
+
+} // end namespace llvm
+
+char DominanceFrontierWrapperPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(DominanceFrontierWrapperPass, "domfrontier",
+ "Dominance Frontier Construction", true, true)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_END(DominanceFrontierWrapperPass, "domfrontier",
+ "Dominance Frontier Construction", true, true)
+
+DominanceFrontierWrapperPass::DominanceFrontierWrapperPass()
+ : FunctionPass(ID), DF() {
+ initializeDominanceFrontierWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+void DominanceFrontierWrapperPass::releaseMemory() {
+ DF.releaseMemory();
+}
+
+bool DominanceFrontierWrapperPass::runOnFunction(Function &) {
+ releaseMemory();
+ DF.analyze(getAnalysis<DominatorTreeWrapperPass>().getDomTree());
+ return false;
+}
+
+void DominanceFrontierWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequired<DominatorTreeWrapperPass>();
+}
+
+void DominanceFrontierWrapperPass::print(raw_ostream &OS, const Module *) const {
+ DF.print(OS);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void DominanceFrontierWrapperPass::dump() const {
+ print(dbgs());
+}
+#endif
+
+/// Handle invalidation explicitly.
+bool DominanceFrontier::invalidate(Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &) {
+ // Check whether the analysis, all analyses on functions, or the function's
+ // CFG have been preserved.
+ auto PAC = PA.getChecker<DominanceFrontierAnalysis>();
+ return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() ||
+ PAC.preservedSet<CFGAnalyses>());
+}
+
+AnalysisKey DominanceFrontierAnalysis::Key;
+
+DominanceFrontier DominanceFrontierAnalysis::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ DominanceFrontier DF;
+ DF.analyze(AM.getResult<DominatorTreeAnalysis>(F));
+ return DF;
+}
+
+DominanceFrontierPrinterPass::DominanceFrontierPrinterPass(raw_ostream &OS)
+ : OS(OS) {}
+
+PreservedAnalyses
+DominanceFrontierPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
+ OS << "DominanceFrontier for function: " << F.getName() << "\n";
+ AM.getResult<DominanceFrontierAnalysis>(F).print(OS);
+
+ return PreservedAnalyses::all();
+}
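+
+// A minimal sketch of how a new-pass-manager pass could consume this analysis
+// (MyPass is hypothetical; only DominanceFrontierAnalysis itself is defined in
+// this file):
+//
+//   PreservedAnalyses MyPass::run(Function &F, FunctionAnalysisManager &AM) {
+//     const DominanceFrontier &DF = AM.getResult<DominanceFrontierAnalysis>(F);
+//     for (const auto &Entry : DF) {
+//       // Entry.first is a BasicBlock*, Entry.second its dominance frontier.
+//     }
+//     return PreservedAnalyses::all();
+//   }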
diff --git a/llvm/lib/Analysis/EHPersonalities.cpp b/llvm/lib/Analysis/EHPersonalities.cpp
new file mode 100644
index 000000000000..2242541696a4
--- /dev/null
+++ b/llvm/lib/Analysis/EHPersonalities.cpp
@@ -0,0 +1,135 @@
+//===- EHPersonalities.cpp - Compute EH-related information ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/EHPersonalities.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace llvm;
+
+/// See if the given exception handling personality function is one that we
+/// understand. If so, return a description of it; otherwise return Unknown.
+EHPersonality llvm::classifyEHPersonality(const Value *Pers) {
+ const Function *F =
+ Pers ? dyn_cast<Function>(Pers->stripPointerCasts()) : nullptr;
+ if (!F)
+ return EHPersonality::Unknown;
+ return StringSwitch<EHPersonality>(F->getName())
+ .Case("__gnat_eh_personality", EHPersonality::GNU_Ada)
+ .Case("__gxx_personality_v0", EHPersonality::GNU_CXX)
+ .Case("__gxx_personality_seh0", EHPersonality::GNU_CXX)
+ .Case("__gxx_personality_sj0", EHPersonality::GNU_CXX_SjLj)
+ .Case("__gcc_personality_v0", EHPersonality::GNU_C)
+ .Case("__gcc_personality_seh0", EHPersonality::GNU_C)
+ .Case("__gcc_personality_sj0", EHPersonality::GNU_C_SjLj)
+ .Case("__objc_personality_v0", EHPersonality::GNU_ObjC)
+ .Case("_except_handler3", EHPersonality::MSVC_X86SEH)
+ .Case("_except_handler4", EHPersonality::MSVC_X86SEH)
+ .Case("__C_specific_handler", EHPersonality::MSVC_Win64SEH)
+ .Case("__CxxFrameHandler3", EHPersonality::MSVC_CXX)
+ .Case("ProcessCLRException", EHPersonality::CoreCLR)
+ .Case("rust_eh_personality", EHPersonality::Rust)
+ .Case("__gxx_wasm_personality_v0", EHPersonality::Wasm_CXX)
+ .Default(EHPersonality::Unknown);
+}
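+
+// Typical use (an illustrative sketch; F is any Function with a personality
+// attached, and isFuncletEHPersonality comes from
+// llvm/Analysis/EHPersonalities.h):
+//
+//   EHPersonality Pers = classifyEHPersonality(F.getPersonalityFn());
+//   if (isFuncletEHPersonality(Pers)) {
+//     // Windows-style funclet-based EH (e.g. MSVC_CXX, CoreCLR).
+//   }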
+
+StringRef llvm::getEHPersonalityName(EHPersonality Pers) {
+ switch (Pers) {
+ case EHPersonality::GNU_Ada: return "__gnat_eh_personality";
+ case EHPersonality::GNU_CXX: return "__gxx_personality_v0";
+ case EHPersonality::GNU_CXX_SjLj: return "__gxx_personality_sj0";
+ case EHPersonality::GNU_C: return "__gcc_personality_v0";
+ case EHPersonality::GNU_C_SjLj: return "__gcc_personality_sj0";
+ case EHPersonality::GNU_ObjC: return "__objc_personality_v0";
+ case EHPersonality::MSVC_X86SEH: return "_except_handler3";
+ case EHPersonality::MSVC_Win64SEH: return "__C_specific_handler";
+ case EHPersonality::MSVC_CXX: return "__CxxFrameHandler3";
+ case EHPersonality::CoreCLR: return "ProcessCLRException";
+ case EHPersonality::Rust: return "rust_eh_personality";
+ case EHPersonality::Wasm_CXX: return "__gxx_wasm_personality_v0";
+ case EHPersonality::Unknown: llvm_unreachable("Unknown EHPersonality!");
+ }
+
+ llvm_unreachable("Invalid EHPersonality!");
+}
+
+EHPersonality llvm::getDefaultEHPersonality(const Triple &T) {
+ return EHPersonality::GNU_C;
+}
+
+bool llvm::canSimplifyInvokeNoUnwind(const Function *F) {
+ EHPersonality Personality = classifyEHPersonality(F->getPersonalityFn());
+ // We can't simplify any invokes to nounwind functions if the personality
+ // function wants to catch asynchronous exceptions. The nounwind attribute only
+ // implies that the function does not throw synchronous exceptions.
+ return !isAsynchronousEHPersonality(Personality);
+}
+
+DenseMap<BasicBlock *, ColorVector> llvm::colorEHFunclets(Function &F) {
+ SmallVector<std::pair<BasicBlock *, BasicBlock *>, 16> Worklist;
+ BasicBlock *EntryBlock = &F.getEntryBlock();
+ DenseMap<BasicBlock *, ColorVector> BlockColors;
+
+ // Build up the color map, which maps each block to its set of 'colors'.
+ // For any block B the "colors" of B are the set of funclets F (possibly
+ // including a root "funclet" representing the main function) such that
+ // F will need to directly contain B or a copy of B (where the term "directly
+ // contain" is used to distinguish from being "transitively contained" in
+ // a nested funclet).
+ //
+ // Note: Despite not being a funclet in the truest sense, a catchswitch is
+ // considered to belong to its own funclet for the purposes of coloring.
+
+ DEBUG_WITH_TYPE("winehprepare-coloring", dbgs() << "\nColoring funclets for "
+ << F.getName() << "\n");
+
+ Worklist.push_back({EntryBlock, EntryBlock});
+
+ while (!Worklist.empty()) {
+ BasicBlock *Visiting;
+ BasicBlock *Color;
+ std::tie(Visiting, Color) = Worklist.pop_back_val();
+ DEBUG_WITH_TYPE("winehprepare-coloring",
+ dbgs() << "Visiting " << Visiting->getName() << ", "
+ << Color->getName() << "\n");
+ Instruction *VisitingHead = Visiting->getFirstNonPHI();
+ if (VisitingHead->isEHPad()) {
+ // Mark this funclet head as a member of itself.
+ Color = Visiting;
+ }
+ // Note that this is a member of the given color.
+ ColorVector &Colors = BlockColors[Visiting];
+ if (!is_contained(Colors, Color))
+ Colors.push_back(Color);
+ else
+ continue;
+
+ DEBUG_WITH_TYPE("winehprepare-coloring",
+ dbgs() << " Assigned color \'" << Color->getName()
+ << "\' to block \'" << Visiting->getName()
+ << "\'.\n");
+
+ BasicBlock *SuccColor = Color;
+ Instruction *Terminator = Visiting->getTerminator();
+ if (auto *CatchRet = dyn_cast<CatchReturnInst>(Terminator)) {
+ Value *ParentPad = CatchRet->getCatchSwitchParentPad();
+ if (isa<ConstantTokenNone>(ParentPad))
+ SuccColor = EntryBlock;
+ else
+ SuccColor = cast<Instruction>(ParentPad)->getParent();
+ }
+
+ for (BasicBlock *Succ : successors(Visiting))
+ Worklist.push_back({Succ, SuccColor});
+ }
+ return BlockColors;
+}
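+
+// Illustrative use of the coloring result (a sketch): each block maps to the
+// funclet head(s) whose funclet must contain it, with the entry block standing
+// in for the root "funclet".
+//
+//   DenseMap<BasicBlock *, ColorVector> Colors = colorEHFunclets(F);
+//   for (BasicBlock &BB : F)
+//     for (BasicBlock *FuncletHead : Colors[&BB]) {
+//       // BB (or a clone of it) must live in FuncletHead's funclet.
+//     }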
diff --git a/llvm/lib/Analysis/GlobalsModRef.cpp b/llvm/lib/Analysis/GlobalsModRef.cpp
new file mode 100644
index 000000000000..efdf9706ba3c
--- /dev/null
+++ b/llvm/lib/Analysis/GlobalsModRef.cpp
@@ -0,0 +1,1026 @@
+//===- GlobalsModRef.cpp - Simple Mod/Ref Analysis for Globals ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This simple pass provides alias and mod/ref information for global values
+// that do not have their address taken, and keeps track of whether functions
+// read or write memory (are "pure"). For this simple (but very common) case,
+// we can provide pretty accurate and useful information.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/ADT/SCCIterator.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+using namespace llvm;
+
+#define DEBUG_TYPE "globalsmodref-aa"
+
+STATISTIC(NumNonAddrTakenGlobalVars,
+ "Number of global vars without address taken");
+STATISTIC(NumNonAddrTakenFunctions,"Number of functions without address taken");
+STATISTIC(NumNoMemFunctions, "Number of functions that do not access memory");
+STATISTIC(NumReadMemFunctions, "Number of functions that only read memory");
+STATISTIC(NumIndirectGlobalVars, "Number of indirect global objects");
+
+// An option to enable unsafe alias results from the GlobalsModRef analysis.
+// When enabled, GlobalsModRef will provide no-alias results which in extremely
+// rare cases may not be conservatively correct. In particular, in the face of
+// transforms which cause asymmetry between how effective GetUnderlyingObject
+// is for two pointers, it may produce incorrect results.
+//
+// These unsafe results have been returned by GMR for many years without
+// causing significant issues in the wild and so we provide a mechanism to
+// re-enable them for users of LLVM that have a particular performance
+// sensitivity and no known issues. The option also makes it easy to evaluate
+// the performance impact of these results.
+static cl::opt<bool> EnableUnsafeGlobalsModRefAliasResults(
+ "enable-unsafe-globalsmodref-alias-results", cl::init(false), cl::Hidden);
+
+/// The mod/ref information collected for a particular function.
+///
+/// We collect information about mod/ref behavior of a function here, both in
+/// general and as pertains to specific globals. We only have this detailed
+/// information when we know *something* useful about the behavior. If we
+/// saturate to fully general mod/ref, we remove the info for the function.
+class GlobalsAAResult::FunctionInfo {
+ typedef SmallDenseMap<const GlobalValue *, ModRefInfo, 16> GlobalInfoMapType;
+
+ /// Build a wrapper struct that has 8-byte alignment. All heap allocations
+ /// should provide this much alignment at least, but this makes it clear we
+ /// specifically rely on this amount of alignment.
+ struct alignas(8) AlignedMap {
+ AlignedMap() {}
+ AlignedMap(const AlignedMap &Arg) : Map(Arg.Map) {}
+ GlobalInfoMapType Map;
+ };
+
+ /// Pointer traits for our aligned map.
+ struct AlignedMapPointerTraits {
+ static inline void *getAsVoidPointer(AlignedMap *P) { return P; }
+ static inline AlignedMap *getFromVoidPointer(void *P) {
+ return (AlignedMap *)P;
+ }
+ enum { NumLowBitsAvailable = 3 };
+ static_assert(alignof(AlignedMap) >= (1 << NumLowBitsAvailable),
+ "AlignedMap insufficiently aligned to have enough low bits.");
+ };
+
+ /// The bit that flags that this function may read any global. This is
+ /// chosen to mix together with ModRefInfo bits.
+ /// FIXME: This assumes ModRefInfo lattice will remain 4 bits!
+ /// It overlaps with ModRefInfo::Must bit!
+ /// FunctionInfo.getModRefInfo() masks out everything except ModRef so
+ /// this remains correct, but the Must info is lost.
+ enum { MayReadAnyGlobal = 4 };
+
+ /// Checks to document the invariants of the bit packing here.
+ static_assert((MayReadAnyGlobal & static_cast<int>(ModRefInfo::MustModRef)) ==
+ 0,
+ "ModRef and the MayReadAnyGlobal flag bits overlap.");
+ static_assert(((MayReadAnyGlobal |
+ static_cast<int>(ModRefInfo::MustModRef)) >>
+ AlignedMapPointerTraits::NumLowBitsAvailable) == 0,
+ "Insufficient low bits to store our flag and ModRef info.");
+
+public:
+ FunctionInfo() : Info() {}
+ ~FunctionInfo() {
+ delete Info.getPointer();
+ }
+ // Spell out the copy and move constructors and assignment operators to get
+ // deep copy semantics and correct move semantics in the face of the
+ // pointer-int pair.
+ FunctionInfo(const FunctionInfo &Arg)
+ : Info(nullptr, Arg.Info.getInt()) {
+ if (const auto *ArgPtr = Arg.Info.getPointer())
+ Info.setPointer(new AlignedMap(*ArgPtr));
+ }
+ FunctionInfo(FunctionInfo &&Arg)
+ : Info(Arg.Info.getPointer(), Arg.Info.getInt()) {
+ Arg.Info.setPointerAndInt(nullptr, 0);
+ }
+ FunctionInfo &operator=(const FunctionInfo &RHS) {
+ delete Info.getPointer();
+ Info.setPointerAndInt(nullptr, RHS.Info.getInt());
+ if (const auto *RHSPtr = RHS.Info.getPointer())
+ Info.setPointer(new AlignedMap(*RHSPtr));
+ return *this;
+ }
+ FunctionInfo &operator=(FunctionInfo &&RHS) {
+ delete Info.getPointer();
+ Info.setPointerAndInt(RHS.Info.getPointer(), RHS.Info.getInt());
+ RHS.Info.setPointerAndInt(nullptr, 0);
+ return *this;
+ }
+
+ /// This method clears the MayReadAnyGlobal bit added by GlobalsAAResult to
+ /// return the corresponding ModRefInfo. It must align in functionality with
+ /// clearMust().
+ ModRefInfo globalClearMayReadAnyGlobal(int I) const {
+ return ModRefInfo((I & static_cast<int>(ModRefInfo::ModRef)) |
+ static_cast<int>(ModRefInfo::NoModRef));
+ }
+
+ /// Returns the \c ModRefInfo info for this function.
+ ModRefInfo getModRefInfo() const {
+ return globalClearMayReadAnyGlobal(Info.getInt());
+ }
+
+ /// Adds new \c ModRefInfo for this function to its state.
+ void addModRefInfo(ModRefInfo NewMRI) {
+ Info.setInt(Info.getInt() | static_cast<int>(setMust(NewMRI)));
+ }
+
+ /// Returns whether this function may read any global variable, and we don't
+ /// know which global.
+ bool mayReadAnyGlobal() const { return Info.getInt() & MayReadAnyGlobal; }
+
+ /// Sets this function as potentially reading from any global.
+ void setMayReadAnyGlobal() { Info.setInt(Info.getInt() | MayReadAnyGlobal); }
+
+ /// Returns the \c ModRefInfo info for this function w.r.t. a particular
+ /// global, which may be more precise than the general information above.
+ ModRefInfo getModRefInfoForGlobal(const GlobalValue &GV) const {
+ ModRefInfo GlobalMRI =
+ mayReadAnyGlobal() ? ModRefInfo::Ref : ModRefInfo::NoModRef;
+ if (AlignedMap *P = Info.getPointer()) {
+ auto I = P->Map.find(&GV);
+ if (I != P->Map.end())
+ GlobalMRI = unionModRef(GlobalMRI, I->second);
+ }
+ return GlobalMRI;
+ }
+
+ /// Add mod/ref info from another function into ours, saturating towards
+ /// ModRef.
+ void addFunctionInfo(const FunctionInfo &FI) {
+ addModRefInfo(FI.getModRefInfo());
+
+ if (FI.mayReadAnyGlobal())
+ setMayReadAnyGlobal();
+
+ if (AlignedMap *P = FI.Info.getPointer())
+ for (const auto &G : P->Map)
+ addModRefInfoForGlobal(*G.first, G.second);
+ }
+
+ void addModRefInfoForGlobal(const GlobalValue &GV, ModRefInfo NewMRI) {
+ AlignedMap *P = Info.getPointer();
+ if (!P) {
+ P = new AlignedMap();
+ Info.setPointer(P);
+ }
+ auto &GlobalMRI = P->Map[&GV];
+ GlobalMRI = unionModRef(GlobalMRI, NewMRI);
+ }
+
+ /// Clear a global's ModRef info. Should be used when a global is being
+ /// deleted.
+ void eraseModRefInfoForGlobal(const GlobalValue &GV) {
+ if (AlignedMap *P = Info.getPointer())
+ P->Map.erase(&GV);
+ }
+
+private:
+ /// All of the information is encoded into a single pointer, with a three bit
+ /// integer in the low three bits. The high bit provides a flag for when this
+ /// function may read any global. The low two bits are the ModRefInfo. And
+ /// the pointer, when non-null, points to a map from GlobalValue to
+ /// ModRefInfo specific to that GlobalValue.
+ PointerIntPair<AlignedMap *, 3, unsigned, AlignedMapPointerTraits> Info;
+};
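+
+// Illustrative behaviour of the bit packing above (a sketch using only the
+// FunctionInfo members defined in this file; G stands for any tracked
+// GlobalValue): the MayReadAnyGlobal flag is stored next to the ModRefInfo
+// bits but masked back out by getModRefInfo().
+//
+//   FunctionInfo FI;
+//   FI.addModRefInfo(ModRefInfo::Ref); // this function reads some memory...
+//   FI.setMayReadAnyGlobal();          // ...possibly any global
+//   FI.getModRefInfo();                // ModRefInfo::Ref - flag is masked out
+//   FI.getModRefInfoForGlobal(G);      // at least ModRefInfo::Ref for any G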
+
+void GlobalsAAResult::DeletionCallbackHandle::deleted() {
+ Value *V = getValPtr();
+ if (auto *F = dyn_cast<Function>(V))
+ GAR->FunctionInfos.erase(F);
+
+ if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
+ if (GAR->NonAddressTakenGlobals.erase(GV)) {
+ // This global might be an indirect global. If so, remove it and
+ // remove any AllocRelatedValues for it.
+ if (GAR->IndirectGlobals.erase(GV)) {
+ // Remove any entries in AllocsForIndirectGlobals for this global.
+ for (auto I = GAR->AllocsForIndirectGlobals.begin(),
+ E = GAR->AllocsForIndirectGlobals.end();
+ I != E; ++I)
+ if (I->second == GV)
+ GAR->AllocsForIndirectGlobals.erase(I);
+ }
+
+ // Scan the function info we have collected and remove this global
+ // from all of them.
+ for (auto &FIPair : GAR->FunctionInfos)
+ FIPair.second.eraseModRefInfoForGlobal(*GV);
+ }
+ }
+
+ // If this is an allocation related to an indirect global, remove it.
+ GAR->AllocsForIndirectGlobals.erase(V);
+
+ // And clear out the handle.
+ setValPtr(nullptr);
+ GAR->Handles.erase(I);
+ // This object is now destroyed!
+}
+
+FunctionModRefBehavior GlobalsAAResult::getModRefBehavior(const Function *F) {
+ FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior;
+
+ if (FunctionInfo *FI = getFunctionInfo(F)) {
+ if (!isModOrRefSet(FI->getModRefInfo()))
+ Min = FMRB_DoesNotAccessMemory;
+ else if (!isModSet(FI->getModRefInfo()))
+ Min = FMRB_OnlyReadsMemory;
+ }
+
+ return FunctionModRefBehavior(AAResultBase::getModRefBehavior(F) & Min);
+}
+
+FunctionModRefBehavior
+GlobalsAAResult::getModRefBehavior(const CallBase *Call) {
+ FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior;
+
+ if (!Call->hasOperandBundles())
+ if (const Function *F = Call->getCalledFunction())
+ if (FunctionInfo *FI = getFunctionInfo(F)) {
+ if (!isModOrRefSet(FI->getModRefInfo()))
+ Min = FMRB_DoesNotAccessMemory;
+ else if (!isModSet(FI->getModRefInfo()))
+ Min = FMRB_OnlyReadsMemory;
+ }
+
+ return FunctionModRefBehavior(AAResultBase::getModRefBehavior(Call) & Min);
+}
+
+/// Returns the function info for the function, or null if we don't have
+/// anything useful to say about it.
+GlobalsAAResult::FunctionInfo *
+GlobalsAAResult::getFunctionInfo(const Function *F) {
+ auto I = FunctionInfos.find(F);
+ if (I != FunctionInfos.end())
+ return &I->second;
+ return nullptr;
+}
+
+/// AnalyzeGlobals - Scan through the users of all of the internal
+/// GlobalValue's in the program. If none of them have their "address taken"
+/// (really, their address passed to something nontrivial), record this fact,
+/// and record the functions that they are used directly in.
+void GlobalsAAResult::AnalyzeGlobals(Module &M) {
+ SmallPtrSet<Function *, 32> TrackedFunctions;
+ for (Function &F : M)
+ if (F.hasLocalLinkage())
+ if (!AnalyzeUsesOfPointer(&F)) {
+ // Remember that we are tracking this global.
+ NonAddressTakenGlobals.insert(&F);
+ TrackedFunctions.insert(&F);
+ Handles.emplace_front(*this, &F);
+ Handles.front().I = Handles.begin();
+ ++NumNonAddrTakenFunctions;
+ }
+
+ SmallPtrSet<Function *, 16> Readers, Writers;
+ for (GlobalVariable &GV : M.globals())
+ if (GV.hasLocalLinkage()) {
+ if (!AnalyzeUsesOfPointer(&GV, &Readers,
+ GV.isConstant() ? nullptr : &Writers)) {
+ // Remember that we are tracking this global, and the mod/ref fns
+ NonAddressTakenGlobals.insert(&GV);
+ Handles.emplace_front(*this, &GV);
+ Handles.front().I = Handles.begin();
+
+ for (Function *Reader : Readers) {
+ if (TrackedFunctions.insert(Reader).second) {
+ Handles.emplace_front(*this, Reader);
+ Handles.front().I = Handles.begin();
+ }
+ FunctionInfos[Reader].addModRefInfoForGlobal(GV, ModRefInfo::Ref);
+ }
+
+ if (!GV.isConstant()) // No need to keep track of writers to constants
+ for (Function *Writer : Writers) {
+ if (TrackedFunctions.insert(Writer).second) {
+ Handles.emplace_front(*this, Writer);
+ Handles.front().I = Handles.begin();
+ }
+ FunctionInfos[Writer].addModRefInfoForGlobal(GV, ModRefInfo::Mod);
+ }
+ ++NumNonAddrTakenGlobalVars;
+
+ // If this global holds a pointer type, see if it is an indirect global.
+ if (GV.getValueType()->isPointerTy() &&
+ AnalyzeIndirectGlobalMemory(&GV))
+ ++NumIndirectGlobalVars;
+ }
+ Readers.clear();
+ Writers.clear();
+ }
+}
+
+/// AnalyzeUsesOfPointer - Look at all of the users of the specified pointer.
+/// If this is used by anything complex (i.e., the address escapes), return
+/// true. Also, while we are at it, keep track of those functions that read and
+/// write to the value.
+///
+/// If OkayStoreDest is non-null, stores into this global are allowed.
+bool GlobalsAAResult::AnalyzeUsesOfPointer(Value *V,
+ SmallPtrSetImpl<Function *> *Readers,
+ SmallPtrSetImpl<Function *> *Writers,
+ GlobalValue *OkayStoreDest) {
+ if (!V->getType()->isPointerTy())
+ return true;
+
+ for (Use &U : V->uses()) {
+ User *I = U.getUser();
+ if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
+ if (Readers)
+ Readers->insert(LI->getParent()->getParent());
+ } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
+ if (V == SI->getOperand(1)) {
+ if (Writers)
+ Writers->insert(SI->getParent()->getParent());
+ } else if (SI->getOperand(1) != OkayStoreDest) {
+ return true; // Storing the pointer
+ }
+ } else if (Operator::getOpcode(I) == Instruction::GetElementPtr) {
+ if (AnalyzeUsesOfPointer(I, Readers, Writers))
+ return true;
+ } else if (Operator::getOpcode(I) == Instruction::BitCast) {
+ if (AnalyzeUsesOfPointer(I, Readers, Writers, OkayStoreDest))
+ return true;
+ } else if (auto *Call = dyn_cast<CallBase>(I)) {
+ // Make sure that this is just the function being called, not that its
+ // address is being passed into the function.
+ if (Call->isDataOperand(&U)) {
+ // Detect calls to free.
+ if (Call->isArgOperand(&U) &&
+ isFreeCall(I, &GetTLI(*Call->getFunction()))) {
+ if (Writers)
+ Writers->insert(Call->getParent()->getParent());
+ } else {
+ return true; // Argument of an unknown call.
+ }
+ }
+ } else if (ICmpInst *ICI = dyn_cast<ICmpInst>(I)) {
+ if (!isa<ConstantPointerNull>(ICI->getOperand(1)))
+ return true; // Allow comparison against null.
+ } else if (Constant *C = dyn_cast<Constant>(I)) {
+ // Ignore constants which don't have any live uses.
+ if (isa<GlobalValue>(C) || C->isConstantUsed())
+ return true;
+ } else {
+ return true;
+ }
+ }
+
+ return false;
+}
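+
+// IR-level intuition for the checks above (an illustrative sketch, not taken
+// from a test; @g and @use are made-up names): direct loads and stores keep a
+// global non-address-taken, while passing its address to an unknown call makes
+// it escape.
+//
+//   @g = internal global i32 0
+//   %v = load i32, i32* @g    ; fine - the parent function is noted as a reader
+//   store i32 1, i32* @g      ; fine - the parent function is noted as a writer
+//   call void @use(i32* @g)   ; escapes - AnalyzeUsesOfPointer returns true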
+
+/// AnalyzeIndirectGlobalMemory - We found a non-address-taken global variable
+/// which holds a pointer type. See if the global always points to non-aliased
+/// heap memory: that is, all initializers of the globals are allocations, and
+/// those allocations have no use other than initialization of the global.
+/// Further, all loads out of GV must directly use the memory, not store the
+/// pointer somewhere. If this is true, we consider the memory pointed to by
+/// GV to be owned by GV and can disambiguate other pointers from it.
+bool GlobalsAAResult::AnalyzeIndirectGlobalMemory(GlobalVariable *GV) {
+ // Keep track of values related to the allocation of the memory, f.e. the
+ // value produced by the malloc call and any casts.
+ std::vector<Value *> AllocRelatedValues;
+
+ // If the initializer is anything other than a null pointer, bail.
+ if (Constant *C = GV->getInitializer())
+ if (!C->isNullValue())
+ return false;
+
+ // Walk the user list of the global. If we find anything other than a direct
+ // load or store, bail out.
+ for (User *U : GV->users()) {
+ if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
+ // The pointer loaded from the global can only be used in simple ways:
+ // we allow addressing of it and loading from or storing to it. We do *not*
+ // storing the loaded pointer somewhere else or passing to a function.
+ if (AnalyzeUsesOfPointer(LI))
+ return false; // Loaded pointer escapes.
+ // TODO: Could try some IP mod/ref of the loaded pointer.
+ } else if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
+ // Storing the global itself.
+ if (SI->getOperand(0) == GV)
+ return false;
+
+ // If storing the null pointer, ignore it.
+ if (isa<ConstantPointerNull>(SI->getOperand(0)))
+ continue;
+
+ // Check the value being stored.
+ Value *Ptr = GetUnderlyingObject(SI->getOperand(0),
+ GV->getParent()->getDataLayout());
+
+ if (!isAllocLikeFn(Ptr, &GetTLI(*SI->getFunction())))
+ return false; // Too hard to analyze.
+
+ // Analyze all uses of the allocation. If any of them are used in a
+ // non-simple way (e.g. stored to another global) bail out.
+ if (AnalyzeUsesOfPointer(Ptr, /*Readers*/ nullptr, /*Writers*/ nullptr,
+ GV))
+ return false; // Loaded pointer escapes.
+
+ // Remember that this allocation is related to the indirect global.
+ AllocRelatedValues.push_back(Ptr);
+ } else {
+ // Something complex, bail out.
+ return false;
+ }
+ }
+
+ // Okay, this is an indirect global. Remember all of the allocations for
+ // this global in AllocsForIndirectGlobals.
+ while (!AllocRelatedValues.empty()) {
+ AllocsForIndirectGlobals[AllocRelatedValues.back()] = GV;
+ Handles.emplace_front(*this, AllocRelatedValues.back());
+ Handles.front().I = Handles.begin();
+ AllocRelatedValues.pop_back();
+ }
+ IndirectGlobals.insert(GV);
+ Handles.emplace_front(*this, GV);
+ Handles.front().I = Handles.begin();
+ return true;
+}
+
+void GlobalsAAResult::CollectSCCMembership(CallGraph &CG) {
+ // We do a bottom-up SCC traversal of the call graph. In other words, we
+ // visit all callees before callers (leaf-first).
+ unsigned SCCID = 0;
+ for (scc_iterator<CallGraph *> I = scc_begin(&CG); !I.isAtEnd(); ++I) {
+ const std::vector<CallGraphNode *> &SCC = *I;
+ assert(!SCC.empty() && "SCC with no functions?");
+
+ for (auto *CGN : SCC)
+ if (Function *F = CGN->getFunction())
+ FunctionToSCCMap[F] = SCCID;
+ ++SCCID;
+ }
+}
+
+/// AnalyzeCallGraph - At this point, we know the functions where globals are
+/// immediately stored to and read from. Propagate this information up the call
+/// graph to all callers and compute the mod/ref info for all memory for each
+/// function.
+void GlobalsAAResult::AnalyzeCallGraph(CallGraph &CG, Module &M) {
+ // We do a bottom-up SCC traversal of the call graph. In other words, we
+ // visit all callees before callers (leaf-first).
+ for (scc_iterator<CallGraph *> I = scc_begin(&CG); !I.isAtEnd(); ++I) {
+ const std::vector<CallGraphNode *> &SCC = *I;
+ assert(!SCC.empty() && "SCC with no functions?");
+
+ Function *F = SCC[0]->getFunction();
+
+ if (!F || !F->isDefinitionExact()) {
+ // Calls externally or not exact - can't say anything useful. Remove any
+ // existing function records (may have been created when scanning
+ // globals).
+ for (auto *Node : SCC)
+ FunctionInfos.erase(Node->getFunction());
+ continue;
+ }
+
+ FunctionInfo &FI = FunctionInfos[F];
+ Handles.emplace_front(*this, F);
+ Handles.front().I = Handles.begin();
+ bool KnowNothing = false;
+
+ // Collect the mod/ref properties due to called functions. We only compute
+ // one mod-ref set.
+ for (unsigned i = 0, e = SCC.size(); i != e && !KnowNothing; ++i) {
+ if (!F) {
+ KnowNothing = true;
+ break;
+ }
+
+ if (F->isDeclaration() || F->hasOptNone()) {
+ // Try to get mod/ref behaviour from function attributes.
+ if (F->doesNotAccessMemory()) {
+ // Can't do better than that!
+ } else if (F->onlyReadsMemory()) {
+ FI.addModRefInfo(ModRefInfo::Ref);
+ if (!F->isIntrinsic() && !F->onlyAccessesArgMemory())
+ // This function might call back into the module and read a global -
+ // consider every global as possibly being read by this function.
+ FI.setMayReadAnyGlobal();
+ } else {
+ FI.addModRefInfo(ModRefInfo::ModRef);
+ // Can't say anything useful unless it's an intrinsic - they don't
+ // read or write global variables of the kind considered here.
+ KnowNothing = !F->isIntrinsic();
+ }
+ continue;
+ }
+
+ for (CallGraphNode::iterator CI = SCC[i]->begin(), E = SCC[i]->end();
+ CI != E && !KnowNothing; ++CI)
+ if (Function *Callee = CI->second->getFunction()) {
+ if (FunctionInfo *CalleeFI = getFunctionInfo(Callee)) {
+ // Propagate function effect up.
+ FI.addFunctionInfo(*CalleeFI);
+ } else {
+ // Can't say anything about it. However, if it is inside our SCC,
+ // then nothing needs to be done.
+ CallGraphNode *CalleeNode = CG[Callee];
+ if (!is_contained(SCC, CalleeNode))
+ KnowNothing = true;
+ }
+ } else {
+ KnowNothing = true;
+ }
+ }
+
+ // If we can't say anything useful about this SCC, remove all SCC functions
+ // from the FunctionInfos map.
+ if (KnowNothing) {
+ for (auto *Node : SCC)
+ FunctionInfos.erase(Node->getFunction());
+ continue;
+ }
+
+ // Scan the function bodies for explicit loads or stores.
+ for (auto *Node : SCC) {
+ if (isModAndRefSet(FI.getModRefInfo()))
+ break; // The mod/ref lattice saturates here.
+
+ // Don't prove any properties based on the implementation of an optnone
+ // function. Function attributes were already used as a best approximation
+ // above.
+ if (Node->getFunction()->hasOptNone())
+ continue;
+
+ for (Instruction &I : instructions(Node->getFunction())) {
+ if (isModAndRefSet(FI.getModRefInfo()))
+ break; // The mod/ref lattice saturates here.
+
+ // We handle calls specially because the graph-relevant aspects are
+ // handled above.
+ if (auto *Call = dyn_cast<CallBase>(&I)) {
+ auto &TLI = GetTLI(*Node->getFunction());
+ if (isAllocationFn(Call, &TLI) || isFreeCall(Call, &TLI)) {
+ // FIXME: It is completely unclear why this is necessary and not
+ // handled by the above graph code.
+ FI.addModRefInfo(ModRefInfo::ModRef);
+ } else if (Function *Callee = Call->getCalledFunction()) {
+ // The callgraph doesn't include intrinsic calls.
+ if (Callee->isIntrinsic()) {
+ if (isa<DbgInfoIntrinsic>(Call))
+ // Don't let dbg intrinsics affect alias info.
+ continue;
+
+ FunctionModRefBehavior Behaviour =
+ AAResultBase::getModRefBehavior(Callee);
+ FI.addModRefInfo(createModRefInfo(Behaviour));
+ }
+ }
+ continue;
+ }
+
+ // For all non-call instructions, use the primary predicates to determine
+ // whether they read or write memory.
+ if (I.mayReadFromMemory())
+ FI.addModRefInfo(ModRefInfo::Ref);
+ if (I.mayWriteToMemory())
+ FI.addModRefInfo(ModRefInfo::Mod);
+ }
+ }
+
+ if (!isModSet(FI.getModRefInfo()))
+ ++NumReadMemFunctions;
+ if (!isModOrRefSet(FI.getModRefInfo()))
+ ++NumNoMemFunctions;
+
+ // Finally, now that we know the full effect on this SCC, clone the
+ // information to each function in the SCC.
+ // FI is a reference into FunctionInfos, so copy it now so that it doesn't
+ // get invalidated if DenseMap decides to re-hash.
+ FunctionInfo CachedFI = FI;
+ for (unsigned i = 1, e = SCC.size(); i != e; ++i)
+ FunctionInfos[SCC[i]->getFunction()] = CachedFI;
+ }
+}
+
+// GV is a non-escaping global. V is a pointer address that has been loaded from.
+// If we can prove that V must escape, we can conclude that a load from V cannot
+// alias GV.
+static bool isNonEscapingGlobalNoAliasWithLoad(const GlobalValue *GV,
+ const Value *V,
+ int &Depth,
+ const DataLayout &DL) {
+ SmallPtrSet<const Value *, 8> Visited;
+ SmallVector<const Value *, 8> Inputs;
+ Visited.insert(V);
+ Inputs.push_back(V);
+ do {
+ const Value *Input = Inputs.pop_back_val();
+
+ if (isa<GlobalValue>(Input) || isa<Argument>(Input) || isa<CallInst>(Input) ||
+ isa<InvokeInst>(Input))
+ // Arguments to functions or returns from functions are inherently
+ // escaping, so we can immediately classify those as not aliasing any
+ // non-addr-taken globals.
+ //
+ // (Transitive) loads from a global are also safe - if this aliased
+ // another global, its address would escape, so no alias.
+ continue;
+
+ // Recurse through a limited number of selects, loads and PHIs. This is an
+ // arbitrary depth of 4, lower numbers could be used to fix compile time
+ // issues if needed, but this is generally expected to only be important
+ // for small depths.
+ if (++Depth > 4)
+ return false;
+
+ if (auto *LI = dyn_cast<LoadInst>(Input)) {
+ Inputs.push_back(GetUnderlyingObject(LI->getPointerOperand(), DL));
+ continue;
+ }
+ if (auto *SI = dyn_cast<SelectInst>(Input)) {
+ const Value *LHS = GetUnderlyingObject(SI->getTrueValue(), DL);
+ const Value *RHS = GetUnderlyingObject(SI->getFalseValue(), DL);
+ if (Visited.insert(LHS).second)
+ Inputs.push_back(LHS);
+ if (Visited.insert(RHS).second)
+ Inputs.push_back(RHS);
+ continue;
+ }
+ if (auto *PN = dyn_cast<PHINode>(Input)) {
+ for (const Value *Op : PN->incoming_values()) {
+ Op = GetUnderlyingObject(Op, DL);
+ if (Visited.insert(Op).second)
+ Inputs.push_back(Op);
+ }
+ continue;
+ }
+
+ return false;
+ } while (!Inputs.empty());
+
+ // All inputs were known to be no-alias.
+ return true;
+}
+
+// There are particular cases where we can conclude no-alias between
+// a non-addr-taken global and some other underlying object. Specifically,
+// a non-addr-taken global is known to not be escaped from any function. It is
+// also incorrect for a transformation to introduce an escape of a global in
+// a way that is observable when it was not there previously. One function
+// being transformed to introduce an escape which could possibly be observed
+// (via loading from a global or the return value for example) within another
+// function is never safe. If the observation is made through non-atomic
+// operations on different threads, it is a data-race and UB. If the
+// observation is well defined, by being observed the transformation would have
+// changed program behavior by introducing the observed escape, making it an
+// invalid transform.
+//
+// This property does require that transformations which *temporarily* escape
+// a global that was not previously escaped, prior to restoring it, cannot rely
+// on the results of GMR::alias. This seems a reasonable restriction, although
+// currently there is no way to enforce it. There is also no realistic
+// optimization pass that would make this mistake. The closest example is
+// a transformation pass which does reg2mem of SSA values but stores them into
+// global variables temporarily before restoring the global variable's value.
+// This could be useful to expose "benign" races for example. However, it seems
+// reasonable to require that a pass which introduces escapes of global
+// variables in this way either not trust AA results while the escape is
+// active, or be forced to operate as a module pass that cannot co-exist
+// with an alias analysis such as GMR.
+bool GlobalsAAResult::isNonEscapingGlobalNoAlias(const GlobalValue *GV,
+ const Value *V) {
+ // In order to know that the underlying object cannot alias the
+ // non-addr-taken global, we must know that it would have to be an escape.
+ // Thus if the underlying object is a function argument, a load from
+ // a global, or the return of a function, it cannot alias. We can also
+ // recurse through PHI nodes and select nodes provided all of their inputs
+ // resolve to one of these known-escaping roots.
+ SmallPtrSet<const Value *, 8> Visited;
+ SmallVector<const Value *, 8> Inputs;
+ Visited.insert(V);
+ Inputs.push_back(V);
+ int Depth = 0;
+ do {
+ const Value *Input = Inputs.pop_back_val();
+
+ if (auto *InputGV = dyn_cast<GlobalValue>(Input)) {
+ // If one input is the very global we're querying against, then we can't
+ // conclude no-alias.
+ if (InputGV == GV)
+ return false;
+
+ // Distinct GlobalVariables never alias, unless overridden or zero-sized.
+ // FIXME: The condition can be refined, but be conservative for now.
+ auto *GVar = dyn_cast<GlobalVariable>(GV);
+ auto *InputGVar = dyn_cast<GlobalVariable>(InputGV);
+ if (GVar && InputGVar &&
+ !GVar->isDeclaration() && !InputGVar->isDeclaration() &&
+ !GVar->isInterposable() && !InputGVar->isInterposable()) {
+ Type *GVType = GVar->getInitializer()->getType();
+ Type *InputGVType = InputGVar->getInitializer()->getType();
+ if (GVType->isSized() && InputGVType->isSized() &&
+ (DL.getTypeAllocSize(GVType) > 0) &&
+ (DL.getTypeAllocSize(InputGVType) > 0))
+ continue;
+ }
+
+ // Conservatively return false, even though we could be smarter
+ // (e.g. look through GlobalAliases).
+ return false;
+ }
+
+ if (isa<Argument>(Input) || isa<CallInst>(Input) ||
+ isa<InvokeInst>(Input)) {
+ // Arguments to functions or returns from functions are inherently
+ // escaping, so we can immediately classify those as not aliasing any
+ // non-addr-taken globals.
+ continue;
+ }
+
+ // Recurse through a limited number of selects, loads and PHIs. This is an
+ // arbitrary depth of 4, lower numbers could be used to fix compile time
+ // issues if needed, but this is generally expected to only be important
+ // for small depths.
+ if (++Depth > 4)
+ return false;
+
+ if (auto *LI = dyn_cast<LoadInst>(Input)) {
+ // A pointer loaded from a global would have been captured, and we know
+ // that the global is non-escaping, so no alias.
+ const Value *Ptr = GetUnderlyingObject(LI->getPointerOperand(), DL);
+ if (isNonEscapingGlobalNoAliasWithLoad(GV, Ptr, Depth, DL))
+ // The load does not alias with GV.
+ continue;
+ // Otherwise, a load could come from anywhere, so bail.
+ return false;
+ }
+ if (auto *SI = dyn_cast<SelectInst>(Input)) {
+ const Value *LHS = GetUnderlyingObject(SI->getTrueValue(), DL);
+ const Value *RHS = GetUnderlyingObject(SI->getFalseValue(), DL);
+ if (Visited.insert(LHS).second)
+ Inputs.push_back(LHS);
+ if (Visited.insert(RHS).second)
+ Inputs.push_back(RHS);
+ continue;
+ }
+ if (auto *PN = dyn_cast<PHINode>(Input)) {
+ for (const Value *Op : PN->incoming_values()) {
+ Op = GetUnderlyingObject(Op, DL);
+ if (Visited.insert(Op).second)
+ Inputs.push_back(Op);
+ }
+ continue;
+ }
+
+ // FIXME: It would be good to handle other obvious no-alias cases here, but
+ // it isn't clear how to do so reasonably without building a small version
+ // of BasicAA into this code. We could recurse into AAResultBase::alias
+ // here but that seems likely to go poorly as we're inside the
+ // implementation of such a query. Until then, just conservatively return
+ // false.
+ return false;
+ } while (!Inputs.empty());
+
+ // If all the inputs to V were definitively no-alias, then V is no-alias.
+ return true;
+}
+
+/// alias - If one of the pointers is to a global that we are tracking, and the
+/// other is some random pointer, we know there cannot be an alias, because the
+/// address of the global isn't taken.
+AliasResult GlobalsAAResult::alias(const MemoryLocation &LocA,
+ const MemoryLocation &LocB,
+ AAQueryInfo &AAQI) {
+ // Get the base object these pointers point to.
+ const Value *UV1 = GetUnderlyingObject(LocA.Ptr, DL);
+ const Value *UV2 = GetUnderlyingObject(LocB.Ptr, DL);
+
+ // If either of the underlying values is a global, they may be non-addr-taken
+ // globals, which we can answer queries about.
+ const GlobalValue *GV1 = dyn_cast<GlobalValue>(UV1);
+ const GlobalValue *GV2 = dyn_cast<GlobalValue>(UV2);
+ if (GV1 || GV2) {
+ // If the global's address is taken, pretend we don't know it's a pointer to
+ // the global.
+ if (GV1 && !NonAddressTakenGlobals.count(GV1))
+ GV1 = nullptr;
+ if (GV2 && !NonAddressTakenGlobals.count(GV2))
+ GV2 = nullptr;
+
+ // If the two pointers are derived from two different non-addr-taken
+ // globals we know these can't alias.
+ if (GV1 && GV2 && GV1 != GV2)
+ return NoAlias;
+
+ // If one is and the other isn't, it isn't strictly safe but we can fake
+ // this result if necessary for performance. This does not appear to be
+ // a common problem in practice.
+ if (EnableUnsafeGlobalsModRefAliasResults)
+ if ((GV1 || GV2) && GV1 != GV2)
+ return NoAlias;
+
+ // Check for a special case where a non-escaping global can be used to
+ // conclude no-alias.
+ if ((GV1 || GV2) && GV1 != GV2) {
+ const GlobalValue *GV = GV1 ? GV1 : GV2;
+ const Value *UV = GV1 ? UV2 : UV1;
+ if (isNonEscapingGlobalNoAlias(GV, UV))
+ return NoAlias;
+ }
+
+ // Otherwise if they are both derived from the same addr-taken global, we
+ // can't know the two accesses don't overlap.
+ }
+
+ // These pointers may be based on the memory owned by an indirect global. If
+ // so, we may be able to handle this. First check to see if the base pointer
+ // is a direct load from an indirect global.
+ GV1 = GV2 = nullptr;
+ if (const LoadInst *LI = dyn_cast<LoadInst>(UV1))
+ if (GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getOperand(0)))
+ if (IndirectGlobals.count(GV))
+ GV1 = GV;
+ if (const LoadInst *LI = dyn_cast<LoadInst>(UV2))
+ if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getOperand(0)))
+ if (IndirectGlobals.count(GV))
+ GV2 = GV;
+
+ // These pointers may also be from an allocation for the indirect global. If
+ // so, also handle them.
+ if (!GV1)
+ GV1 = AllocsForIndirectGlobals.lookup(UV1);
+ if (!GV2)
+ GV2 = AllocsForIndirectGlobals.lookup(UV2);
+
+ // Now that we know whether the two pointers are related to indirect globals,
+ // use this to disambiguate the pointers. If the pointers are based on
+ // different indirect globals they cannot alias.
+ if (GV1 && GV2 && GV1 != GV2)
+ return NoAlias;
+
+ // If one is based on an indirect global and the other isn't, it isn't
+ // strictly safe but we can fake this result if necessary for performance.
+ // This does not appear to be a common problem in practice.
+ if (EnableUnsafeGlobalsModRefAliasResults)
+ if ((GV1 || GV2) && GV1 != GV2)
+ return NoAlias;
+
+ return AAResultBase::alias(LocA, LocB, AAQI);
+}
+
+ModRefInfo GlobalsAAResult::getModRefInfoForArgument(const CallBase *Call,
+ const GlobalValue *GV,
+ AAQueryInfo &AAQI) {
+ if (Call->doesNotAccessMemory())
+ return ModRefInfo::NoModRef;
+ ModRefInfo ConservativeResult =
+ Call->onlyReadsMemory() ? ModRefInfo::Ref : ModRefInfo::ModRef;
+
+ // Iterate through all the arguments to the called function. If any argument
+ // is based on GV, return the conservative result.
+ for (auto &A : Call->args()) {
+ SmallVector<const Value*, 4> Objects;
+ GetUnderlyingObjects(A, Objects, DL);
+
+ // All objects must be identified.
+ if (!all_of(Objects, isIdentifiedObject) &&
+ // Try ::alias to see if all objects are known not to alias GV.
+ !all_of(Objects, [&](const Value *V) {
+ return this->alias(MemoryLocation(V), MemoryLocation(GV), AAQI) ==
+ NoAlias;
+ }))
+ return ConservativeResult;
+
+ if (is_contained(Objects, GV))
+ return ConservativeResult;
+ }
+
+ // We identified all objects in the argument list, and none of them were GV.
+ return ModRefInfo::NoModRef;
+}
+
+ModRefInfo GlobalsAAResult::getModRefInfo(const CallBase *Call,
+ const MemoryLocation &Loc,
+ AAQueryInfo &AAQI) {
+ ModRefInfo Known = ModRefInfo::ModRef;
+
+ // If we are asking for mod/ref info of a direct call with a pointer to a
+ // global we are tracking, return information if we have it.
+ if (const GlobalValue *GV =
+ dyn_cast<GlobalValue>(GetUnderlyingObject(Loc.Ptr, DL)))
+ if (GV->hasLocalLinkage())
+ if (const Function *F = Call->getCalledFunction())
+ if (NonAddressTakenGlobals.count(GV))
+ if (const FunctionInfo *FI = getFunctionInfo(F))
+ Known = unionModRef(FI->getModRefInfoForGlobal(*GV),
+ getModRefInfoForArgument(Call, GV, AAQI));
+
+ if (!isModOrRefSet(Known))
+ return ModRefInfo::NoModRef; // No need to query other mod/ref analyses
+ return intersectModRef(Known, AAResultBase::getModRefInfo(Call, Loc, AAQI));
+}
+
+GlobalsAAResult::GlobalsAAResult(
+ const DataLayout &DL,
+ std::function<const TargetLibraryInfo &(Function &F)> GetTLI)
+ : AAResultBase(), DL(DL), GetTLI(std::move(GetTLI)) {}
+
+GlobalsAAResult::GlobalsAAResult(GlobalsAAResult &&Arg)
+ : AAResultBase(std::move(Arg)), DL(Arg.DL), GetTLI(std::move(Arg.GetTLI)),
+ NonAddressTakenGlobals(std::move(Arg.NonAddressTakenGlobals)),
+ IndirectGlobals(std::move(Arg.IndirectGlobals)),
+ AllocsForIndirectGlobals(std::move(Arg.AllocsForIndirectGlobals)),
+ FunctionInfos(std::move(Arg.FunctionInfos)),
+ Handles(std::move(Arg.Handles)) {
+ // Update the parent for each DeletionCallbackHandle.
+ for (auto &H : Handles) {
+ assert(H.GAR == &Arg);
+ H.GAR = this;
+ }
+}
+
+GlobalsAAResult::~GlobalsAAResult() {}
+
+/*static*/ GlobalsAAResult GlobalsAAResult::analyzeModule(
+ Module &M, std::function<const TargetLibraryInfo &(Function &F)> GetTLI,
+ CallGraph &CG) {
+ GlobalsAAResult Result(M.getDataLayout(), GetTLI);
+
+ // Discover which functions aren't recursive, to feed into AnalyzeGlobals.
+ Result.CollectSCCMembership(CG);
+
+ // Find non-addr taken globals.
+ Result.AnalyzeGlobals(M);
+
+ // Propagate on CG.
+ Result.AnalyzeCallGraph(CG, M);
+
+ return Result;
+}
+
+AnalysisKey GlobalsAA::Key;
+
+GlobalsAAResult GlobalsAA::run(Module &M, ModuleAnalysisManager &AM) {
+ FunctionAnalysisManager &FAM =
+ AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ auto GetTLI = [&FAM](Function &F) -> TargetLibraryInfo & {
+ return FAM.getResult<TargetLibraryAnalysis>(F);
+ };
+ return GlobalsAAResult::analyzeModule(M, GetTLI,
+ AM.getResult<CallGraphAnalysis>(M));
+}
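+
+// A minimal sketch of wiring this result into a new-pass-manager AA pipeline
+// (assumes the AAManager interface from llvm/Analysis/AliasAnalysis.h;
+// buildAA is a hypothetical helper, and BasicAA is shown only as the usual
+// fallback):
+//
+//   AAManager buildAA() {
+//     AAManager AA;
+//     AA.registerModuleAnalysis<GlobalsAA>();
+//     AA.registerFunctionAnalysis<BasicAA>();
+//     return AA;
+//   }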
+
+char GlobalsAAWrapperPass::ID = 0;
+INITIALIZE_PASS_BEGIN(GlobalsAAWrapperPass, "globals-aa",
+ "Globals Alias Analysis", false, true)
+INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(GlobalsAAWrapperPass, "globals-aa",
+ "Globals Alias Analysis", false, true)
+
+ModulePass *llvm::createGlobalsAAWrapperPass() {
+ return new GlobalsAAWrapperPass();
+}
+
+GlobalsAAWrapperPass::GlobalsAAWrapperPass() : ModulePass(ID) {
+ initializeGlobalsAAWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+bool GlobalsAAWrapperPass::runOnModule(Module &M) {
+ auto GetTLI = [this](Function &F) -> TargetLibraryInfo & {
+ return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ };
+ Result.reset(new GlobalsAAResult(GlobalsAAResult::analyzeModule(
+ M, GetTLI, getAnalysis<CallGraphWrapperPass>().getCallGraph())));
+ return false;
+}
+
+bool GlobalsAAWrapperPass::doFinalization(Module &M) {
+ Result.reset();
+ return false;
+}
+
+void GlobalsAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequired<CallGraphWrapperPass>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
+}
diff --git a/llvm/lib/Analysis/GuardUtils.cpp b/llvm/lib/Analysis/GuardUtils.cpp
new file mode 100644
index 000000000000..cad92f6e56bb
--- /dev/null
+++ b/llvm/lib/Analysis/GuardUtils.cpp
@@ -0,0 +1,49 @@
+//===-- GuardUtils.cpp - Utils for work with guards -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Utils that are used to perform analyses related to guards and their
+// conditions.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/GuardUtils.h"
+#include "llvm/IR/PatternMatch.h"
+
+using namespace llvm;
+
+bool llvm::isGuard(const User *U) {
+ using namespace llvm::PatternMatch;
+ return match(U, m_Intrinsic<Intrinsic::experimental_guard>());
+}
+
+bool llvm::isGuardAsWidenableBranch(const User *U) {
+ Value *Condition, *WidenableCondition;
+ BasicBlock *GuardedBB, *DeoptBB;
+ if (!parseWidenableBranch(U, Condition, WidenableCondition, GuardedBB,
+ DeoptBB))
+ return false;
+ using namespace llvm::PatternMatch;
+ for (auto &Insn : *DeoptBB) {
+ if (match(&Insn, m_Intrinsic<Intrinsic::experimental_deoptimize>()))
+ return true;
+ if (Insn.mayHaveSideEffects())
+ return false;
+ }
+ return false;
+}
+
+bool llvm::parseWidenableBranch(const User *U, Value *&Condition,
+ Value *&WidenableCondition,
+ BasicBlock *&IfTrueBB, BasicBlock *&IfFalseBB) {
+ using namespace llvm::PatternMatch;
+ if (!match(U, m_Br(m_And(m_Value(Condition), m_Value(WidenableCondition)),
+ IfTrueBB, IfFalseBB)))
+ return false;
+ // TODO: At the moment, we only recognize the branch if the WC call is in
+ // this specific position. We should generalize!
+ return match(WidenableCondition,
+ m_Intrinsic<Intrinsic::experimental_widenable_condition>());
+}
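+
+// The IR shape recognized by parseWidenableBranch (illustrative; value and
+// block names are arbitrary):
+//
+//   %wc = call i1 @llvm.experimental.widenable.condition()
+//   %guard.cond = and i1 %cond, %wc
+//   br i1 %guard.cond, label %guarded, label %deopt
+//
+// isGuardAsWidenableBranch additionally requires %deopt to reach a call to
+// @llvm.experimental.deoptimize with no side-effecting instruction before it.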
diff --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp
new file mode 100644
index 000000000000..6fb600114bc6
--- /dev/null
+++ b/llvm/lib/Analysis/IVDescriptors.cpp
@@ -0,0 +1,1101 @@
+//===- llvm/Analysis/IVDescriptors.cpp - IndVar Descriptors -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file "describes" induction and recurrence variables.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/IVDescriptors.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MustExecute.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/KnownBits.h"
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+#define DEBUG_TYPE "iv-descriptors"
+
+bool RecurrenceDescriptor::areAllUsesIn(Instruction *I,
+ SmallPtrSetImpl<Instruction *> &Set) {
+ for (User::op_iterator Use = I->op_begin(), E = I->op_end(); Use != E; ++Use)
+ if (!Set.count(dyn_cast<Instruction>(*Use)))
+ return false;
+ return true;
+}
+
+bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurrenceKind Kind) {
+ switch (Kind) {
+ default:
+ break;
+ case RK_IntegerAdd:
+ case RK_IntegerMult:
+ case RK_IntegerOr:
+ case RK_IntegerAnd:
+ case RK_IntegerXor:
+ case RK_IntegerMinMax:
+ return true;
+ }
+ return false;
+}
+
+bool RecurrenceDescriptor::isFloatingPointRecurrenceKind(RecurrenceKind Kind) {
+ return (Kind != RK_NoRecurrence) && !isIntegerRecurrenceKind(Kind);
+}
+
+bool RecurrenceDescriptor::isArithmeticRecurrenceKind(RecurrenceKind Kind) {
+ switch (Kind) {
+ default:
+ break;
+ case RK_IntegerAdd:
+ case RK_IntegerMult:
+ case RK_FloatAdd:
+ case RK_FloatMult:
+ return true;
+ }
+ return false;
+}
+
+/// Determines if Phi may have been type-promoted. If Phi has a single user
+/// that ANDs the Phi with a type mask, return the user. RT is updated to
+/// account for the narrower bit width represented by the mask, and the AND
+/// instruction is added to CI.
+static Instruction *lookThroughAnd(PHINode *Phi, Type *&RT,
+ SmallPtrSetImpl<Instruction *> &Visited,
+ SmallPtrSetImpl<Instruction *> &CI) {
+ if (!Phi->hasOneUse())
+ return Phi;
+
+ const APInt *M = nullptr;
+ Instruction *I, *J = cast<Instruction>(Phi->use_begin()->getUser());
+
+ // Matches either I & 2^x-1 or 2^x-1 & I. If we find a match, we update RT
+ // with a new integer type of the corresponding bit width.
+ if (match(J, m_c_And(m_Instruction(I), m_APInt(M)))) {
+ int32_t Bits = (*M + 1).exactLogBase2();
+ if (Bits > 0) {
+ RT = IntegerType::get(Phi->getContext(), Bits);
+ Visited.insert(Phi);
+ CI.insert(J);
+ return J;
+ }
+ }
+ return Phi;
+}
+
+/// Compute the minimal bit width needed to represent a reduction whose exit
+/// instruction is given by Exit.
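+///
+/// Illustrative example (hypothetical values, not from the original source):
+/// for an i32 exit value of which only the low 8 bits are demanded, this
+/// returns (i8, /*IsSigned=*/false); if demanded bits cannot narrow the value
+/// but value tracking shows 24 sign bits and the sign is unknown, the width
+/// becomes 32 - 24 + 1 = 9, which rounds up to (i16, /*IsSigned=*/true).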
+static std::pair<Type *, bool> computeRecurrenceType(Instruction *Exit,
+ DemandedBits *DB,
+ AssumptionCache *AC,
+ DominatorTree *DT) {
+ bool IsSigned = false;
+ const DataLayout &DL = Exit->getModule()->getDataLayout();
+ uint64_t MaxBitWidth = DL.getTypeSizeInBits(Exit->getType());
+
+ if (DB) {
+ // Use the demanded bits analysis to determine the bits that are live out
+ // of the exit instruction, rounding up to the nearest power of two. If the
+ // use of demanded bits results in a smaller bit width, we know the value
+ // must be positive (i.e., IsSigned = false), because if this were not the
+ // case, the sign bit would have been demanded.
+ auto Mask = DB->getDemandedBits(Exit);
+ MaxBitWidth = Mask.getBitWidth() - Mask.countLeadingZeros();
+ }
+
+ if (MaxBitWidth == DL.getTypeSizeInBits(Exit->getType()) && AC && DT) {
+ // If demanded bits wasn't able to limit the bit width, we can try to use
+ // value tracking instead. This can be the case, for example, if the value
+ // may be negative.
+ auto NumSignBits = ComputeNumSignBits(Exit, DL, 0, AC, nullptr, DT);
+ auto NumTypeBits = DL.getTypeSizeInBits(Exit->getType());
+ MaxBitWidth = NumTypeBits - NumSignBits;
+ KnownBits Bits = computeKnownBits(Exit, DL);
+ if (!Bits.isNonNegative()) {
+ // If the value is not known to be non-negative, we set IsSigned to true,
+ // meaning that we will use sext instructions instead of zext
+ // instructions to restore the original type.
+ IsSigned = true;
+ if (!Bits.isNegative())
+      // If the value is not known to be negative, we don't know what the
+ // upper bit is, and therefore, we don't know what kind of extend we
+ // will need. In this case, just increase the bit width by one bit and
+ // use sext.
+ ++MaxBitWidth;
+ }
+ }
+ if (!isPowerOf2_64(MaxBitWidth))
+ MaxBitWidth = NextPowerOf2(MaxBitWidth);
+
+ return std::make_pair(Type::getIntNTy(Exit->getContext(), MaxBitWidth),
+ IsSigned);
+}
+
+/// Collect cast instructions that can be ignored in the vectorizer's cost
+/// model, given a reduction exit value and the minimal type in which the
+/// reduction can be represented.
+static void collectCastsToIgnore(Loop *TheLoop, Instruction *Exit,
+ Type *RecurrenceType,
+ SmallPtrSetImpl<Instruction *> &Casts) {
+
+ SmallVector<Instruction *, 8> Worklist;
+ SmallPtrSet<Instruction *, 8> Visited;
+ Worklist.push_back(Exit);
+
+ while (!Worklist.empty()) {
+ Instruction *Val = Worklist.pop_back_val();
+ Visited.insert(Val);
+ if (auto *Cast = dyn_cast<CastInst>(Val))
+ if (Cast->getSrcTy() == RecurrenceType) {
+ // If the source type of a cast instruction is equal to the recurrence
+ // type, it will be eliminated, and should be ignored in the vectorizer
+ // cost model.
+ Casts.insert(Cast);
+ continue;
+ }
+
+ // Add all operands to the work list if they are loop-varying values that
+ // we haven't yet visited.
+ for (Value *O : cast<User>(Val)->operands())
+ if (auto *I = dyn_cast<Instruction>(O))
+ if (TheLoop->contains(I) && !Visited.count(I))
+ Worklist.push_back(I);
+ }
+}
+
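+/// Check whether \p Phi is a reduction variable of kind \p Kind in loop
+/// \p TheLoop. If it is, fill in \p RedDes with a description of the
+/// reduction and return true; otherwise return false.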
+bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,
+ Loop *TheLoop, bool HasFunNoNaNAttr,
+ RecurrenceDescriptor &RedDes,
+ DemandedBits *DB,
+ AssumptionCache *AC,
+ DominatorTree *DT) {
+ if (Phi->getNumIncomingValues() != 2)
+ return false;
+
+ // Reduction variables are only found in the loop header block.
+ if (Phi->getParent() != TheLoop->getHeader())
+ return false;
+
+ // Obtain the reduction start value from the value that comes from the loop
+ // preheader.
+ Value *RdxStart = Phi->getIncomingValueForBlock(TheLoop->getLoopPreheader());
+
+ // ExitInstruction is the single value which is used outside the loop.
+ // We only allow for a single reduction value to be used outside the loop.
+  // This includes users of the reduction variables, which form a cycle that
+  // ends in the phi node.
+ Instruction *ExitInstruction = nullptr;
+ // Indicates that we found a reduction operation in our scan.
+ bool FoundReduxOp = false;
+
+ // We start with the PHI node and scan for all of the users of this
+ // instruction. All users must be instructions that can be used as reduction
+  // variables (such as ADD). We must have a single out-of-loop user. The cycle
+ // must include the original PHI.
+ bool FoundStartPHI = false;
+
+  // To recognize min/max patterns formed by an icmp/select sequence, we store
+  // the number of instructions we saw from the recognized min/max pattern,
+  // to make sure we see exactly the two expected instructions.
+ unsigned NumCmpSelectPatternInst = 0;
+ InstDesc ReduxDesc(false, nullptr);
+
+ // Data used for determining if the recurrence has been type-promoted.
+ Type *RecurrenceType = Phi->getType();
+ SmallPtrSet<Instruction *, 4> CastInsts;
+ Instruction *Start = Phi;
+ bool IsSigned = false;
+
+ SmallPtrSet<Instruction *, 8> VisitedInsts;
+ SmallVector<Instruction *, 8> Worklist;
+
+ // Return early if the recurrence kind does not match the type of Phi. If the
+ // recurrence kind is arithmetic, we attempt to look through AND operations
+ // resulting from the type promotion performed by InstCombine. Vector
+ // operations are not limited to the legal integer widths, so we may be able
+ // to evaluate the reduction in the narrower width.
+ if (RecurrenceType->isFloatingPointTy()) {
+ if (!isFloatingPointRecurrenceKind(Kind))
+ return false;
+ } else {
+ if (!isIntegerRecurrenceKind(Kind))
+ return false;
+ if (isArithmeticRecurrenceKind(Kind))
+ Start = lookThroughAnd(Phi, RecurrenceType, VisitedInsts, CastInsts);
+ }
+
+ Worklist.push_back(Start);
+ VisitedInsts.insert(Start);
+
+ // Start with all flags set because we will intersect this with the reduction
+ // flags from all the reduction operations.
+ FastMathFlags FMF = FastMathFlags::getFast();
+
+ // A value in the reduction can be used:
+ // - By the reduction:
+ // - Reduction operation:
+ // - One use of reduction value (safe).
+ // - Multiple use of reduction value (not safe).
+ // - PHI:
+ // - All uses of the PHI must be the reduction (safe).
+ // - Otherwise, not safe.
+ // - By instructions outside of the loop (safe).
+ // * One value may have several outside users, but all outside
+ // uses must be of the same value.
+ // - By an instruction that is not part of the reduction (not safe).
+ // This is either:
+ // * An instruction type other than PHI or the reduction operation.
+ // * A PHI in the header other than the initial PHI.
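+  //
+  // Illustrative safe cycle (hypothetical IR, not from the original source):
+  //   %sum = phi i32 [ 0, %preheader ], [ %sum.next, %loop ]
+  //   %sum.next = add i32 %sum, %x    ; single use inside the cycle
+  //   ...                             ; %sum.next used once after the loop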
+ while (!Worklist.empty()) {
+ Instruction *Cur = Worklist.back();
+ Worklist.pop_back();
+
+ // No Users.
+ // If the instruction has no users then this is a broken chain and can't be
+ // a reduction variable.
+ if (Cur->use_empty())
+ return false;
+
+ bool IsAPhi = isa<PHINode>(Cur);
+
+ // A header PHI use other than the original PHI.
+ if (Cur != Phi && IsAPhi && Cur->getParent() == Phi->getParent())
+ return false;
+
+    // Reductions of instructions such as Div and Sub are only possible if the
+    // LHS is the reduction variable.
+ if (!Cur->isCommutative() && !IsAPhi && !isa<SelectInst>(Cur) &&
+ !isa<ICmpInst>(Cur) && !isa<FCmpInst>(Cur) &&
+ !VisitedInsts.count(dyn_cast<Instruction>(Cur->getOperand(0))))
+ return false;
+
+ // Any reduction instruction must be of one of the allowed kinds. We ignore
+ // the starting value (the Phi or an AND instruction if the Phi has been
+ // type-promoted).
+ if (Cur != Start) {
+ ReduxDesc = isRecurrenceInstr(Cur, Kind, ReduxDesc, HasFunNoNaNAttr);
+ if (!ReduxDesc.isRecurrence())
+ return false;
+ // FIXME: FMF is allowed on phi, but propagation is not handled correctly.
+ if (isa<FPMathOperator>(ReduxDesc.getPatternInst()) && !IsAPhi)
+ FMF &= ReduxDesc.getPatternInst()->getFastMathFlags();
+ }
+
+ bool IsASelect = isa<SelectInst>(Cur);
+
+    // A conditional reduction operation must only have 2 or fewer uses in
+ // VisitedInsts.
+ if (IsASelect && (Kind == RK_FloatAdd || Kind == RK_FloatMult) &&
+ hasMultipleUsesOf(Cur, VisitedInsts, 2))
+ return false;
+
+ // A reduction operation must only have one use of the reduction value.
+ if (!IsAPhi && !IsASelect && Kind != RK_IntegerMinMax &&
+ Kind != RK_FloatMinMax && hasMultipleUsesOf(Cur, VisitedInsts, 1))
+ return false;
+
+ // All inputs to a PHI node must be a reduction value.
+ if (IsAPhi && Cur != Phi && !areAllUsesIn(Cur, VisitedInsts))
+ return false;
+
+ if (Kind == RK_IntegerMinMax &&
+ (isa<ICmpInst>(Cur) || isa<SelectInst>(Cur)))
+ ++NumCmpSelectPatternInst;
+ if (Kind == RK_FloatMinMax && (isa<FCmpInst>(Cur) || isa<SelectInst>(Cur)))
+ ++NumCmpSelectPatternInst;
+
+ // Check whether we found a reduction operator.
+ FoundReduxOp |= !IsAPhi && Cur != Start;
+
+ // Process users of current instruction. Push non-PHI nodes after PHI nodes
+ // onto the stack. This way we are going to have seen all inputs to PHI
+ // nodes once we get to them.
+ SmallVector<Instruction *, 8> NonPHIs;
+ SmallVector<Instruction *, 8> PHIs;
+ for (User *U : Cur->users()) {
+ Instruction *UI = cast<Instruction>(U);
+
+ // Check if we found the exit user.
+ BasicBlock *Parent = UI->getParent();
+ if (!TheLoop->contains(Parent)) {
+ // If we already know this instruction is used externally, move on to
+ // the next user.
+ if (ExitInstruction == Cur)
+ continue;
+
+        // Exit if we find multiple values used outside or if the header phi
+        // node is being used. In this case the user uses the value of the
+        // previous iteration, in which case we would lose "VF-1" iterations of
+        // the reduction operation if we vectorize.
+ if (ExitInstruction != nullptr || Cur == Phi)
+ return false;
+
+ // The instruction used by an outside user must be the last instruction
+        // before we feed back to the reduction phi. Otherwise, we lose VF-1
+ // operations on the value.
+ if (!is_contained(Phi->operands(), Cur))
+ return false;
+
+ ExitInstruction = Cur;
+ continue;
+ }
+
+ // Process instructions only once (termination). Each reduction cycle
+ // value must only be used once, except by phi nodes and min/max
+ // reductions which are represented as a cmp followed by a select.
+ InstDesc IgnoredVal(false, nullptr);
+ if (VisitedInsts.insert(UI).second) {
+ if (isa<PHINode>(UI))
+ PHIs.push_back(UI);
+ else
+ NonPHIs.push_back(UI);
+ } else if (!isa<PHINode>(UI) &&
+ ((!isa<FCmpInst>(UI) && !isa<ICmpInst>(UI) &&
+ !isa<SelectInst>(UI)) ||
+ (!isConditionalRdxPattern(Kind, UI).isRecurrence() &&
+ !isMinMaxSelectCmpPattern(UI, IgnoredVal).isRecurrence())))
+ return false;
+
+ // Remember that we completed the cycle.
+ if (UI == Phi)
+ FoundStartPHI = true;
+ }
+ Worklist.append(PHIs.begin(), PHIs.end());
+ Worklist.append(NonPHIs.begin(), NonPHIs.end());
+ }
+
+  // This means we have seen either one but not the other instruction of the
+  // min/max pattern, or more than just a select and cmp.
+ if ((Kind == RK_IntegerMinMax || Kind == RK_FloatMinMax) &&
+ NumCmpSelectPatternInst != 2)
+ return false;
+
+ if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction)
+ return false;
+
+ if (Start != Phi) {
+ // If the starting value is not the same as the phi node, we speculatively
+ // looked through an 'and' instruction when evaluating a potential
+ // arithmetic reduction to determine if it may have been type-promoted.
+ //
+ // We now compute the minimal bit width that is required to represent the
+ // reduction. If this is the same width that was indicated by the 'and', we
+ // can represent the reduction in the smaller type. The 'and' instruction
+ // will be eliminated since it will essentially be a cast instruction that
+    // can be ignored in the cost model. If we compute a different type than we
+    // did when evaluating the 'and', the 'and' will not be eliminated, and we
+    // will end up with different kinds of operations in the recurrence
+    // expression (e.g., RK_IntegerAnd, RK_IntegerAdd). We give up if this is
+ // the case.
+ //
+ // The vectorizer relies on InstCombine to perform the actual
+ // type-shrinking. It does this by inserting instructions to truncate the
+ // exit value of the reduction to the width indicated by RecurrenceType and
+ // then extend this value back to the original width. If IsSigned is false,
+ // a 'zext' instruction will be generated; otherwise, a 'sext' will be
+ // used.
+ //
+ // TODO: We should not rely on InstCombine to rewrite the reduction in the
+ // smaller type. We should just generate a correctly typed expression
+ // to begin with.
+ Type *ComputedType;
+ std::tie(ComputedType, IsSigned) =
+ computeRecurrenceType(ExitInstruction, DB, AC, DT);
+ if (ComputedType != RecurrenceType)
+ return false;
+
+ // The recurrence expression will be represented in a narrower type. If
+ // there are any cast instructions that will be unnecessary, collect them
+ // in CastInsts. Note that the 'and' instruction was already included in
+ // this list.
+ //
+ // TODO: A better way to represent this may be to tag in some way all the
+ // instructions that are a part of the reduction. The vectorizer cost
+ // model could then apply the recurrence type to these instructions,
+ // without needing a white list of instructions to ignore.
+ collectCastsToIgnore(TheLoop, ExitInstruction, RecurrenceType, CastInsts);
+ }
+
+ // We found a reduction var if we have reached the original phi node and we
+ // only have a single instruction with out-of-loop users.
+
+  // The ExitInstruction (the instruction which is allowed to have out-of-loop
+  // users) is saved as part of the RecurrenceDescriptor.
+
+ // Save the description of this reduction variable.
+ RecurrenceDescriptor RD(
+ RdxStart, ExitInstruction, Kind, FMF, ReduxDesc.getMinMaxKind(),
+ ReduxDesc.getUnsafeAlgebraInst(), RecurrenceType, IsSigned, CastInsts);
+ RedDes = RD;
+
+ return true;
+}
+
+/// Returns a descriptor recognizing the instruction as a
+/// Select(ICmp(X, Y), X, Y) pattern corresponding to a min(X, Y) or max(X, Y).
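+///
+/// Illustrative example (not from the original source):
+///   %cmp = icmp slt i32 %x, %y
+///   %min = select i1 %cmp, i32 %x, i32 %y   ; recognized as MRK_SIntMin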
+RecurrenceDescriptor::InstDesc
+RecurrenceDescriptor::isMinMaxSelectCmpPattern(Instruction *I, InstDesc &Prev) {
+
+  assert((isa<ICmpInst>(I) || isa<FCmpInst>(I) || isa<SelectInst>(I)) &&
+         "Expected a cmp or select instruction");
+ Instruction *Cmp = nullptr;
+ SelectInst *Select = nullptr;
+
+ // We must handle the select(cmp()) as a single instruction. Advance to the
+ // select.
+ if ((Cmp = dyn_cast<ICmpInst>(I)) || (Cmp = dyn_cast<FCmpInst>(I))) {
+ if (!Cmp->hasOneUse() || !(Select = dyn_cast<SelectInst>(*I->user_begin())))
+ return InstDesc(false, I);
+ return InstDesc(Select, Prev.getMinMaxKind());
+ }
+
+ // Only handle single use cases for now.
+ if (!(Select = dyn_cast<SelectInst>(I)))
+ return InstDesc(false, I);
+ if (!(Cmp = dyn_cast<ICmpInst>(I->getOperand(0))) &&
+ !(Cmp = dyn_cast<FCmpInst>(I->getOperand(0))))
+ return InstDesc(false, I);
+ if (!Cmp->hasOneUse())
+ return InstDesc(false, I);
+
+ Value *CmpLeft;
+ Value *CmpRight;
+
+ // Look for a min/max pattern.
+ if (m_UMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+ return InstDesc(Select, MRK_UIntMin);
+ else if (m_UMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+ return InstDesc(Select, MRK_UIntMax);
+ else if (m_SMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+ return InstDesc(Select, MRK_SIntMax);
+ else if (m_SMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+ return InstDesc(Select, MRK_SIntMin);
+ else if (m_OrdFMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+ return InstDesc(Select, MRK_FloatMin);
+ else if (m_OrdFMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+ return InstDesc(Select, MRK_FloatMax);
+ else if (m_UnordFMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+ return InstDesc(Select, MRK_FloatMin);
+ else if (m_UnordFMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
+ return InstDesc(Select, MRK_FloatMax);
+
+ return InstDesc(false, I);
+}
+
+/// Returns a descriptor recognizing the select instruction as the last one in
+/// the conditional compare-and-add reduction pattern below.
+///
+/// %sum.1 = phi ...
+/// ...
+/// %cmp = fcmp pred %0, %CFP
+/// %add = fadd %0, %sum.1
+/// %sum.2 = select %cmp, %add, %sum.1
+RecurrenceDescriptor::InstDesc
+RecurrenceDescriptor::isConditionalRdxPattern(
+ RecurrenceKind Kind, Instruction *I) {
+ SelectInst *SI = dyn_cast<SelectInst>(I);
+ if (!SI)
+ return InstDesc(false, I);
+
+ CmpInst *CI = dyn_cast<CmpInst>(SI->getCondition());
+ // Only handle single use cases for now.
+ if (!CI || !CI->hasOneUse())
+ return InstDesc(false, I);
+
+ Value *TrueVal = SI->getTrueValue();
+ Value *FalseVal = SI->getFalseValue();
+  // For now, handle only the case where exactly one of the select's operands
+  // is a PHI node.
+ if ((isa<PHINode>(*TrueVal) && isa<PHINode>(*FalseVal)) ||
+ (!isa<PHINode>(*TrueVal) && !isa<PHINode>(*FalseVal)))
+ return InstDesc(false, I);
+
+ Instruction *I1 =
+ isa<PHINode>(*TrueVal) ? dyn_cast<Instruction>(FalseVal)
+ : dyn_cast<Instruction>(TrueVal);
+ if (!I1 || !I1->isBinaryOp())
+ return InstDesc(false, I);
+
+ Value *Op1, *Op2;
+ if ((m_FAdd(m_Value(Op1), m_Value(Op2)).match(I1) ||
+ m_FSub(m_Value(Op1), m_Value(Op2)).match(I1)) &&
+ I1->isFast())
+ return InstDesc(Kind == RK_FloatAdd, SI);
+
+ if (m_FMul(m_Value(Op1), m_Value(Op2)).match(I1) && (I1->isFast()))
+ return InstDesc(Kind == RK_FloatMult, SI);
+
+ return InstDesc(false, I);
+}
+
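+/// Check whether \p I is an instruction that can participate in a recurrence
+/// of kind \p Kind, returning a descriptor for it if so. \p Prev carries the
+/// state (min/max kind and unsafe-algebra instruction) found so far.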
+RecurrenceDescriptor::InstDesc
+RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurrenceKind Kind,
+ InstDesc &Prev, bool HasFunNoNaNAttr) {
+ Instruction *UAI = Prev.getUnsafeAlgebraInst();
+ if (!UAI && isa<FPMathOperator>(I) && !I->hasAllowReassoc())
+ UAI = I; // Found an unsafe (unvectorizable) algebra instruction.
+
+ switch (I->getOpcode()) {
+ default:
+ return InstDesc(false, I);
+ case Instruction::PHI:
+ return InstDesc(I, Prev.getMinMaxKind(), Prev.getUnsafeAlgebraInst());
+ case Instruction::Sub:
+ case Instruction::Add:
+ return InstDesc(Kind == RK_IntegerAdd, I);
+ case Instruction::Mul:
+ return InstDesc(Kind == RK_IntegerMult, I);
+ case Instruction::And:
+ return InstDesc(Kind == RK_IntegerAnd, I);
+ case Instruction::Or:
+ return InstDesc(Kind == RK_IntegerOr, I);
+ case Instruction::Xor:
+ return InstDesc(Kind == RK_IntegerXor, I);
+ case Instruction::FMul:
+ return InstDesc(Kind == RK_FloatMult, I, UAI);
+ case Instruction::FSub:
+ case Instruction::FAdd:
+ return InstDesc(Kind == RK_FloatAdd, I, UAI);
+ case Instruction::Select:
+ if (Kind == RK_FloatAdd || Kind == RK_FloatMult)
+ return isConditionalRdxPattern(Kind, I);
+ LLVM_FALLTHROUGH;
+ case Instruction::FCmp:
+ case Instruction::ICmp:
+ if (Kind != RK_IntegerMinMax &&
+ (!HasFunNoNaNAttr || Kind != RK_FloatMinMax))
+ return InstDesc(false, I);
+ return isMinMaxSelectCmpPattern(I, Prev);
+ }
+}
+
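+/// Return true if more than \p MaxNumUses of \p I's operands are members of
+/// \p Insts.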
+bool RecurrenceDescriptor::hasMultipleUsesOf(
+ Instruction *I, SmallPtrSetImpl<Instruction *> &Insts,
+ unsigned MaxNumUses) {
+ unsigned NumUses = 0;
+ for (User::op_iterator Use = I->op_begin(), E = I->op_end(); Use != E;
+ ++Use) {
+ if (Insts.count(dyn_cast<Instruction>(*Use)))
+ ++NumUses;
+ if (NumUses > MaxNumUses)
+ return true;
+ }
+
+ return false;
+}
+
+bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
+ RecurrenceDescriptor &RedDes,
+ DemandedBits *DB, AssumptionCache *AC,
+ DominatorTree *DT) {
+
+ BasicBlock *Header = TheLoop->getHeader();
+ Function &F = *Header->getParent();
+ bool HasFunNoNaNAttr =
+ F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
+
+ if (AddReductionVar(Phi, RK_IntegerAdd, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+ AC, DT)) {
+ LLVM_DEBUG(dbgs() << "Found an ADD reduction PHI." << *Phi << "\n");
+ return true;
+ }
+ if (AddReductionVar(Phi, RK_IntegerMult, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+ AC, DT)) {
+ LLVM_DEBUG(dbgs() << "Found a MUL reduction PHI." << *Phi << "\n");
+ return true;
+ }
+ if (AddReductionVar(Phi, RK_IntegerOr, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+ AC, DT)) {
+ LLVM_DEBUG(dbgs() << "Found an OR reduction PHI." << *Phi << "\n");
+ return true;
+ }
+ if (AddReductionVar(Phi, RK_IntegerAnd, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+ AC, DT)) {
+ LLVM_DEBUG(dbgs() << "Found an AND reduction PHI." << *Phi << "\n");
+ return true;
+ }
+ if (AddReductionVar(Phi, RK_IntegerXor, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+ AC, DT)) {
+ LLVM_DEBUG(dbgs() << "Found a XOR reduction PHI." << *Phi << "\n");
+ return true;
+ }
+ if (AddReductionVar(Phi, RK_IntegerMinMax, TheLoop, HasFunNoNaNAttr, RedDes,
+ DB, AC, DT)) {
+ LLVM_DEBUG(dbgs() << "Found a MINMAX reduction PHI." << *Phi << "\n");
+ return true;
+ }
+ if (AddReductionVar(Phi, RK_FloatMult, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+ AC, DT)) {
+ LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n");
+ return true;
+ }
+ if (AddReductionVar(Phi, RK_FloatAdd, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+ AC, DT)) {
+ LLVM_DEBUG(dbgs() << "Found an FAdd reduction PHI." << *Phi << "\n");
+ return true;
+ }
+ if (AddReductionVar(Phi, RK_FloatMinMax, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+ AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found a float MINMAX reduction PHI." << *Phi
+ << "\n");
+ return true;
+ }
+ // Not a reduction of known type.
+ return false;
+}
+
+bool RecurrenceDescriptor::isFirstOrderRecurrence(
+ PHINode *Phi, Loop *TheLoop,
+ DenseMap<Instruction *, Instruction *> &SinkAfter, DominatorTree *DT) {
+
+ // Ensure the phi node is in the loop header and has two incoming values.
+ if (Phi->getParent() != TheLoop->getHeader() ||
+ Phi->getNumIncomingValues() != 2)
+ return false;
+
+ // Ensure the loop has a preheader and a single latch block. The loop
+ // vectorizer will need the latch to set up the next iteration of the loop.
+ auto *Preheader = TheLoop->getLoopPreheader();
+ auto *Latch = TheLoop->getLoopLatch();
+ if (!Preheader || !Latch)
+ return false;
+
+ // Ensure the phi node's incoming blocks are the loop preheader and latch.
+ if (Phi->getBasicBlockIndex(Preheader) < 0 ||
+ Phi->getBasicBlockIndex(Latch) < 0)
+ return false;
+
+ // Get the previous value. The previous value comes from the latch edge while
+  // the initial value comes from the preheader edge.
+ auto *Previous = dyn_cast<Instruction>(Phi->getIncomingValueForBlock(Latch));
+ if (!Previous || !TheLoop->contains(Previous) || isa<PHINode>(Previous) ||
+ SinkAfter.count(Previous)) // Cannot rely on dominance due to motion.
+ return false;
+
+ // Ensure every user of the phi node is dominated by the previous value.
+ // The dominance requirement ensures the loop vectorizer will not need to
+ // vectorize the initial value prior to the first iteration of the loop.
+ // TODO: Consider extending this sinking to handle other kinds of instructions
+ // and expressions, beyond sinking a single cast past Previous.
+ if (Phi->hasOneUse()) {
+ auto *I = Phi->user_back();
+ if (I->isCast() && (I->getParent() == Phi->getParent()) && I->hasOneUse() &&
+ DT->dominates(Previous, I->user_back())) {
+ if (!DT->dominates(Previous, I)) // Otherwise we're good w/o sinking.
+ SinkAfter[I] = Previous;
+ return true;
+ }
+ }
+
+ for (User *U : Phi->users())
+ if (auto *I = dyn_cast<Instruction>(U)) {
+ if (!DT->dominates(Previous, I))
+ return false;
+ }
+
+ return true;
+}
+
+/// This function returns the identity element (or neutral element) for
+/// the operation K.
+Constant *RecurrenceDescriptor::getRecurrenceIdentity(RecurrenceKind K,
+ Type *Tp) {
+ switch (K) {
+ case RK_IntegerXor:
+ case RK_IntegerAdd:
+ case RK_IntegerOr:
+ // Adding, Xoring, Oring zero to a number does not change it.
+ return ConstantInt::get(Tp, 0);
+ case RK_IntegerMult:
+ // Multiplying a number by 1 does not change it.
+ return ConstantInt::get(Tp, 1);
+ case RK_IntegerAnd:
+ // AND-ing a number with an all-1 value does not change it.
+ return ConstantInt::get(Tp, -1, true);
+ case RK_FloatMult:
+ // Multiplying a number by 1 does not change it.
+ return ConstantFP::get(Tp, 1.0L);
+ case RK_FloatAdd:
+ // Adding zero to a number does not change it.
+ return ConstantFP::get(Tp, 0.0L);
+ default:
+ llvm_unreachable("Unknown recurrence kind");
+ }
+}
+
+/// This function translates the recurrence kind to an LLVM binary operator.
+unsigned RecurrenceDescriptor::getRecurrenceBinOp(RecurrenceKind Kind) {
+ switch (Kind) {
+ case RK_IntegerAdd:
+ return Instruction::Add;
+ case RK_IntegerMult:
+ return Instruction::Mul;
+ case RK_IntegerOr:
+ return Instruction::Or;
+ case RK_IntegerAnd:
+ return Instruction::And;
+ case RK_IntegerXor:
+ return Instruction::Xor;
+ case RK_FloatMult:
+ return Instruction::FMul;
+ case RK_FloatAdd:
+ return Instruction::FAdd;
+ case RK_IntegerMinMax:
+ return Instruction::ICmp;
+ case RK_FloatMinMax:
+ return Instruction::FCmp;
+ default:
+ llvm_unreachable("Unknown recurrence operation");
+ }
+}
+
+InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K,
+ const SCEV *Step, BinaryOperator *BOp,
+ SmallVectorImpl<Instruction *> *Casts)
+ : StartValue(Start), IK(K), Step(Step), InductionBinOp(BOp) {
+ assert(IK != IK_NoInduction && "Not an induction");
+
+ // Start value type should match the induction kind and the value
+ // itself should not be null.
+ assert(StartValue && "StartValue is null");
+ assert((IK != IK_PtrInduction || StartValue->getType()->isPointerTy()) &&
+ "StartValue is not a pointer for pointer induction");
+ assert((IK != IK_IntInduction || StartValue->getType()->isIntegerTy()) &&
+ "StartValue is not an integer for integer induction");
+
+  // Check the Step Value. It should be a non-zero integer value.
+ assert((!getConstIntStepValue() || !getConstIntStepValue()->isZero()) &&
+ "Step value is zero");
+
+ assert((IK != IK_PtrInduction || getConstIntStepValue()) &&
+ "Step value should be constant for pointer induction");
+ assert((IK == IK_FpInduction || Step->getType()->isIntegerTy()) &&
+ "StepValue is not an integer");
+
+ assert((IK != IK_FpInduction || Step->getType()->isFloatingPointTy()) &&
+ "StepValue is not FP for FpInduction");
+ assert((IK != IK_FpInduction ||
+ (InductionBinOp &&
+ (InductionBinOp->getOpcode() == Instruction::FAdd ||
+ InductionBinOp->getOpcode() == Instruction::FSub))) &&
+ "Binary opcode should be specified for FP induction");
+
+ if (Casts) {
+ for (auto &Inst : *Casts) {
+ RedundantCasts.push_back(Inst);
+ }
+ }
+}
+
+int InductionDescriptor::getConsecutiveDirection() const {
+ ConstantInt *ConstStep = getConstIntStepValue();
+ if (ConstStep && (ConstStep->isOne() || ConstStep->isMinusOne()))
+ return ConstStep->getSExtValue();
+ return 0;
+}
+
+ConstantInt *InductionDescriptor::getConstIntStepValue() const {
+ if (isa<SCEVConstant>(Step))
+ return dyn_cast<ConstantInt>(cast<SCEVConstant>(Step)->getValue());
+ return nullptr;
+}
+
+bool InductionDescriptor::isFPInductionPHI(PHINode *Phi, const Loop *TheLoop,
+ ScalarEvolution *SE,
+ InductionDescriptor &D) {
+
+ // Here we only handle FP induction variables.
+ assert(Phi->getType()->isFloatingPointTy() && "Unexpected Phi type");
+
+ if (TheLoop->getHeader() != Phi->getParent())
+ return false;
+
+ // The loop may have multiple entrances or multiple exits; we can analyze
+ // this phi if it has a unique entry value and a unique backedge value.
+ if (Phi->getNumIncomingValues() != 2)
+ return false;
+ Value *BEValue = nullptr, *StartValue = nullptr;
+ if (TheLoop->contains(Phi->getIncomingBlock(0))) {
+ BEValue = Phi->getIncomingValue(0);
+ StartValue = Phi->getIncomingValue(1);
+ } else {
+ assert(TheLoop->contains(Phi->getIncomingBlock(1)) &&
+ "Unexpected Phi node in the loop");
+ BEValue = Phi->getIncomingValue(1);
+ StartValue = Phi->getIncomingValue(0);
+ }
+
+ BinaryOperator *BOp = dyn_cast<BinaryOperator>(BEValue);
+ if (!BOp)
+ return false;
+
+ Value *Addend = nullptr;
+ if (BOp->getOpcode() == Instruction::FAdd) {
+ if (BOp->getOperand(0) == Phi)
+ Addend = BOp->getOperand(1);
+ else if (BOp->getOperand(1) == Phi)
+ Addend = BOp->getOperand(0);
+ } else if (BOp->getOpcode() == Instruction::FSub)
+ if (BOp->getOperand(0) == Phi)
+ Addend = BOp->getOperand(1);
+
+ if (!Addend)
+ return false;
+
+ // The addend should be loop invariant
+ if (auto *I = dyn_cast<Instruction>(Addend))
+ if (TheLoop->contains(I))
+ return false;
+
+ // FP Step has unknown SCEV
+ const SCEV *Step = SE->getUnknown(Addend);
+ D = InductionDescriptor(StartValue, IK_FpInduction, Step, BOp);
+ return true;
+}
+
+/// This function is called when we suspect that the update-chain of a phi node
+/// (whose symbolic SCEV expression is in \p PhiScev) contains redundant casts,
+/// that can be ignored. (This can happen when the PSCEV rewriter adds a runtime
+/// predicate P under which the SCEV expression for the phi can be the
+/// AddRecurrence \p AR; See createAddRecFromPHIWithCast). We want to find the
+/// cast instructions that are involved in the update-chain of this induction.
+/// A caller that adds the required runtime predicate can be free to drop these
+/// cast instructions, and compute the phi using \p AR (instead of some scev
+/// expression with casts).
+///
+/// For example, without a predicate the scev expression can take the following
+/// form:
+/// (Ext ix (Trunc iy ( Start + i*Step ) to ix) to iy)
+///
+/// It corresponds to the following IR sequence:
+/// %for.body:
+/// %x = phi i64 [ 0, %ph ], [ %add, %for.body ]
+/// %casted_phi = "ExtTrunc i64 %x"
+/// %add = add i64 %casted_phi, %step
+///
+/// where %x is given in \p PN,
+/// PSE.getSCEV(%x) is equal to PSE.getSCEV(%casted_phi) under a predicate,
+/// and the IR sequence that "ExtTrunc i64 %x" represents can take one of
+/// several forms, for example, such as:
+/// ExtTrunc1: %casted_phi = and %x, 2^n-1
+/// or:
+/// ExtTrunc2: %t = shl %x, m
+/// %casted_phi = ashr %t, m
+///
+/// If we are able to find such sequence, we return the instructions
+/// we found, namely %casted_phi and the instructions on its use-def chain up
+/// to the phi (not including the phi).
+static bool getCastsForInductionPHI(PredicatedScalarEvolution &PSE,
+ const SCEVUnknown *PhiScev,
+ const SCEVAddRecExpr *AR,
+ SmallVectorImpl<Instruction *> &CastInsts) {
+
+ assert(CastInsts.empty() && "CastInsts is expected to be empty.");
+ auto *PN = cast<PHINode>(PhiScev->getValue());
+ assert(PSE.getSCEV(PN) == AR && "Unexpected phi node SCEV expression");
+ const Loop *L = AR->getLoop();
+
+ // Find any cast instructions that participate in the def-use chain of
+ // PhiScev in the loop.
+ // FORNOW/TODO: We currently expect the def-use chain to include only
+ // two-operand instructions, where one of the operands is an invariant.
+ // createAddRecFromPHIWithCasts() currently does not support anything more
+ // involved than that, so we keep the search simple. This can be
+ // extended/generalized as needed.
+
+ auto getDef = [&](const Value *Val) -> Value * {
+ const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Val);
+ if (!BinOp)
+ return nullptr;
+ Value *Op0 = BinOp->getOperand(0);
+ Value *Op1 = BinOp->getOperand(1);
+ Value *Def = nullptr;
+ if (L->isLoopInvariant(Op0))
+ Def = Op1;
+ else if (L->isLoopInvariant(Op1))
+ Def = Op0;
+ return Def;
+ };
+
+ // Look for the instruction that defines the induction via the
+ // loop backedge.
+ BasicBlock *Latch = L->getLoopLatch();
+ if (!Latch)
+ return false;
+ Value *Val = PN->getIncomingValueForBlock(Latch);
+ if (!Val)
+ return false;
+
+ // Follow the def-use chain until the induction phi is reached.
+ // If on the way we encounter a Value that has the same SCEV Expr as the
+ // phi node, we can consider the instructions we visit from that point
+ // as part of the cast-sequence that can be ignored.
+ bool InCastSequence = false;
+ auto *Inst = dyn_cast<Instruction>(Val);
+ while (Val != PN) {
+ // If we encountered a phi node other than PN, or if we left the loop,
+ // we bail out.
+ if (!Inst || !L->contains(Inst)) {
+ return false;
+ }
+ auto *AddRec = dyn_cast<SCEVAddRecExpr>(PSE.getSCEV(Val));
+ if (AddRec && PSE.areAddRecsEqualWithPreds(AddRec, AR))
+ InCastSequence = true;
+ if (InCastSequence) {
+ // Only the last instruction in the cast sequence is expected to have
+ // uses outside the induction def-use chain.
+ if (!CastInsts.empty())
+ if (!Inst->hasOneUse())
+ return false;
+ CastInsts.push_back(Inst);
+ }
+ Val = getDef(Val);
+ if (!Val)
+ return false;
+ Inst = dyn_cast<Instruction>(Val);
+ }
+
+ return InCastSequence;
+}
+
+bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
+ PredicatedScalarEvolution &PSE,
+ InductionDescriptor &D, bool Assume) {
+ Type *PhiTy = Phi->getType();
+
+  // Handle integer and pointer induction variables.
+  // We now also handle FP induction, but do not try to make a
+  // recurrent expression from the PHI node in place.
+
+ if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy() && !PhiTy->isFloatTy() &&
+ !PhiTy->isDoubleTy() && !PhiTy->isHalfTy())
+ return false;
+
+ if (PhiTy->isFloatingPointTy())
+ return isFPInductionPHI(Phi, TheLoop, PSE.getSE(), D);
+
+ const SCEV *PhiScev = PSE.getSCEV(Phi);
+ const auto *AR = dyn_cast<SCEVAddRecExpr>(PhiScev);
+
+ // We need this expression to be an AddRecExpr.
+ if (Assume && !AR)
+ AR = PSE.getAsAddRec(Phi);
+
+ if (!AR) {
+ LLVM_DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n");
+ return false;
+ }
+
+ // Record any Cast instructions that participate in the induction update
+ const auto *SymbolicPhi = dyn_cast<SCEVUnknown>(PhiScev);
+ // If we started from an UnknownSCEV, and managed to build an addRecurrence
+ // only after enabling Assume with PSCEV, this means we may have encountered
+ // cast instructions that required adding a runtime check in order to
+  // guarantee the correctness of the AddRecurrence representation of the
+ // induction.
+ if (PhiScev != AR && SymbolicPhi) {
+ SmallVector<Instruction *, 2> Casts;
+ if (getCastsForInductionPHI(PSE, SymbolicPhi, AR, Casts))
+ return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR, &Casts);
+ }
+
+ return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR);
+}
+
+bool InductionDescriptor::isInductionPHI(
+ PHINode *Phi, const Loop *TheLoop, ScalarEvolution *SE,
+ InductionDescriptor &D, const SCEV *Expr,
+ SmallVectorImpl<Instruction *> *CastsToIgnore) {
+ Type *PhiTy = Phi->getType();
+  // We only handle integer and pointer induction variables.
+ if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy())
+ return false;
+
+ // Check that the PHI is consecutive.
+ const SCEV *PhiScev = Expr ? Expr : SE->getSCEV(Phi);
+ const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PhiScev);
+
+ if (!AR) {
+ LLVM_DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n");
+ return false;
+ }
+
+ if (AR->getLoop() != TheLoop) {
+ // FIXME: We should treat this as a uniform. Unfortunately, we
+    // don't currently know how to handle uniform PHIs.
+ LLVM_DEBUG(
+ dbgs() << "LV: PHI is a recurrence with respect to an outer loop.\n");
+ return false;
+ }
+
+ Value *StartValue =
+ Phi->getIncomingValueForBlock(AR->getLoop()->getLoopPreheader());
+
+ BasicBlock *Latch = AR->getLoop()->getLoopLatch();
+ if (!Latch)
+ return false;
+ BinaryOperator *BOp =
+ dyn_cast<BinaryOperator>(Phi->getIncomingValueForBlock(Latch));
+
+ const SCEV *Step = AR->getStepRecurrence(*SE);
+ // Calculate the pointer stride and check if it is consecutive.
+ // The stride may be a constant or a loop invariant integer value.
+ const SCEVConstant *ConstStep = dyn_cast<SCEVConstant>(Step);
+ if (!ConstStep && !SE->isLoopInvariant(Step, TheLoop))
+ return false;
+
+ if (PhiTy->isIntegerTy()) {
+ D = InductionDescriptor(StartValue, IK_IntInduction, Step, BOp,
+ CastsToIgnore);
+ return true;
+ }
+
+ assert(PhiTy->isPointerTy() && "The PHI must be a pointer");
+ // Pointer induction should be a constant.
+ if (!ConstStep)
+ return false;
+
+ ConstantInt *CV = ConstStep->getValue();
+ Type *PointerElementType = PhiTy->getPointerElementType();
+ // The pointer stride cannot be determined if the pointer element type is not
+ // sized.
+ if (!PointerElementType->isSized())
+ return false;
+
+ const DataLayout &DL = Phi->getModule()->getDataLayout();
+ int64_t Size = static_cast<int64_t>(DL.getTypeAllocSize(PointerElementType));
+ if (!Size)
+ return false;
+
+ int64_t CVSize = CV->getSExtValue();
+ if (CVSize % Size)
+ return false;
+ auto *StepValue =
+ SE->getConstant(CV->getType(), CVSize / Size, true /* signed */);
+ D = InductionDescriptor(StartValue, IK_PtrInduction, StepValue, BOp);
+ return true;
+}
diff --git a/llvm/lib/Analysis/IVUsers.cpp b/llvm/lib/Analysis/IVUsers.cpp
new file mode 100644
index 000000000000..681a0cf7e981
--- /dev/null
+++ b/llvm/lib/Analysis/IVUsers.cpp
@@ -0,0 +1,426 @@
+//===- IVUsers.cpp - Induction Variable Users -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements bookkeeping for "interesting" users of expressions
+// computed from induction variables.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/IVUsers.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+using namespace llvm;
+
+#define DEBUG_TYPE "iv-users"
+
+AnalysisKey IVUsersAnalysis::Key;
+
+IVUsers IVUsersAnalysis::run(Loop &L, LoopAnalysisManager &AM,
+ LoopStandardAnalysisResults &AR) {
+ return IVUsers(&L, &AR.AC, &AR.LI, &AR.DT, &AR.SE);
+}
+
+char IVUsersWrapperPass::ID = 0;
+INITIALIZE_PASS_BEGIN(IVUsersWrapperPass, "iv-users",
+ "Induction Variable Users", false, true)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
+INITIALIZE_PASS_END(IVUsersWrapperPass, "iv-users", "Induction Variable Users",
+ false, true)
+
+Pass *llvm::createIVUsersPass() { return new IVUsersWrapperPass(); }
+
+/// isInteresting - Test whether the given expression is "interesting" when
+/// used by the given instruction, within the context of analyzing the
+/// given loop.
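+///
+/// Illustrative example (not from the original source): an affine addrec such
+/// as {0,+,4}<%L> is interesting when analyzing %L, while a plain constant or
+/// loop-invariant unknown value is not.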
+static bool isInteresting(const SCEV *S, const Instruction *I, const Loop *L,
+ ScalarEvolution *SE, LoopInfo *LI) {
+ // An addrec is interesting if it's affine or if it has an interesting start.
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
+ // Keep things simple. Don't touch loop-variant strides unless they're
+ // only used outside the loop and we can simplify them.
+ if (AR->getLoop() == L)
+ return AR->isAffine() ||
+ (!L->contains(I) &&
+ SE->getSCEVAtScope(AR, LI->getLoopFor(I->getParent())) != AR);
+ // Otherwise recurse to see if the start value is interesting, and that
+ // the step value is not interesting, since we don't yet know how to
+ // do effective SCEV expansions for addrecs with interesting steps.
+ return isInteresting(AR->getStart(), I, L, SE, LI) &&
+ !isInteresting(AR->getStepRecurrence(*SE), I, L, SE, LI);
+ }
+
+ // An add is interesting if exactly one of its operands is interesting.
+ if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
+ bool AnyInterestingYet = false;
+ for (const auto *Op : Add->operands())
+ if (isInteresting(Op, I, L, SE, LI)) {
+ if (AnyInterestingYet)
+ return false;
+ AnyInterestingYet = true;
+ }
+ return AnyInterestingYet;
+ }
+
+ // Nothing else is interesting here.
+ return false;
+}
+
+/// Return true if all loop headers that dominate this block are in simplified
+/// form.
+static bool isSimplifiedLoopNest(BasicBlock *BB, const DominatorTree *DT,
+ const LoopInfo *LI,
+ SmallPtrSetImpl<Loop*> &SimpleLoopNests) {
+ Loop *NearestLoop = nullptr;
+ for (DomTreeNode *Rung = DT->getNode(BB);
+ Rung; Rung = Rung->getIDom()) {
+ BasicBlock *DomBB = Rung->getBlock();
+ Loop *DomLoop = LI->getLoopFor(DomBB);
+ if (DomLoop && DomLoop->getHeader() == DomBB) {
+ // If the domtree walk reaches a loop with no preheader, return false.
+ if (!DomLoop->isLoopSimplifyForm())
+ return false;
+ // If we have already checked this loop nest, stop checking.
+ if (SimpleLoopNests.count(DomLoop))
+ break;
+ // If we have not already checked this loop nest, remember the loop
+ // header nearest to BB. The nearest loop may not contain BB.
+ if (!NearestLoop)
+ NearestLoop = DomLoop;
+ }
+ }
+ if (NearestLoop)
+ SimpleLoopNests.insert(NearestLoop);
+ return true;
+}
+
+/// IVUseShouldUsePostIncValue - We have discovered a "User" of an IV expression
+/// and now we need to decide whether the user should use the pre-inc or post-inc
+/// value. If this user should use the post-inc version of the IV, return true.
+///
+/// Choosing wrong here can break dominance properties (if we choose to use the
+/// post-inc value when we cannot) or it can end up adding extra live-ranges to
+/// the loop, resulting in reg-reg copies (if we use the pre-inc value when we
+/// should use the post-inc value).
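+///
+/// Illustrative example (hypothetical IR, not from the original source): given
+///   loop:
+///     %iv = phi i64 [ 0, %ph ], [ %iv.next, %loop ]
+///     %iv.next = add i64 %iv, 1
+///     br i1 %c, label %loop, label %exit
+/// a user of the IV in %exit is dominated by the latch, so it should use the
+/// post-incremented value.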
+static bool IVUseShouldUsePostIncValue(Instruction *User, Value *Operand,
+ const Loop *L, DominatorTree *DT) {
+ // If the user is in the loop, use the preinc value.
+ if (L->contains(User))
+ return false;
+
+ BasicBlock *LatchBlock = L->getLoopLatch();
+ if (!LatchBlock)
+ return false;
+
+ // Ok, the user is outside of the loop. If it is dominated by the latch
+ // block, use the post-inc value.
+ if (DT->dominates(LatchBlock, User->getParent()))
+ return true;
+
+ // There is one case we have to be careful of: PHI nodes. These little guys
+ // can live in blocks that are not dominated by the latch block, but (since
+ // their uses occur in the predecessor block, not the block the PHI lives in)
+ // should still use the post-inc value. Check for this case now.
+ PHINode *PN = dyn_cast<PHINode>(User);
+ if (!PN || !Operand)
+ return false; // not a phi, not dominated by latch block.
+
+ // Look at all of the uses of Operand by the PHI node. If any use corresponds
+ // to a block that is not dominated by the latch block, give up and use the
+ // preincremented value.
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+ if (PN->getIncomingValue(i) == Operand &&
+ !DT->dominates(LatchBlock, PN->getIncomingBlock(i)))
+ return false;
+
+ // Okay, all uses of Operand by PN are in predecessor blocks that really are
+ // dominated by the latch block. Use the post-incremented value.
+ return true;
+}
+
+/// AddUsersImpl - Inspect the specified instruction. If it computes a
+/// reducible SCEV expression, recursively add its users to the IVUsesByStride
+/// set and return true. Otherwise, return false.
+bool IVUsers::AddUsersImpl(Instruction *I,
+ SmallPtrSetImpl<Loop*> &SimpleLoopNests) {
+ const DataLayout &DL = I->getModule()->getDataLayout();
+
+ // Add this IV user to the Processed set before returning false to ensure that
+ // all IV users are members of the set. See IVUsers::isIVUserOrOperand.
+ if (!Processed.insert(I).second)
+ return true; // Instruction already handled.
+
+ if (!SE->isSCEVable(I->getType()))
+ return false; // Void and FP expressions cannot be reduced.
+
+ // IVUsers is used by LSR which assumes that all SCEV expressions are safe to
+ // pass to SCEVExpander. Expressions are not safe to expand if they represent
+ // operations that are not safe to speculate, namely integer division.
+ if (!isa<PHINode>(I) && !isSafeToSpeculativelyExecute(I))
+ return false;
+
+ // LSR is not APInt clean, do not touch integers bigger than 64-bits.
+ // Also avoid creating IVs of non-native types. For example, we don't want a
+ // 64-bit IV in 32-bit code just because the loop has one 64-bit cast.
+ uint64_t Width = SE->getTypeSizeInBits(I->getType());
+ if (Width > 64 || !DL.isLegalInteger(Width))
+ return false;
+
+ // Don't attempt to promote ephemeral values to indvars. They will be removed
+ // later anyway.
+ if (EphValues.count(I))
+ return false;
+
+ // Get the symbolic expression for this instruction.
+ const SCEV *ISE = SE->getSCEV(I);
+
+ // If we've come to an uninteresting expression, stop the traversal and
+ // call this a user.
+ if (!isInteresting(ISE, I, L, SE, LI))
+ return false;
+
+ SmallPtrSet<Instruction *, 4> UniqueUsers;
+ for (Use &U : I->uses()) {
+ Instruction *User = cast<Instruction>(U.getUser());
+ if (!UniqueUsers.insert(User).second)
+ continue;
+
+ // Do not infinitely recurse on PHI nodes.
+ if (isa<PHINode>(User) && Processed.count(User))
+ continue;
+
+ // Only consider IVUsers that are dominated by simplified loop
+ // headers. Otherwise, SCEVExpander will crash.
+ BasicBlock *UseBB = User->getParent();
+ // A phi's use is live out of its predecessor block.
+ if (PHINode *PHI = dyn_cast<PHINode>(User)) {
+ unsigned OperandNo = U.getOperandNo();
+ unsigned ValNo = PHINode::getIncomingValueNumForOperand(OperandNo);
+ UseBB = PHI->getIncomingBlock(ValNo);
+ }
+ if (!isSimplifiedLoopNest(UseBB, DT, LI, SimpleLoopNests))
+ return false;
+
+ // Descend recursively, but not into PHI nodes outside the current loop.
+    // It's important to see the entire expression outside the loop so that
+    // choices that depend on addressing-mode use are made correctly, although
+    // we won't consider references outside the loop in all cases.
+ // If User is already in Processed, we don't want to recurse into it again,
+ // but do want to record a second reference in the same instruction.
+ bool AddUserToIVUsers = false;
+ if (LI->getLoopFor(User->getParent()) != L) {
+ if (isa<PHINode>(User) || Processed.count(User) ||
+ !AddUsersImpl(User, SimpleLoopNests)) {
+ LLVM_DEBUG(dbgs() << "FOUND USER in other loop: " << *User << '\n'
+ << " OF SCEV: " << *ISE << '\n');
+ AddUserToIVUsers = true;
+ }
+ } else if (Processed.count(User) || !AddUsersImpl(User, SimpleLoopNests)) {
+ LLVM_DEBUG(dbgs() << "FOUND USER: " << *User << '\n'
+ << " OF SCEV: " << *ISE << '\n');
+ AddUserToIVUsers = true;
+ }
+
+ if (AddUserToIVUsers) {
+ // Okay, we found a user that we cannot reduce.
+ IVStrideUse &NewUse = AddUser(User, I);
+ // Autodetect the post-inc loop set, populating NewUse.PostIncLoops.
+ // The regular return value here is discarded; instead of recording
+ // it, we just recompute it when we need it.
+ const SCEV *OriginalISE = ISE;
+
+ auto NormalizePred = [&](const SCEVAddRecExpr *AR) {
+ auto *L = AR->getLoop();
+ bool Result = IVUseShouldUsePostIncValue(User, I, L, DT);
+ if (Result)
+ NewUse.PostIncLoops.insert(L);
+ return Result;
+ };
+
+ ISE = normalizeForPostIncUseIf(ISE, NormalizePred, *SE);
+
+ // PostIncNormalization effectively simplifies the expression under
+ // pre-increment assumptions. Those assumptions (no wrapping) might not
+ // hold for the post-inc value. Catch such cases by making sure the
+ // transformation is invertible.
+ if (OriginalISE != ISE) {
+ const SCEV *DenormalizedISE =
+ denormalizeForPostIncUse(ISE, NewUse.PostIncLoops, *SE);
+
+ // If we normalized the expression, but denormalization doesn't give the
+ // original one, discard this user.
+ if (OriginalISE != DenormalizedISE) {
+ LLVM_DEBUG(dbgs()
+ << " DISCARDING (NORMALIZATION ISN'T INVERTIBLE): "
+ << *ISE << '\n');
+ IVUses.pop_back();
+ return false;
+ }
+ }
+ LLVM_DEBUG(if (SE->getSCEV(I) != ISE) dbgs()
+ << " NORMALIZED TO: " << *ISE << '\n');
+ }
+ }
+ return true;
+}
+
+bool IVUsers::AddUsersIfInteresting(Instruction *I) {
+ // SCEVExpander can only handle users that are dominated by simplified loop
+ // entries. Keep track of all loops that are only dominated by other simple
+ // loops so we don't traverse the domtree for each user.
+ SmallPtrSet<Loop*,16> SimpleLoopNests;
+
+ return AddUsersImpl(I, SimpleLoopNests);
+}
+
+IVStrideUse &IVUsers::AddUser(Instruction *User, Value *Operand) {
+ IVUses.push_back(new IVStrideUse(this, User, Operand));
+ return IVUses.back();
+}
+
+IVUsers::IVUsers(Loop *L, AssumptionCache *AC, LoopInfo *LI, DominatorTree *DT,
+ ScalarEvolution *SE)
+ : L(L), AC(AC), LI(LI), DT(DT), SE(SE), IVUses() {
+ // Collect ephemeral values so that AddUsersIfInteresting skips them.
+ EphValues.clear();
+ CodeMetrics::collectEphemeralValues(L, AC, EphValues);
+
+ // Find all uses of induction variables in this loop, and categorize
+ // them by stride. Start by finding all of the PHI nodes in the header for
+ // this loop. If they are induction variables, inspect their uses.
+ for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I)
+ (void)AddUsersIfInteresting(&*I);
+}
+
+void IVUsers::print(raw_ostream &OS, const Module *M) const {
+ OS << "IV Users for loop ";
+ L->getHeader()->printAsOperand(OS, false);
+ if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
+ OS << " with backedge-taken count " << *SE->getBackedgeTakenCount(L);
+ }
+ OS << ":\n";
+
+ for (const IVStrideUse &IVUse : IVUses) {
+ OS << " ";
+ IVUse.getOperandValToReplace()->printAsOperand(OS, false);
+ OS << " = " << *getReplacementExpr(IVUse);
+ for (auto PostIncLoop : IVUse.PostIncLoops) {
+ OS << " (post-inc with loop ";
+ PostIncLoop->getHeader()->printAsOperand(OS, false);
+ OS << ")";
+ }
+ OS << " in ";
+ if (IVUse.getUser())
+ IVUse.getUser()->print(OS);
+ else
+ OS << "Printing <null> User";
+ OS << '\n';
+ }
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void IVUsers::dump() const { print(dbgs()); }
+#endif
+
+void IVUsers::releaseMemory() {
+ Processed.clear();
+ IVUses.clear();
+}
+
+IVUsersWrapperPass::IVUsersWrapperPass() : LoopPass(ID) {
+ initializeIVUsersWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+void IVUsersWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addRequired<AssumptionCacheTracker>();
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<ScalarEvolutionWrapperPass>();
+ AU.setPreservesAll();
+}
+
+bool IVUsersWrapperPass::runOnLoop(Loop *L, LPPassManager &LPM) {
+ auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
+ *L->getHeader()->getParent());
+ auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+
+ IU.reset(new IVUsers(L, AC, LI, DT, SE));
+ return false;
+}
+
+void IVUsersWrapperPass::print(raw_ostream &OS, const Module *M) const {
+ IU->print(OS, M);
+}
+
+void IVUsersWrapperPass::releaseMemory() { IU->releaseMemory(); }
+
+/// getReplacementExpr - Return a SCEV expression which computes the
+/// value of the OperandValToReplace.
+const SCEV *IVUsers::getReplacementExpr(const IVStrideUse &IU) const {
+ return SE->getSCEV(IU.getOperandValToReplace());
+}
+
+/// getExpr - Return the expression for the use.
+const SCEV *IVUsers::getExpr(const IVStrideUse &IU) const {
+ return normalizeForPostIncUse(getReplacementExpr(IU), IU.getPostIncLoops(),
+ *SE);
+}
+
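+/// Find an addrec for loop \p L in the given SCEV expression, looking through
+/// addrec start values and the operands of add expressions.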
+static const SCEVAddRecExpr *findAddRecForLoop(const SCEV *S, const Loop *L) {
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
+ if (AR->getLoop() == L)
+ return AR;
+ return findAddRecForLoop(AR->getStart(), L);
+ }
+
+ if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
+ for (const auto *Op : Add->operands())
+ if (const SCEVAddRecExpr *AR = findAddRecForLoop(Op, L))
+ return AR;
+ return nullptr;
+ }
+
+ return nullptr;
+}
+
+const SCEV *IVUsers::getStride(const IVStrideUse &IU, const Loop *L) const {
+ if (const SCEVAddRecExpr *AR = findAddRecForLoop(getExpr(IU), L))
+ return AR->getStepRecurrence(*SE);
+ return nullptr;
+}
+
+void IVStrideUse::transformToPostInc(const Loop *L) {
+ PostIncLoops.insert(L);
+}
+
+void IVStrideUse::deleted() {
+ // Remove this user from the list.
+ Parent->Processed.erase(this->getUser());
+ Parent->IVUses.erase(this);
+ // this now dangles!
+}
diff --git a/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp b/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp
new file mode 100644
index 000000000000..68153de8219f
--- /dev/null
+++ b/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp
@@ -0,0 +1,106 @@
+//===-- IndirectCallPromotionAnalysis.cpp - Find promotion candidates ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Helper methods for identifying profitable indirect call promotion
+// candidates for an instruction when the indirect-call value profile metadata
+// is available.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/IndirectCallPromotionAnalysis.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/IndirectCallVisitor.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/ProfileData/InstrProf.h"
+#include "llvm/Support/Debug.h"
+#include <string>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "pgo-icall-prom-analysis"
+
+// The percent threshold for a direct-call target (this target's count vs. the
+// remaining call count) for it to be considered as a promotion target.
+static cl::opt<unsigned> ICPRemainingPercentThreshold(
+ "icp-remaining-percent-threshold", cl::init(30), cl::Hidden, cl::ZeroOrMore,
+ cl::desc("The percentage threshold against remaining unpromoted indirect "
+ "call count for the promotion"));
+
+// The percent threshold for a direct-call target (this target's count vs. the
+// total call count) for it to be considered as a promotion target.
+static cl::opt<unsigned>
+ ICPTotalPercentThreshold("icp-total-percent-threshold", cl::init(5),
+ cl::Hidden, cl::ZeroOrMore,
+ cl::desc("The percentage threshold against total "
+ "count for the promotion"));
+
+// Set the maximum number of targets to promote for a single indirect-call
+// callsite.
+static cl::opt<unsigned>
+ MaxNumPromotions("icp-max-prom", cl::init(3), cl::Hidden, cl::ZeroOrMore,
+ cl::desc("Max number of promotions for a single indirect "
+ "call callsite"));
+
+ICallPromotionAnalysis::ICallPromotionAnalysis() {
+ ValueDataArray = std::make_unique<InstrProfValueData[]>(MaxNumPromotions);
+}
+
+bool ICallPromotionAnalysis::isPromotionProfitable(uint64_t Count,
+ uint64_t TotalCount,
+ uint64_t RemainingCount) {
+ return Count * 100 >= ICPRemainingPercentThreshold * RemainingCount &&
+ Count * 100 >= ICPTotalPercentThreshold * TotalCount;
+}
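+
+// Illustrative example (using the default thresholds above): with
+// TotalCount = 100 and RemainingCount = 100, a target with Count = 40 is
+// profitable because 40*100 >= 30*100 and 40*100 >= 5*100; a later target
+// with Count = 10 against RemainingCount = 60 is rejected because
+// 10*100 < 30*60.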
+
+// Indirect-call promotion heuristic. The direct targets are sorted based on
+// the count. Stop at the first target that is not promoted. Returns the
+// number of candidates deemed profitable.
+uint32_t ICallPromotionAnalysis::getProfitablePromotionCandidates(
+ const Instruction *Inst, uint32_t NumVals, uint64_t TotalCount) {
+ ArrayRef<InstrProfValueData> ValueDataRef(ValueDataArray.get(), NumVals);
+
+ LLVM_DEBUG(dbgs() << " \nWork on callsite " << *Inst
+ << " Num_targets: " << NumVals << "\n");
+
+ uint32_t I = 0;
+ uint64_t RemainingCount = TotalCount;
+ for (; I < MaxNumPromotions && I < NumVals; I++) {
+ uint64_t Count = ValueDataRef[I].Count;
+ assert(Count <= RemainingCount);
+ LLVM_DEBUG(dbgs() << " Candidate " << I << " Count=" << Count
+ << " Target_func: " << ValueDataRef[I].Value << "\n");
+
+ if (!isPromotionProfitable(Count, TotalCount, RemainingCount)) {
+ LLVM_DEBUG(dbgs() << " Not promote: Cold target.\n");
+ return I;
+ }
+ RemainingCount -= Count;
+ }
+ return I;
+}
+
+ArrayRef<InstrProfValueData>
+ICallPromotionAnalysis::getPromotionCandidatesForInstruction(
+ const Instruction *I, uint32_t &NumVals, uint64_t &TotalCount,
+ uint32_t &NumCandidates) {
+ bool Res =
+ getValueProfDataFromInst(*I, IPVK_IndirectCallTarget, MaxNumPromotions,
+ ValueDataArray.get(), NumVals, TotalCount);
+ if (!Res) {
+ NumCandidates = 0;
+ return ArrayRef<InstrProfValueData>();
+ }
+ NumCandidates = getProfitablePromotionCandidates(I, NumVals, TotalCount);
+ return ArrayRef<InstrProfValueData>(ValueDataArray.get(), NumVals);
+}
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
new file mode 100644
index 000000000000..89811ec0e377
--- /dev/null
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -0,0 +1,2223 @@
+//===- InlineCost.cpp - Cost analysis for inliner -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements inline cost analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/InlineCost.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ProfileSummaryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/CallingConv.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
+#include "llvm/IR/GlobalAlias.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "inline-cost"
+
+STATISTIC(NumCallsAnalyzed, "Number of call sites analyzed");
+
+static cl::opt<int> InlineThreshold(
+ "inline-threshold", cl::Hidden, cl::init(225), cl::ZeroOrMore,
+ cl::desc("Control the amount of inlining to perform (default = 225)"));
+
+static cl::opt<int> HintThreshold(
+ "inlinehint-threshold", cl::Hidden, cl::init(325), cl::ZeroOrMore,
+ cl::desc("Threshold for inlining functions with inline hint"));
+
+static cl::opt<int>
+ ColdCallSiteThreshold("inline-cold-callsite-threshold", cl::Hidden,
+ cl::init(45), cl::ZeroOrMore,
+ cl::desc("Threshold for inlining cold callsites"));
+
+// We introduce this threshold to help performance of instrumentation-based
+// PGO before we actually hook up the inliner with analysis passes such as BPI
+// and BFI.
+static cl::opt<int> ColdThreshold(
+ "inlinecold-threshold", cl::Hidden, cl::init(45), cl::ZeroOrMore,
+ cl::desc("Threshold for inlining functions with cold attribute"));
+
+static cl::opt<int>
+ HotCallSiteThreshold("hot-callsite-threshold", cl::Hidden, cl::init(3000),
+ cl::ZeroOrMore,
+ cl::desc("Threshold for hot callsites "));
+
+static cl::opt<int> LocallyHotCallSiteThreshold(
+ "locally-hot-callsite-threshold", cl::Hidden, cl::init(525), cl::ZeroOrMore,
+ cl::desc("Threshold for locally hot callsites "));
+
+static cl::opt<int> ColdCallSiteRelFreq(
+ "cold-callsite-rel-freq", cl::Hidden, cl::init(2), cl::ZeroOrMore,
+ cl::desc("Maximum block frequency, expressed as a percentage of caller's "
+ "entry frequency, for a callsite to be cold in the absence of "
+ "profile information."));
+
+static cl::opt<int> HotCallSiteRelFreq(
+ "hot-callsite-rel-freq", cl::Hidden, cl::init(60), cl::ZeroOrMore,
+ cl::desc("Minimum block frequency, expressed as a multiple of caller's "
+ "entry frequency, for a callsite to be hot in the absence of "
+ "profile information."));
+
+static cl::opt<bool> OptComputeFullInlineCost(
+ "inline-cost-full", cl::Hidden, cl::init(false), cl::ZeroOrMore,
+ cl::desc("Compute the full inline cost of a call site even when the cost "
+ "exceeds the threshold."));
+
+namespace {
+
+class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
+ typedef InstVisitor<CallAnalyzer, bool> Base;
+ friend class InstVisitor<CallAnalyzer, bool>;
+
+ /// The TargetTransformInfo available for this compilation.
+ const TargetTransformInfo &TTI;
+
+ /// Getter for the cache of @llvm.assume intrinsics.
+ std::function<AssumptionCache &(Function &)> &GetAssumptionCache;
+
+ /// Getter for BlockFrequencyInfo
+ Optional<function_ref<BlockFrequencyInfo &(Function &)>> &GetBFI;
+
+ /// Profile summary information.
+ ProfileSummaryInfo *PSI;
+
+ /// The called function.
+ Function &F;
+
+ // Cache the DataLayout since we use it a lot.
+ const DataLayout &DL;
+
+ /// The OptimizationRemarkEmitter available for this compilation.
+ OptimizationRemarkEmitter *ORE;
+
+ /// The candidate callsite being analyzed. Please do not use this to do
+ /// analysis in the caller function; we want the inline cost query to be
+ /// easily cacheable. Instead, use the cover function paramHasAttr.
+ CallBase &CandidateCall;
+
+ /// Tunable parameters that control the analysis.
+ const InlineParams &Params;
+
+ /// Upper bound for the inlining cost. Bonuses are being applied to account
+ /// for speculative "expected profit" of the inlining decision.
+ int Threshold;
+
+ /// Inlining cost measured in abstract units, accounts for all the
+ /// instructions expected to be executed for a given function invocation.
+ /// Instructions that are statically proven to be dead based on call-site
+ /// arguments are not counted here.
+ int Cost = 0;
+
+ bool ComputeFullInlineCost;
+
+ bool IsCallerRecursive = false;
+ bool IsRecursiveCall = false;
+ bool ExposesReturnsTwice = false;
+ bool HasDynamicAlloca = false;
+ bool ContainsNoDuplicateCall = false;
+ bool HasReturn = false;
+ bool HasIndirectBr = false;
+ bool HasUninlineableIntrinsic = false;
+ bool InitsVargArgs = false;
+
+ /// Number of bytes allocated statically by the callee.
+ uint64_t AllocatedSize = 0;
+ unsigned NumInstructions = 0;
+ unsigned NumVectorInstructions = 0;
+
+ /// Bonus to be applied when the percentage of vector instructions in the
+ /// callee is high (see more details in updateThreshold).
+ int VectorBonus = 0;
+ /// Bonus to be applied when the callee has only one reachable basic block.
+ int SingleBBBonus = 0;
+
+ /// While we walk the potentially-inlined instructions, we build up and
+ /// maintain a mapping of simplified values specific to this callsite. The
+ /// idea is to propagate any special information we have about arguments to
+ /// this call through the inlinable section of the function, and account for
+ /// likely simplifications post-inlining. The most important aspect we track
+ /// is CFG altering simplifications -- when we prove a basic block dead, that
+ /// can cause dramatic shifts in the cost of inlining a function.
+ DenseMap<Value *, Constant *> SimplifiedValues;
+
+ /// Keep track of the values which map back (through function arguments) to
+ /// allocas on the caller stack which could be simplified through SROA.
+ DenseMap<Value *, Value *> SROAArgValues;
+
+ /// The mapping of caller Alloca values to their accumulated cost savings. If
+ /// we have to disable SROA for one of the allocas, this tells us how much
+ /// cost must be added.
+ DenseMap<Value *, int> SROAArgCosts;
+
+ /// Keep track of values which map to a pointer base and constant offset.
+ DenseMap<Value *, std::pair<Value *, APInt>> ConstantOffsetPtrs;
+
+ /// Keep track of dead blocks due to the constant arguments.
+ SetVector<BasicBlock *> DeadBlocks;
+
+ /// The mapping of the blocks to their known unique successors due to the
+ /// constant arguments.
+ DenseMap<BasicBlock *, BasicBlock *> KnownSuccessors;
+
+ /// Model the elimination of repeated loads that is expected to happen
+ /// whenever we simplify away the stores that would otherwise force them to
+ /// be reloaded.
+ bool EnableLoadElimination;
+ SmallPtrSet<Value *, 16> LoadAddrSet;
+ int LoadEliminationCost = 0;
+
+ // Custom simplification helper routines.
+ bool isAllocaDerivedArg(Value *V);
+ bool lookupSROAArgAndCost(Value *V, Value *&Arg,
+ DenseMap<Value *, int>::iterator &CostIt);
+ void disableSROA(DenseMap<Value *, int>::iterator CostIt);
+ void disableSROA(Value *V);
+ void findDeadBlocks(BasicBlock *CurrBB, BasicBlock *NextBB);
+ void accumulateSROACost(DenseMap<Value *, int>::iterator CostIt,
+ int InstructionCost);
+ void disableLoadElimination();
+ bool isGEPFree(GetElementPtrInst &GEP);
+ bool canFoldInboundsGEP(GetElementPtrInst &I);
+ bool accumulateGEPOffset(GEPOperator &GEP, APInt &Offset);
+ bool simplifyCallSite(Function *F, CallBase &Call);
+ template <typename Callable>
+ bool simplifyInstruction(Instruction &I, Callable Evaluate);
+ ConstantInt *stripAndComputeInBoundsConstantOffsets(Value *&V);
+
+ /// Return true if the given argument to the function being considered for
+ /// inlining has the given attribute set either at the call site or the
+ /// function declaration. Primarily used to inspect call site specific
+ /// attributes since these can be more precise than the ones on the callee
+ /// itself.
+ bool paramHasAttr(Argument *A, Attribute::AttrKind Attr);
+
+ /// Return true if the given value is known non null within the callee if
+ /// inlined through this particular callsite.
+ bool isKnownNonNullInCallee(Value *V);
+
+ /// Update Threshold based on callsite properties such as callee
+ /// attributes and callee hotness for PGO builds. The Callee is explicitly
+ /// passed to support analyzing indirect calls whose target is inferred by
+ /// analysis.
+ void updateThreshold(CallBase &Call, Function &Callee);
+
+ /// Return true if size growth is allowed when inlining the callee at \p Call.
+ bool allowSizeGrowth(CallBase &Call);
+
+ /// Return true if \p Call is a cold callsite.
+ bool isColdCallSite(CallBase &Call, BlockFrequencyInfo *CallerBFI);
+
+ /// Return a higher threshold if \p Call is a hot callsite.
+ Optional<int> getHotCallSiteThreshold(CallBase &Call,
+ BlockFrequencyInfo *CallerBFI);
+
+ // Custom analysis routines.
+ InlineResult analyzeBlock(BasicBlock *BB,
+ SmallPtrSetImpl<const Value *> &EphValues);
+
+ /// Handle a capped 'int' increment for Cost.
+ void addCost(int64_t Inc, int64_t UpperBound = INT_MAX) {
+ assert(UpperBound > 0 && UpperBound <= INT_MAX && "invalid upper bound");
+ Cost = (int)std::min(UpperBound, Cost + Inc);
+ }
+
+ // Disable several entry points to the visitor so we don't accidentally use
+ // them by declaring but not defining them here.
+ void visit(Module *);
+ void visit(Module &);
+ void visit(Function *);
+ void visit(Function &);
+ void visit(BasicBlock *);
+ void visit(BasicBlock &);
+
+ // Provide base case for our instruction visit.
+ bool visitInstruction(Instruction &I);
+
+ // Our visit overrides.
+ bool visitAlloca(AllocaInst &I);
+ bool visitPHI(PHINode &I);
+ bool visitGetElementPtr(GetElementPtrInst &I);
+ bool visitBitCast(BitCastInst &I);
+ bool visitPtrToInt(PtrToIntInst &I);
+ bool visitIntToPtr(IntToPtrInst &I);
+ bool visitCastInst(CastInst &I);
+ bool visitUnaryInstruction(UnaryInstruction &I);
+ bool visitCmpInst(CmpInst &I);
+ bool visitSub(BinaryOperator &I);
+ bool visitBinaryOperator(BinaryOperator &I);
+ bool visitFNeg(UnaryOperator &I);
+ bool visitLoad(LoadInst &I);
+ bool visitStore(StoreInst &I);
+ bool visitExtractValue(ExtractValueInst &I);
+ bool visitInsertValue(InsertValueInst &I);
+ bool visitCallBase(CallBase &Call);
+ bool visitReturnInst(ReturnInst &RI);
+ bool visitBranchInst(BranchInst &BI);
+ bool visitSelectInst(SelectInst &SI);
+ bool visitSwitchInst(SwitchInst &SI);
+ bool visitIndirectBrInst(IndirectBrInst &IBI);
+ bool visitResumeInst(ResumeInst &RI);
+ bool visitCleanupReturnInst(CleanupReturnInst &RI);
+ bool visitCatchReturnInst(CatchReturnInst &RI);
+ bool visitUnreachableInst(UnreachableInst &I);
+
+public:
+ CallAnalyzer(const TargetTransformInfo &TTI,
+ std::function<AssumptionCache &(Function &)> &GetAssumptionCache,
+ Optional<function_ref<BlockFrequencyInfo &(Function &)>> &GetBFI,
+ ProfileSummaryInfo *PSI, OptimizationRemarkEmitter *ORE,
+ Function &Callee, CallBase &Call, const InlineParams &Params)
+ : TTI(TTI), GetAssumptionCache(GetAssumptionCache), GetBFI(GetBFI),
+ PSI(PSI), F(Callee), DL(F.getParent()->getDataLayout()), ORE(ORE),
+ CandidateCall(Call), Params(Params), Threshold(Params.DefaultThreshold),
+ ComputeFullInlineCost(OptComputeFullInlineCost ||
+ Params.ComputeFullInlineCost || ORE),
+ EnableLoadElimination(true) {}
+
+ InlineResult analyzeCall(CallBase &Call);
+
+ int getThreshold() { return Threshold; }
+ int getCost() { return Cost; }
+
+ // Keep a bunch of stats about the cost savings found so we can print them
+ // out when debugging.
+ unsigned NumConstantArgs = 0;
+ unsigned NumConstantOffsetPtrArgs = 0;
+ unsigned NumAllocaArgs = 0;
+ unsigned NumConstantPtrCmps = 0;
+ unsigned NumConstantPtrDiffs = 0;
+ unsigned NumInstructionsSimplified = 0;
+ unsigned SROACostSavings = 0;
+ unsigned SROACostSavingsLost = 0;
+
+ void dump();
+};
+
+} // namespace
+
+/// Test whether the given value is an Alloca-derived function argument.
+bool CallAnalyzer::isAllocaDerivedArg(Value *V) {
+ return SROAArgValues.count(V);
+}
+
+/// Lookup the SROA-candidate argument and cost iterator which V maps to.
+/// Returns false if V does not map to a SROA-candidate.
+bool CallAnalyzer::lookupSROAArgAndCost(
+ Value *V, Value *&Arg, DenseMap<Value *, int>::iterator &CostIt) {
+ if (SROAArgValues.empty() || SROAArgCosts.empty())
+ return false;
+
+ DenseMap<Value *, Value *>::iterator ArgIt = SROAArgValues.find(V);
+ if (ArgIt == SROAArgValues.end())
+ return false;
+
+ Arg = ArgIt->second;
+ CostIt = SROAArgCosts.find(Arg);
+ return CostIt != SROAArgCosts.end();
+}
+
+/// Disable SROA for the candidate marked by this cost iterator.
+///
+/// This marks the candidate as no longer viable for SROA, and adds the cost
+/// savings associated with it back into the inline cost measurement.
+void CallAnalyzer::disableSROA(DenseMap<Value *, int>::iterator CostIt) {
+ // If we're no longer able to perform SROA we need to undo its cost savings
+ // and prevent subsequent analysis.
+ addCost(CostIt->second);
+ SROACostSavings -= CostIt->second;
+ SROACostSavingsLost += CostIt->second;
+ SROAArgCosts.erase(CostIt);
+ disableLoadElimination();
+}
+
+/// If 'V' maps to a SROA candidate, disable SROA for it.
+void CallAnalyzer::disableSROA(Value *V) {
+ Value *SROAArg;
+ DenseMap<Value *, int>::iterator CostIt;
+ if (lookupSROAArgAndCost(V, SROAArg, CostIt))
+ disableSROA(CostIt);
+}
+
+/// Accumulate the given cost for a particular SROA candidate.
+void CallAnalyzer::accumulateSROACost(DenseMap<Value *, int>::iterator CostIt,
+ int InstructionCost) {
+ CostIt->second += InstructionCost;
+ SROACostSavings += InstructionCost;
+}
+
+void CallAnalyzer::disableLoadElimination() {
+ if (EnableLoadElimination) {
+ addCost(LoadEliminationCost);
+ LoadEliminationCost = 0;
+ EnableLoadElimination = false;
+ }
+}
+
+/// Accumulate a constant GEP offset into an APInt if possible.
+///
+/// Returns false if unable to compute the offset for any reason. Respects any
+/// simplified values known during the analysis of this callsite.
+bool CallAnalyzer::accumulateGEPOffset(GEPOperator &GEP, APInt &Offset) {
+ unsigned IntPtrWidth = DL.getIndexTypeSizeInBits(GEP.getType());
+ assert(IntPtrWidth == Offset.getBitWidth());
+
+ for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP);
+ GTI != GTE; ++GTI) {
+ ConstantInt *OpC = dyn_cast<ConstantInt>(GTI.getOperand());
+ if (!OpC)
+ if (Constant *SimpleOp = SimplifiedValues.lookup(GTI.getOperand()))
+ OpC = dyn_cast<ConstantInt>(SimpleOp);
+ if (!OpC)
+ return false;
+ if (OpC->isZero())
+ continue;
+
+ // Handle a struct index, which adds its field offset to the pointer.
+ if (StructType *STy = GTI.getStructTypeOrNull()) {
+ unsigned ElementIdx = OpC->getZExtValue();
+ const StructLayout *SL = DL.getStructLayout(STy);
+ Offset += APInt(IntPtrWidth, SL->getElementOffset(ElementIdx));
+ continue;
+ }
+
+ APInt TypeSize(IntPtrWidth, DL.getTypeAllocSize(GTI.getIndexedType()));
+ Offset += OpC->getValue().sextOrTrunc(IntPtrWidth) * TypeSize;
+ }
+ return true;
+}
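+
+// For illustration of accumulateGEPOffset above (a hypothetical GEP, assuming
+// a typical data layout where i64 is 8-byte aligned): for
+// "getelementptr {i32, i64}, {i32, i64}* %p, i64 0, i32 1" the leading zero
+// index contributes nothing and the struct index contributes
+// SL->getElementOffset(1) == 8, so Offset accumulates to 8 bytes.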
+
+/// Use TTI to check whether a GEP is free.
+///
+/// Respects any simplified values known during the analysis of this callsite.
+bool CallAnalyzer::isGEPFree(GetElementPtrInst &GEP) {
+ SmallVector<Value *, 4> Operands;
+ Operands.push_back(GEP.getOperand(0));
+ for (User::op_iterator I = GEP.idx_begin(), E = GEP.idx_end(); I != E; ++I)
+ if (Constant *SimpleOp = SimplifiedValues.lookup(*I))
+ Operands.push_back(SimpleOp);
+ else
+ Operands.push_back(*I);
+ return TargetTransformInfo::TCC_Free == TTI.getUserCost(&GEP, Operands);
+}
+
+bool CallAnalyzer::visitAlloca(AllocaInst &I) {
+ // Check whether inlining will turn a dynamic alloca into a static
+ // alloca and handle that case.
+ if (I.isArrayAllocation()) {
+ Constant *Size = SimplifiedValues.lookup(I.getArraySize());
+ if (auto *AllocSize = dyn_cast_or_null<ConstantInt>(Size)) {
+ Type *Ty = I.getAllocatedType();
+ AllocatedSize = SaturatingMultiplyAdd(
+ AllocSize->getLimitedValue(), DL.getTypeAllocSize(Ty).getFixedSize(),
+ AllocatedSize);
+ return Base::visitAlloca(I);
+ }
+ }
+
+ // Accumulate the allocated size.
+ if (I.isStaticAlloca()) {
+ Type *Ty = I.getAllocatedType();
+ AllocatedSize = SaturatingAdd(DL.getTypeAllocSize(Ty).getFixedSize(),
+ AllocatedSize);
+ }
+
+ // We will happily inline static alloca instructions.
+ if (I.isStaticAlloca())
+ return Base::visitAlloca(I);
+
+ // FIXME: This is overly conservative. Dynamic allocas are inefficient for
+ // a variety of reasons, and so we would like to not inline them into
+ // functions which don't currently have a dynamic alloca. This simply
+ // disables inlining altogether in the presence of a dynamic alloca.
+ HasDynamicAlloca = true;
+ return false;
+}
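+
+// For illustration (hypothetical callee IR): a static "alloca [16 x i32]"
+// adds DL.getTypeAllocSize([16 x i32]) == 64 bytes to AllocatedSize, while an
+// array alloca whose size only becomes constant through SimplifiedValues is
+// folded in via SaturatingMultiplyAdd above.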
+
+bool CallAnalyzer::visitPHI(PHINode &I) {
+ // FIXME: We need to propagate SROA *disabling* through phi nodes, even
+ // though we don't want to propagate its bonuses. The idea is to disable
+ // SROA if it *might* be used in an inappropriate manner.
+
+ // Phi nodes are always zero-cost.
+ // FIXME: Pointer sizes may differ between different address spaces, so do we
+ // need to use the correct address space in the call to getPointerSizeInBits
+ // here? Or could we skip the getPointerSizeInBits call completely? As far
+ // as I can see, the ZeroOffset is used as a dummy value, so we can probably
+ // use any bit width for the ZeroOffset?
+ APInt ZeroOffset = APInt::getNullValue(DL.getPointerSizeInBits(0));
+ bool CheckSROA = I.getType()->isPointerTy();
+
+ // Track the constant or pointer with constant offset we've seen so far.
+ Constant *FirstC = nullptr;
+ std::pair<Value *, APInt> FirstBaseAndOffset = {nullptr, ZeroOffset};
+ Value *FirstV = nullptr;
+
+ for (unsigned i = 0, e = I.getNumIncomingValues(); i != e; ++i) {
+ BasicBlock *Pred = I.getIncomingBlock(i);
+ // If the incoming block is dead, skip the incoming block.
+ if (DeadBlocks.count(Pred))
+ continue;
+ // If the parent block of phi is not the known successor of the incoming
+ // block, skip the incoming block.
+ BasicBlock *KnownSuccessor = KnownSuccessors[Pred];
+ if (KnownSuccessor && KnownSuccessor != I.getParent())
+ continue;
+
+ Value *V = I.getIncomingValue(i);
+ // If the incoming value is this phi itself, skip the incoming value.
+ if (&I == V)
+ continue;
+
+ Constant *C = dyn_cast<Constant>(V);
+ if (!C)
+ C = SimplifiedValues.lookup(V);
+
+ std::pair<Value *, APInt> BaseAndOffset = {nullptr, ZeroOffset};
+ if (!C && CheckSROA)
+ BaseAndOffset = ConstantOffsetPtrs.lookup(V);
+
+ if (!C && !BaseAndOffset.first)
+ // The incoming value is neither a constant nor a pointer with constant
+ // offset, exit early.
+ return true;
+
+ if (FirstC) {
+ if (FirstC == C)
+ // If we've seen a constant incoming value before and it is the same
+ // constant we see this time, continue checking the next incoming value.
+ continue;
+ // Otherwise early exit because we either see a different constant or saw
+ // a constant before but we have a pointer with constant offset this time.
+ return true;
+ }
+
+ if (FirstV) {
+ // The same logic as above, but check pointer with constant offset here.
+ if (FirstBaseAndOffset == BaseAndOffset)
+ continue;
+ return true;
+ }
+
+ if (C) {
+ // This is the 1st time we've seen a constant, record it.
+ FirstC = C;
+ continue;
+ }
+
+ // The remaining case is that this is the 1st time we've seen a pointer with
+ // constant offset, record it.
+ FirstV = V;
+ FirstBaseAndOffset = BaseAndOffset;
+ }
+
+ // Check if we can map phi to a constant.
+ if (FirstC) {
+ SimplifiedValues[&I] = FirstC;
+ return true;
+ }
+
+ // Check if we can map phi to a pointer with constant offset.
+ if (FirstBaseAndOffset.first) {
+ ConstantOffsetPtrs[&I] = FirstBaseAndOffset;
+
+ Value *SROAArg;
+ DenseMap<Value *, int>::iterator CostIt;
+ if (lookupSROAArgAndCost(FirstV, SROAArg, CostIt))
+ SROAArgValues[&I] = SROAArg;
+ }
+
+ return true;
+}
+
+/// Check we can fold GEPs of constant-offset call site argument pointers.
+/// This requires target data and inbounds GEPs.
+///
+/// \return true if the specified GEP can be folded.
+bool CallAnalyzer::canFoldInboundsGEP(GetElementPtrInst &I) {
+ // Check if we have a base + offset for the pointer.
+ std::pair<Value *, APInt> BaseAndOffset =
+ ConstantOffsetPtrs.lookup(I.getPointerOperand());
+ if (!BaseAndOffset.first)
+ return false;
+
+ // Check if the offset of this GEP is constant, and if so accumulate it
+ // into Offset.
+ if (!accumulateGEPOffset(cast<GEPOperator>(I), BaseAndOffset.second))
+ return false;
+
+ // Add the result as a new mapping to Base + Offset.
+ ConstantOffsetPtrs[&I] = BaseAndOffset;
+
+ return true;
+}
+
+bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) {
+ Value *SROAArg;
+ DenseMap<Value *, int>::iterator CostIt;
+ bool SROACandidate =
+ lookupSROAArgAndCost(I.getPointerOperand(), SROAArg, CostIt);
+
+ // Lambda to check whether a GEP's indices are all constant.
+ auto IsGEPOffsetConstant = [&](GetElementPtrInst &GEP) {
+ for (User::op_iterator I = GEP.idx_begin(), E = GEP.idx_end(); I != E; ++I)
+ if (!isa<Constant>(*I) && !SimplifiedValues.lookup(*I))
+ return false;
+ return true;
+ };
+
+ if ((I.isInBounds() && canFoldInboundsGEP(I)) || IsGEPOffsetConstant(I)) {
+ if (SROACandidate)
+ SROAArgValues[&I] = SROAArg;
+
+ // Constant GEPs are modeled as free.
+ return true;
+ }
+
+ // Variable GEPs will require math and will disable SROA.
+ if (SROACandidate)
+ disableSROA(CostIt);
+ return isGEPFree(I);
+}
+
+/// Simplify \p I if its operands are constants and update SimplifiedValues.
+/// \p Evaluate is a callable specific to instruction type that evaluates the
+/// instruction when all the operands are constants.
+template <typename Callable>
+bool CallAnalyzer::simplifyInstruction(Instruction &I, Callable Evaluate) {
+ SmallVector<Constant *, 2> COps;
+ for (Value *Op : I.operands()) {
+ Constant *COp = dyn_cast<Constant>(Op);
+ if (!COp)
+ COp = SimplifiedValues.lookup(Op);
+ if (!COp)
+ return false;
+ COps.push_back(COp);
+ }
+ auto *C = Evaluate(COps);
+ if (!C)
+ return false;
+ SimplifiedValues[&I] = C;
+ return true;
+}
+
+bool CallAnalyzer::visitBitCast(BitCastInst &I) {
+ // Propagate constants through bitcasts.
+ if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) {
+ return ConstantExpr::getBitCast(COps[0], I.getType());
+ }))
+ return true;
+
+ // Track base/offsets through casts
+ std::pair<Value *, APInt> BaseAndOffset =
+ ConstantOffsetPtrs.lookup(I.getOperand(0));
+ // Casts don't change the offset, just wrap it up.
+ if (BaseAndOffset.first)
+ ConstantOffsetPtrs[&I] = BaseAndOffset;
+
+ // Also look for SROA candidates here.
+ Value *SROAArg;
+ DenseMap<Value *, int>::iterator CostIt;
+ if (lookupSROAArgAndCost(I.getOperand(0), SROAArg, CostIt))
+ SROAArgValues[&I] = SROAArg;
+
+ // Bitcasts are always zero cost.
+ return true;
+}
+
+bool CallAnalyzer::visitPtrToInt(PtrToIntInst &I) {
+ // Propagate constants through ptrtoint.
+ if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) {
+ return ConstantExpr::getPtrToInt(COps[0], I.getType());
+ }))
+ return true;
+
+ // Track base/offset pairs when converted to a plain integer provided the
+ // integer is large enough to represent the pointer.
+ unsigned IntegerSize = I.getType()->getScalarSizeInBits();
+ unsigned AS = I.getOperand(0)->getType()->getPointerAddressSpace();
+ if (IntegerSize >= DL.getPointerSizeInBits(AS)) {
+ std::pair<Value *, APInt> BaseAndOffset =
+ ConstantOffsetPtrs.lookup(I.getOperand(0));
+ if (BaseAndOffset.first)
+ ConstantOffsetPtrs[&I] = BaseAndOffset;
+ }
+
+ // This is really weird. Technically, ptrtoint will disable SROA. However,
+ // unless that ptrtoint is *used* somewhere in the live basic blocks after
+ // inlining, it will be nuked, and SROA should proceed. All of the uses which
+ // would block SROA would also block SROA if applied directly to a pointer,
+ // and so we can just add the integer in here. The only places where SROA is
+ // preserved either cannot fire on an integer, or won't in-and-of themselves
+ // disable SROA (ext) w/o some later use that we would see and disable.
+ Value *SROAArg;
+ DenseMap<Value *, int>::iterator CostIt;
+ if (lookupSROAArgAndCost(I.getOperand(0), SROAArg, CostIt))
+ SROAArgValues[&I] = SROAArg;
+
+ return TargetTransformInfo::TCC_Free == TTI.getUserCost(&I);
+}
+
+bool CallAnalyzer::visitIntToPtr(IntToPtrInst &I) {
+ // Propagate constants through inttoptr.
+ if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) {
+ return ConstantExpr::getIntToPtr(COps[0], I.getType());
+ }))
+ return true;
+
+ // Track base/offset pairs when round-tripped through a pointer without
+ // modifications provided the integer is not too large.
+ Value *Op = I.getOperand(0);
+ unsigned IntegerSize = Op->getType()->getScalarSizeInBits();
+ if (IntegerSize <= DL.getPointerTypeSizeInBits(I.getType())) {
+ std::pair<Value *, APInt> BaseAndOffset = ConstantOffsetPtrs.lookup(Op);
+ if (BaseAndOffset.first)
+ ConstantOffsetPtrs[&I] = BaseAndOffset;
+ }
+
+ // "Propagate" SROA here in the same manner as we do for ptrtoint above.
+ Value *SROAArg;
+ DenseMap<Value *, int>::iterator CostIt;
+ if (lookupSROAArgAndCost(Op, SROAArg, CostIt))
+ SROAArgValues[&I] = SROAArg;
+
+ return TargetTransformInfo::TCC_Free == TTI.getUserCost(&I);
+}
+
+bool CallAnalyzer::visitCastInst(CastInst &I) {
+ // Propagate constants through casts.
+ if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) {
+ return ConstantExpr::getCast(I.getOpcode(), COps[0], I.getType());
+ }))
+ return true;
+
+ // Disable SROA in the face of arbitrary casts we don't whitelist elsewhere.
+ disableSROA(I.getOperand(0));
+
+ // If this is a floating-point cast, and the target says this operation
+ // is expensive, this may eventually become a library call. Treat the cost
+ // as such.
+ switch (I.getOpcode()) {
+ case Instruction::FPTrunc:
+ case Instruction::FPExt:
+ case Instruction::UIToFP:
+ case Instruction::SIToFP:
+ case Instruction::FPToUI:
+ case Instruction::FPToSI:
+ if (TTI.getFPOpCost(I.getType()) == TargetTransformInfo::TCC_Expensive)
+ addCost(InlineConstants::CallPenalty);
+ break;
+ default:
+ break;
+ }
+
+ return TargetTransformInfo::TCC_Free == TTI.getUserCost(&I);
+}
+
+bool CallAnalyzer::visitUnaryInstruction(UnaryInstruction &I) {
+ Value *Operand = I.getOperand(0);
+ if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) {
+ return ConstantFoldInstOperands(&I, COps[0], DL);
+ }))
+ return true;
+
+ // Disable any SROA on the argument to arbitrary unary instructions.
+ disableSROA(Operand);
+
+ return false;
+}
+
+bool CallAnalyzer::paramHasAttr(Argument *A, Attribute::AttrKind Attr) {
+ return CandidateCall.paramHasAttr(A->getArgNo(), Attr);
+}
+
+bool CallAnalyzer::isKnownNonNullInCallee(Value *V) {
+ // Does the *call site* have the NonNull attribute set on an argument? We
+ // use the attribute on the call site to memoize any analysis done in the
+ // caller. This will also trip if the callee function has a non-null
+ // parameter attribute, but that's a less interesting case because hopefully
+ // the callee would already have been simplified based on that.
+ if (Argument *A = dyn_cast<Argument>(V))
+ if (paramHasAttr(A, Attribute::NonNull))
+ return true;
+
+ // Is this an alloca in the caller? This is distinct from the attribute case
+ // above because attributes aren't updated within the inliner itself and we
+ // always want to catch the alloca derived case.
+ if (isAllocaDerivedArg(V))
+ // We can actually predict the result of comparisons between an
+ // alloca-derived value and null. Note that this fires regardless of
+ // SROA firing.
+ return true;
+
+ return false;
+}
+
+bool CallAnalyzer::allowSizeGrowth(CallBase &Call) {
+ // If the normal destination of the invoke or the parent block of the call
+ // site is unreachable-terminated, there is little point in inlining this
+ // unless there is literally zero cost.
+ // FIXME: Note that it is possible that an unreachable-terminated block has a
+ // hot entry. For example, in the scenario below inlining hot_call_X() may be
+ // beneficial:
+ // main() {
+ // hot_call_1();
+ // ...
+ // hot_call_N()
+ // exit(0);
+ // }
+ // For now, we are not handling this corner case here as it is rare in real
+ // code. In future, we should elaborate this based on BPI and BFI in more
+ // general threshold adjusting heuristics in updateThreshold().
+ if (InvokeInst *II = dyn_cast<InvokeInst>(&Call)) {
+ if (isa<UnreachableInst>(II->getNormalDest()->getTerminator()))
+ return false;
+ } else if (isa<UnreachableInst>(Call.getParent()->getTerminator()))
+ return false;
+
+ return true;
+}
+
+bool CallAnalyzer::isColdCallSite(CallBase &Call,
+ BlockFrequencyInfo *CallerBFI) {
+ // If global profile summary is available, then callsite's coldness is
+ // determined based on that.
+ if (PSI && PSI->hasProfileSummary())
+ return PSI->isColdCallSite(CallSite(&Call), CallerBFI);
+
+ // Otherwise we need BFI to be available.
+ if (!CallerBFI)
+ return false;
+
+ // Determine if the callsite is cold relative to caller's entry. We could
+ // potentially cache the computation of scaled entry frequency, but the added
+ // complexity is not worth it unless this scaling shows up high in the
+ // profiles.
+ const BranchProbability ColdProb(ColdCallSiteRelFreq, 100);
+ auto CallSiteBB = Call.getParent();
+ auto CallSiteFreq = CallerBFI->getBlockFreq(CallSiteBB);
+ auto CallerEntryFreq =
+ CallerBFI->getBlockFreq(&(Call.getCaller()->getEntryBlock()));
+ return CallSiteFreq < CallerEntryFreq * ColdProb;
+}
+
+Optional<int>
+CallAnalyzer::getHotCallSiteThreshold(CallBase &Call,
+ BlockFrequencyInfo *CallerBFI) {
+
+ // If global profile summary is available, then callsite's hotness is
+ // determined based on that.
+ if (PSI && PSI->hasProfileSummary() &&
+ PSI->isHotCallSite(CallSite(&Call), CallerBFI))
+ return Params.HotCallSiteThreshold;
+
+ // Otherwise we need BFI to be available and to have a locally hot callsite
+ // threshold.
+ if (!CallerBFI || !Params.LocallyHotCallSiteThreshold)
+ return None;
+
+ // Determine if the callsite is hot relative to caller's entry. We could
+ // potentially cache the computation of scaled entry frequency, but the added
+ // complexity is not worth it unless this scaling shows up high in the
+ // profiles.
+ auto CallSiteBB = Call.getParent();
+ auto CallSiteFreq = CallerBFI->getBlockFreq(CallSiteBB).getFrequency();
+ auto CallerEntryFreq = CallerBFI->getEntryFreq();
+ if (CallSiteFreq >= CallerEntryFreq * HotCallSiteRelFreq)
+ return Params.LocallyHotCallSiteThreshold;
+
+ // Otherwise treat it normally.
+ return None;
+}
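+
+// For illustration (hypothetical frequencies, using the default
+// HotCallSiteRelFreq of 60 above): with a caller entry frequency of 100, a
+// callsite whose block frequency is at least 6000 is treated as locally hot
+// and returns Params.LocallyHotCallSiteThreshold; anything below that falls
+// through and returns None.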
+
+void CallAnalyzer::updateThreshold(CallBase &Call, Function &Callee) {
+ // If no size growth is allowed for this inlining, set Threshold to 0.
+ if (!allowSizeGrowth(Call)) {
+ Threshold = 0;
+ return;
+ }
+
+ Function *Caller = Call.getCaller();
+
+ // return min(A, B) if B is valid.
+ auto MinIfValid = [](int A, Optional<int> B) {
+ return B ? std::min(A, B.getValue()) : A;
+ };
+
+ // return max(A, B) if B is valid.
+ auto MaxIfValid = [](int A, Optional<int> B) {
+ return B ? std::max(A, B.getValue()) : A;
+ };
+
+ // Various bonus percentages. These are multiplied by Threshold to get the
+ // bonus values.
+ // SingleBBBonus: This bonus is applied if the callee has a single reachable
+ // basic block at the given callsite context. This is speculatively applied
+ // and withdrawn if more than one basic block is seen.
+ //
+ // LastCallToStaticBonus: This large bonus is applied to ensure the inlining
+ // of the last call to a static function as inlining such functions is
+ // guaranteed to reduce code size.
+ //
+ // These bonus percentages may be set to 0 based on properties of the caller
+ // and the callsite.
+ int SingleBBBonusPercent = 50;
+ int VectorBonusPercent = TTI.getInlinerVectorBonusPercent();
+ int LastCallToStaticBonus = InlineConstants::LastCallToStaticBonus;
+
+ // Lambda to set all the above bonus and bonus percentages to 0.
+ auto DisallowAllBonuses = [&]() {
+ SingleBBBonusPercent = 0;
+ VectorBonusPercent = 0;
+ LastCallToStaticBonus = 0;
+ };
+
+ // Use the OptMinSizeThreshold or OptSizeThreshold knob if they are available
+ // and reduce the threshold if the caller has the necessary attribute.
+ if (Caller->hasMinSize()) {
+ Threshold = MinIfValid(Threshold, Params.OptMinSizeThreshold);
+ // For minsize, we want to disable the single BB bonus and the vector
+ // bonuses, but not the last-call-to-static bonus. Inlining the last call to
+ // a static function will, at the minimum, eliminate the parameter setup and
+ // call/return instructions.
+ SingleBBBonusPercent = 0;
+ VectorBonusPercent = 0;
+ } else if (Caller->hasOptSize())
+ Threshold = MinIfValid(Threshold, Params.OptSizeThreshold);
+
+ // Adjust the threshold based on inlinehint attribute and profile based
+ // hotness information if the caller does not have MinSize attribute.
+ if (!Caller->hasMinSize()) {
+ if (Callee.hasFnAttribute(Attribute::InlineHint))
+ Threshold = MaxIfValid(Threshold, Params.HintThreshold);
+
+ // FIXME: After switching to the new passmanager, simplify the logic below
+ // by checking only the callsite hotness/coldness as we will reliably
+ // have local profile information.
+ //
+ // Callsite hotness and coldness can be determined if sample profile is
+ // used (which adds hotness metadata to calls) or if caller's
+ // BlockFrequencyInfo is available.
+ BlockFrequencyInfo *CallerBFI = GetBFI ? &((*GetBFI)(*Caller)) : nullptr;
+ auto HotCallSiteThreshold = getHotCallSiteThreshold(Call, CallerBFI);
+ if (!Caller->hasOptSize() && HotCallSiteThreshold) {
+ LLVM_DEBUG(dbgs() << "Hot callsite.\n");
+ // FIXME: This should update the threshold only if it exceeds the
+ // current threshold, but AutoFDO + ThinLTO currently relies on this
+ // behavior to prevent inlining of hot callsites during ThinLTO
+ // compile phase.
+ Threshold = HotCallSiteThreshold.getValue();
+ } else if (isColdCallSite(Call, CallerBFI)) {
+ LLVM_DEBUG(dbgs() << "Cold callsite.\n");
+ // Do not apply bonuses for a cold callsite including the
+ // LastCallToStatic bonus. While this bonus might result in code size
+ // reduction, it can cause the size of a non-cold caller to increase
+ // preventing it from being inlined.
+ DisallowAllBonuses();
+ Threshold = MinIfValid(Threshold, Params.ColdCallSiteThreshold);
+ } else if (PSI) {
+ // Use callee's global profile information only if we have no way of
+ // determining this via callsite information.
+ if (PSI->isFunctionEntryHot(&Callee)) {
+ LLVM_DEBUG(dbgs() << "Hot callee.\n");
+ // If callsite hotness can not be determined, we may still know
+ // that the callee is hot and treat it as a weaker hint for threshold
+ // increase.
+ Threshold = MaxIfValid(Threshold, Params.HintThreshold);
+ } else if (PSI->isFunctionEntryCold(&Callee)) {
+ LLVM_DEBUG(dbgs() << "Cold callee.\n");
+ // Do not apply bonuses for a cold callee including the
+ // LastCallToStatic bonus. While this bonus might result in code size
+ // reduction, it can cause the size of a non-cold caller to increase
+ // preventing it from being inlined.
+ DisallowAllBonuses();
+ Threshold = MinIfValid(Threshold, Params.ColdThreshold);
+ }
+ }
+ }
+
+ // Finally, take the target-specific inlining threshold multiplier into
+ // account.
+ Threshold *= TTI.getInliningThresholdMultiplier();
+
+ SingleBBBonus = Threshold * SingleBBBonusPercent / 100;
+ VectorBonus = Threshold * VectorBonusPercent / 100;
+
+ bool OnlyOneCallAndLocalLinkage =
+ F.hasLocalLinkage() && F.hasOneUse() && &F == Call.getCalledFunction();
+ // If there is only one call of the function, and it has internal linkage,
+ // the cost of inlining it drops dramatically. It may seem odd to update
+ // Cost in updateThreshold, but the bonus depends on the logic in this method.
+ if (OnlyOneCallAndLocalLinkage)
+ Cost -= LastCallToStaticBonus;
+}
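+
+// For illustration (hypothetical values): with a final Threshold of 225 and
+// the default SingleBBBonusPercent of 50, SingleBBBonus becomes
+// 225 * 50 / 100 = 112; if TTI.getInlinerVectorBonusPercent() were 150,
+// VectorBonus would be 337. As noted above, SingleBBBonus is applied
+// speculatively and withdrawn if more than one basic block is seen.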
+
+bool CallAnalyzer::visitCmpInst(CmpInst &I) {
+ Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+ // First try to handle simplified comparisons.
+ if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) {
+ return ConstantExpr::getCompare(I.getPredicate(), COps[0], COps[1]);
+ }))
+ return true;
+
+ if (I.getOpcode() == Instruction::FCmp)
+ return false;
+
+ // Otherwise look for a comparison between constant offset pointers with
+ // a common base.
+ Value *LHSBase, *RHSBase;
+ APInt LHSOffset, RHSOffset;
+ std::tie(LHSBase, LHSOffset) = ConstantOffsetPtrs.lookup(LHS);
+ if (LHSBase) {
+ std::tie(RHSBase, RHSOffset) = ConstantOffsetPtrs.lookup(RHS);
+ if (RHSBase && LHSBase == RHSBase) {
+ // We have common bases, fold the icmp to a constant based on the
+ // offsets.
+ Constant *CLHS = ConstantInt::get(LHS->getContext(), LHSOffset);
+ Constant *CRHS = ConstantInt::get(RHS->getContext(), RHSOffset);
+ if (Constant *C = ConstantExpr::getICmp(I.getPredicate(), CLHS, CRHS)) {
+ SimplifiedValues[&I] = C;
+ ++NumConstantPtrCmps;
+ return true;
+ }
+ }
+ }
+
+ // If the comparison is an equality comparison with null, we can simplify it
+ // if we know the value (argument) can't be null
+ if (I.isEquality() && isa<ConstantPointerNull>(I.getOperand(1)) &&
+ isKnownNonNullInCallee(I.getOperand(0))) {
+ bool IsNotEqual = I.getPredicate() == CmpInst::ICMP_NE;
+ SimplifiedValues[&I] = IsNotEqual ? ConstantInt::getTrue(I.getType())
+ : ConstantInt::getFalse(I.getType());
+ return true;
+ }
+ // Finally check for SROA candidates in comparisons.
+ Value *SROAArg;
+ DenseMap<Value *, int>::iterator CostIt;
+ if (lookupSROAArgAndCost(I.getOperand(0), SROAArg, CostIt)) {
+ if (isa<ConstantPointerNull>(I.getOperand(1))) {
+ accumulateSROACost(CostIt, InlineConstants::InstrCost);
+ return true;
+ }
+
+ disableSROA(CostIt);
+ }
+
+ return false;
+}
+
+bool CallAnalyzer::visitSub(BinaryOperator &I) {
+ // Try to handle a special case: we can fold computing the difference of two
+ // constant-related pointers.
+ Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+ Value *LHSBase, *RHSBase;
+ APInt LHSOffset, RHSOffset;
+ std::tie(LHSBase, LHSOffset) = ConstantOffsetPtrs.lookup(LHS);
+ if (LHSBase) {
+ std::tie(RHSBase, RHSOffset) = ConstantOffsetPtrs.lookup(RHS);
+ if (RHSBase && LHSBase == RHSBase) {
+ // We have common bases, fold the subtract to a constant based on the
+ // offsets.
+ Constant *CLHS = ConstantInt::get(LHS->getContext(), LHSOffset);
+ Constant *CRHS = ConstantInt::get(RHS->getContext(), RHSOffset);
+ if (Constant *C = ConstantExpr::getSub(CLHS, CRHS)) {
+ SimplifiedValues[&I] = C;
+ ++NumConstantPtrDiffs;
+ return true;
+ }
+ }
+ }
+
+ // Otherwise, fall back to the generic logic for simplifying and handling
+ // instructions.
+ return Base::visitSub(I);
+}
+
+bool CallAnalyzer::visitBinaryOperator(BinaryOperator &I) {
+ Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+ Constant *CLHS = dyn_cast<Constant>(LHS);
+ if (!CLHS)
+ CLHS = SimplifiedValues.lookup(LHS);
+ Constant *CRHS = dyn_cast<Constant>(RHS);
+ if (!CRHS)
+ CRHS = SimplifiedValues.lookup(RHS);
+
+ Value *SimpleV = nullptr;
+ if (auto FI = dyn_cast<FPMathOperator>(&I))
+ SimpleV = SimplifyBinOp(I.getOpcode(), CLHS ? CLHS : LHS,
+ CRHS ? CRHS : RHS, FI->getFastMathFlags(), DL);
+ else
+ SimpleV =
+ SimplifyBinOp(I.getOpcode(), CLHS ? CLHS : LHS, CRHS ? CRHS : RHS, DL);
+
+ if (Constant *C = dyn_cast_or_null<Constant>(SimpleV))
+ SimplifiedValues[&I] = C;
+
+ if (SimpleV)
+ return true;
+
+ // Disable any SROA on arguments to arbitrary, unsimplified binary operators.
+ disableSROA(LHS);
+ disableSROA(RHS);
+
+ // If the instruction is floating point, and the target says this operation
+ // is expensive, this may eventually become a library call. Treat the cost
+ // as such. Unless it's fneg which can be implemented with an xor.
+ using namespace llvm::PatternMatch;
+ if (I.getType()->isFloatingPointTy() &&
+ TTI.getFPOpCost(I.getType()) == TargetTransformInfo::TCC_Expensive &&
+ !match(&I, m_FNeg(m_Value())))
+ addCost(InlineConstants::CallPenalty);
+
+ return false;
+}
+
+bool CallAnalyzer::visitFNeg(UnaryOperator &I) {
+ Value *Op = I.getOperand(0);
+ Constant *COp = dyn_cast<Constant>(Op);
+ if (!COp)
+ COp = SimplifiedValues.lookup(Op);
+
+ Value *SimpleV = SimplifyFNegInst(COp ? COp : Op,
+ cast<FPMathOperator>(I).getFastMathFlags(),
+ DL);
+
+ if (Constant *C = dyn_cast_or_null<Constant>(SimpleV))
+ SimplifiedValues[&I] = C;
+
+ if (SimpleV)
+ return true;
+
+ // Disable any SROA on arguments to arbitrary, unsimplified fneg.
+ disableSROA(Op);
+
+ return false;
+}
+
+bool CallAnalyzer::visitLoad(LoadInst &I) {
+ Value *SROAArg;
+ DenseMap<Value *, int>::iterator CostIt;
+ if (lookupSROAArgAndCost(I.getPointerOperand(), SROAArg, CostIt)) {
+ if (I.isSimple()) {
+ accumulateSROACost(CostIt, InlineConstants::InstrCost);
+ return true;
+ }
+
+ disableSROA(CostIt);
+ }
+
+ // If the data is already loaded from this address and hasn't been clobbered
+ // by any stores or calls, this load is likely to be redundant and can be
+ // eliminated.
+ if (EnableLoadElimination &&
+ !LoadAddrSet.insert(I.getPointerOperand()).second && I.isUnordered()) {
+ LoadEliminationCost += InlineConstants::InstrCost;
+ return true;
+ }
+
+ return false;
+}
+
+bool CallAnalyzer::visitStore(StoreInst &I) {
+ Value *SROAArg;
+ DenseMap<Value *, int>::iterator CostIt;
+ if (lookupSROAArgAndCost(I.getPointerOperand(), SROAArg, CostIt)) {
+ if (I.isSimple()) {
+ accumulateSROACost(CostIt, InlineConstants::InstrCost);
+ return true;
+ }
+
+ disableSROA(CostIt);
+ }
+
+ // The store can potentially clobber loads and prevent repeated loads from
+ // being eliminated.
+ // FIXME:
+ // 1. We can probably keep an initial set of eliminatable loads subtracted
+ // from the cost even when we finally see a store. We just need to disable
+ // *further* accumulation of elimination savings.
+ // 2. We should probably at some point thread MemorySSA for the callee into
+ // this and then use that to actually compute *really* precise savings.
+ disableLoadElimination();
+ return false;
+}
+
+bool CallAnalyzer::visitExtractValue(ExtractValueInst &I) {
+ // Constant folding for extract value is trivial.
+ if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) {
+ return ConstantExpr::getExtractValue(COps[0], I.getIndices());
+ }))
+ return true;
+
+ // SROA can look through these, but we still give them a cost.
+ return false;
+}
+
+bool CallAnalyzer::visitInsertValue(InsertValueInst &I) {
+ // Constant folding for insert value is trivial.
+ if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) {
+ return ConstantExpr::getInsertValue(/*AggregateOperand*/ COps[0],
+ /*InsertedValueOperand*/ COps[1],
+ I.getIndices());
+ }))
+ return true;
+
+ // SROA can look through these, but we still give them a cost.
+ return false;
+}
+
+/// Try to simplify a call site.
+///
+/// Takes a concrete function and callsite and tries to actually simplify it by
+/// analyzing the arguments and the call itself with instsimplify. Returns true if
+/// it has simplified the callsite to some other entity (a constant), making it
+/// free.
+bool CallAnalyzer::simplifyCallSite(Function *F, CallBase &Call) {
+ // FIXME: Using the instsimplify logic directly for this is inefficient
+ // because we have to continually rebuild the argument list even when no
+ // simplifications can be performed. Until that is fixed with remapping
+ // inside of instsimplify, directly constant fold calls here.
+ if (!canConstantFoldCallTo(&Call, F))
+ return false;
+
+ // Try to re-map the arguments to constants.
+ SmallVector<Constant *, 4> ConstantArgs;
+ ConstantArgs.reserve(Call.arg_size());
+ for (Value *I : Call.args()) {
+ Constant *C = dyn_cast<Constant>(I);
+ if (!C)
+ C = dyn_cast_or_null<Constant>(SimplifiedValues.lookup(I));
+ if (!C)
+ return false; // This argument doesn't map to a constant.
+
+ ConstantArgs.push_back(C);
+ }
+ if (Constant *C = ConstantFoldCall(&Call, F, ConstantArgs)) {
+ SimplifiedValues[&Call] = C;
+ return true;
+ }
+
+ return false;
+}
+
+bool CallAnalyzer::visitCallBase(CallBase &Call) {
+ if (Call.hasFnAttr(Attribute::ReturnsTwice) &&
+ !F.hasFnAttribute(Attribute::ReturnsTwice)) {
+ // This aborts the entire analysis.
+ ExposesReturnsTwice = true;
+ return false;
+ }
+ if (isa<CallInst>(Call) && cast<CallInst>(Call).cannotDuplicate())
+ ContainsNoDuplicateCall = true;
+
+ if (Function *F = Call.getCalledFunction()) {
+ // When we have a concrete function, first try to simplify it directly.
+ if (simplifyCallSite(F, Call))
+ return true;
+
+ // Next check if it is an intrinsic we know about.
+ // FIXME: Lift this into part of the InstVisitor.
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Call)) {
+ switch (II->getIntrinsicID()) {
+ default:
+ if (!Call.onlyReadsMemory() && !isAssumeLikeIntrinsic(II))
+ disableLoadElimination();
+ return Base::visitCallBase(Call);
+
+ case Intrinsic::load_relative:
+ // This is normally lowered to 4 LLVM instructions.
+ addCost(3 * InlineConstants::InstrCost);
+ return false;
+
+ case Intrinsic::memset:
+ case Intrinsic::memcpy:
+ case Intrinsic::memmove:
+ disableLoadElimination();
+ // SROA can usually chew through these intrinsics, but they aren't free.
+ return false;
+ case Intrinsic::icall_branch_funnel:
+ case Intrinsic::localescape:
+ HasUninlineableIntrinsic = true;
+ return false;
+ case Intrinsic::vastart:
+ InitsVargArgs = true;
+ return false;
+ }
+ }
+
+ if (F == Call.getFunction()) {
+ // This flag will fully abort the analysis, so don't bother with anything
+ // else.
+ IsRecursiveCall = true;
+ return false;
+ }
+
+ if (TTI.isLoweredToCall(F)) {
+ // We account for the average 1 instruction per call argument setup
+ // here.
+ addCost(Call.arg_size() * InlineConstants::InstrCost);
+
+ // Everything other than inline ASM will also have a significant cost
+ // merely from making the call.
+ if (!isa<InlineAsm>(Call.getCalledValue()))
+ addCost(InlineConstants::CallPenalty);
+ }
+
+ if (!Call.onlyReadsMemory())
+ disableLoadElimination();
+ return Base::visitCallBase(Call);
+ }
+
+ // Otherwise we're in a very special case -- an indirect function call. See
+ // if we can be particularly clever about this.
+ Value *Callee = Call.getCalledValue();
+
+ // First, pay the price of the argument setup. We account for the average
+ // 1 instruction per call argument setup here.
+ addCost(Call.arg_size() * InlineConstants::InstrCost);
+
+ // Next, check if this happens to be an indirect function call to a known
+ // function in this inline context. If not, we've done all we can.
+ Function *F = dyn_cast_or_null<Function>(SimplifiedValues.lookup(Callee));
+ if (!F) {
+ if (!Call.onlyReadsMemory())
+ disableLoadElimination();
+ return Base::visitCallBase(Call);
+ }
+
+ // If we have a constant that we are calling as a function, we can peer
+ // through it and see the function target. This happens not infrequently
+ // during devirtualization and so we want to give it a hefty bonus for
+ // inlining, but cap that bonus in the event that inlining wouldn't pan
+ // out. Pretend to inline the function, with a custom threshold.
+ auto IndirectCallParams = Params;
+ IndirectCallParams.DefaultThreshold = InlineConstants::IndirectCallThreshold;
+ CallAnalyzer CA(TTI, GetAssumptionCache, GetBFI, PSI, ORE, *F, Call,
+ IndirectCallParams);
+ if (CA.analyzeCall(Call)) {
+ // We were able to inline the indirect call! Subtract the cost from the
+ // threshold to get the bonus we want to apply, but don't go below zero.
+ Cost -= std::max(0, CA.getThreshold() - CA.getCost());
+ }
+
+ if (!F->onlyReadsMemory())
+ disableLoadElimination();
+ return Base::visitCallBase(Call);
+}
+
+bool CallAnalyzer::visitReturnInst(ReturnInst &RI) {
+ // At least one return instruction will be free after inlining.
+ bool Free = !HasReturn;
+ HasReturn = true;
+ return Free;
+}
+
+bool CallAnalyzer::visitBranchInst(BranchInst &BI) {
+ // We model unconditional branches as essentially free -- they really
+ // shouldn't exist at all, but handling them makes the behavior of the
+ // inliner more regular and predictable. Interestingly, conditional branches
+ // which will fold away are also free.
+ return BI.isUnconditional() || isa<ConstantInt>(BI.getCondition()) ||
+ dyn_cast_or_null<ConstantInt>(
+ SimplifiedValues.lookup(BI.getCondition()));
+}
+
+bool CallAnalyzer::visitSelectInst(SelectInst &SI) {
+ bool CheckSROA = SI.getType()->isPointerTy();
+ Value *TrueVal = SI.getTrueValue();
+ Value *FalseVal = SI.getFalseValue();
+
+ Constant *TrueC = dyn_cast<Constant>(TrueVal);
+ if (!TrueC)
+ TrueC = SimplifiedValues.lookup(TrueVal);
+ Constant *FalseC = dyn_cast<Constant>(FalseVal);
+ if (!FalseC)
+ FalseC = SimplifiedValues.lookup(FalseVal);
+ Constant *CondC =
+ dyn_cast_or_null<Constant>(SimplifiedValues.lookup(SI.getCondition()));
+
+ if (!CondC) {
+ // Select C, X, X => X
+ if (TrueC == FalseC && TrueC) {
+ SimplifiedValues[&SI] = TrueC;
+ return true;
+ }
+
+ if (!CheckSROA)
+ return Base::visitSelectInst(SI);
+
+ std::pair<Value *, APInt> TrueBaseAndOffset =
+ ConstantOffsetPtrs.lookup(TrueVal);
+ std::pair<Value *, APInt> FalseBaseAndOffset =
+ ConstantOffsetPtrs.lookup(FalseVal);
+ if (TrueBaseAndOffset == FalseBaseAndOffset && TrueBaseAndOffset.first) {
+ ConstantOffsetPtrs[&SI] = TrueBaseAndOffset;
+
+ Value *SROAArg;
+ DenseMap<Value *, int>::iterator CostIt;
+ if (lookupSROAArgAndCost(TrueVal, SROAArg, CostIt))
+ SROAArgValues[&SI] = SROAArg;
+ return true;
+ }
+
+ return Base::visitSelectInst(SI);
+ }
+
+ // Select condition is a constant.
+ Value *SelectedV = CondC->isAllOnesValue()
+ ? TrueVal
+ : (CondC->isNullValue()) ? FalseVal : nullptr;
+ if (!SelectedV) {
+ // Condition is a vector constant that is not all 1s or all 0s. If all
+ // operands are constants, ConstantExpr::getSelect() can handle cases such
+ // as vector selects.
+ if (TrueC && FalseC) {
+ if (auto *C = ConstantExpr::getSelect(CondC, TrueC, FalseC)) {
+ SimplifiedValues[&SI] = C;
+ return true;
+ }
+ }
+ return Base::visitSelectInst(SI);
+ }
+
+ // Condition is either all 1s or all 0s. SI can be simplified.
+ if (Constant *SelectedC = dyn_cast<Constant>(SelectedV)) {
+ SimplifiedValues[&SI] = SelectedC;
+ return true;
+ }
+
+ if (!CheckSROA)
+ return true;
+
+ std::pair<Value *, APInt> BaseAndOffset =
+ ConstantOffsetPtrs.lookup(SelectedV);
+ if (BaseAndOffset.first) {
+ ConstantOffsetPtrs[&SI] = BaseAndOffset;
+
+ Value *SROAArg;
+ DenseMap<Value *, int>::iterator CostIt;
+ if (lookupSROAArgAndCost(SelectedV, SROAArg, CostIt))
+ SROAArgValues[&SI] = SROAArg;
+ }
+
+ return true;
+}
+
+bool CallAnalyzer::visitSwitchInst(SwitchInst &SI) {
+ // We model switches whose condition folds to a constant as free; see the
+ // comments on handling branches.
+ if (isa<ConstantInt>(SI.getCondition()))
+ return true;
+ if (Value *V = SimplifiedValues.lookup(SI.getCondition()))
+ if (isa<ConstantInt>(V))
+ return true;
+
+ // Assume the most general case where the switch is lowered into
+ // either a jump table, bit test, or a balanced binary tree consisting of
+ // case clusters without merging adjacent clusters with the same
+ // destination. We do not consider the switches that are lowered with a mix
+ // of jump table/bit test/binary search tree. The cost of the switch is
+ // proportional to the size of the tree or the size of jump table range.
+ //
+ // NB: In simplify-cfg we convert large switches that are only used to
+ // initialize large phi nodes into lookup tables, so this shouldn't prevent
+ // inlining those. It will prevent inlining in cases where the optimization
+ // does not (yet) fire.
+
+ // Maximum valid cost to which Cost may be increased in this function.
+ int CostUpperBound = INT_MAX - InlineConstants::InstrCost - 1;
+
+ unsigned JumpTableSize = 0;
+ unsigned NumCaseCluster =
+ TTI.getEstimatedNumberOfCaseClusters(SI, JumpTableSize);
+
+ // If suitable for a jump table, consider the cost for the table size and
+ // branch to destination.
+ if (JumpTableSize) {
+ int64_t JTCost = (int64_t)JumpTableSize * InlineConstants::InstrCost +
+ 4 * InlineConstants::InstrCost;
+
+ addCost(JTCost, (int64_t)CostUpperBound);
+ return false;
+ }
+
+ // Considering forming a binary search tree, we should find the number of
+ // nodes, which is the same as the number of comparisons when lowered. For a
+ // given number of clusters, n, we can define a recursive function, f(n), to
+ // find the number of nodes in the tree. The recursion is:
+ // f(n) = 1 + f(n/2) + f (n - n/2), when n > 3,
+ // and f(n) = n, when n <= 3.
+ // This will lead to a binary tree where each leaf is either f(2) or f(3)
+ // when n > 3. So, the number of comparisons from leaves should be n, while
+ // the number from non-leaf nodes should be:
+ // 2^(log2(n) - 1) - 1
+ // = 2^log2(n) * 2^-1 - 1
+ // = n / 2 - 1.
+ // Considering comparisons from leaf and non-leaf nodes, we can estimate the
+ // number of comparisons in a simple closed form :
+ // n + n / 2 - 1 = n * 3 / 2 - 1
+ if (NumCaseCluster <= 3) {
+ // Suppose a comparison includes one compare and one conditional branch.
+ addCost(NumCaseCluster * 2 * InlineConstants::InstrCost);
+ return false;
+ }
+
+ int64_t ExpectedNumberOfCompare = 3 * (int64_t)NumCaseCluster / 2 - 1;
+ int64_t SwitchCost =
+ ExpectedNumberOfCompare * 2 * InlineConstants::InstrCost;
+
+ addCost(SwitchCost, (int64_t)CostUpperBound);
+ return false;
+}
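+
+// For illustration (a hypothetical switch with 8 case clusters and no jump
+// table): ExpectedNumberOfCompare = 3 * 8 / 2 - 1 = 11, so the switch adds
+// 11 * 2 * InlineConstants::InstrCost to Cost, capped at CostUpperBound.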
+
+bool CallAnalyzer::visitIndirectBrInst(IndirectBrInst &IBI) {
+ // We never want to inline functions that contain an indirectbr. This is
+ // incorrect because all the blockaddresses (in static global initializers
+ // for example) would be referring to the original function, and this
+ // indirect jump would jump from the inlined copy of the function into the
+ // original function which is extremely undefined behavior.
+ // FIXME: This logic isn't really right; we can safely inline functions with
+ // indirectbr's as long as no other function or global references the
+ // blockaddress of a block within the current function.
+ HasIndirectBr = true;
+ return false;
+}
+
+bool CallAnalyzer::visitResumeInst(ResumeInst &RI) {
+ // FIXME: It's not clear that a single instruction is an accurate model for
+ // the inline cost of a resume instruction.
+ return false;
+}
+
+bool CallAnalyzer::visitCleanupReturnInst(CleanupReturnInst &CRI) {
+ // FIXME: It's not clear that a single instruction is an accurate model for
+ // the inline cost of a cleanupret instruction.
+ return false;
+}
+
+bool CallAnalyzer::visitCatchReturnInst(CatchReturnInst &CRI) {
+ // FIXME: It's not clear that a single instruction is an accurate model for
+ // the inline cost of a catchret instruction.
+ return false;
+}
+
+bool CallAnalyzer::visitUnreachableInst(UnreachableInst &I) {
+ // FIXME: It might be reasonable to discount the cost of instructions leading
+ // to unreachable as they have the lowest possible impact on both runtime and
+ // code size.
+ return true; // No actual code is needed for unreachable.
+}
+
+bool CallAnalyzer::visitInstruction(Instruction &I) {
+ // Some instructions are free. All of the free intrinsics can also be
+ // handled by SROA, etc.
+ if (TargetTransformInfo::TCC_Free == TTI.getUserCost(&I))
+ return true;
+
+ // We found something we don't understand or can't handle. Mark any SROA-able
+ // values in the operand list as no longer viable.
+ for (User::op_iterator OI = I.op_begin(), OE = I.op_end(); OI != OE; ++OI)
+ disableSROA(*OI);
+
+ return false;
+}
+
+/// Analyze a basic block for its contribution to the inline cost.
+///
+/// This method walks the analyzer over every instruction in the given basic
+/// block and accounts for their cost during inlining at this callsite. It
+/// aborts early if the threshold has been exceeded or an impossible-to-inline
+/// construct has been detected. It returns an InlineResult that evaluates to
+/// false if inlining is no longer viable, and to true if it remains viable.
+InlineResult
+CallAnalyzer::analyzeBlock(BasicBlock *BB,
+ SmallPtrSetImpl<const Value *> &EphValues) {
+ for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) {
+ // FIXME: Currently, the number of instructions in a function, regardless of
+ // our ability to simplify them during inlining to constants or dead code,
+ // is used by the vector bonus heuristic. As long as that's true, we have to
+ // special-case debug intrinsics here to prevent differences in inlining due
+ // to debug symbols. Eventually, the number of unsimplified instructions
+ // shouldn't factor into the cost computation, but until then, hack around
+ // it here.
+ if (isa<DbgInfoIntrinsic>(I))
+ continue;
+
+ // Skip ephemeral values.
+ if (EphValues.count(&*I))
+ continue;
+
+ ++NumInstructions;
+ if (isa<ExtractElementInst>(I) || I->getType()->isVectorTy())
+ ++NumVectorInstructions;
+
+ // If the instruction simplified to a constant, there is no cost to this
+ // instruction. Visit the instructions using our InstVisitor to account for
+ // all of the per-instruction logic. The visit tree returns true if we
+ // consumed the instruction in any way, and false if the instruction's base
+ // cost should count against inlining.
+ if (Base::visit(&*I))
+ ++NumInstructionsSimplified;
+ else
+ addCost(InlineConstants::InstrCost);
+
+ using namespace ore;
+ // If visiting this instruction detected an uninlinable pattern, abort.
+ InlineResult IR;
+ if (IsRecursiveCall)
+ IR = "recursive";
+ else if (ExposesReturnsTwice)
+ IR = "exposes returns twice";
+ else if (HasDynamicAlloca)
+ IR = "dynamic alloca";
+ else if (HasIndirectBr)
+ IR = "indirect branch";
+ else if (HasUninlineableIntrinsic)
+ IR = "uninlinable intrinsic";
+ else if (InitsVargArgs)
+ IR = "varargs";
+ if (!IR) {
+ if (ORE)
+ ORE->emit([&]() {
+ return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline",
+ &CandidateCall)
+ << NV("Callee", &F) << " has uninlinable pattern ("
+ << NV("InlineResult", IR.message)
+ << ") and cost is not fully computed";
+ });
+ return IR;
+ }
+
+ // If the caller is a recursive function then we don't want to inline
+ // functions which allocate a lot of stack space because it would increase
+ // the caller stack usage dramatically.
+ if (IsCallerRecursive &&
+ AllocatedSize > InlineConstants::TotalAllocaSizeRecursiveCaller) {
+ InlineResult IR = "recursive and allocates too much stack space";
+ if (ORE)
+ ORE->emit([&]() {
+ return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline",
+ &CandidateCall)
+ << NV("Callee", &F) << " is " << NV("InlineResult", IR.message)
+ << ". Cost is not fully computed";
+ });
+ return IR;
+ }
+
+ // Check if we've passed the maximum possible threshold so we don't spin in
+ // huge basic blocks that will never inline.
+ if (Cost >= Threshold && !ComputeFullInlineCost)
+ return false;
+ }
+
+ return true;
+}
+
+/// Compute the base pointer and cumulative constant offsets for V.
+///
+/// This strips all constant offsets off of V, leaving it the base pointer, and
+/// accumulates the total constant offset applied in the returned constant. It
+/// returns null if V is not a pointer, and returns the constant '0' if there are
+/// no constant offsets applied.
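+///
+/// For example, given V = "getelementptr inbounds i32, i32* %p, i64 4" on a
+/// target with 4-byte i32 and a 64-bit index type, V is rewritten to %p and
+/// the returned constant is 16.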
+ConstantInt *CallAnalyzer::stripAndComputeInBoundsConstantOffsets(Value *&V) {
+ if (!V->getType()->isPointerTy())
+ return nullptr;
+
+ unsigned AS = V->getType()->getPointerAddressSpace();
+ unsigned IntPtrWidth = DL.getIndexSizeInBits(AS);
+ APInt Offset = APInt::getNullValue(IntPtrWidth);
+
+ // Even though we don't look through PHI nodes, we could be called on an
+ // instruction in an unreachable block, which may be on a cycle.
+ SmallPtrSet<Value *, 4> Visited;
+ Visited.insert(V);
+ do {
+ if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
+ if (!GEP->isInBounds() || !accumulateGEPOffset(*GEP, Offset))
+ return nullptr;
+ V = GEP->getPointerOperand();
+ } else if (Operator::getOpcode(V) == Instruction::BitCast) {
+ V = cast<Operator>(V)->getOperand(0);
+ } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
+ if (GA->isInterposable())
+ break;
+ V = GA->getAliasee();
+ } else {
+ break;
+ }
+ assert(V->getType()->isPointerTy() && "Unexpected operand type!");
+ } while (Visited.insert(V).second);
+
+ Type *IntPtrTy = DL.getIntPtrType(V->getContext(), AS);
+ return cast<ConstantInt>(ConstantInt::get(IntPtrTy, Offset));
+}
+
+/// Find dead blocks due to deleted CFG edges during inlining.
+///
+/// If we know the successor of the current block, \p CurrBB, has to be \p
+/// NextBB, the other successors of \p CurrBB are dead if these successors have
+/// no live incoming CFG edges. If one block is found to be dead, we can
+/// continue growing the dead block list by checking the successors of the dead
+/// blocks to see if all their incoming edges are dead or not.
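+///
+/// For example, if \p CurrBB ends in "br i1 %c, label %T, label %F" and %c is
+/// known to be true, then \p NextBB is %T and %F becomes dead as long as %F
+/// has no other live incoming CFG edges.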
+void CallAnalyzer::findDeadBlocks(BasicBlock *CurrBB, BasicBlock *NextBB) {
+ auto IsEdgeDead = [&](BasicBlock *Pred, BasicBlock *Succ) {
+ // A CFG edge is dead if the predecessor is dead or the predecessor has a
+ // known successor which is not the one under examination.
+ return (DeadBlocks.count(Pred) ||
+ (KnownSuccessors[Pred] && KnownSuccessors[Pred] != Succ));
+ };
+
+ auto IsNewlyDead = [&](BasicBlock *BB) {
+ // If all the edges to a block are dead, the block is also dead.
+ return (!DeadBlocks.count(BB) &&
+ llvm::all_of(predecessors(BB),
+ [&](BasicBlock *P) { return IsEdgeDead(P, BB); }));
+ };
+
+ for (BasicBlock *Succ : successors(CurrBB)) {
+ if (Succ == NextBB || !IsNewlyDead(Succ))
+ continue;
+ SmallVector<BasicBlock *, 4> NewDead;
+ NewDead.push_back(Succ);
+ while (!NewDead.empty()) {
+ BasicBlock *Dead = NewDead.pop_back_val();
+ if (DeadBlocks.insert(Dead))
+ // Continue growing the dead block lists.
+ for (BasicBlock *S : successors(Dead))
+ if (IsNewlyDead(S))
+ NewDead.push_back(S);
+ }
+ }
+}
+
+/// Analyze a call site for potential inlining.
+///
+/// Returns true if inlining this call is viable, and false if it is not
+/// viable. It computes the cost and adjusts the threshold based on numerous
+/// factors and heuristics. If this method returns false but the computed cost
+/// is below the computed threshold, then inlining was forcibly disabled by
+/// some artifact of the routine.
+InlineResult CallAnalyzer::analyzeCall(CallBase &Call) {
+ ++NumCallsAnalyzed;
+
+ // Perform some tweaks to the cost and threshold based on the direct
+ // callsite information.
+
+ // We want to more aggressively inline vector-dense kernels, so up the
+ // threshold, and we'll lower it if the % of vector instructions gets too
+ // low. Note that these bonuses are somewhat arbitrary and evolved over time
+ // by accident as much as because they are principled bonuses.
+ //
+ // FIXME: It would be nice to remove all such bonuses. At least it would be
+ // nice to base the bonus values on something more scientific.
+ assert(NumInstructions == 0);
+ assert(NumVectorInstructions == 0);
+
+ // Update the threshold based on callsite properties
+ updateThreshold(Call, F);
+
+ // While Threshold depends on commandline options that can take negative
+ // values, we want to enforce the invariant that the computed threshold and
+ // bonuses are non-negative.
+ assert(Threshold >= 0);
+ assert(SingleBBBonus >= 0);
+ assert(VectorBonus >= 0);
+
+ // Speculatively apply all possible bonuses to Threshold. If cost exceeds
+ // this Threshold any time, and cost cannot decrease, we can stop processing
+ // the rest of the function body.
+ Threshold += (SingleBBBonus + VectorBonus);
+
+ // Give out bonuses for the callsite, as the instructions setting them up
+ // will be gone after inlining.
+ addCost(-getCallsiteCost(Call, DL));
+
+ // If this function uses the coldcc calling convention, prefer not to inline
+ // it.
+ if (F.getCallingConv() == CallingConv::Cold)
+ Cost += InlineConstants::ColdccPenalty;
+
+ // Check if we're done. This can happen due to bonuses and penalties.
+ if (Cost >= Threshold && !ComputeFullInlineCost)
+ return "high cost";
+
+ if (F.empty())
+ return true;
+
+ Function *Caller = Call.getFunction();
+ // Check if the caller function is recursive itself.
+ for (User *U : Caller->users()) {
+ CallBase *Call = dyn_cast<CallBase>(U);
+ if (Call && Call->getFunction() == Caller) {
+ IsCallerRecursive = true;
+ break;
+ }
+ }
+
+ // Populate our simplified values by mapping from function arguments to call
+ // arguments with known important simplifications.
+ auto CAI = Call.arg_begin();
+ for (Function::arg_iterator FAI = F.arg_begin(), FAE = F.arg_end();
+ FAI != FAE; ++FAI, ++CAI) {
+ assert(CAI != Call.arg_end());
+ if (Constant *C = dyn_cast<Constant>(CAI))
+ SimplifiedValues[&*FAI] = C;
+
+ Value *PtrArg = *CAI;
+ if (ConstantInt *C = stripAndComputeInBoundsConstantOffsets(PtrArg)) {
+ ConstantOffsetPtrs[&*FAI] = std::make_pair(PtrArg, C->getValue());
+
+ // We can SROA any pointer arguments derived from alloca instructions.
+ if (isa<AllocaInst>(PtrArg)) {
+ SROAArgValues[&*FAI] = PtrArg;
+ SROAArgCosts[PtrArg] = 0;
+ }
+ }
+ }
+ NumConstantArgs = SimplifiedValues.size();
+ NumConstantOffsetPtrArgs = ConstantOffsetPtrs.size();
+ NumAllocaArgs = SROAArgValues.size();
+
+ // FIXME: If a caller has multiple calls to a callee, we end up recomputing
+ // the ephemeral values multiple times (and they're completely determined by
+ // the callee, so this is purely duplicate work).
+ SmallPtrSet<const Value *, 32> EphValues;
+ CodeMetrics::collectEphemeralValues(&F, &GetAssumptionCache(F), EphValues);
+
+ // The worklist of live basic blocks in the callee *after* inlining. We avoid
+ // adding basic blocks of the callee which can be proven to be dead for this
+ // particular call site in order to get more accurate cost estimates. This
+ // requires a somewhat heavyweight iteration pattern: we need to walk the
+ // basic blocks in breadth-first order as we insert live successors. To
+ // accomplish this, and because we expect to exit early once we cross our
+ // threshold, we use a small-size-optimized SetVector.
+ typedef SetVector<BasicBlock *, SmallVector<BasicBlock *, 16>,
+ SmallPtrSet<BasicBlock *, 16>>
+ BBSetVector;
+ BBSetVector BBWorklist;
+ BBWorklist.insert(&F.getEntryBlock());
+ bool SingleBB = true;
+ // Note that we *must not* cache the size: this loop grows the worklist.
+ for (unsigned Idx = 0; Idx != BBWorklist.size(); ++Idx) {
+ // Bail out the moment we cross the threshold. This means we'll under-count
+ // the cost, but only when undercounting doesn't matter.
+ if (Cost >= Threshold && !ComputeFullInlineCost)
+ break;
+
+ BasicBlock *BB = BBWorklist[Idx];
+ if (BB->empty())
+ continue;
+
+ // Disallow inlining a blockaddress with uses other than strictly callbr.
+ // A blockaddress only has defined behavior for an indirect branch in the
+ // same function, and we do not currently support inlining indirect
+ // branches. But, the inliner may not see an indirect branch that ends up
+ // being dead code at a particular call site. If the blockaddress escapes
+ // the function, e.g., via a global variable, inlining may lead to an
+ // invalid cross-function reference.
+ // FIXME: pr/39560: continue relaxing this overt restriction.
+ if (BB->hasAddressTaken())
+ for (User *U : BlockAddress::get(&*BB)->users())
+ if (!isa<CallBrInst>(*U))
+ return "blockaddress used outside of callbr";
+
+ // Analyze the cost of this block. If we blow through the threshold, this
+ // returns false, and we can bail out.
+ InlineResult IR = analyzeBlock(BB, EphValues);
+ if (!IR)
+ return IR;
+
+ Instruction *TI = BB->getTerminator();
+
+ // Add in the live successors by first checking whether we have a terminator
+ // that may be simplified based on the values simplified by this call.
+ if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
+ if (BI->isConditional()) {
+ Value *Cond = BI->getCondition();
+ if (ConstantInt *SimpleCond =
+ dyn_cast_or_null<ConstantInt>(SimplifiedValues.lookup(Cond))) {
+ BasicBlock *NextBB = BI->getSuccessor(SimpleCond->isZero() ? 1 : 0);
+ BBWorklist.insert(NextBB);
+ KnownSuccessors[BB] = NextBB;
+ findDeadBlocks(BB, NextBB);
+ continue;
+ }
+ }
+ } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
+ Value *Cond = SI->getCondition();
+ if (ConstantInt *SimpleCond =
+ dyn_cast_or_null<ConstantInt>(SimplifiedValues.lookup(Cond))) {
+ BasicBlock *NextBB = SI->findCaseValue(SimpleCond)->getCaseSuccessor();
+ BBWorklist.insert(NextBB);
+ KnownSuccessors[BB] = NextBB;
+ findDeadBlocks(BB, NextBB);
+ continue;
+ }
+ }
+
+ // If we're unable to select a particular successor, just count all of
+ // them.
+ for (unsigned TIdx = 0, TSize = TI->getNumSuccessors(); TIdx != TSize;
+ ++TIdx)
+ BBWorklist.insert(TI->getSuccessor(TIdx));
+
+ // If we had any successors at this point, then post-inlining is likely to
+ // have them as well. Note that we assume any basic blocks which existed
+ // due to branches or switches which folded above will also fold after
+ // inlining.
+ if (SingleBB && TI->getNumSuccessors() > 1) {
+ // Take off the bonus we applied to the threshold.
+ Threshold -= SingleBBBonus;
+ SingleBB = false;
+ }
+ }
+
+ bool OnlyOneCallAndLocalLinkage =
+ F.hasLocalLinkage() && F.hasOneUse() && &F == Call.getCalledFunction();
+ // If this is a noduplicate call, we can still inline as long as
+ // inlining this would cause the removal of the callee (so the instruction
+ // is not actually duplicated, just moved).
+ if (!OnlyOneCallAndLocalLinkage && ContainsNoDuplicateCall)
+ return "noduplicate";
+
+ // Loops generally act a lot like calls in that they act like barriers to
+ // movement, require a certain amount of setup, etc. So when optimising for
+ // size, we penalise any call sites whose callee contains loops. We do this
+ // after all other costs here, so we will likely only be dealing with
+ // relatively small functions (and hence DT and LI will hopefully be cheap).
+ if (Caller->hasMinSize()) {
+ DominatorTree DT(F);
+ LoopInfo LI(DT);
+ int NumLoops = 0;
+ for (Loop *L : LI) {
+ // Ignore loops that will not be executed
+ if (DeadBlocks.count(L->getHeader()))
+ continue;
+ NumLoops++;
+ }
+ addCost(NumLoops * InlineConstants::CallPenalty);
+ }
+
+ // We applied the maximum possible vector bonus at the beginning. Now,
+ // subtract the excess bonus, if any, from the Threshold before
+ // comparing against Cost.
+ if (NumVectorInstructions <= NumInstructions / 10)
+ Threshold -= VectorBonus;
+ else if (NumVectorInstructions <= NumInstructions / 2)
+ Threshold -= VectorBonus/2;
+
+ return Cost < std::max(1, Threshold);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+/// Dump stats about this call's analysis.
+LLVM_DUMP_METHOD void CallAnalyzer::dump() {
+#define DEBUG_PRINT_STAT(x) dbgs() << " " #x ": " << x << "\n"
+ DEBUG_PRINT_STAT(NumConstantArgs);
+ DEBUG_PRINT_STAT(NumConstantOffsetPtrArgs);
+ DEBUG_PRINT_STAT(NumAllocaArgs);
+ DEBUG_PRINT_STAT(NumConstantPtrCmps);
+ DEBUG_PRINT_STAT(NumConstantPtrDiffs);
+ DEBUG_PRINT_STAT(NumInstructionsSimplified);
+ DEBUG_PRINT_STAT(NumInstructions);
+ DEBUG_PRINT_STAT(SROACostSavings);
+ DEBUG_PRINT_STAT(SROACostSavingsLost);
+ DEBUG_PRINT_STAT(LoadEliminationCost);
+ DEBUG_PRINT_STAT(ContainsNoDuplicateCall);
+ DEBUG_PRINT_STAT(Cost);
+ DEBUG_PRINT_STAT(Threshold);
+#undef DEBUG_PRINT_STAT
+}
+#endif
+
+/// Test that there are no attribute conflicts between Caller and Callee
+/// that prevent inlining.
+static bool functionsHaveCompatibleAttributes(Function *Caller,
+ Function *Callee,
+ TargetTransformInfo &TTI) {
+ return TTI.areInlineCompatible(Caller, Callee) &&
+ AttributeFuncs::areInlineCompatible(*Caller, *Callee);
+}
+
+int llvm::getCallsiteCost(CallBase &Call, const DataLayout &DL) {
+ int Cost = 0;
+ for (unsigned I = 0, E = Call.arg_size(); I != E; ++I) {
+ if (Call.isByValArgument(I)) {
+ // We approximate the number of loads and stores needed by dividing the
+ // size of the byval type by the target's pointer size.
+ PointerType *PTy = cast<PointerType>(Call.getArgOperand(I)->getType());
+ unsigned TypeSize = DL.getTypeSizeInBits(PTy->getElementType());
+ unsigned AS = PTy->getAddressSpace();
+ unsigned PointerSize = DL.getPointerSizeInBits(AS);
+ // Ceiling division.
+ unsigned NumStores = (TypeSize + PointerSize - 1) / PointerSize;
+
+ // If it generates more than 8 stores it is likely to be expanded as an
+ // inline memcpy so we take that as an upper bound. Otherwise we assume
+ // one load and one store per word copied.
+ // FIXME: The maxStoresPerMemcpy setting from the target should be used
+ // here instead of a magic number of 8, but it's not available via
+ // DataLayout.
+ NumStores = std::min(NumStores, 8U);
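+ // For example, a 64-byte byval struct copied with 64-bit pointers gives
+ // NumStores = 512 / 64 = 8, so the copy is charged 2 * 8 * InstrCost below.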
+
+ Cost += 2 * NumStores * InlineConstants::InstrCost;
+ } else {
+ // For non-byval arguments subtract off one instruction per call
+ // argument.
+ Cost += InlineConstants::InstrCost;
+ }
+ }
+ // The call instruction also disappears after inlining.
+ Cost += InlineConstants::InstrCost + InlineConstants::CallPenalty;
+ return Cost;
+}
+
+InlineCost llvm::getInlineCost(
+ CallBase &Call, const InlineParams &Params, TargetTransformInfo &CalleeTTI,
+ std::function<AssumptionCache &(Function &)> &GetAssumptionCache,
+ Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI,
+ ProfileSummaryInfo *PSI, OptimizationRemarkEmitter *ORE) {
+ return getInlineCost(Call, Call.getCalledFunction(), Params, CalleeTTI,
+ GetAssumptionCache, GetBFI, PSI, ORE);
+}
+
+InlineCost llvm::getInlineCost(
+ CallBase &Call, Function *Callee, const InlineParams &Params,
+ TargetTransformInfo &CalleeTTI,
+ std::function<AssumptionCache &(Function &)> &GetAssumptionCache,
+ Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI,
+ ProfileSummaryInfo *PSI, OptimizationRemarkEmitter *ORE) {
+
+ // Cannot inline indirect calls.
+ if (!Callee)
+ return llvm::InlineCost::getNever("indirect call");
+
+ // Never inline calls with byval arguments that are not in the alloca
+ // address space. Since byval arguments can be replaced with a copy to an
+ // alloca, the inlined code would need to be adjusted to handle the argument
+ // being in the alloca address space (which is a little complicated to
+ // solve).
+ unsigned AllocaAS = Callee->getParent()->getDataLayout().getAllocaAddrSpace();
+ for (unsigned I = 0, E = Call.arg_size(); I != E; ++I)
+ if (Call.isByValArgument(I)) {
+ PointerType *PTy = cast<PointerType>(Call.getArgOperand(I)->getType());
+ if (PTy->getAddressSpace() != AllocaAS)
+ return llvm::InlineCost::getNever("byval arguments without alloca"
+ " address space");
+ }
+
+ // Calls to functions with always-inline attributes should be inlined
+ // whenever possible.
+ if (Call.hasFnAttr(Attribute::AlwaysInline)) {
+ auto IsViable = isInlineViable(*Callee);
+ if (IsViable)
+ return llvm::InlineCost::getAlways("always inline attribute");
+ return llvm::InlineCost::getNever(IsViable.message);
+ }
+
+ // Never inline functions with conflicting attributes (unless callee has
+ // always-inline attribute).
+ Function *Caller = Call.getCaller();
+ if (!functionsHaveCompatibleAttributes(Caller, Callee, CalleeTTI))
+ return llvm::InlineCost::getNever("conflicting attributes");
+
+ // Don't inline this call if the caller has the optnone attribute.
+ if (Caller->hasOptNone())
+ return llvm::InlineCost::getNever("optnone attribute");
+
+ // Don't inline a function that treats null pointer as valid into a caller
+ // that does not have this attribute.
+ if (!Caller->nullPointerIsDefined() && Callee->nullPointerIsDefined())
+ return llvm::InlineCost::getNever("nullptr definitions incompatible");
+
+ // Don't inline functions which can be interposed at link-time.
+ if (Callee->isInterposable())
+ return llvm::InlineCost::getNever("interposable");
+
+ // Don't inline functions marked noinline.
+ if (Callee->hasFnAttribute(Attribute::NoInline))
+ return llvm::InlineCost::getNever("noinline function attribute");
+
+ // Don't inline call sites marked noinline.
+ if (Call.isNoInline())
+ return llvm::InlineCost::getNever("noinline call site attribute");
+
+ LLVM_DEBUG(llvm::dbgs() << " Analyzing call of " << Callee->getName()
+ << "... (caller:" << Caller->getName() << ")\n");
+
+ CallAnalyzer CA(CalleeTTI, GetAssumptionCache, GetBFI, PSI, ORE, *Callee,
+ Call, Params);
+ InlineResult ShouldInline = CA.analyzeCall(Call);
+
+ LLVM_DEBUG(CA.dump());
+
+ // Check if there was a reason to force inlining or no inlining.
+ if (!ShouldInline && CA.getCost() < CA.getThreshold())
+ return InlineCost::getNever(ShouldInline.message);
+ if (ShouldInline && CA.getCost() >= CA.getThreshold())
+ return InlineCost::getAlways("empty function");
+
+ return llvm::InlineCost::get(CA.getCost(), CA.getThreshold());
+}
+
+InlineResult llvm::isInlineViable(Function &F) {
+ bool ReturnsTwice = F.hasFnAttribute(Attribute::ReturnsTwice);
+ for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) {
+ // Disallow inlining of functions which contain indirect branches.
+ if (isa<IndirectBrInst>(BI->getTerminator()))
+ return "contains indirect branches";
+
+ // Disallow inlining of blockaddresses which are used by non-callbr
+ // instructions.
+ if (BI->hasAddressTaken())
+ for (User *U : BlockAddress::get(&*BI)->users())
+ if (!isa<CallBrInst>(*U))
+ return "blockaddress used outside of callbr";
+
+ for (auto &II : *BI) {
+ CallBase *Call = dyn_cast<CallBase>(&II);
+ if (!Call)
+ continue;
+
+ // Disallow recursive calls.
+ if (&F == Call->getCalledFunction())
+ return "recursive call";
+
+ // Disallow calls which expose returns-twice to a function not previously
+ // attributed as such.
+ if (!ReturnsTwice && isa<CallInst>(Call) &&
+ cast<CallInst>(Call)->canReturnTwice())
+ return "exposes returns-twice attribute";
+
+ if (Call->getCalledFunction())
+ switch (Call->getCalledFunction()->getIntrinsicID()) {
+ default:
+ break;
+ // Disallow inlining of @llvm.icall.branch.funnel because the current
+ // backend can't separate call targets from call arguments.
+ case llvm::Intrinsic::icall_branch_funnel:
+ return "disallowed inlining of @llvm.icall.branch.funnel";
+ // Disallow inlining functions that call @llvm.localescape. Doing this
+ // correctly would require major changes to the inliner.
+ case llvm::Intrinsic::localescape:
+ return "disallowed inlining of @llvm.localescape";
+ // Disallow inlining of functions that initialize VarArgs with va_start.
+ case llvm::Intrinsic::vastart:
+ return "contains VarArgs initialized with va_start";
+ }
+ }
+ }
+
+ return true;
+}
+
+// APIs to create InlineParams based on command line flags and/or other
+// parameters.
+
+InlineParams llvm::getInlineParams(int Threshold) {
+ InlineParams Params;
+
+ // This field is the threshold to use for a callee by default. This is
+ // derived from one or more of:
+ // * optimization or size-optimization levels,
+ // * a value passed to createFunctionInliningPass function, or
+ // * the -inline-threshold flag.
+ // If the -inline-threshold flag is explicitly specified, that is used
+ // irrespective of anything else.
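+ // For example, if -inline-threshold=500 is given on the command line, 500
+ // becomes the default threshold even if a different value was passed to
+ // createFunctionInliningPass.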
+ if (InlineThreshold.getNumOccurrences() > 0)
+ Params.DefaultThreshold = InlineThreshold;
+ else
+ Params.DefaultThreshold = Threshold;
+
+ // Set the HintThreshold knob from the -inlinehint-threshold.
+ Params.HintThreshold = HintThreshold;
+
+ // Set the HotCallSiteThreshold knob from the -hot-callsite-threshold.
+ Params.HotCallSiteThreshold = HotCallSiteThreshold;
+
+ // If the -locally-hot-callsite-threshold is explicitly specified, use it to
+ // populate LocallyHotCallSiteThreshold. Later, we populate
+ // Params.LocallyHotCallSiteThreshold from -locally-hot-callsite-threshold if
+ // we know that the optimization level is O3 (in the getInlineParams variant that
+ // takes the opt and size levels).
+ // FIXME: Remove this check (and make the assignment unconditional) after
+ // addressing size regression issues at O2.
+ if (LocallyHotCallSiteThreshold.getNumOccurrences() > 0)
+ Params.LocallyHotCallSiteThreshold = LocallyHotCallSiteThreshold;
+
+ // Set the ColdCallSiteThreshold knob from the -inline-cold-callsite-threshold.
+ Params.ColdCallSiteThreshold = ColdCallSiteThreshold;
+
+ // Set the OptMinSizeThreshold and OptSizeThreshold params only if the
+ // -inline-threshold commandline option is not explicitly given. If that
+ // option is present, then its value applies even for callees with size and
+ // minsize attributes.
+ // If the -inline-threshold is not specified, set the ColdThreshold from the
+ // -inlinecold-threshold even if it is not explicitly passed. If
+ // -inline-threshold is specified, then -inlinecold-threshold needs to be
+ // explicitly specified to set the ColdThreshold knob
+ if (InlineThreshold.getNumOccurrences() == 0) {
+ Params.OptMinSizeThreshold = InlineConstants::OptMinSizeThreshold;
+ Params.OptSizeThreshold = InlineConstants::OptSizeThreshold;
+ Params.ColdThreshold = ColdThreshold;
+ } else if (ColdThreshold.getNumOccurrences() > 0) {
+ Params.ColdThreshold = ColdThreshold;
+ }
+ return Params;
+}
+
+InlineParams llvm::getInlineParams() {
+ return getInlineParams(InlineThreshold);
+}
+
+// Compute the default threshold for inlining based on the opt level and the
+// size opt level.
+static int computeThresholdFromOptLevels(unsigned OptLevel,
+ unsigned SizeOptLevel) {
+ if (OptLevel > 2)
+ return InlineConstants::OptAggressiveThreshold;
+ if (SizeOptLevel == 1) // -Os
+ return InlineConstants::OptSizeThreshold;
+ if (SizeOptLevel == 2) // -Oz
+ return InlineConstants::OptMinSizeThreshold;
+ return InlineThreshold;
+}
+
+InlineParams llvm::getInlineParams(unsigned OptLevel, unsigned SizeOptLevel) {
+ auto Params =
+ getInlineParams(computeThresholdFromOptLevels(OptLevel, SizeOptLevel));
+ // At O3, use the value of -locally-hot-callsite-threshold option to populate
+ // Params.LocallyHotCallSiteThreshold. Below O3, this flag has effect only
+ // when it is specified explicitly.
+ if (OptLevel > 2)
+ Params.LocallyHotCallSiteThreshold = LocallyHotCallSiteThreshold;
+ return Params;
+}
diff --git a/llvm/lib/Analysis/InstCount.cpp b/llvm/lib/Analysis/InstCount.cpp
new file mode 100644
index 000000000000..943a99a5f46d
--- /dev/null
+++ b/llvm/lib/Analysis/InstCount.cpp
@@ -0,0 +1,78 @@
+//===-- InstCount.cpp - Collects the count of all instructions ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass collects the count of all instructions and reports them
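+// (A typical way to inspect the counts is to run this pass under opt with
+// statistics enabled, e.g. "opt -instcount -stats -disable-output foo.ll".)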
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace llvm;
+
+#define DEBUG_TYPE "instcount"
+
+STATISTIC(TotalInsts , "Number of instructions (of all types)");
+STATISTIC(TotalBlocks, "Number of basic blocks");
+STATISTIC(TotalFuncs , "Number of non-external functions");
+
+#define HANDLE_INST(N, OPCODE, CLASS) \
+ STATISTIC(Num ## OPCODE ## Inst, "Number of " #OPCODE " insts");
+
+#include "llvm/IR/Instruction.def"
+
+namespace {
+ class InstCount : public FunctionPass, public InstVisitor<InstCount> {
+ friend class InstVisitor<InstCount>;
+
+ void visitFunction (Function &F) { ++TotalFuncs; }
+ void visitBasicBlock(BasicBlock &BB) { ++TotalBlocks; }
+
+#define HANDLE_INST(N, OPCODE, CLASS) \
+ void visit##OPCODE(CLASS &) { ++Num##OPCODE##Inst; ++TotalInsts; }
+
+#include "llvm/IR/Instruction.def"
+
+ void visitInstruction(Instruction &I) {
+ errs() << "Instruction Count does not know about " << I;
+ llvm_unreachable(nullptr);
+ }
+ public:
+ static char ID; // Pass identification, replacement for typeid
+ InstCount() : FunctionPass(ID) {
+ initializeInstCountPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override;
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ }
+ void print(raw_ostream &O, const Module *M) const override {}
+
+ };
+}
+
+char InstCount::ID = 0;
+INITIALIZE_PASS(InstCount, "instcount",
+ "Counts the various types of Instructions", false, true)
+
+FunctionPass *llvm::createInstCountPass() { return new InstCount(); }
+
+// InstCount::runOnFunction - This is the main analysis entry point for a
+// function.
+//
+bool InstCount::runOnFunction(Function &F) {
+ visit(F);
+ return false;
+}
diff --git a/llvm/lib/Analysis/InstructionPrecedenceTracking.cpp b/llvm/lib/Analysis/InstructionPrecedenceTracking.cpp
new file mode 100644
index 000000000000..35190ce3e11a
--- /dev/null
+++ b/llvm/lib/Analysis/InstructionPrecedenceTracking.cpp
@@ -0,0 +1,160 @@
+//===-- InstructionPrecedenceTracking.cpp -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Implements a class that is able to define some instructions as "special"
+// (e.g. as having implicit control flow, or writing memory, or having another
+// interesting property) and then efficiently answers queries of the types:
+// 1. Are there any special instructions in the block of interest?
+// 2. Return the first special instruction in the given block;
+// 3. Check if the given instruction is preceded by the first special
+// instruction in the same block.
+// The class provides caching that allows these queries to be answered
+// quickly. The user must make sure that the cached data is invalidated
+// properly whenever the contents of a tracked block change.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/InstructionPrecedenceTracking.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/PatternMatch.h"
+
+using namespace llvm;
+
+#ifndef NDEBUG
+static cl::opt<bool> ExpensiveAsserts(
+ "ipt-expensive-asserts",
+ cl::desc("Perform expensive assert validation on every query to Instruction"
+ " Precedence Tracking"),
+ cl::init(false), cl::Hidden);
+#endif
+
+const Instruction *InstructionPrecedenceTracking::getFirstSpecialInstruction(
+ const BasicBlock *BB) {
+#ifndef NDEBUG
+ // If a bug related to a stale cache is suspected, turn on ExpensiveAsserts
+ // to catch it as early as possible.
+ if (ExpensiveAsserts)
+ validateAll();
+ else
+ validate(BB);
+#endif
+
+ if (FirstSpecialInsts.find(BB) == FirstSpecialInsts.end()) {
+ fill(BB);
+ assert(FirstSpecialInsts.find(BB) != FirstSpecialInsts.end() && "Must be!");
+ }
+ return FirstSpecialInsts[BB];
+}
+
+bool InstructionPrecedenceTracking::hasSpecialInstructions(
+ const BasicBlock *BB) {
+ return getFirstSpecialInstruction(BB) != nullptr;
+}
+
+bool InstructionPrecedenceTracking::isPreceededBySpecialInstruction(
+ const Instruction *Insn) {
+ const Instruction *MaybeFirstSpecial =
+ getFirstSpecialInstruction(Insn->getParent());
+ return MaybeFirstSpecial && OI.dominates(MaybeFirstSpecial, Insn);
+}
+
+void InstructionPrecedenceTracking::fill(const BasicBlock *BB) {
+ FirstSpecialInsts.erase(BB);
+ for (auto &I : *BB)
+ if (isSpecialInstruction(&I)) {
+ FirstSpecialInsts[BB] = &I;
+ return;
+ }
+
+ // Mark this block as having no special instructions.
+ FirstSpecialInsts[BB] = nullptr;
+}
+
+#ifndef NDEBUG
+void InstructionPrecedenceTracking::validate(const BasicBlock *BB) const {
+ auto It = FirstSpecialInsts.find(BB);
+ // Bail if we don't have anything cached for this block.
+ if (It == FirstSpecialInsts.end())
+ return;
+
+ for (const Instruction &Insn : *BB)
+ if (isSpecialInstruction(&Insn)) {
+ assert(It->second == &Insn &&
+ "Cached first special instruction is wrong!");
+ return;
+ }
+
+ assert(It->second == nullptr &&
+ "Block is marked as having special instructions but in fact it has "
+ "none!");
+}
+
+void InstructionPrecedenceTracking::validateAll() const {
+ // Check that for every known block the cached value is correct.
+ for (auto &It : FirstSpecialInsts)
+ validate(It.first);
+}
+#endif
+
+void InstructionPrecedenceTracking::insertInstructionTo(const Instruction *Inst,
+ const BasicBlock *BB) {
+ if (isSpecialInstruction(Inst))
+ FirstSpecialInsts.erase(BB);
+ OI.invalidateBlock(BB);
+}
+
+void InstructionPrecedenceTracking::removeInstruction(const Instruction *Inst) {
+ if (isSpecialInstruction(Inst))
+ FirstSpecialInsts.erase(Inst->getParent());
+ OI.invalidateBlock(Inst->getParent());
+}
+
+void InstructionPrecedenceTracking::clear() {
+ for (auto It : FirstSpecialInsts)
+ OI.invalidateBlock(It.first);
+ FirstSpecialInsts.clear();
+#ifndef NDEBUG
+ // The map should be valid after clearing (at least empty).
+ validateAll();
+#endif
+}
+
+bool ImplicitControlFlowTracking::isSpecialInstruction(
+ const Instruction *Insn) const {
+ // If an instruction in a block doesn't always pass control to its successor
+ // instruction, mark the block as having implicit control flow. We use this
+ // to avoid wrong assumptions of the form "if A is executed and B
+ // post-dominates A, then B is also executed". This does not hold if there is
+ // an implicit control flow instruction (e.g. a guard) between them.
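+ // For example, a call to @llvm.experimental.guard between A and B may
+ // deoptimize instead of falling through, in which case B never executes.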
+ //
+ // TODO: Currently, isGuaranteedToTransferExecutionToSuccessor returns false
+ // for volatile stores and loads because they can trap. The discussion on
+ // whether or not it is correct is still ongoing. We might want to get rid
+ // of this logic in the future. Anyway, trapping instructions shouldn't
+ // introduce implicit control flow, so we explicitly allow them here. This
+ // must be removed once isGuaranteedToTransferExecutionToSuccessor is fixed.
+ if (isGuaranteedToTransferExecutionToSuccessor(Insn))
+ return false;
+ if (isa<LoadInst>(Insn)) {
+ assert(cast<LoadInst>(Insn)->isVolatile() &&
+ "Non-volatile load should transfer execution to successor!");
+ return false;
+ }
+ if (isa<StoreInst>(Insn)) {
+ assert(cast<StoreInst>(Insn)->isVolatile() &&
+ "Non-volatile store should transfer execution to successor!");
+ return false;
+ }
+ return true;
+}
+
+bool MemoryWriteTracking::isSpecialInstruction(
+ const Instruction *Insn) const {
+ using namespace PatternMatch;
+ if (match(Insn, m_Intrinsic<Intrinsic::experimental_widenable_condition>()))
+ return false;
+ return Insn->mayWriteToMemory();
+}
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
new file mode 100644
index 000000000000..cb8987721700
--- /dev/null
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -0,0 +1,5518 @@
+//===- InstructionSimplify.cpp - Fold instruction operands ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements routines for folding instructions into simpler forms
+// that do not require creating new instructions. This does constant folding
+// ("add i32 1, 1" -> "2") but can also handle non-constant operands, either
+// returning a constant ("and i32 %x, 0" -> "0") or an already existing value
+// ("and i32 %x, %x" -> "%x"). All operands are assumed to have already been
+// simplified: This is usually true and assuming it simplifies the logic (if
+// they have not been simplified then results are correct but maybe suboptimal).
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/CaptureTracking.h"
+#include "llvm/Analysis/CmpInstAnalysis.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
+#include "llvm/IR/GlobalAlias.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/Support/KnownBits.h"
+#include <algorithm>
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+#define DEBUG_TYPE "instsimplify"
+
+enum { RecursionLimit = 3 };
+
+STATISTIC(NumExpand, "Number of expansions");
+STATISTIC(NumReassoc, "Number of reassociations");
+
+static Value *SimplifyAndInst(Value *, Value *, const SimplifyQuery &, unsigned);
+static Value *simplifyUnOp(unsigned, Value *, const SimplifyQuery &, unsigned);
+static Value *simplifyFPUnOp(unsigned, Value *, const FastMathFlags &,
+ const SimplifyQuery &, unsigned);
+static Value *SimplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &,
+ unsigned);
+static Value *SimplifyBinOp(unsigned, Value *, Value *, const FastMathFlags &,
+ const SimplifyQuery &, unsigned);
+static Value *SimplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &,
+ unsigned);
+static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+ const SimplifyQuery &Q, unsigned MaxRecurse);
+static Value *SimplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned);
+static Value *SimplifyXorInst(Value *, Value *, const SimplifyQuery &, unsigned);
+static Value *SimplifyCastInst(unsigned, Value *, Type *,
+ const SimplifyQuery &, unsigned);
+static Value *SimplifyGEPInst(Type *, ArrayRef<Value *>, const SimplifyQuery &,
+ unsigned);
+
+static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal,
+ Value *FalseVal) {
+ BinaryOperator::BinaryOps BinOpCode;
+ if (auto *BO = dyn_cast<BinaryOperator>(Cond))
+ BinOpCode = BO->getOpcode();
+ else
+ return nullptr;
+
+ CmpInst::Predicate ExpectedPred, Pred1, Pred2;
+ if (BinOpCode == BinaryOperator::Or) {
+ ExpectedPred = ICmpInst::ICMP_NE;
+ } else if (BinOpCode == BinaryOperator::And) {
+ ExpectedPred = ICmpInst::ICMP_EQ;
+ } else
+ return nullptr;
+
+ // %A = icmp eq %TV, %FV
+ // %B = icmp eq %X, %Y (and one of these is a select operand)
+ // %C = and %A, %B
+ // %D = select %C, %TV, %FV
+ // -->
+ // %FV
+
+ // %A = icmp ne %TV, %FV
+ // %B = icmp ne %X, %Y (and one of these is a select operand)
+ // %C = or %A, %B
+ // %D = select %C, %TV, %FV
+ // -->
+ // %TV
+ Value *X, *Y;
+ if (!match(Cond, m_c_BinOp(m_c_ICmp(Pred1, m_Specific(TrueVal),
+ m_Specific(FalseVal)),
+ m_ICmp(Pred2, m_Value(X), m_Value(Y)))) ||
+ Pred1 != Pred2 || Pred1 != ExpectedPred)
+ return nullptr;
+
+ if (X == TrueVal || X == FalseVal || Y == TrueVal || Y == FalseVal)
+ return BinOpCode == BinaryOperator::Or ? TrueVal : FalseVal;
+
+ return nullptr;
+}
+
+/// For a boolean type or a vector of boolean type, return false or a vector
+/// with every element false.
+static Constant *getFalse(Type *Ty) {
+ return ConstantInt::getFalse(Ty);
+}
+
+/// For a boolean type or a vector of boolean type, return true or a vector
+/// with every element true.
+static Constant *getTrue(Type *Ty) {
+ return ConstantInt::getTrue(Ty);
+}
+
+/// isSameCompare - Is V equivalent to the comparison "LHS Pred RHS"?
+static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS,
+ Value *RHS) {
+ CmpInst *Cmp = dyn_cast<CmpInst>(V);
+ if (!Cmp)
+ return false;
+ CmpInst::Predicate CPred = Cmp->getPredicate();
+ Value *CLHS = Cmp->getOperand(0), *CRHS = Cmp->getOperand(1);
+ if (CPred == Pred && CLHS == LHS && CRHS == RHS)
+ return true;
+ return CPred == CmpInst::getSwappedPredicate(Pred) && CLHS == RHS &&
+ CRHS == LHS;
+}
+
+/// Does the given value dominate the specified phi node?
+static bool valueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (!I)
+ // Arguments and constants dominate all instructions.
+ return true;
+
+ // If we are processing instructions (and/or basic blocks) that have not been
+ // fully added to a function, the parent nodes may still be null. Simply
+ // return the conservative answer in these cases.
+ if (!I->getParent() || !P->getParent() || !I->getFunction())
+ return false;
+
+ // If we have a DominatorTree then do a precise test.
+ if (DT)
+ return DT->dominates(I, P);
+
+ // Otherwise, if the instruction is in the entry block and is not an invoke,
+ // then it obviously dominates all phi nodes.
+ if (I->getParent() == &I->getFunction()->getEntryBlock() &&
+ !isa<InvokeInst>(I))
+ return true;
+
+ return false;
+}
+
+/// Simplify "A op (B op' C)" by distributing op over op', turning it into
+/// "(A op B) op' (A op C)". Here "op" is given by Opcode and "op'" is
+/// given by OpcodeToExpand, while "A" corresponds to LHS and "B op' C" to RHS.
+/// Also performs the transform "(A op' B) op C" -> "(A op C) op' (B op C)".
+/// Returns the simplified value, or null if no simplification was performed.
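+///
+/// For example, with Opcode = And and OpcodeToExpand = Or, "(A | B) & C" is
+/// considered as "(A & C) | (B & C)", and the expansion is used only if the
+/// resulting pieces themselves simplify.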
+static Value *ExpandBinOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS,
+ Instruction::BinaryOps OpcodeToExpand,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ // Recursion is always used, so bail out at once if we already hit the limit.
+ if (!MaxRecurse--)
+ return nullptr;
+
+ // Check whether the expression has the form "(A op' B) op C".
+ if (BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS))
+ if (Op0->getOpcode() == OpcodeToExpand) {
+ // It does! Try turning it into "(A op C) op' (B op C)".
+ Value *A = Op0->getOperand(0), *B = Op0->getOperand(1), *C = RHS;
+ // Do "A op C" and "B op C" both simplify?
+ if (Value *L = SimplifyBinOp(Opcode, A, C, Q, MaxRecurse))
+ if (Value *R = SimplifyBinOp(Opcode, B, C, Q, MaxRecurse)) {
+ // They do! Return "L op' R" if it simplifies or is already available.
+ // If "L op' R" equals "A op' B" then "L op' R" is just the LHS.
+ if ((L == A && R == B) || (Instruction::isCommutative(OpcodeToExpand)
+ && L == B && R == A)) {
+ ++NumExpand;
+ return LHS;
+ }
+ // Otherwise return "L op' R" if it simplifies.
+ if (Value *V = SimplifyBinOp(OpcodeToExpand, L, R, Q, MaxRecurse)) {
+ ++NumExpand;
+ return V;
+ }
+ }
+ }
+
+ // Check whether the expression has the form "A op (B op' C)".
+ if (BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS))
+ if (Op1->getOpcode() == OpcodeToExpand) {
+ // It does! Try turning it into "(A op B) op' (A op C)".
+ Value *A = LHS, *B = Op1->getOperand(0), *C = Op1->getOperand(1);
+ // Do "A op B" and "A op C" both simplify?
+ if (Value *L = SimplifyBinOp(Opcode, A, B, Q, MaxRecurse))
+ if (Value *R = SimplifyBinOp(Opcode, A, C, Q, MaxRecurse)) {
+ // They do! Return "L op' R" if it simplifies or is already available.
+ // If "L op' R" equals "B op' C" then "L op' R" is just the RHS.
+ if ((L == B && R == C) || (Instruction::isCommutative(OpcodeToExpand)
+ && L == C && R == B)) {
+ ++NumExpand;
+ return RHS;
+ }
+ // Otherwise return "L op' R" if it simplifies.
+ if (Value *V = SimplifyBinOp(OpcodeToExpand, L, R, Q, MaxRecurse)) {
+ ++NumExpand;
+ return V;
+ }
+ }
+ }
+
+ return nullptr;
+}
+
+/// Generic simplifications for associative binary operations.
+/// Returns the simpler value, or null if none was found.
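+///
+/// For example, for "(X + 7) + (-7)": "7 + (-7)" folds to 0 and "X + 0" folds
+/// to X, so the whole expression simplifies to X.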
+static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode,
+ Value *LHS, Value *RHS,
+ const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ assert(Instruction::isAssociative(Opcode) && "Not an associative operation!");
+
+ // Recursion is always used, so bail out at once if we already hit the limit.
+ if (!MaxRecurse--)
+ return nullptr;
+
+ BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
+ BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS);
+
+ // Transform: "(A op B) op C" ==> "A op (B op C)" if it simplifies completely.
+ if (Op0 && Op0->getOpcode() == Opcode) {
+ Value *A = Op0->getOperand(0);
+ Value *B = Op0->getOperand(1);
+ Value *C = RHS;
+
+ // Does "B op C" simplify?
+ if (Value *V = SimplifyBinOp(Opcode, B, C, Q, MaxRecurse)) {
+ // It does! Return "A op V" if it simplifies or is already available.
+ // If V equals B then "A op V" is just the LHS.
+ if (V == B) return LHS;
+ // Otherwise return "A op V" if it simplifies.
+ if (Value *W = SimplifyBinOp(Opcode, A, V, Q, MaxRecurse)) {
+ ++NumReassoc;
+ return W;
+ }
+ }
+ }
+
+ // Transform: "A op (B op C)" ==> "(A op B) op C" if it simplifies completely.
+ if (Op1 && Op1->getOpcode() == Opcode) {
+ Value *A = LHS;
+ Value *B = Op1->getOperand(0);
+ Value *C = Op1->getOperand(1);
+
+ // Does "A op B" simplify?
+ if (Value *V = SimplifyBinOp(Opcode, A, B, Q, MaxRecurse)) {
+ // It does! Return "V op C" if it simplifies or is already available.
+ // If V equals B then "V op C" is just the RHS.
+ if (V == B) return RHS;
+ // Otherwise return "V op C" if it simplifies.
+ if (Value *W = SimplifyBinOp(Opcode, V, C, Q, MaxRecurse)) {
+ ++NumReassoc;
+ return W;
+ }
+ }
+ }
+
+ // The remaining transforms require commutativity as well as associativity.
+ if (!Instruction::isCommutative(Opcode))
+ return nullptr;
+
+ // Transform: "(A op B) op C" ==> "(C op A) op B" if it simplifies completely.
+ if (Op0 && Op0->getOpcode() == Opcode) {
+ Value *A = Op0->getOperand(0);
+ Value *B = Op0->getOperand(1);
+ Value *C = RHS;
+
+ // Does "C op A" simplify?
+ if (Value *V = SimplifyBinOp(Opcode, C, A, Q, MaxRecurse)) {
+ // It does! Return "V op B" if it simplifies or is already available.
+ // If V equals A then "V op B" is just the LHS.
+ if (V == A) return LHS;
+ // Otherwise return "V op B" if it simplifies.
+ if (Value *W = SimplifyBinOp(Opcode, V, B, Q, MaxRecurse)) {
+ ++NumReassoc;
+ return W;
+ }
+ }
+ }
+
+ // Transform: "A op (B op C)" ==> "B op (C op A)" if it simplifies completely.
+ if (Op1 && Op1->getOpcode() == Opcode) {
+ Value *A = LHS;
+ Value *B = Op1->getOperand(0);
+ Value *C = Op1->getOperand(1);
+
+ // Does "C op A" simplify?
+ if (Value *V = SimplifyBinOp(Opcode, C, A, Q, MaxRecurse)) {
+ // It does! Return "B op V" if it simplifies or is already available.
+ // If V equals C then "B op V" is just the RHS.
+ if (V == C) return RHS;
+ // Otherwise return "B op V" if it simplifies.
+ if (Value *W = SimplifyBinOp(Opcode, B, V, Q, MaxRecurse)) {
+ ++NumReassoc;
+ return W;
+ }
+ }
+ }
+
+ return nullptr;
+}
+
+/// In the case of a binary operation with a select instruction as an operand,
+/// try to simplify the binop by seeing whether evaluating it on both branches
+/// of the select results in the same value. Returns the common value if so,
+/// otherwise returns null.
+static Value *ThreadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
+ Value *RHS, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ // Recursion is always used, so bail out at once if we already hit the limit.
+ if (!MaxRecurse--)
+ return nullptr;
+
+ SelectInst *SI;
+ if (isa<SelectInst>(LHS)) {
+ SI = cast<SelectInst>(LHS);
+ } else {
+ assert(isa<SelectInst>(RHS) && "No select instruction operand!");
+ SI = cast<SelectInst>(RHS);
+ }
+
+ // Evaluate the BinOp on the true and false branches of the select.
+ Value *TV;
+ Value *FV;
+ if (SI == LHS) {
+ TV = SimplifyBinOp(Opcode, SI->getTrueValue(), RHS, Q, MaxRecurse);
+ FV = SimplifyBinOp(Opcode, SI->getFalseValue(), RHS, Q, MaxRecurse);
+ } else {
+ TV = SimplifyBinOp(Opcode, LHS, SI->getTrueValue(), Q, MaxRecurse);
+ FV = SimplifyBinOp(Opcode, LHS, SI->getFalseValue(), Q, MaxRecurse);
+ }
+
+ // If they simplified to the same value, then return the common value.
+ // If they both failed to simplify then return null.
+ if (TV == FV)
+ return TV;
+
+ // If one branch simplified to undef, return the other one.
+ if (TV && isa<UndefValue>(TV))
+ return FV;
+ if (FV && isa<UndefValue>(FV))
+ return TV;
+
+ // If applying the operation did not change the true and false select values,
+ // then the result of the binop is the select itself.
+ if (TV == SI->getTrueValue() && FV == SI->getFalseValue())
+ return SI;
+
+ // If one branch simplified and the other did not, and the simplified
+ // value is equal to the unsimplified one, return the simplified value.
+ // For example, select (cond, X, X & Z) & Z -> X & Z.
+ if ((FV && !TV) || (TV && !FV)) {
+ // Check that the simplified value has the form "X op Y" where "op" is the
+ // same as the original operation.
+ Instruction *Simplified = dyn_cast<Instruction>(FV ? FV : TV);
+ if (Simplified && Simplified->getOpcode() == unsigned(Opcode)) {
+ // The value that didn't simplify is "UnsimplifiedLHS op UnsimplifiedRHS".
+ // We already know that "op" is the same as for the simplified value. See
+ // if the operands match too. If so, return the simplified value.
+ Value *UnsimplifiedBranch = FV ? SI->getTrueValue() : SI->getFalseValue();
+ Value *UnsimplifiedLHS = SI == LHS ? UnsimplifiedBranch : LHS;
+ Value *UnsimplifiedRHS = SI == LHS ? RHS : UnsimplifiedBranch;
+ if (Simplified->getOperand(0) == UnsimplifiedLHS &&
+ Simplified->getOperand(1) == UnsimplifiedRHS)
+ return Simplified;
+ if (Simplified->isCommutative() &&
+ Simplified->getOperand(1) == UnsimplifiedLHS &&
+ Simplified->getOperand(0) == UnsimplifiedRHS)
+ return Simplified;
+ }
+ }
+
+ return nullptr;
+}
+
+/// In the case of a comparison with a select instruction, try to simplify the
+/// comparison by seeing whether both branches of the select result in the same
+/// value. Returns the common value if so, otherwise returns null.
+static Value *ThreadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS,
+ Value *RHS, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ // Recursion is always used, so bail out at once if we already hit the limit.
+ if (!MaxRecurse--)
+ return nullptr;
+
+ // Make sure the select is on the LHS.
+ if (!isa<SelectInst>(LHS)) {
+ std::swap(LHS, RHS);
+ Pred = CmpInst::getSwappedPredicate(Pred);
+ }
+ assert(isa<SelectInst>(LHS) && "Not comparing with a select instruction!");
+ SelectInst *SI = cast<SelectInst>(LHS);
+ Value *Cond = SI->getCondition();
+ Value *TV = SI->getTrueValue();
+ Value *FV = SI->getFalseValue();
+
+ // Now that we have "cmp select(Cond, TV, FV), RHS", analyse it.
+ // Does "cmp TV, RHS" simplify?
+ Value *TCmp = SimplifyCmpInst(Pred, TV, RHS, Q, MaxRecurse);
+ if (TCmp == Cond) {
+ // It not only simplified, it simplified to the select condition. Replace
+ // it with 'true'.
+ TCmp = getTrue(Cond->getType());
+ } else if (!TCmp) {
+ // It didn't simplify. However if "cmp TV, RHS" is equal to the select
+ // condition then we can replace it with 'true'. Otherwise give up.
+ if (!isSameCompare(Cond, Pred, TV, RHS))
+ return nullptr;
+ TCmp = getTrue(Cond->getType());
+ }
+
+ // Does "cmp FV, RHS" simplify?
+ Value *FCmp = SimplifyCmpInst(Pred, FV, RHS, Q, MaxRecurse);
+ if (FCmp == Cond) {
+ // It not only simplified, it simplified to the select condition. Replace
+ // it with 'false'.
+ FCmp = getFalse(Cond->getType());
+ } else if (!FCmp) {
+ // It didn't simplify. However if "cmp FV, RHS" is equal to the select
+ // condition then we can replace it with 'false'. Otherwise give up.
+ if (!isSameCompare(Cond, Pred, FV, RHS))
+ return nullptr;
+ FCmp = getFalse(Cond->getType());
+ }
+
+ // If both sides simplified to the same value, then use it as the result of
+ // the original comparison.
+ if (TCmp == FCmp)
+ return TCmp;
+
+ // The remaining cases only make sense if the select condition has the same
+ // type as the result of the comparison, so bail out if this is not so.
+ if (Cond->getType()->isVectorTy() != RHS->getType()->isVectorTy())
+ return nullptr;
+ // If the false value simplified to false, then the result of the compare
+ // is equal to "Cond && TCmp". This also catches the case when the false
+ // value simplified to false and the true value to true, returning "Cond".
+ if (match(FCmp, m_Zero()))
+ if (Value *V = SimplifyAndInst(Cond, TCmp, Q, MaxRecurse))
+ return V;
+ // If the true value simplified to true, then the result of the compare
+ // is equal to "Cond || FCmp".
+ if (match(TCmp, m_One()))
+ if (Value *V = SimplifyOrInst(Cond, FCmp, Q, MaxRecurse))
+ return V;
+ // Finally, if the false value simplified to true and the true value to
+ // false, then the result of the compare is equal to "!Cond".
+ if (match(FCmp, m_One()) && match(TCmp, m_Zero()))
+ if (Value *V =
+ SimplifyXorInst(Cond, Constant::getAllOnesValue(Cond->getType()),
+ Q, MaxRecurse))
+ return V;
+
+ return nullptr;
+}
+
+/// In the case of a binary operation with an operand that is a PHI instruction,
+/// try to simplify the binop by seeing whether evaluating it on the incoming
+/// phi values yields the same result for every value. If so returns the common
+/// value, otherwise returns null.
+static Value *ThreadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS,
+ Value *RHS, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ // Recursion is always used, so bail out at once if we already hit the limit.
+ if (!MaxRecurse--)
+ return nullptr;
+
+ PHINode *PI;
+ if (isa<PHINode>(LHS)) {
+ PI = cast<PHINode>(LHS);
+ // Bail out if RHS and the phi may be mutually interdependent due to a loop.
+ if (!valueDominatesPHI(RHS, PI, Q.DT))
+ return nullptr;
+ } else {
+ assert(isa<PHINode>(RHS) && "No PHI instruction operand!");
+ PI = cast<PHINode>(RHS);
+ // Bail out if LHS and the phi may be mutually interdependent due to a loop.
+ if (!valueDominatesPHI(LHS, PI, Q.DT))
+ return nullptr;
+ }
+
+ // Evaluate the BinOp on the incoming phi values.
+ Value *CommonValue = nullptr;
+ for (Value *Incoming : PI->incoming_values()) {
+ // If the incoming value is the phi node itself, it can safely be skipped.
+ if (Incoming == PI) continue;
+ Value *V = PI == LHS ?
+ SimplifyBinOp(Opcode, Incoming, RHS, Q, MaxRecurse) :
+ SimplifyBinOp(Opcode, LHS, Incoming, Q, MaxRecurse);
+ // If the operation failed to simplify, or simplified to a value different
+ // from the previous one, then give up.
+ if (!V || (CommonValue && V != CommonValue))
+ return nullptr;
+ CommonValue = V;
+ }
+
+ return CommonValue;
+}
+
+/// In the case of a comparison with a PHI instruction, try to simplify the
+/// comparison by seeing whether comparing with all of the incoming phi values
+/// yields the same result every time. If so returns the common result,
+/// otherwise returns null.
+static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ // Recursion is always used, so bail out at once if we already hit the limit.
+ if (!MaxRecurse--)
+ return nullptr;
+
+ // Make sure the phi is on the LHS.
+ if (!isa<PHINode>(LHS)) {
+ std::swap(LHS, RHS);
+ Pred = CmpInst::getSwappedPredicate(Pred);
+ }
+ assert(isa<PHINode>(LHS) && "Not comparing with a phi instruction!");
+ PHINode *PI = cast<PHINode>(LHS);
+
+ // Bail out if RHS and the phi may be mutually interdependent due to a loop.
+ if (!valueDominatesPHI(RHS, PI, Q.DT))
+ return nullptr;
+
+ // Evaluate the BinOp on the incoming phi values.
+ Value *CommonValue = nullptr;
+ for (Value *Incoming : PI->incoming_values()) {
+ // If the incoming value is the phi node itself, it can safely be skipped.
+ if (Incoming == PI) continue;
+ Value *V = SimplifyCmpInst(Pred, Incoming, RHS, Q, MaxRecurse);
+ // If the operation failed to simplify, or simplified to a value different
+ // from the previous one, then give up.
+ if (!V || (CommonValue && V != CommonValue))
+ return nullptr;
+ CommonValue = V;
+ }
+
+ return CommonValue;
+}
+
+static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode,
+ Value *&Op0, Value *&Op1,
+ const SimplifyQuery &Q) {
+ if (auto *CLHS = dyn_cast<Constant>(Op0)) {
+ if (auto *CRHS = dyn_cast<Constant>(Op1))
+ return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL);
+
+ // Canonicalize the constant to the RHS if this is a commutative operation.
+ if (Instruction::isCommutative(Opcode))
+ std::swap(Op0, Op1);
+ }
+ return nullptr;
+}
+
+/// Given operands for an Add, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q))
+ return C;
+
+ // X + undef -> undef
+ if (match(Op1, m_Undef()))
+ return Op1;
+
+ // X + 0 -> X
+ if (match(Op1, m_Zero()))
+ return Op0;
+
+  // If the two operands are negations of each other, the sum is 0.
+ if (isKnownNegation(Op0, Op1))
+ return Constant::getNullValue(Op0->getType());
+
+ // X + (Y - X) -> Y
+ // (Y - X) + X -> Y
+ // Eg: X + -X -> 0
+ Value *Y = nullptr;
+ if (match(Op1, m_Sub(m_Value(Y), m_Specific(Op0))) ||
+ match(Op0, m_Sub(m_Value(Y), m_Specific(Op1))))
+ return Y;
+
+ // X + ~X -> -1 since ~X = -X-1
+ Type *Ty = Op0->getType();
+ if (match(Op0, m_Not(m_Specific(Op1))) ||
+ match(Op1, m_Not(m_Specific(Op0))))
+ return Constant::getAllOnesValue(Ty);
+
+ // add nsw/nuw (xor Y, signmask), signmask --> Y
+ // The no-wrapping add guarantees that the top bit will be set by the add.
+ // Therefore, the xor must be clearing the already set sign bit of Y.
+ if ((IsNSW || IsNUW) && match(Op1, m_SignMask()) &&
+ match(Op0, m_Xor(m_Value(Y), m_SignMask())))
+ return Y;
+
+ // add nuw %x, -1 -> -1, because %x can only be 0.
+ if (IsNUW && match(Op1, m_AllOnes()))
+ return Op1; // Which is -1.
+
+  // i1 add -> xor.
+ if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1))
+ if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1))
+ return V;
+
+ // Try some generic simplifications for associative operations.
+ if (Value *V = SimplifyAssociativeBinOp(Instruction::Add, Op0, Op1, Q,
+ MaxRecurse))
+ return V;
+
+ // Threading Add over selects and phi nodes is pointless, so don't bother.
+ // Threading over the select in "A + select(cond, B, C)" means evaluating
+ // "A+B" and "A+C" and seeing if they are equal; but they are equal if and
+ // only if B and C are equal. If B and C are equal then (since we assume
+ // that operands have already been simplified) "select(cond, B, C)" should
+ // have been simplified to the common value of B and C already. Analysing
+ // "A+B" and "A+C" thus gains nothing, but costs compile time. Similarly
+ // for threading over phi nodes.
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
+ const SimplifyQuery &Query) {
+ return ::SimplifyAddInst(Op0, Op1, IsNSW, IsNUW, Query, RecursionLimit);
+}
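+// A minimal usage sketch (illustrative; assumes a caller that already has a
+// BinaryOperator *Add and its parent Function &F):
+//   const SimplifyQuery SQ(F.getParent()->getDataLayout());
+//   if (Value *V = SimplifyAddInst(Add->getOperand(0), Add->getOperand(1),
+//                                  /*IsNSW=*/false, /*IsNUW=*/false, SQ))
+//     Add->replaceAllUsesWith(V);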
+
+/// Compute the base pointer and cumulative constant offsets for V.
+///
+/// This strips all constant offsets off of V, leaving it the base pointer, and
+/// accumulates the total constant offset applied in the returned constant. It
+/// requires V to be a pointer and returns the constant '0' if there are no
+/// constant offsets applied.
+///
+/// This is very similar to GetPointerBaseWithConstantOffset except it doesn't
+/// follow non-inbounds geps. This allows it to remain usable for icmp ult/etc.
+/// folding.
+static Constant *stripAndComputeConstantOffsets(const DataLayout &DL, Value *&V,
+ bool AllowNonInbounds = false) {
+ assert(V->getType()->isPtrOrPtrVectorTy());
+
+ Type *IntPtrTy = DL.getIntPtrType(V->getType())->getScalarType();
+ APInt Offset = APInt::getNullValue(IntPtrTy->getIntegerBitWidth());
+
+ V = V->stripAndAccumulateConstantOffsets(DL, Offset, AllowNonInbounds);
+  // Because the strip may look through `addrspacecast`, the accumulated offset
+  // may need to be sign-extended or truncated to the new pointer width.
+ IntPtrTy = DL.getIntPtrType(V->getType())->getScalarType();
+ Offset = Offset.sextOrTrunc(IntPtrTy->getIntegerBitWidth());
+
+ Constant *OffsetIntPtr = ConstantInt::get(IntPtrTy, Offset);
+ if (V->getType()->isVectorTy())
+ return ConstantVector::getSplat(V->getType()->getVectorNumElements(),
+ OffsetIntPtr);
+ return OffsetIntPtr;
+}
+
+/// Compute the constant difference between two pointer values.
+/// If the difference is not a constant, returns null.
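+/// For example, if LHS and RHS are inbounds GEPs off the same base pointer at
+/// constant byte offsets 8 and 3, both strip down to that base and the
+/// difference folds to the constant 5.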
+static Constant *computePointerDifference(const DataLayout &DL, Value *LHS,
+ Value *RHS) {
+ Constant *LHSOffset = stripAndComputeConstantOffsets(DL, LHS);
+ Constant *RHSOffset = stripAndComputeConstantOffsets(DL, RHS);
+
+ // If LHS and RHS are not related via constant offsets to the same base
+ // value, there is nothing we can do here.
+ if (LHS != RHS)
+ return nullptr;
+
+ // Otherwise, the difference of LHS - RHS can be computed as:
+ // LHS - RHS
+ // = (LHSOffset + Base) - (RHSOffset + Base)
+ // = LHSOffset - RHSOffset
+ return ConstantExpr::getSub(LHSOffset, RHSOffset);
+}
+
+/// Given operands for a Sub, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (Constant *C = foldOrCommuteConstant(Instruction::Sub, Op0, Op1, Q))
+ return C;
+
+ // X - undef -> undef
+ // undef - X -> undef
+ if (match(Op0, m_Undef()) || match(Op1, m_Undef()))
+ return UndefValue::get(Op0->getType());
+
+ // X - 0 -> X
+ if (match(Op1, m_Zero()))
+ return Op0;
+
+ // X - X -> 0
+ if (Op0 == Op1)
+ return Constant::getNullValue(Op0->getType());
+
+ // Is this a negation?
+ if (match(Op0, m_Zero())) {
+ // 0 - X -> 0 if the sub is NUW.
+ if (isNUW)
+ return Constant::getNullValue(Op0->getType());
+
+ KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (Known.Zero.isMaxSignedValue()) {
+ // Op1 is either 0 or the minimum signed value. If the sub is NSW, then
+ // Op1 must be 0 because negating the minimum signed value is undefined.
+ if (isNSW)
+ return Constant::getNullValue(Op0->getType());
+
+ // 0 - X -> X if X is 0 or the minimum signed value.
+ return Op1;
+ }
+ }
+
+ // (X + Y) - Z -> X + (Y - Z) or Y + (X - Z) if everything simplifies.
+ // For example, (X + Y) - Y -> X; (Y + X) - Y -> X
+ Value *X = nullptr, *Y = nullptr, *Z = Op1;
+ if (MaxRecurse && match(Op0, m_Add(m_Value(X), m_Value(Y)))) { // (X + Y) - Z
+ // See if "V === Y - Z" simplifies.
+ if (Value *V = SimplifyBinOp(Instruction::Sub, Y, Z, Q, MaxRecurse-1))
+ // It does! Now see if "X + V" simplifies.
+ if (Value *W = SimplifyBinOp(Instruction::Add, X, V, Q, MaxRecurse-1)) {
+ // It does, we successfully reassociated!
+ ++NumReassoc;
+ return W;
+ }
+ // See if "V === X - Z" simplifies.
+ if (Value *V = SimplifyBinOp(Instruction::Sub, X, Z, Q, MaxRecurse-1))
+ // It does! Now see if "Y + V" simplifies.
+ if (Value *W = SimplifyBinOp(Instruction::Add, Y, V, Q, MaxRecurse-1)) {
+ // It does, we successfully reassociated!
+ ++NumReassoc;
+ return W;
+ }
+ }
+
+ // X - (Y + Z) -> (X - Y) - Z or (X - Z) - Y if everything simplifies.
+ // For example, X - (X + 1) -> -1
+ X = Op0;
+ if (MaxRecurse && match(Op1, m_Add(m_Value(Y), m_Value(Z)))) { // X - (Y + Z)
+ // See if "V === X - Y" simplifies.
+ if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse-1))
+ // It does! Now see if "V - Z" simplifies.
+ if (Value *W = SimplifyBinOp(Instruction::Sub, V, Z, Q, MaxRecurse-1)) {
+ // It does, we successfully reassociated!
+ ++NumReassoc;
+ return W;
+ }
+ // See if "V === X - Z" simplifies.
+ if (Value *V = SimplifyBinOp(Instruction::Sub, X, Z, Q, MaxRecurse-1))
+ // It does! Now see if "V - Y" simplifies.
+ if (Value *W = SimplifyBinOp(Instruction::Sub, V, Y, Q, MaxRecurse-1)) {
+ // It does, we successfully reassociated!
+ ++NumReassoc;
+ return W;
+ }
+ }
+
+ // Z - (X - Y) -> (Z - X) + Y if everything simplifies.
+ // For example, X - (X - Y) -> Y.
+ Z = Op0;
+ if (MaxRecurse && match(Op1, m_Sub(m_Value(X), m_Value(Y)))) // Z - (X - Y)
+ // See if "V === Z - X" simplifies.
+ if (Value *V = SimplifyBinOp(Instruction::Sub, Z, X, Q, MaxRecurse-1))
+ // It does! Now see if "V + Y" simplifies.
+ if (Value *W = SimplifyBinOp(Instruction::Add, V, Y, Q, MaxRecurse-1)) {
+ // It does, we successfully reassociated!
+ ++NumReassoc;
+ return W;
+ }
+
+ // trunc(X) - trunc(Y) -> trunc(X - Y) if everything simplifies.
+ if (MaxRecurse && match(Op0, m_Trunc(m_Value(X))) &&
+ match(Op1, m_Trunc(m_Value(Y))))
+ if (X->getType() == Y->getType())
+ // See if "V === X - Y" simplifies.
+ if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse-1))
+ // It does! Now see if "trunc V" simplifies.
+ if (Value *W = SimplifyCastInst(Instruction::Trunc, V, Op0->getType(),
+ Q, MaxRecurse - 1))
+ // It does, return the simplified "trunc V".
+ return W;
+
+ // Variations on GEP(base, I, ...) - GEP(base, i, ...) -> GEP(null, I-i, ...).
+ if (match(Op0, m_PtrToInt(m_Value(X))) &&
+ match(Op1, m_PtrToInt(m_Value(Y))))
+ if (Constant *Result = computePointerDifference(Q.DL, X, Y))
+ return ConstantExpr::getIntegerCast(Result, Op0->getType(), true);
+
+ // i1 sub -> xor.
+ if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1))
+ if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1))
+ return V;
+
+ // Threading Sub over selects and phi nodes is pointless, so don't bother.
+ // Threading over the select in "A - select(cond, B, C)" means evaluating
+ // "A-B" and "A-C" and seeing if they are equal; but they are equal if and
+ // only if B and C are equal. If B and C are equal then (since we assume
+ // that operands have already been simplified) "select(cond, B, C)" should
+ // have been simplified to the common value of B and C already. Analysing
+ // "A-B" and "A-C" thus gains nothing, but costs compile time. Similarly
+ // for threading over phi nodes.
+
+ return nullptr;
+}
+
+Value *llvm::SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+ const SimplifyQuery &Q) {
+ return ::SimplifySubInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit);
+}
+
+/// Given operands for a Mul, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q))
+ return C;
+
+ // X * undef -> 0
+ // X * 0 -> 0
+ if (match(Op1, m_CombineOr(m_Undef(), m_Zero())))
+ return Constant::getNullValue(Op0->getType());
+
+ // X * 1 -> X
+ if (match(Op1, m_One()))
+ return Op0;
+
+ // (X / Y) * Y -> X if the division is exact.
+ Value *X = nullptr;
+ if (Q.IIQ.UseInstrInfo &&
+ (match(Op0,
+ m_Exact(m_IDiv(m_Value(X), m_Specific(Op1)))) || // (X / Y) * Y
+ match(Op1, m_Exact(m_IDiv(m_Value(X), m_Specific(Op0)))))) // Y * (X / Y)
+ return X;
+
+ // i1 mul -> and.
+ if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1))
+ if (Value *V = SimplifyAndInst(Op0, Op1, Q, MaxRecurse-1))
+ return V;
+
+ // Try some generic simplifications for associative operations.
+ if (Value *V = SimplifyAssociativeBinOp(Instruction::Mul, Op0, Op1, Q,
+ MaxRecurse))
+ return V;
+
+ // Mul distributes over Add. Try some generic simplifications based on this.
+ if (Value *V = ExpandBinOp(Instruction::Mul, Op0, Op1, Instruction::Add,
+ Q, MaxRecurse))
+ return V;
+
+ // If the operation is with the result of a select instruction, check whether
+ // operating on either branch of the select always yields the same value.
+ if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1))
+ if (Value *V = ThreadBinOpOverSelect(Instruction::Mul, Op0, Op1, Q,
+ MaxRecurse))
+ return V;
+
+ // If the operation is with the result of a phi instruction, check whether
+ // operating on all incoming values of the phi always yields the same value.
+ if (isa<PHINode>(Op0) || isa<PHINode>(Op1))
+ if (Value *V = ThreadBinOpOverPHI(Instruction::Mul, Op0, Op1, Q,
+ MaxRecurse))
+ return V;
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
+ return ::SimplifyMulInst(Op0, Op1, Q, RecursionLimit);
+}
+
+/// Check for common or similar folds of integer division or integer remainder.
+/// This applies to all 4 opcodes (sdiv/udiv/srem/urem).
+static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) {
+ Type *Ty = Op0->getType();
+
+ // X / undef -> undef
+ // X % undef -> undef
+ if (match(Op1, m_Undef()))
+ return Op1;
+
+ // X / 0 -> undef
+ // X % 0 -> undef
+ // We don't need to preserve faults!
+ if (match(Op1, m_Zero()))
+ return UndefValue::get(Ty);
+
+ // If any element of a constant divisor vector is zero or undef, the whole op
+ // is undef.
+ auto *Op1C = dyn_cast<Constant>(Op1);
+ if (Op1C && Ty->isVectorTy()) {
+ unsigned NumElts = Ty->getVectorNumElements();
+ for (unsigned i = 0; i != NumElts; ++i) {
+ Constant *Elt = Op1C->getAggregateElement(i);
+ if (Elt && (Elt->isNullValue() || isa<UndefValue>(Elt)))
+ return UndefValue::get(Ty);
+ }
+ }
+
+ // undef / X -> 0
+ // undef % X -> 0
+ if (match(Op0, m_Undef()))
+ return Constant::getNullValue(Ty);
+
+ // 0 / X -> 0
+ // 0 % X -> 0
+ if (match(Op0, m_Zero()))
+ return Constant::getNullValue(Op0->getType());
+
+ // X / X -> 1
+ // X % X -> 0
+ if (Op0 == Op1)
+ return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty);
+
+ // X / 1 -> X
+ // X % 1 -> 0
+ // If this is a boolean op (single-bit element type), we can't have
+ // division-by-zero or remainder-by-zero, so assume the divisor is 1.
+ // Similarly, if we're zero-extending a boolean divisor, then assume it's a 1.
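+  // For example, 'udiv i8 %x, zext (i1 %b to i8)' avoids an undefined divisor
+  // only when %b is true, so the divisor is treated as 1 and the result is %x.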
+ Value *X;
+ if (match(Op1, m_One()) || Ty->isIntOrIntVectorTy(1) ||
+ (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))
+ return IsDiv ? Op0 : Constant::getNullValue(Ty);
+
+ return nullptr;
+}
+
+/// Given a predicate and two operands, return true if the comparison is true.
+/// This is a helper for div/rem simplification where we return some other value
+/// when we can prove a relationship between the operands.
+static bool isICmpTrue(ICmpInst::Predicate Pred, Value *LHS, Value *RHS,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ Value *V = SimplifyICmpInst(Pred, LHS, RHS, Q, MaxRecurse);
+ Constant *C = dyn_cast_or_null<Constant>(V);
+ return (C && C->isAllOnesValue());
+}
+
+/// Return true if we can simplify X / Y to 0. Remainder can adapt that answer
+/// to simplify X % Y to X.
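+/// For example, 'sdiv i8 3, %y' simplifies to 0 when the icmp queries below
+/// can prove %y slt -3 or %y sgt 3, since |3| < |%y| in either case.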
+static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
+ unsigned MaxRecurse, bool IsSigned) {
+ // Recursion is always used, so bail out at once if we already hit the limit.
+ if (!MaxRecurse--)
+ return false;
+
+ if (IsSigned) {
+ // |X| / |Y| --> 0
+ //
+ // We require that 1 operand is a simple constant. That could be extended to
+ // 2 variables if we computed the sign bit for each.
+ //
+ // Make sure that a constant is not the minimum signed value because taking
+ // the abs() of that is undefined.
+ Type *Ty = X->getType();
+ const APInt *C;
+ if (match(X, m_APInt(C)) && !C->isMinSignedValue()) {
+ // Is the variable divisor magnitude always greater than the constant
+ // dividend magnitude?
+ // |Y| > |C| --> Y < -abs(C) or Y > abs(C)
+ Constant *PosDividendC = ConstantInt::get(Ty, C->abs());
+ Constant *NegDividendC = ConstantInt::get(Ty, -C->abs());
+ if (isICmpTrue(CmpInst::ICMP_SLT, Y, NegDividendC, Q, MaxRecurse) ||
+ isICmpTrue(CmpInst::ICMP_SGT, Y, PosDividendC, Q, MaxRecurse))
+ return true;
+ }
+ if (match(Y, m_APInt(C))) {
+ // Special-case: we can't take the abs() of a minimum signed value. If
+ // that's the divisor, then all we have to do is prove that the dividend
+ // is also not the minimum signed value.
+ if (C->isMinSignedValue())
+ return isICmpTrue(CmpInst::ICMP_NE, X, Y, Q, MaxRecurse);
+
+ // Is the variable dividend magnitude always less than the constant
+ // divisor magnitude?
+ // |X| < |C| --> X > -abs(C) and X < abs(C)
+ Constant *PosDivisorC = ConstantInt::get(Ty, C->abs());
+ Constant *NegDivisorC = ConstantInt::get(Ty, -C->abs());
+ if (isICmpTrue(CmpInst::ICMP_SGT, X, NegDivisorC, Q, MaxRecurse) &&
+ isICmpTrue(CmpInst::ICMP_SLT, X, PosDivisorC, Q, MaxRecurse))
+ return true;
+ }
+ return false;
+ }
+
+ // IsSigned == false.
+ // Is the dividend unsigned less than the divisor?
+ return isICmpTrue(ICmpInst::ICMP_ULT, X, Y, Q, MaxRecurse);
+}
+
+/// These are simplifications common to SDiv and UDiv.
+static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
+ return C;
+
+ if (Value *V = simplifyDivRem(Op0, Op1, true))
+ return V;
+
+ bool IsSigned = Opcode == Instruction::SDiv;
+
+ // (X * Y) / Y -> X if the multiplication does not overflow.
+ Value *X;
+ if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) {
+ auto *Mul = cast<OverflowingBinaryOperator>(Op0);
+ // If the Mul does not overflow, then we are good to go.
+ if ((IsSigned && Q.IIQ.hasNoSignedWrap(Mul)) ||
+ (!IsSigned && Q.IIQ.hasNoUnsignedWrap(Mul)))
+ return X;
+ // If X has the form X = A / Y, then X * Y cannot overflow.
+ if ((IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) ||
+ (!IsSigned && match(X, m_UDiv(m_Value(), m_Specific(Op1)))))
+ return X;
+ }
+
+ // (X rem Y) / Y -> 0
+ if ((IsSigned && match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) ||
+ (!IsSigned && match(Op0, m_URem(m_Value(), m_Specific(Op1)))))
+ return Constant::getNullValue(Op0->getType());
+
+  // (X /u C1) /u C2 -> 0 if C1 * C2 overflows
+ ConstantInt *C1, *C2;
+ if (!IsSigned && match(Op0, m_UDiv(m_Value(X), m_ConstantInt(C1))) &&
+ match(Op1, m_ConstantInt(C2))) {
+ bool Overflow;
+ (void)C1->getValue().umul_ov(C2->getValue(), Overflow);
+ if (Overflow)
+ return Constant::getNullValue(Op0->getType());
+ }
+
+ // If the operation is with the result of a select instruction, check whether
+ // operating on either branch of the select always yields the same value.
+ if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1))
+ if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse))
+ return V;
+
+ // If the operation is with the result of a phi instruction, check whether
+ // operating on all incoming values of the phi always yields the same value.
+ if (isa<PHINode>(Op0) || isa<PHINode>(Op1))
+ if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse))
+ return V;
+
+ if (isDivZero(Op0, Op1, Q, MaxRecurse, IsSigned))
+ return Constant::getNullValue(Op0->getType());
+
+ return nullptr;
+}
+
+/// These are simplifications common to SRem and URem.
+static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
+ return C;
+
+ if (Value *V = simplifyDivRem(Op0, Op1, false))
+ return V;
+
+ // (X % Y) % Y -> X % Y
+ if ((Opcode == Instruction::SRem &&
+ match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) ||
+ (Opcode == Instruction::URem &&
+ match(Op0, m_URem(m_Value(), m_Specific(Op1)))))
+ return Op0;
+
+ // (X << Y) % X -> 0
+ if (Q.IIQ.UseInstrInfo &&
+ ((Opcode == Instruction::SRem &&
+ match(Op0, m_NSWShl(m_Specific(Op1), m_Value()))) ||
+ (Opcode == Instruction::URem &&
+ match(Op0, m_NUWShl(m_Specific(Op1), m_Value())))))
+ return Constant::getNullValue(Op0->getType());
+
+ // If the operation is with the result of a select instruction, check whether
+ // operating on either branch of the select always yields the same value.
+ if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1))
+ if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse))
+ return V;
+
+ // If the operation is with the result of a phi instruction, check whether
+ // operating on all incoming values of the phi always yields the same value.
+ if (isa<PHINode>(Op0) || isa<PHINode>(Op1))
+ if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse))
+ return V;
+
+ // If X / Y == 0, then X % Y == X.
+ if (isDivZero(Op0, Op1, Q, MaxRecurse, Opcode == Instruction::SRem))
+ return Op0;
+
+ return nullptr;
+}
+
+/// Given operands for an SDiv, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ // If two operands are negated and no signed overflow, return -1.
+ if (isKnownNegation(Op0, Op1, /*NeedNSW=*/true))
+ return Constant::getAllOnesValue(Op0->getType());
+
+ return simplifyDiv(Instruction::SDiv, Op0, Op1, Q, MaxRecurse);
+}
+
+Value *llvm::SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
+ return ::SimplifySDivInst(Op0, Op1, Q, RecursionLimit);
+}
+
+/// Given operands for a UDiv, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ return simplifyDiv(Instruction::UDiv, Op0, Op1, Q, MaxRecurse);
+}
+
+Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
+ return ::SimplifyUDivInst(Op0, Op1, Q, RecursionLimit);
+}
+
+/// Given operands for an SRem, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ // If the divisor is 0, the result is undefined, so assume the divisor is -1.
+ // srem Op0, (sext i1 X) --> srem Op0, -1 --> 0
+ Value *X;
+ if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
+ return ConstantInt::getNullValue(Op0->getType());
+
+ // If the two operands are negated, return 0.
+ if (isKnownNegation(Op0, Op1))
+ return ConstantInt::getNullValue(Op0->getType());
+
+ return simplifyRem(Instruction::SRem, Op0, Op1, Q, MaxRecurse);
+}
+
+Value *llvm::SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
+ return ::SimplifySRemInst(Op0, Op1, Q, RecursionLimit);
+}
+
+/// Given operands for a URem, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ return simplifyRem(Instruction::URem, Op0, Op1, Q, MaxRecurse);
+}
+
+Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
+ return ::SimplifyURemInst(Op0, Op1, Q, RecursionLimit);
+}
+
+/// Returns true if a shift by \c Amount always yields undef.
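+/// For example, 'shl i32 %x, 32' and 'lshr i8 %x, undef' both shift by an
+/// amount that makes the result undefined.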
+static bool isUndefShift(Value *Amount) {
+ Constant *C = dyn_cast<Constant>(Amount);
+ if (!C)
+ return false;
+
+ // X shift by undef -> undef because it may shift by the bitwidth.
+ if (isa<UndefValue>(C))
+ return true;
+
+ // Shifting by the bitwidth or more is undefined.
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(C))
+ if (CI->getValue().getLimitedValue() >=
+ CI->getType()->getScalarSizeInBits())
+ return true;
+
+ // If all lanes of a vector shift are undefined the whole shift is.
+ if (isa<ConstantVector>(C) || isa<ConstantDataVector>(C)) {
+ for (unsigned I = 0, E = C->getType()->getVectorNumElements(); I != E; ++I)
+ if (!isUndefShift(C->getAggregateElement(I)))
+ return false;
+ return true;
+ }
+
+ return false;
+}
+
+/// Given operands for a Shl, LShr or AShr, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
+ Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
+ return C;
+
+ // 0 shift by X -> 0
+ if (match(Op0, m_Zero()))
+ return Constant::getNullValue(Op0->getType());
+
+ // X shift by 0 -> X
+ // Shift-by-sign-extended bool must be shift-by-0 because shift-by-all-ones
+ // would be poison.
+ Value *X;
+ if (match(Op1, m_Zero()) ||
+ (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))
+ return Op0;
+
+ // Fold undefined shifts.
+ if (isUndefShift(Op1))
+ return UndefValue::get(Op0->getType());
+
+ // If the operation is with the result of a select instruction, check whether
+ // operating on either branch of the select always yields the same value.
+ if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1))
+ if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse))
+ return V;
+
+ // If the operation is with the result of a phi instruction, check whether
+ // operating on all incoming values of the phi always yields the same value.
+ if (isa<PHINode>(Op0) || isa<PHINode>(Op1))
+ if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse))
+ return V;
+
+ // If any bits in the shift amount make that value greater than or equal to
+ // the number of bits in the type, the shift is undefined.
+ KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (Known.One.getLimitedValue() >= Known.getBitWidth())
+ return UndefValue::get(Op0->getType());
+
+ // If all valid bits in the shift amount are known zero, the first operand is
+ // unchanged.
+ unsigned NumValidShiftBits = Log2_32_Ceil(Known.getBitWidth());
+ if (Known.countMinTrailingZeros() >= NumValidShiftBits)
+ return Op0;
+
+ return nullptr;
+}
+
+/// Given operands for an LShr or AShr, see if we can
+/// fold the result. If not, this returns null.
+static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
+ Value *Op1, bool isExact, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse))
+ return V;
+
+ // X >> X -> 0
+ if (Op0 == Op1)
+ return Constant::getNullValue(Op0->getType());
+
+ // undef >> X -> 0
+ // undef >> X -> undef (if it's exact)
+ if (match(Op0, m_Undef()))
+ return isExact ? Op0 : Constant::getNullValue(Op0->getType());
+
+ // The low bit cannot be shifted out of an exact shift if it is set.
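+  // For example, 'lshr exact i8 %x, %y' with the low bit of %x known to be set
+  // can only be exact when %y is 0, so it folds to %x.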
+ if (isExact) {
+ KnownBits Op0Known = computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+ if (Op0Known.One[0])
+ return Op0;
+ }
+
+ return nullptr;
+}
+
+/// Given operands for an Shl, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (Value *V = SimplifyShift(Instruction::Shl, Op0, Op1, Q, MaxRecurse))
+ return V;
+
+ // undef << X -> 0
+  // undef << X -> undef (if it's NSW/NUW)
+ if (match(Op0, m_Undef()))
+ return isNSW || isNUW ? Op0 : Constant::getNullValue(Op0->getType());
+
+ // (X >> A) << A -> X
+ Value *X;
+ if (Q.IIQ.UseInstrInfo &&
+ match(Op0, m_Exact(m_Shr(m_Value(X), m_Specific(Op1)))))
+ return X;
+
+ // shl nuw i8 C, %x -> C iff C has sign bit set.
+ if (isNUW && match(Op0, m_Negative()))
+ return Op0;
+ // NOTE: could use computeKnownBits() / LazyValueInfo,
+ // but the cost-benefit analysis suggests it isn't worth it.
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+ const SimplifyQuery &Q) {
+ return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit);
+}
+
+/// Given operands for an LShr, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (Value *V = SimplifyRightShift(Instruction::LShr, Op0, Op1, isExact, Q,
+ MaxRecurse))
+ return V;
+
+ // (X << A) >> A -> X
+ Value *X;
+ if (match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1))))
+ return X;
+
+ // ((X << A) | Y) >> A -> X if effective width of Y is not larger than A.
+ // We can return X as we do in the above case since OR alters no bits in X.
+ // SimplifyDemandedBits in InstCombine can do more general optimization for
+ // bit manipulation. This pattern aims to provide opportunities for other
+ // optimizers by supporting a simple but common case in InstSimplify.
+ Value *Y;
+ const APInt *ShRAmt, *ShLAmt;
+ if (match(Op1, m_APInt(ShRAmt)) &&
+ match(Op0, m_c_Or(m_NUWShl(m_Value(X), m_APInt(ShLAmt)), m_Value(Y))) &&
+ *ShRAmt == *ShLAmt) {
+ const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ const unsigned Width = Op0->getType()->getScalarSizeInBits();
+ const unsigned EffWidthY = Width - YKnown.countMinLeadingZeros();
+ if (ShRAmt->uge(EffWidthY))
+ return X;
+ }
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact,
+ const SimplifyQuery &Q) {
+ return ::SimplifyLShrInst(Op0, Op1, isExact, Q, RecursionLimit);
+}
+
+/// Given operands for an AShr, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (Value *V = SimplifyRightShift(Instruction::AShr, Op0, Op1, isExact, Q,
+ MaxRecurse))
+ return V;
+
+ // all ones >>a X -> -1
+ // Do not return Op0 because it may contain undef elements if it's a vector.
+ if (match(Op0, m_AllOnes()))
+ return Constant::getAllOnesValue(Op0->getType());
+
+ // (X << A) >> A -> X
+ Value *X;
+ if (Q.IIQ.UseInstrInfo && match(Op0, m_NSWShl(m_Value(X), m_Specific(Op1))))
+ return X;
+
+ // Arithmetic shifting an all-sign-bit value is a no-op.
+ unsigned NumSignBits = ComputeNumSignBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (NumSignBits == Op0->getType()->getScalarSizeInBits())
+ return Op0;
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact,
+ const SimplifyQuery &Q) {
+ return ::SimplifyAShrInst(Op0, Op1, isExact, Q, RecursionLimit);
+}
+
+/// Commuted variants are assumed to be handled by calling this function again
+/// with the parameters swapped.
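+/// For example, '(icmp ult %a, %b) & (icmp eq (sub %a, %b), 0)' folds to
+/// false: %a u< %b and %a - %b == 0 cannot both hold.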
+static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp,
+ ICmpInst *UnsignedICmp, bool IsAnd,
+ const SimplifyQuery &Q) {
+ Value *X, *Y;
+
+ ICmpInst::Predicate EqPred;
+ if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(Y), m_Zero())) ||
+ !ICmpInst::isEquality(EqPred))
+ return nullptr;
+
+ ICmpInst::Predicate UnsignedPred;
+
+ Value *A, *B;
+ // Y = (A - B);
+ if (match(Y, m_Sub(m_Value(A), m_Value(B)))) {
+ if (match(UnsignedICmp,
+ m_c_ICmp(UnsignedPred, m_Specific(A), m_Specific(B))) &&
+ ICmpInst::isUnsigned(UnsignedPred)) {
+ if (UnsignedICmp->getOperand(0) != A)
+ UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred);
+
+ // A >=/<= B || (A - B) != 0 <--> true
+ if ((UnsignedPred == ICmpInst::ICMP_UGE ||
+ UnsignedPred == ICmpInst::ICMP_ULE) &&
+ EqPred == ICmpInst::ICMP_NE && !IsAnd)
+ return ConstantInt::getTrue(UnsignedICmp->getType());
+ // A </> B && (A - B) == 0 <--> false
+ if ((UnsignedPred == ICmpInst::ICMP_ULT ||
+ UnsignedPred == ICmpInst::ICMP_UGT) &&
+ EqPred == ICmpInst::ICMP_EQ && IsAnd)
+ return ConstantInt::getFalse(UnsignedICmp->getType());
+
+ // A </> B && (A - B) != 0 <--> A </> B
+ // A </> B || (A - B) != 0 <--> (A - B) != 0
+ if (EqPred == ICmpInst::ICMP_NE && (UnsignedPred == ICmpInst::ICMP_ULT ||
+ UnsignedPred == ICmpInst::ICMP_UGT))
+ return IsAnd ? UnsignedICmp : ZeroICmp;
+
+ // A <=/>= B && (A - B) == 0 <--> (A - B) == 0
+ // A <=/>= B || (A - B) == 0 <--> A <=/>= B
+ if (EqPred == ICmpInst::ICMP_EQ && (UnsignedPred == ICmpInst::ICMP_ULE ||
+ UnsignedPred == ICmpInst::ICMP_UGE))
+ return IsAnd ? ZeroICmp : UnsignedICmp;
+ }
+
+ // Given Y = (A - B)
+ // Y >= A && Y != 0 --> Y >= A iff B != 0
+ // Y < A || Y == 0 --> Y < A iff B != 0
+ if (match(UnsignedICmp,
+ m_c_ICmp(UnsignedPred, m_Specific(Y), m_Specific(A)))) {
+ if (UnsignedICmp->getOperand(0) != Y)
+ UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred);
+
+ if (UnsignedPred == ICmpInst::ICMP_UGE && IsAnd &&
+ EqPred == ICmpInst::ICMP_NE &&
+ isKnownNonZero(B, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
+ return UnsignedICmp;
+ if (UnsignedPred == ICmpInst::ICMP_ULT && !IsAnd &&
+ EqPred == ICmpInst::ICMP_EQ &&
+ isKnownNonZero(B, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
+ return UnsignedICmp;
+ }
+ }
+
+ if (match(UnsignedICmp, m_ICmp(UnsignedPred, m_Value(X), m_Specific(Y))) &&
+ ICmpInst::isUnsigned(UnsignedPred))
+ ;
+ else if (match(UnsignedICmp,
+ m_ICmp(UnsignedPred, m_Specific(Y), m_Value(X))) &&
+ ICmpInst::isUnsigned(UnsignedPred))
+ UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred);
+ else
+ return nullptr;
+
+ // X < Y && Y != 0 --> X < Y
+ // X < Y || Y != 0 --> Y != 0
+ if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_NE)
+ return IsAnd ? UnsignedICmp : ZeroICmp;
+
+ // X <= Y && Y != 0 --> X <= Y iff X != 0
+ // X <= Y || Y != 0 --> Y != 0 iff X != 0
+ if (UnsignedPred == ICmpInst::ICMP_ULE && EqPred == ICmpInst::ICMP_NE &&
+ isKnownNonZero(X, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
+ return IsAnd ? UnsignedICmp : ZeroICmp;
+
+ // X >= Y && Y == 0 --> Y == 0
+ // X >= Y || Y == 0 --> X >= Y
+ if (UnsignedPred == ICmpInst::ICMP_UGE && EqPred == ICmpInst::ICMP_EQ)
+ return IsAnd ? ZeroICmp : UnsignedICmp;
+
+ // X > Y && Y == 0 --> Y == 0 iff X != 0
+ // X > Y || Y == 0 --> X > Y iff X != 0
+ if (UnsignedPred == ICmpInst::ICMP_UGT && EqPred == ICmpInst::ICMP_EQ &&
+ isKnownNonZero(X, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
+ return IsAnd ? ZeroICmp : UnsignedICmp;
+
+ // X < Y && Y == 0 --> false
+ if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_EQ &&
+ IsAnd)
+ return getFalse(UnsignedICmp->getType());
+
+ // X >= Y || Y != 0 --> true
+ if (UnsignedPred == ICmpInst::ICMP_UGE && EqPred == ICmpInst::ICMP_NE &&
+ !IsAnd)
+ return getTrue(UnsignedICmp->getType());
+
+ return nullptr;
+}
+
+/// Commuted variants are assumed to be handled by calling this function again
+/// with the parameters swapped.
+static Value *simplifyAndOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) {
+ ICmpInst::Predicate Pred0, Pred1;
+  Value *A, *B;
+ if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) ||
+ !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B))))
+ return nullptr;
+
+ // We have (icmp Pred0, A, B) & (icmp Pred1, A, B).
+ // If Op1 is always implied true by Op0, then Op0 is a subset of Op1, and we
+ // can eliminate Op1 from this 'and'.
+ if (ICmpInst::isImpliedTrueByMatchingCmp(Pred0, Pred1))
+ return Op0;
+
+ // Check for any combination of predicates that are guaranteed to be disjoint.
+ if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) ||
+ (Pred0 == ICmpInst::ICMP_EQ && ICmpInst::isFalseWhenEqual(Pred1)) ||
+ (Pred0 == ICmpInst::ICMP_SLT && Pred1 == ICmpInst::ICMP_SGT) ||
+ (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_UGT))
+ return getFalse(Op0->getType());
+
+ return nullptr;
+}
+
+/// Commuted variants are assumed to be handled by calling this function again
+/// with the parameters swapped.
+static Value *simplifyOrOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) {
+ ICmpInst::Predicate Pred0, Pred1;
+  Value *A, *B;
+ if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) ||
+ !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B))))
+ return nullptr;
+
+ // We have (icmp Pred0, A, B) | (icmp Pred1, A, B).
+ // If Op1 is always implied true by Op0, then Op0 is a subset of Op1, and we
+ // can eliminate Op0 from this 'or'.
+ if (ICmpInst::isImpliedTrueByMatchingCmp(Pred0, Pred1))
+ return Op1;
+
+ // Check for any combination of predicates that cover the entire range of
+ // possibilities.
+ if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) ||
+ (Pred0 == ICmpInst::ICMP_NE && ICmpInst::isTrueWhenEqual(Pred1)) ||
+ (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGE) ||
+ (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGE))
+ return getTrue(Op0->getType());
+
+ return nullptr;
+}
+
+/// Test if a pair of compares with a shared operand and 2 constants has an
+/// empty set intersection, full set union, or if one compare is a superset of
+/// the other.
+static Value *simplifyAndOrOfICmpsWithConstants(ICmpInst *Cmp0, ICmpInst *Cmp1,
+ bool IsAnd) {
+ // Look for this pattern: {and/or} (icmp X, C0), (icmp X, C1)).
+ if (Cmp0->getOperand(0) != Cmp1->getOperand(0))
+ return nullptr;
+
+ const APInt *C0, *C1;
+ if (!match(Cmp0->getOperand(1), m_APInt(C0)) ||
+ !match(Cmp1->getOperand(1), m_APInt(C1)))
+ return nullptr;
+
+ auto Range0 = ConstantRange::makeExactICmpRegion(Cmp0->getPredicate(), *C0);
+ auto Range1 = ConstantRange::makeExactICmpRegion(Cmp1->getPredicate(), *C1);
+
+ // For and-of-compares, check if the intersection is empty:
+ // (icmp X, C0) && (icmp X, C1) --> empty set --> false
+ if (IsAnd && Range0.intersectWith(Range1).isEmptySet())
+ return getFalse(Cmp0->getType());
+
+ // For or-of-compares, check if the union is full:
+ // (icmp X, C0) || (icmp X, C1) --> full set --> true
+ if (!IsAnd && Range0.unionWith(Range1).isFullSet())
+ return getTrue(Cmp0->getType());
+
+ // Is one range a superset of the other?
+ // If this is and-of-compares, take the smaller set:
+ // (icmp sgt X, 4) && (icmp sgt X, 42) --> icmp sgt X, 42
+ // If this is or-of-compares, take the larger set:
+ // (icmp sgt X, 4) || (icmp sgt X, 42) --> icmp sgt X, 4
+ if (Range0.contains(Range1))
+ return IsAnd ? Cmp1 : Cmp0;
+ if (Range1.contains(Range0))
+ return IsAnd ? Cmp0 : Cmp1;
+
+ return nullptr;
+}
+
+static Value *simplifyAndOrOfICmpsWithZero(ICmpInst *Cmp0, ICmpInst *Cmp1,
+ bool IsAnd) {
+ ICmpInst::Predicate P0 = Cmp0->getPredicate(), P1 = Cmp1->getPredicate();
+ if (!match(Cmp0->getOperand(1), m_Zero()) ||
+ !match(Cmp1->getOperand(1), m_Zero()) || P0 != P1)
+ return nullptr;
+
+ if ((IsAnd && P0 != ICmpInst::ICMP_NE) || (!IsAnd && P1 != ICmpInst::ICMP_EQ))
+ return nullptr;
+
+ // We have either "(X == 0 || Y == 0)" or "(X != 0 && Y != 0)".
+ Value *X = Cmp0->getOperand(0);
+ Value *Y = Cmp1->getOperand(0);
+
+ // If one of the compares is a masked version of a (not) null check, then
+ // that compare implies the other, so we eliminate the other. Optionally, look
+ // through a pointer-to-int cast to match a null check of a pointer type.
+
+ // (X == 0) || (([ptrtoint] X & ?) == 0) --> ([ptrtoint] X & ?) == 0
+ // (X == 0) || ((? & [ptrtoint] X) == 0) --> (? & [ptrtoint] X) == 0
+ // (X != 0) && (([ptrtoint] X & ?) != 0) --> ([ptrtoint] X & ?) != 0
+ // (X != 0) && ((? & [ptrtoint] X) != 0) --> (? & [ptrtoint] X) != 0
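+  // For example, '(icmp ne i32 %x, 0) & (icmp ne (and i32 %x, 7), 0)' keeps
+  // only the masked compare: a nonzero masked value already implies %x != 0.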
+ if (match(Y, m_c_And(m_Specific(X), m_Value())) ||
+ match(Y, m_c_And(m_PtrToInt(m_Specific(X)), m_Value())))
+ return Cmp1;
+
+ // (([ptrtoint] Y & ?) == 0) || (Y == 0) --> ([ptrtoint] Y & ?) == 0
+ // ((? & [ptrtoint] Y) == 0) || (Y == 0) --> (? & [ptrtoint] Y) == 0
+ // (([ptrtoint] Y & ?) != 0) && (Y != 0) --> ([ptrtoint] Y & ?) != 0
+ // ((? & [ptrtoint] Y) != 0) && (Y != 0) --> (? & [ptrtoint] Y) != 0
+ if (match(X, m_c_And(m_Specific(Y), m_Value())) ||
+ match(X, m_c_And(m_PtrToInt(m_Specific(Y)), m_Value())))
+ return Cmp0;
+
+ return nullptr;
+}
+
+static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
+ const InstrInfoQuery &IIQ) {
+ // (icmp (add V, C0), C1) & (icmp V, C0)
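+  // For example, with C0 = 1 and C1 = 3 (Delta = 2),
+  //   (icmp ult (add %v, 1), 3) & (icmp sgt %v, 1)
+  // is known false: the first compare restricts %v to {-1, 0, 1}, none of
+  // which is sgt 1.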
+ ICmpInst::Predicate Pred0, Pred1;
+ const APInt *C0, *C1;
+ Value *V;
+ if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1))))
+ return nullptr;
+
+ if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value())))
+ return nullptr;
+
+ auto *AddInst = cast<OverflowingBinaryOperator>(Op0->getOperand(0));
+ if (AddInst->getOperand(1) != Op1->getOperand(1))
+ return nullptr;
+
+ Type *ITy = Op0->getType();
+ bool isNSW = IIQ.hasNoSignedWrap(AddInst);
+ bool isNUW = IIQ.hasNoUnsignedWrap(AddInst);
+
+ const APInt Delta = *C1 - *C0;
+ if (C0->isStrictlyPositive()) {
+ if (Delta == 2) {
+ if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_SGT)
+ return getFalse(ITy);
+ if (Pred0 == ICmpInst::ICMP_SLT && Pred1 == ICmpInst::ICMP_SGT && isNSW)
+ return getFalse(ITy);
+ }
+ if (Delta == 1) {
+ if (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_SGT)
+ return getFalse(ITy);
+ if (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGT && isNSW)
+ return getFalse(ITy);
+ }
+ }
+ if (C0->getBoolValue() && isNUW) {
+ if (Delta == 2)
+ if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_UGT)
+ return getFalse(ITy);
+ if (Delta == 1)
+ if (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGT)
+ return getFalse(ITy);
+ }
+
+ return nullptr;
+}
+
+static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1,
+ const SimplifyQuery &Q) {
+ if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true, Q))
+ return X;
+ if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/true, Q))
+ return X;
+
+ if (Value *X = simplifyAndOfICmpsWithSameOperands(Op0, Op1))
+ return X;
+ if (Value *X = simplifyAndOfICmpsWithSameOperands(Op1, Op0))
+ return X;
+
+ if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, true))
+ return X;
+
+ if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, true))
+ return X;
+
+ if (Value *X = simplifyAndOfICmpsWithAdd(Op0, Op1, Q.IIQ))
+ return X;
+ if (Value *X = simplifyAndOfICmpsWithAdd(Op1, Op0, Q.IIQ))
+ return X;
+
+ return nullptr;
+}
+
+static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
+ const InstrInfoQuery &IIQ) {
+ // (icmp (add V, C0), C1) | (icmp V, C0)
+ ICmpInst::Predicate Pred0, Pred1;
+ const APInt *C0, *C1;
+ Value *V;
+ if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1))))
+ return nullptr;
+
+ if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value())))
+ return nullptr;
+
+ auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0));
+ if (AddInst->getOperand(1) != Op1->getOperand(1))
+ return nullptr;
+
+ Type *ITy = Op0->getType();
+ bool isNSW = IIQ.hasNoSignedWrap(AddInst);
+ bool isNUW = IIQ.hasNoUnsignedWrap(AddInst);
+
+ const APInt Delta = *C1 - *C0;
+ if (C0->isStrictlyPositive()) {
+ if (Delta == 2) {
+ if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_SLE)
+ return getTrue(ITy);
+ if (Pred0 == ICmpInst::ICMP_SGE && Pred1 == ICmpInst::ICMP_SLE && isNSW)
+ return getTrue(ITy);
+ }
+ if (Delta == 1) {
+ if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_SLE)
+ return getTrue(ITy);
+ if (Pred0 == ICmpInst::ICMP_SGT && Pred1 == ICmpInst::ICMP_SLE && isNSW)
+ return getTrue(ITy);
+ }
+ }
+ if (C0->getBoolValue() && isNUW) {
+ if (Delta == 2)
+ if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_ULE)
+ return getTrue(ITy);
+ if (Delta == 1)
+ if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_ULE)
+ return getTrue(ITy);
+ }
+
+ return nullptr;
+}
+
+static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1,
+ const SimplifyQuery &Q) {
+ if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/false, Q))
+ return X;
+ if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/false, Q))
+ return X;
+
+ if (Value *X = simplifyOrOfICmpsWithSameOperands(Op0, Op1))
+ return X;
+ if (Value *X = simplifyOrOfICmpsWithSameOperands(Op1, Op0))
+ return X;
+
+ if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, false))
+ return X;
+
+ if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, false))
+ return X;
+
+ if (Value *X = simplifyOrOfICmpsWithAdd(Op0, Op1, Q.IIQ))
+ return X;
+ if (Value *X = simplifyOrOfICmpsWithAdd(Op1, Op0, Q.IIQ))
+ return X;
+
+ return nullptr;
+}
+
+static Value *simplifyAndOrOfFCmps(const TargetLibraryInfo *TLI,
+ FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) {
+ Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1);
+ Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1);
+ if (LHS0->getType() != RHS0->getType())
+ return nullptr;
+
+ FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
+ if ((PredL == FCmpInst::FCMP_ORD && PredR == FCmpInst::FCMP_ORD && IsAnd) ||
+ (PredL == FCmpInst::FCMP_UNO && PredR == FCmpInst::FCMP_UNO && !IsAnd)) {
+ // (fcmp ord NNAN, X) & (fcmp ord X, Y) --> fcmp ord X, Y
+ // (fcmp ord NNAN, X) & (fcmp ord Y, X) --> fcmp ord Y, X
+ // (fcmp ord X, NNAN) & (fcmp ord X, Y) --> fcmp ord X, Y
+ // (fcmp ord X, NNAN) & (fcmp ord Y, X) --> fcmp ord Y, X
+ // (fcmp uno NNAN, X) | (fcmp uno X, Y) --> fcmp uno X, Y
+ // (fcmp uno NNAN, X) | (fcmp uno Y, X) --> fcmp uno Y, X
+ // (fcmp uno X, NNAN) | (fcmp uno X, Y) --> fcmp uno X, Y
+ // (fcmp uno X, NNAN) | (fcmp uno Y, X) --> fcmp uno Y, X
+ if ((isKnownNeverNaN(LHS0, TLI) && (LHS1 == RHS0 || LHS1 == RHS1)) ||
+ (isKnownNeverNaN(LHS1, TLI) && (LHS0 == RHS0 || LHS0 == RHS1)))
+ return RHS;
+
+ // (fcmp ord X, Y) & (fcmp ord NNAN, X) --> fcmp ord X, Y
+ // (fcmp ord Y, X) & (fcmp ord NNAN, X) --> fcmp ord Y, X
+ // (fcmp ord X, Y) & (fcmp ord X, NNAN) --> fcmp ord X, Y
+ // (fcmp ord Y, X) & (fcmp ord X, NNAN) --> fcmp ord Y, X
+ // (fcmp uno X, Y) | (fcmp uno NNAN, X) --> fcmp uno X, Y
+ // (fcmp uno Y, X) | (fcmp uno NNAN, X) --> fcmp uno Y, X
+ // (fcmp uno X, Y) | (fcmp uno X, NNAN) --> fcmp uno X, Y
+ // (fcmp uno Y, X) | (fcmp uno X, NNAN) --> fcmp uno Y, X
+ if ((isKnownNeverNaN(RHS0, TLI) && (RHS1 == LHS0 || RHS1 == LHS1)) ||
+ (isKnownNeverNaN(RHS1, TLI) && (RHS0 == LHS0 || RHS0 == LHS1)))
+ return LHS;
+ }
+
+ return nullptr;
+}
+
+static Value *simplifyAndOrOfCmps(const SimplifyQuery &Q,
+ Value *Op0, Value *Op1, bool IsAnd) {
+ // Look through casts of the 'and' operands to find compares.
+ auto *Cast0 = dyn_cast<CastInst>(Op0);
+ auto *Cast1 = dyn_cast<CastInst>(Op1);
+ if (Cast0 && Cast1 && Cast0->getOpcode() == Cast1->getOpcode() &&
+ Cast0->getSrcTy() == Cast1->getSrcTy()) {
+ Op0 = Cast0->getOperand(0);
+ Op1 = Cast1->getOperand(0);
+ }
+
+ Value *V = nullptr;
+ auto *ICmp0 = dyn_cast<ICmpInst>(Op0);
+ auto *ICmp1 = dyn_cast<ICmpInst>(Op1);
+ if (ICmp0 && ICmp1)
+ V = IsAnd ? simplifyAndOfICmps(ICmp0, ICmp1, Q)
+ : simplifyOrOfICmps(ICmp0, ICmp1, Q);
+
+ auto *FCmp0 = dyn_cast<FCmpInst>(Op0);
+ auto *FCmp1 = dyn_cast<FCmpInst>(Op1);
+ if (FCmp0 && FCmp1)
+ V = simplifyAndOrOfFCmps(Q.TLI, FCmp0, FCmp1, IsAnd);
+
+ if (!V)
+ return nullptr;
+ if (!Cast0)
+ return V;
+
+ // If we looked through casts, we can only handle a constant simplification
+ // because we are not allowed to create a cast instruction here.
+ if (auto *C = dyn_cast<Constant>(V))
+ return ConstantExpr::getCast(Cast0->getOpcode(), C, Cast0->getType());
+
+ return nullptr;
+}
+
+/// Check that Op1 is in the expected form, i.e.:
+/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???)
+/// %Op1 = extractvalue { i4, i1 } %Agg, 1
+static bool omitCheckForZeroBeforeMulWithOverflowInternal(Value *Op1,
+ Value *X) {
+ auto *Extract = dyn_cast<ExtractValueInst>(Op1);
+ // We should only be extracting the overflow bit.
+ if (!Extract || !Extract->getIndices().equals(1))
+ return false;
+ Value *Agg = Extract->getAggregateOperand();
+ // This should be a multiplication-with-overflow intrinsic.
+ if (!match(Agg, m_CombineOr(m_Intrinsic<Intrinsic::umul_with_overflow>(),
+ m_Intrinsic<Intrinsic::smul_with_overflow>())))
+ return false;
+ // One of its multipliers should be the value we checked for zero before.
+ if (!match(Agg, m_CombineOr(m_Argument<0>(m_Specific(X)),
+ m_Argument<1>(m_Specific(X)))))
+ return false;
+ return true;
+}
+
+/// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some
+/// other form of check, e.g. one that was using division; it may have been
+/// guarded against division-by-zero. We can drop that check now.
+/// Look for:
+/// %Op0 = icmp ne i4 %X, 0
+/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???)
+/// %Op1 = extractvalue { i4, i1 } %Agg, 1
+/// %??? = and i1 %Op0, %Op1
+/// We can just return %Op1
+static Value *omitCheckForZeroBeforeMulWithOverflow(Value *Op0, Value *Op1) {
+ ICmpInst::Predicate Pred;
+ Value *X;
+ if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())) ||
+ Pred != ICmpInst::Predicate::ICMP_NE)
+ return nullptr;
+ // Is Op1 in expected form?
+ if (!omitCheckForZeroBeforeMulWithOverflowInternal(Op1, X))
+ return nullptr;
+ // Can omit 'and', and just return the overflow bit.
+ return Op1;
+}
+
+/// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some
+/// other form of check, e.g. one that was using division; it may have been
+/// guarded against division-by-zero. We can drop that check now.
+/// Look for:
+/// %Op0 = icmp eq i4 %X, 0
+/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???)
+/// %Op1 = extractvalue { i4, i1 } %Agg, 1
+/// %NotOp1 = xor i1 %Op1, true
+/// %or = or i1 %Op0, %NotOp1
+/// We can just return %NotOp1
+static Value *omitCheckForZeroBeforeInvertedMulWithOverflow(Value *Op0,
+ Value *NotOp1) {
+ ICmpInst::Predicate Pred;
+ Value *X;
+ if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())) ||
+ Pred != ICmpInst::Predicate::ICMP_EQ)
+ return nullptr;
+ // We expect the other hand of an 'or' to be a 'not'.
+ Value *Op1;
+ if (!match(NotOp1, m_Not(m_Value(Op1))))
+ return nullptr;
+ // Is Op1 in expected form?
+ if (!omitCheckForZeroBeforeMulWithOverflowInternal(Op1, X))
+ return nullptr;
+  // Can omit 'or', and just return the inverted overflow bit.
+ return NotOp1;
+}
+
+/// Given operands for an And, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ if (Constant *C = foldOrCommuteConstant(Instruction::And, Op0, Op1, Q))
+ return C;
+
+ // X & undef -> 0
+ if (match(Op1, m_Undef()))
+ return Constant::getNullValue(Op0->getType());
+
+ // X & X = X
+ if (Op0 == Op1)
+ return Op0;
+
+ // X & 0 = 0
+ if (match(Op1, m_Zero()))
+ return Constant::getNullValue(Op0->getType());
+
+ // X & -1 = X
+ if (match(Op1, m_AllOnes()))
+ return Op0;
+
+ // A & ~A = ~A & A = 0
+ if (match(Op0, m_Not(m_Specific(Op1))) ||
+ match(Op1, m_Not(m_Specific(Op0))))
+ return Constant::getNullValue(Op0->getType());
+
+ // (A | ?) & A = A
+ if (match(Op0, m_c_Or(m_Specific(Op1), m_Value())))
+ return Op1;
+
+ // A & (A | ?) = A
+ if (match(Op1, m_c_Or(m_Specific(Op0), m_Value())))
+ return Op0;
+
+ // A mask that only clears known zeros of a shifted value is a no-op.
+ Value *X;
+ const APInt *Mask;
+ const APInt *ShAmt;
+ if (match(Op1, m_APInt(Mask))) {
+ // If all bits in the inverted and shifted mask are clear:
+ // and (shl X, ShAmt), Mask --> shl X, ShAmt
+ if (match(Op0, m_Shl(m_Value(X), m_APInt(ShAmt))) &&
+ (~(*Mask)).lshr(*ShAmt).isNullValue())
+ return Op0;
+
+ // If all bits in the inverted and shifted mask are clear:
+ // and (lshr X, ShAmt), Mask --> lshr X, ShAmt
+ if (match(Op0, m_LShr(m_Value(X), m_APInt(ShAmt))) &&
+ (~(*Mask)).shl(*ShAmt).isNullValue())
+ return Op0;
+ }
+
+ // If we have a multiplication overflow check that is being 'and'ed with a
+ // check that one of the multipliers is not zero, we can omit the 'and', and
+ // only keep the overflow check.
+ if (Value *V = omitCheckForZeroBeforeMulWithOverflow(Op0, Op1))
+ return V;
+ if (Value *V = omitCheckForZeroBeforeMulWithOverflow(Op1, Op0))
+ return V;
+
+ // A & (-A) = A if A is a power of two or zero.
+ if (match(Op0, m_Neg(m_Specific(Op1))) ||
+ match(Op1, m_Neg(m_Specific(Op0)))) {
+ if (isKnownToBeAPowerOfTwo(Op0, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI,
+ Q.DT))
+ return Op0;
+ if (isKnownToBeAPowerOfTwo(Op1, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI,
+ Q.DT))
+ return Op1;
+ }
+
+ // This is a similar pattern used for checking if a value is a power-of-2:
+ // (A - 1) & A --> 0 (if A is a power-of-2 or 0)
+ // A & (A - 1) --> 0 (if A is a power-of-2 or 0)
+ if (match(Op0, m_Add(m_Specific(Op1), m_AllOnes())) &&
+ isKnownToBeAPowerOfTwo(Op1, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT))
+ return Constant::getNullValue(Op1->getType());
+ if (match(Op1, m_Add(m_Specific(Op0), m_AllOnes())) &&
+ isKnownToBeAPowerOfTwo(Op0, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT))
+ return Constant::getNullValue(Op0->getType());
+
+ if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, true))
+ return V;
+
+ // Try some generic simplifications for associative operations.
+ if (Value *V = SimplifyAssociativeBinOp(Instruction::And, Op0, Op1, Q,
+ MaxRecurse))
+ return V;
+
+ // And distributes over Or. Try some generic simplifications based on this.
+ if (Value *V = ExpandBinOp(Instruction::And, Op0, Op1, Instruction::Or,
+ Q, MaxRecurse))
+ return V;
+
+ // And distributes over Xor. Try some generic simplifications based on this.
+ if (Value *V = ExpandBinOp(Instruction::And, Op0, Op1, Instruction::Xor,
+ Q, MaxRecurse))
+ return V;
+
+ // If the operation is with the result of a select instruction, check whether
+ // operating on either branch of the select always yields the same value.
+ if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1))
+ if (Value *V = ThreadBinOpOverSelect(Instruction::And, Op0, Op1, Q,
+ MaxRecurse))
+ return V;
+
+ // If the operation is with the result of a phi instruction, check whether
+ // operating on all incoming values of the phi always yields the same value.
+ if (isa<PHINode>(Op0) || isa<PHINode>(Op1))
+ if (Value *V = ThreadBinOpOverPHI(Instruction::And, Op0, Op1, Q,
+ MaxRecurse))
+ return V;
+
+ // Assuming the effective width of Y is not larger than A, i.e. all bits
+ // from X and Y are disjoint in (X << A) | Y,
+ // if the mask of this AND op covers all bits of X or Y, while it covers
+ // no bits from the other, we can bypass this AND op. E.g.,
+ // ((X << A) | Y) & Mask -> Y,
+ // if Mask = ((1 << effective_width_of(Y)) - 1)
+ // ((X << A) | Y) & Mask -> X << A,
+ // if Mask = ((1 << effective_width_of(X)) - 1) << A
+ // SimplifyDemandedBits in InstCombine can optimize the general case.
+ // This pattern aims to help other passes for a common case.
+ Value *Y, *XShifted;
+ if (match(Op1, m_APInt(Mask)) &&
+ match(Op0, m_c_Or(m_CombineAnd(m_NUWShl(m_Value(X), m_APInt(ShAmt)),
+ m_Value(XShifted)),
+ m_Value(Y)))) {
+ const unsigned Width = Op0->getType()->getScalarSizeInBits();
+ const unsigned ShftCnt = ShAmt->getLimitedValue(Width);
+ const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ const unsigned EffWidthY = Width - YKnown.countMinLeadingZeros();
+ if (EffWidthY <= ShftCnt) {
+ const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI,
+ Q.DT);
+ const unsigned EffWidthX = Width - XKnown.countMinLeadingZeros();
+ const APInt EffBitsY = APInt::getLowBitsSet(Width, EffWidthY);
+ const APInt EffBitsX = APInt::getLowBitsSet(Width, EffWidthX) << ShftCnt;
+ // If the mask is extracting all bits from X or Y as is, we can skip
+ // this AND op.
+ if (EffBitsY.isSubsetOf(*Mask) && !EffBitsX.intersects(*Mask))
+ return Y;
+ if (EffBitsX.isSubsetOf(*Mask) && !EffBitsY.intersects(*Mask))
+ return XShifted;
+ }
+ }
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
+ return ::SimplifyAndInst(Op0, Op1, Q, RecursionLimit);
+}
+
+/// Given operands for an Or, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ if (Constant *C = foldOrCommuteConstant(Instruction::Or, Op0, Op1, Q))
+ return C;
+
+ // X | undef -> -1
+ // X | -1 = -1
+ // Do not return Op1 because it may contain undef elements if it's a vector.
+ if (match(Op1, m_Undef()) || match(Op1, m_AllOnes()))
+ return Constant::getAllOnesValue(Op0->getType());
+
+ // X | X = X
+ // X | 0 = X
+ if (Op0 == Op1 || match(Op1, m_Zero()))
+ return Op0;
+
+ // A | ~A = ~A | A = -1
+ if (match(Op0, m_Not(m_Specific(Op1))) ||
+ match(Op1, m_Not(m_Specific(Op0))))
+ return Constant::getAllOnesValue(Op0->getType());
+
+ // (A & ?) | A = A
+ if (match(Op0, m_c_And(m_Specific(Op1), m_Value())))
+ return Op1;
+
+ // A | (A & ?) = A
+ if (match(Op1, m_c_And(m_Specific(Op0), m_Value())))
+ return Op0;
+
+ // ~(A & ?) | A = -1
+ if (match(Op0, m_Not(m_c_And(m_Specific(Op1), m_Value()))))
+ return Constant::getAllOnesValue(Op1->getType());
+
+ // A | ~(A & ?) = -1
+  if (match(Op1, m_Not(m_c_And(m_Specific(Op0), m_Value()))))
+ return Constant::getAllOnesValue(Op0->getType());
+
+ Value *A, *B;
+ // (A & ~B) | (A ^ B) -> (A ^ B)
+ // (~B & A) | (A ^ B) -> (A ^ B)
+ // (A & ~B) | (B ^ A) -> (B ^ A)
+ // (~B & A) | (B ^ A) -> (B ^ A)
+ if (match(Op1, m_Xor(m_Value(A), m_Value(B))) &&
+ (match(Op0, m_c_And(m_Specific(A), m_Not(m_Specific(B)))) ||
+ match(Op0, m_c_And(m_Not(m_Specific(A)), m_Specific(B)))))
+ return Op1;
+
+ // Commute the 'or' operands.
+ // (A ^ B) | (A & ~B) -> (A ^ B)
+ // (A ^ B) | (~B & A) -> (A ^ B)
+ // (B ^ A) | (A & ~B) -> (B ^ A)
+ // (B ^ A) | (~B & A) -> (B ^ A)
+ if (match(Op0, m_Xor(m_Value(A), m_Value(B))) &&
+ (match(Op1, m_c_And(m_Specific(A), m_Not(m_Specific(B)))) ||
+ match(Op1, m_c_And(m_Not(m_Specific(A)), m_Specific(B)))))
+ return Op0;
+
+ // (A & B) | (~A ^ B) -> (~A ^ B)
+ // (B & A) | (~A ^ B) -> (~A ^ B)
+ // (A & B) | (B ^ ~A) -> (B ^ ~A)
+ // (B & A) | (B ^ ~A) -> (B ^ ~A)
+ if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
+ (match(Op1, m_c_Xor(m_Specific(A), m_Not(m_Specific(B)))) ||
+ match(Op1, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B)))))
+ return Op1;
+
+ // (~A ^ B) | (A & B) -> (~A ^ B)
+ // (~A ^ B) | (B & A) -> (~A ^ B)
+ // (B ^ ~A) | (A & B) -> (B ^ ~A)
+ // (B ^ ~A) | (B & A) -> (B ^ ~A)
+ if (match(Op1, m_And(m_Value(A), m_Value(B))) &&
+ (match(Op0, m_c_Xor(m_Specific(A), m_Not(m_Specific(B)))) ||
+ match(Op0, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B)))))
+ return Op0;
+
+ if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, false))
+ return V;
+
+  // If we have an inverted multiplication overflow check that is being 'or'ed
+  // with a check that one of the multipliers is zero, we can omit the 'or' and
+  // only keep the inverted overflow check.
+ if (Value *V = omitCheckForZeroBeforeInvertedMulWithOverflow(Op0, Op1))
+ return V;
+ if (Value *V = omitCheckForZeroBeforeInvertedMulWithOverflow(Op1, Op0))
+ return V;
+
+ // Try some generic simplifications for associative operations.
+ if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q,
+ MaxRecurse))
+ return V;
+
+ // Or distributes over And. Try some generic simplifications based on this.
+ if (Value *V = ExpandBinOp(Instruction::Or, Op0, Op1, Instruction::And, Q,
+ MaxRecurse))
+ return V;
+
+ // If the operation is with the result of a select instruction, check whether
+ // operating on either branch of the select always yields the same value.
+ if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1))
+ if (Value *V = ThreadBinOpOverSelect(Instruction::Or, Op0, Op1, Q,
+ MaxRecurse))
+ return V;
+
+ // (A & C1)|(B & C2)
+ const APInt *C1, *C2;
+ if (match(Op0, m_And(m_Value(A), m_APInt(C1))) &&
+ match(Op1, m_And(m_Value(B), m_APInt(C2)))) {
+ if (*C1 == ~*C2) {
+ // (A & C1)|(B & C2)
+ // If we have: ((V + N) & C1) | (V & C2)
+ // .. and C2 = ~C1 and C2 is 0+1+ and (N & C2) == 0
+ // replace with V+N.
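+ // e.g. ((V + 4) & -4) | (V & 3) --> V + 4, because (4 & 3) == 0 means the
+ // add does not disturb the low two bits of V.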
+ Value *N;
+ if (C2->isMask() && // C2 == 0+1+
+ match(A, m_c_Add(m_Specific(B), m_Value(N)))) {
+ // Add commutes, try both ways.
+ if (MaskedValueIsZero(N, *C2, Q.DL, 0, Q.AC, Q.CxtI, Q.DT))
+ return A;
+ }
+ // Or commutes, try both ways.
+ if (C1->isMask() &&
+ match(B, m_c_Add(m_Specific(A), m_Value(N)))) {
+ // Add commutes, try both ways.
+ if (MaskedValueIsZero(N, *C1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT))
+ return B;
+ }
+ }
+ }
+
+ // If the operation is with the result of a phi instruction, check whether
+ // operating on all incoming values of the phi always yields the same value.
+ if (isa<PHINode>(Op0) || isa<PHINode>(Op1))
+ if (Value *V = ThreadBinOpOverPHI(Instruction::Or, Op0, Op1, Q, MaxRecurse))
+ return V;
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
+ return ::SimplifyOrInst(Op0, Op1, Q, RecursionLimit);
+}
+
+/// Given operands for a Xor, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ if (Constant *C = foldOrCommuteConstant(Instruction::Xor, Op0, Op1, Q))
+ return C;
+
+ // A ^ undef -> undef
+ if (match(Op1, m_Undef()))
+ return Op1;
+
+ // A ^ 0 = A
+ if (match(Op1, m_Zero()))
+ return Op0;
+
+ // A ^ A = 0
+ if (Op0 == Op1)
+ return Constant::getNullValue(Op0->getType());
+
+ // A ^ ~A = ~A ^ A = -1
+ if (match(Op0, m_Not(m_Specific(Op1))) ||
+ match(Op1, m_Not(m_Specific(Op0))))
+ return Constant::getAllOnesValue(Op0->getType());
+
+ // Try some generic simplifications for associative operations.
+ if (Value *V = SimplifyAssociativeBinOp(Instruction::Xor, Op0, Op1, Q,
+ MaxRecurse))
+ return V;
+
+ // Threading Xor over selects and phi nodes is pointless, so don't bother.
+ // Threading over the select in "A ^ select(cond, B, C)" means evaluating
+ // "A^B" and "A^C" and seeing if they are equal; but they are equal if and
+ // only if B and C are equal. If B and C are equal then (since we assume
+ // that operands have already been simplified) "select(cond, B, C)" should
+ // have been simplified to the common value of B and C already. Analysing
+ // "A^B" and "A^C" thus gains nothing, but costs compile time. Similarly
+ // for threading over phi nodes.
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
+ return ::SimplifyXorInst(Op0, Op1, Q, RecursionLimit);
+}
+
+
+static Type *GetCompareTy(Value *Op) {
+ return CmpInst::makeCmpResultType(Op->getType());
+}
+
+/// Rummage around inside V looking for something equivalent to the comparison
+/// "LHS Pred RHS". Return such a value if found, otherwise return null.
+/// Helper function for analyzing max/min idioms.
+static Value *ExtractEquivalentCondition(Value *V, CmpInst::Predicate Pred,
+ Value *LHS, Value *RHS) {
+ SelectInst *SI = dyn_cast<SelectInst>(V);
+ if (!SI)
+ return nullptr;
+ CmpInst *Cmp = dyn_cast<CmpInst>(SI->getCondition());
+ if (!Cmp)
+ return nullptr;
+ Value *CmpLHS = Cmp->getOperand(0), *CmpRHS = Cmp->getOperand(1);
+ if (Pred == Cmp->getPredicate() && LHS == CmpLHS && RHS == CmpRHS)
+ return Cmp;
+ if (Pred == CmpInst::getSwappedPredicate(Cmp->getPredicate()) &&
+ LHS == CmpRHS && RHS == CmpLHS)
+ return Cmp;
+ return nullptr;
+}
+
+// A significant optimization not implemented here is assuming that alloca
+// addresses are not equal to incoming argument values. They don't *alias*,
+// as we say, but that doesn't mean they aren't equal, so we take a
+// conservative approach.
+//
+// This is inspired in part by C++11 5.10p1:
+// "Two pointers of the same type compare equal if and only if they are both
+// null, both point to the same function, or both represent the same
+// address."
+//
+// This is pretty permissive.
+//
+// It's also partly due to C11 6.5.9p6:
+// "Two pointers compare equal if and only if both are null pointers, both are
+// pointers to the same object (including a pointer to an object and a
+// subobject at its beginning) or function, both are pointers to one past the
+// last element of the same array object, or one is a pointer to one past the
+// end of one array object and the other is a pointer to the start of a
+// different array object that happens to immediately follow the first array
+ // object in the address space."
+//
+ // C11's version is more restrictive; however, there's no reason why an argument
+// couldn't be a one-past-the-end value for a stack object in the caller and be
+// equal to the beginning of a stack object in the callee.
+//
+// If the C and C++ standards are ever made sufficiently restrictive in this
+// area, it may be possible to update LLVM's semantics accordingly and reinstate
+// this optimization.
+static Constant *
+computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI,
+ const DominatorTree *DT, CmpInst::Predicate Pred,
+ AssumptionCache *AC, const Instruction *CxtI,
+ const InstrInfoQuery &IIQ, Value *LHS, Value *RHS) {
+ // First, skip past any trivial no-ops.
+ LHS = LHS->stripPointerCasts();
+ RHS = RHS->stripPointerCasts();
+
+ // A non-null pointer is not equal to a null pointer.
+ if (llvm::isKnownNonZero(LHS, DL, 0, nullptr, nullptr, nullptr,
+ IIQ.UseInstrInfo) &&
+ isa<ConstantPointerNull>(RHS) &&
+ (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE))
+ return ConstantInt::get(GetCompareTy(LHS),
+ !CmpInst::isTrueWhenEqual(Pred));
+
+ // We can only fold certain predicates on pointer comparisons.
+ switch (Pred) {
+ default:
+ return nullptr;
+
+ // Equality comparisons are easy to fold.
+ case CmpInst::ICMP_EQ:
+ case CmpInst::ICMP_NE:
+ break;
+
+ // We can only handle unsigned relational comparisons because 'inbounds' on
+ // a GEP only protects against unsigned wrapping.
+ case CmpInst::ICMP_UGT:
+ case CmpInst::ICMP_UGE:
+ case CmpInst::ICMP_ULT:
+ case CmpInst::ICMP_ULE:
+ // However, we have to switch them to their signed variants to handle
+ // negative indices from the base pointer.
+ Pred = ICmpInst::getSignedPredicate(Pred);
+ break;
+ }
+
+ // Strip off any constant offsets so that we can reason about them.
+ // It's tempting to use getUnderlyingObject or even just stripInBoundsOffsets
+ // here and compare base addresses like AliasAnalysis does; however, there are
+ // numerous hazards. AliasAnalysis and its utilities rely on special rules
+ // governing loads and stores which don't apply to icmps. Also, AliasAnalysis
+ // doesn't need to guarantee pointer inequality when it says NoAlias.
+ Constant *LHSOffset = stripAndComputeConstantOffsets(DL, LHS);
+ Constant *RHSOffset = stripAndComputeConstantOffsets(DL, RHS);
+
+ // If LHS and RHS are related via constant offsets to the same base
+ // value, we can replace it with an icmp which just compares the offsets.
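+ // e.g. "icmp ult (gep inbounds i8, i8* %p, i64 4), (gep inbounds i8, i8* %p,
+ // i64 8)" reduces to comparing the offsets 4 and 8 with the signed form of
+ // the predicate chosen above.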
+ if (LHS == RHS)
+ return ConstantExpr::getICmp(Pred, LHSOffset, RHSOffset);
+
+ // Various optimizations for (in)equality comparisons.
+ if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE) {
+ // Different non-empty allocations that exist at the same time have
+ // different addresses (if the program can tell). Global variables always
+ // exist, so they always exist during the lifetime of each other and all
+ // allocas. Two different allocas usually have different addresses...
+ //
+ // However, if there's an @llvm.stackrestore dynamically in between two
+ // allocas, they may have the same address. It's tempting to reduce the
+ // scope of the problem by only looking at *static* allocas here. That would
+ // cover the majority of allocas while significantly reducing the likelihood
+ // of having an @llvm.stackrestore pop up in the middle. However, it's not
+ // actually impossible for an @llvm.stackrestore to pop up in the middle of
+ // an entry block. Also, if we have a block that's not attached to a
+ // function, we can't tell if it's "static" under the current definition.
+ // Theoretically, this problem could be fixed by creating a new instruction
+ // kind specifically for static allocas. Such a new instruction
+ // could be required to be at the top of the entry block, thus preventing it
+ // from being subject to a @llvm.stackrestore. Instcombine could even
+ // convert regular allocas into these special allocas. It'd be nifty.
+ // However, until then, this problem remains open.
+ //
+ // So, we'll assume that two non-empty allocas have different addresses
+ // for now.
+ //
+ // With all that, if the offsets are within the bounds of their allocations
+ // (and not one-past-the-end! so we can't use inbounds!), and their
+ // allocations aren't the same, the pointers are not equal.
+ //
+ // Note that it's not necessary to check for LHS being a global variable
+ // address, due to canonicalization and constant folding.
+ if (isa<AllocaInst>(LHS) &&
+ (isa<AllocaInst>(RHS) || isa<GlobalVariable>(RHS))) {
+ ConstantInt *LHSOffsetCI = dyn_cast<ConstantInt>(LHSOffset);
+ ConstantInt *RHSOffsetCI = dyn_cast<ConstantInt>(RHSOffset);
+ uint64_t LHSSize, RHSSize;
+ ObjectSizeOpts Opts;
+ Opts.NullIsUnknownSize =
+ NullPointerIsDefined(cast<AllocaInst>(LHS)->getFunction());
+ if (LHSOffsetCI && RHSOffsetCI &&
+ getObjectSize(LHS, LHSSize, DL, TLI, Opts) &&
+ getObjectSize(RHS, RHSSize, DL, TLI, Opts)) {
+ const APInt &LHSOffsetValue = LHSOffsetCI->getValue();
+ const APInt &RHSOffsetValue = RHSOffsetCI->getValue();
+ if (!LHSOffsetValue.isNegative() &&
+ !RHSOffsetValue.isNegative() &&
+ LHSOffsetValue.ult(LHSSize) &&
+ RHSOffsetValue.ult(RHSSize)) {
+ return ConstantInt::get(GetCompareTy(LHS),
+ !CmpInst::isTrueWhenEqual(Pred));
+ }
+ }
+
+ // Repeat the above check but this time without depending on DataLayout
+ // or being able to compute a precise size.
+ if (!cast<PointerType>(LHS->getType())->isEmptyTy() &&
+ !cast<PointerType>(RHS->getType())->isEmptyTy() &&
+ LHSOffset->isNullValue() &&
+ RHSOffset->isNullValue())
+ return ConstantInt::get(GetCompareTy(LHS),
+ !CmpInst::isTrueWhenEqual(Pred));
+ }
+
+ // Even if a non-inbounds GEP occurs along the path we can still optimize
+ // equality comparisons concerning the result. We avoid walking the whole
+ // chain again by starting where the last calls to
+ // stripAndComputeConstantOffsets left off and accumulate the offsets.
+ Constant *LHSNoBound = stripAndComputeConstantOffsets(DL, LHS, true);
+ Constant *RHSNoBound = stripAndComputeConstantOffsets(DL, RHS, true);
+ if (LHS == RHS)
+ return ConstantExpr::getICmp(Pred,
+ ConstantExpr::getAdd(LHSOffset, LHSNoBound),
+ ConstantExpr::getAdd(RHSOffset, RHSNoBound));
+
+ // If one side of the equality comparison must come from a noalias call
+ // (meaning a system memory allocation function), and the other side must
+ // come from a pointer that cannot overlap with dynamically-allocated
+ // memory within the lifetime of the current function (allocas, byval
+ // arguments, globals), then determine the comparison result here.
+ SmallVector<const Value *, 8> LHSUObjs, RHSUObjs;
+ GetUnderlyingObjects(LHS, LHSUObjs, DL);
+ GetUnderlyingObjects(RHS, RHSUObjs, DL);
+
+ // Is the set of underlying objects all noalias calls?
+ auto IsNAC = [](ArrayRef<const Value *> Objects) {
+ return all_of(Objects, isNoAliasCall);
+ };
+
+ // Is the set of underlying objects all things which must be disjoint from
+ // noalias calls? For allocas, we consider only static ones (dynamic
+ // allocas might be transformed into calls to malloc not simultaneously
+ // live with the compared-to allocation). For globals, we exclude symbols
+ // that might be resolved lazily to symbols in another dynamically-loaded
+ // library (and, thus, could be malloc'ed by the implementation).
+ auto IsAllocDisjoint = [](ArrayRef<const Value *> Objects) {
+ return all_of(Objects, [](const Value *V) {
+ if (const AllocaInst *AI = dyn_cast<AllocaInst>(V))
+ return AI->getParent() && AI->getFunction() && AI->isStaticAlloca();
+ if (const GlobalValue *GV = dyn_cast<GlobalValue>(V))
+ return (GV->hasLocalLinkage() || GV->hasHiddenVisibility() ||
+ GV->hasProtectedVisibility() || GV->hasGlobalUnnamedAddr()) &&
+ !GV->isThreadLocal();
+ if (const Argument *A = dyn_cast<Argument>(V))
+ return A->hasByValAttr();
+ return false;
+ });
+ };
+
+ if ((IsNAC(LHSUObjs) && IsAllocDisjoint(RHSUObjs)) ||
+ (IsNAC(RHSUObjs) && IsAllocDisjoint(LHSUObjs)))
+ return ConstantInt::get(GetCompareTy(LHS),
+ !CmpInst::isTrueWhenEqual(Pred));
+
+ // Fold comparisons for non-escaping pointer even if the allocation call
+ // cannot be elided. We cannot fold malloc comparison to null. Also, the
+ // dynamic allocation call could be either of the operands.
+ Value *MI = nullptr;
+ if (isAllocLikeFn(LHS, TLI) &&
+ llvm::isKnownNonZero(RHS, DL, 0, nullptr, CxtI, DT))
+ MI = LHS;
+ else if (isAllocLikeFn(RHS, TLI) &&
+ llvm::isKnownNonZero(LHS, DL, 0, nullptr, CxtI, DT))
+ MI = RHS;
+ // FIXME: We should also fold the compare when the pointer escapes, but the
+ // compare dominates the pointer escape
+ if (MI && !PointerMayBeCaptured(MI, true, true))
+ return ConstantInt::get(GetCompareTy(LHS),
+ CmpInst::isFalseWhenEqual(Pred));
+ }
+
+ // Otherwise, fail.
+ return nullptr;
+}
+
+/// Fold an icmp when its operands have i1 scalar type.
+static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS,
+ Value *RHS, const SimplifyQuery &Q) {
+ Type *ITy = GetCompareTy(LHS); // The return type.
+ Type *OpTy = LHS->getType(); // The operand type.
+ if (!OpTy->isIntOrIntVectorTy(1))
+ return nullptr;
+
+ // A boolean compared to true/false can be simplified in 14 out of the 20
+ // (10 predicates * 2 constants) possible combinations. Cases not handled here
+ // require a 'not' of the LHS, so those must be transformed in InstCombine.
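+ // e.g. "icmp eq i1 %x, false" is really "xor %x, true"; creating that new
+ // instruction is left to InstCombine.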
+ if (match(RHS, m_Zero())) {
+ switch (Pred) {
+ case CmpInst::ICMP_NE: // X != 0 -> X
+ case CmpInst::ICMP_UGT: // X >u 0 -> X
+ case CmpInst::ICMP_SLT: // X <s 0 -> X
+ return LHS;
+
+ case CmpInst::ICMP_ULT: // X <u 0 -> false
+ case CmpInst::ICMP_SGT: // X >s 0 -> false
+ return getFalse(ITy);
+
+ case CmpInst::ICMP_UGE: // X >=u 0 -> true
+ case CmpInst::ICMP_SLE: // X <=s 0 -> true
+ return getTrue(ITy);
+
+ default: break;
+ }
+ } else if (match(RHS, m_One())) {
+ switch (Pred) {
+ case CmpInst::ICMP_EQ: // X == 1 -> X
+ case CmpInst::ICMP_UGE: // X >=u 1 -> X
+ case CmpInst::ICMP_SLE: // X <=s -1 -> X
+ return LHS;
+
+ case CmpInst::ICMP_UGT: // X >u 1 -> false
+ case CmpInst::ICMP_SLT: // X <s -1 -> false
+ return getFalse(ITy);
+
+ case CmpInst::ICMP_ULE: // X <=u 1 -> true
+ case CmpInst::ICMP_SGE: // X >=s -1 -> true
+ return getTrue(ITy);
+
+ default: break;
+ }
+ }
+
+ switch (Pred) {
+ default:
+ break;
+ case ICmpInst::ICMP_UGE:
+ if (isImpliedCondition(RHS, LHS, Q.DL).getValueOr(false))
+ return getTrue(ITy);
+ break;
+ case ICmpInst::ICMP_SGE:
+ /// For signed comparison, the values for an i1 are 0 and -1
+ /// respectively. This maps into a truth table of:
+ /// LHS | RHS | LHS >=s RHS | LHS implies RHS
+ /// 0 | 0 | 1 (0 >= 0) | 1
+ /// 0 | 1 | 1 (0 >= -1) | 1
+ /// 1 | 0 | 0 (-1 >= 0) | 0
+ /// 1 | 1 | 1 (-1 >= -1) | 1
+ if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false))
+ return getTrue(ITy);
+ break;
+ case ICmpInst::ICMP_ULE:
+ if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false))
+ return getTrue(ITy);
+ break;
+ }
+
+ return nullptr;
+}
+
+/// Try hard to fold icmp with zero RHS because this is a common case.
+static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
+ Value *RHS, const SimplifyQuery &Q) {
+ if (!match(RHS, m_Zero()))
+ return nullptr;
+
+ Type *ITy = GetCompareTy(LHS); // The return type.
+ switch (Pred) {
+ default:
+ llvm_unreachable("Unknown ICmp predicate!");
+ case ICmpInst::ICMP_ULT:
+ return getFalse(ITy);
+ case ICmpInst::ICMP_UGE:
+ return getTrue(ITy);
+ case ICmpInst::ICMP_EQ:
+ case ICmpInst::ICMP_ULE:
+ if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo))
+ return getFalse(ITy);
+ break;
+ case ICmpInst::ICMP_NE:
+ case ICmpInst::ICMP_UGT:
+ if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo))
+ return getTrue(ITy);
+ break;
+ case ICmpInst::ICMP_SLT: {
+ KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (LHSKnown.isNegative())
+ return getTrue(ITy);
+ if (LHSKnown.isNonNegative())
+ return getFalse(ITy);
+ break;
+ }
+ case ICmpInst::ICMP_SLE: {
+ KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (LHSKnown.isNegative())
+ return getTrue(ITy);
+ if (LHSKnown.isNonNegative() &&
+ isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT))
+ return getFalse(ITy);
+ break;
+ }
+ case ICmpInst::ICMP_SGE: {
+ KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (LHSKnown.isNegative())
+ return getFalse(ITy);
+ if (LHSKnown.isNonNegative())
+ return getTrue(ITy);
+ break;
+ }
+ case ICmpInst::ICMP_SGT: {
+ KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (LHSKnown.isNegative())
+ return getFalse(ITy);
+ if (LHSKnown.isNonNegative() &&
+ isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT))
+ return getTrue(ITy);
+ break;
+ }
+ }
+
+ return nullptr;
+}
+
+static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
+ Value *RHS, const InstrInfoQuery &IIQ) {
+ Type *ITy = GetCompareTy(RHS); // The return type.
+
+ Value *X;
+ // Sign-bit checks can be optimized to true/false after unsigned
+ // floating-point casts:
+ // icmp slt (bitcast (uitofp X)), 0 --> false
+ // icmp sgt (bitcast (uitofp X)), -1 --> true
+ if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) {
+ if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero()))
+ return ConstantInt::getFalse(ITy);
+ if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()))
+ return ConstantInt::getTrue(ITy);
+ }
+
+ const APInt *C;
+ if (!match(RHS, m_APInt(C)))
+ return nullptr;
+
+ // Rule out tautological comparisons (e.g., ult 0 or uge 0).
+ ConstantRange RHS_CR = ConstantRange::makeExactICmpRegion(Pred, *C);
+ if (RHS_CR.isEmptySet())
+ return ConstantInt::getFalse(ITy);
+ if (RHS_CR.isFullSet())
+ return ConstantInt::getTrue(ITy);
+
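+ // If the full range of values the LHS can take is known, compare it against
+ // the region implied by the predicate; e.g. a urem by 8 always lies in [0, 8),
+ // so "icmp ult (urem %x, 8), 16" folds to true.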
+ ConstantRange LHS_CR = computeConstantRange(LHS, IIQ.UseInstrInfo);
+ if (!LHS_CR.isFullSet()) {
+ if (RHS_CR.contains(LHS_CR))
+ return ConstantInt::getTrue(ITy);
+ if (RHS_CR.inverse().contains(LHS_CR))
+ return ConstantInt::getFalse(ITy);
+ }
+
+ return nullptr;
+}
+
+/// TODO: A large part of this logic is duplicated in InstCombine's
+/// foldICmpBinOp(). We should be able to share that and avoid the code
+/// duplication.
+static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
+ Value *RHS, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ Type *ITy = GetCompareTy(LHS); // The return type.
+
+ BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS);
+ BinaryOperator *RBO = dyn_cast<BinaryOperator>(RHS);
+ if (MaxRecurse && (LBO || RBO)) {
+ // Analyze the case when either LHS or RHS is an add instruction.
+ Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
+ // LHS = A + B (or A and B are null); RHS = C + D (or C and D are null).
+ bool NoLHSWrapProblem = false, NoRHSWrapProblem = false;
+ if (LBO && LBO->getOpcode() == Instruction::Add) {
+ A = LBO->getOperand(0);
+ B = LBO->getOperand(1);
+ NoLHSWrapProblem =
+ ICmpInst::isEquality(Pred) ||
+ (CmpInst::isUnsigned(Pred) &&
+ Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(LBO))) ||
+ (CmpInst::isSigned(Pred) &&
+ Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(LBO)));
+ }
+ if (RBO && RBO->getOpcode() == Instruction::Add) {
+ C = RBO->getOperand(0);
+ D = RBO->getOperand(1);
+ NoRHSWrapProblem =
+ ICmpInst::isEquality(Pred) ||
+ (CmpInst::isUnsigned(Pred) &&
+ Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(RBO))) ||
+ (CmpInst::isSigned(Pred) &&
+ Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(RBO)));
+ }
+
+ // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow.
+ if ((A == RHS || B == RHS) && NoLHSWrapProblem)
+ if (Value *V = SimplifyICmpInst(Pred, A == RHS ? B : A,
+ Constant::getNullValue(RHS->getType()), Q,
+ MaxRecurse - 1))
+ return V;
+
+ // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow.
+ if ((C == LHS || D == LHS) && NoRHSWrapProblem)
+ if (Value *V =
+ SimplifyICmpInst(Pred, Constant::getNullValue(LHS->getType()),
+ C == LHS ? D : C, Q, MaxRecurse - 1))
+ return V;
+
+ // icmp (X+Y), (X+Z) -> icmp Y,Z for equalities or if there is no overflow.
+ if (A && C && (A == C || A == D || B == C || B == D) && NoLHSWrapProblem &&
+ NoRHSWrapProblem) {
+ // Determine Y and Z in the form icmp (X+Y), (X+Z).
+ Value *Y, *Z;
+ if (A == C) {
+ // C + B == C + D -> B == D
+ Y = B;
+ Z = D;
+ } else if (A == D) {
+ // D + B == C + D -> B == C
+ Y = B;
+ Z = C;
+ } else if (B == C) {
+ // A + C == C + D -> A == D
+ Y = A;
+ Z = D;
+ } else {
+ assert(B == D);
+ // A + D == C + D -> A == C
+ Y = A;
+ Z = C;
+ }
+ if (Value *V = SimplifyICmpInst(Pred, Y, Z, Q, MaxRecurse - 1))
+ return V;
+ }
+ }
+
+ {
+ Value *Y = nullptr;
+ // icmp pred (or X, Y), X
+ if (LBO && match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) {
+ if (Pred == ICmpInst::ICMP_ULT)
+ return getFalse(ITy);
+ if (Pred == ICmpInst::ICMP_UGE)
+ return getTrue(ITy);
+
+ if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
+ KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (RHSKnown.isNonNegative() && YKnown.isNegative())
+ return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy);
+ if (RHSKnown.isNegative() || YKnown.isNonNegative())
+ return Pred == ICmpInst::ICMP_SLT ? getFalse(ITy) : getTrue(ITy);
+ }
+ }
+ // icmp pred X, (or X, Y)
+ if (RBO && match(RBO, m_c_Or(m_Value(Y), m_Specific(LHS)))) {
+ if (Pred == ICmpInst::ICMP_ULE)
+ return getTrue(ITy);
+ if (Pred == ICmpInst::ICMP_UGT)
+ return getFalse(ITy);
+
+ if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE) {
+ KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (LHSKnown.isNonNegative() && YKnown.isNegative())
+ return Pred == ICmpInst::ICMP_SGT ? getTrue(ITy) : getFalse(ITy);
+ if (LHSKnown.isNegative() || YKnown.isNonNegative())
+ return Pred == ICmpInst::ICMP_SGT ? getFalse(ITy) : getTrue(ITy);
+ }
+ }
+ }
+
+ // icmp pred (and X, Y), X
+ if (LBO && match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) {
+ if (Pred == ICmpInst::ICMP_UGT)
+ return getFalse(ITy);
+ if (Pred == ICmpInst::ICMP_ULE)
+ return getTrue(ITy);
+ }
+ // icmp pred X, (and X, Y)
+ if (RBO && match(RBO, m_c_And(m_Value(), m_Specific(LHS)))) {
+ if (Pred == ICmpInst::ICMP_UGE)
+ return getTrue(ITy);
+ if (Pred == ICmpInst::ICMP_ULT)
+ return getFalse(ITy);
+ }
+
+ // 0 - (zext X) pred C
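+ // e.g. "(0 - zext i8 %x to i32) slt 7" folds to true: the negated zext can
+ // never be strictly positive, so it is always less than a positive constant.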
+ if (!CmpInst::isUnsigned(Pred) && match(LHS, m_Neg(m_ZExt(m_Value())))) {
+ if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) {
+ if (RHSC->getValue().isStrictlyPositive()) {
+ if (Pred == ICmpInst::ICMP_SLT)
+ return ConstantInt::getTrue(RHSC->getContext());
+ if (Pred == ICmpInst::ICMP_SGE)
+ return ConstantInt::getFalse(RHSC->getContext());
+ if (Pred == ICmpInst::ICMP_EQ)
+ return ConstantInt::getFalse(RHSC->getContext());
+ if (Pred == ICmpInst::ICMP_NE)
+ return ConstantInt::getTrue(RHSC->getContext());
+ }
+ if (RHSC->getValue().isNonNegative()) {
+ if (Pred == ICmpInst::ICMP_SLE)
+ return ConstantInt::getTrue(RHSC->getContext());
+ if (Pred == ICmpInst::ICMP_SGT)
+ return ConstantInt::getFalse(RHSC->getContext());
+ }
+ }
+ }
+
+ // icmp pred (urem X, Y), Y
+ if (LBO && match(LBO, m_URem(m_Value(), m_Specific(RHS)))) {
+ switch (Pred) {
+ default:
+ break;
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_SGE: {
+ KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (!Known.isNonNegative())
+ break;
+ LLVM_FALLTHROUGH;
+ }
+ case ICmpInst::ICMP_EQ:
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_UGE:
+ return getFalse(ITy);
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_SLE: {
+ KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (!Known.isNonNegative())
+ break;
+ LLVM_FALLTHROUGH;
+ }
+ case ICmpInst::ICMP_NE:
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULE:
+ return getTrue(ITy);
+ }
+ }
+
+ // icmp pred X, (urem Y, X)
+ if (RBO && match(RBO, m_URem(m_Value(), m_Specific(LHS)))) {
+ switch (Pred) {
+ default:
+ break;
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_SGE: {
+ KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (!Known.isNonNegative())
+ break;
+ LLVM_FALLTHROUGH;
+ }
+ case ICmpInst::ICMP_NE:
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_UGE:
+ return getTrue(ITy);
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_SLE: {
+ KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (!Known.isNonNegative())
+ break;
+ LLVM_FALLTHROUGH;
+ }
+ case ICmpInst::ICMP_EQ:
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULE:
+ return getFalse(ITy);
+ }
+ }
+
+ // x >> y <=u x
+ // x udiv y <=u x.
+ if (LBO && (match(LBO, m_LShr(m_Specific(RHS), m_Value())) ||
+ match(LBO, m_UDiv(m_Specific(RHS), m_Value())))) {
+ // icmp pred (X op Y), X
+ if (Pred == ICmpInst::ICMP_UGT)
+ return getFalse(ITy);
+ if (Pred == ICmpInst::ICMP_ULE)
+ return getTrue(ITy);
+ }
+
+ // x >=u x >> y
+ // x >=u x udiv y.
+ if (RBO && (match(RBO, m_LShr(m_Specific(LHS), m_Value())) ||
+ match(RBO, m_UDiv(m_Specific(LHS), m_Value())))) {
+ // icmp pred X, (X op Y)
+ if (Pred == ICmpInst::ICMP_ULT)
+ return getFalse(ITy);
+ if (Pred == ICmpInst::ICMP_UGE)
+ return getTrue(ITy);
+ }
+
+ // handle:
+ // CI2 << X == CI
+ // CI2 << X != CI
+ //
+ // where CI2 is a power of 2 and CI isn't
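+ //
+ // e.g. "(2 << %x) == 12" folds to false: the shift only ever produces a power
+ // of two (or zero on overflow), and 12 is neither.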
+ if (auto *CI = dyn_cast<ConstantInt>(RHS)) {
+ const APInt *CI2Val, *CIVal = &CI->getValue();
+ if (LBO && match(LBO, m_Shl(m_APInt(CI2Val), m_Value())) &&
+ CI2Val->isPowerOf2()) {
+ if (!CIVal->isPowerOf2()) {
+ // CI2 << X can equal zero in some circumstances,
+ // this simplification is unsafe if CI is zero.
+ //
+ // We know it is safe if:
+ // - The shift is nsw, we can't shift out the one bit.
+ // - The shift is nuw, we can't shift out the one bit.
+ // - CI2 is one
+ // - CI isn't zero
+ if (Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(LBO)) ||
+ Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(LBO)) ||
+ CI2Val->isOneValue() || !CI->isZero()) {
+ if (Pred == ICmpInst::ICMP_EQ)
+ return ConstantInt::getFalse(RHS->getContext());
+ if (Pred == ICmpInst::ICMP_NE)
+ return ConstantInt::getTrue(RHS->getContext());
+ }
+ }
+ if (CIVal->isSignMask() && CI2Val->isOneValue()) {
+ if (Pred == ICmpInst::ICMP_UGT)
+ return ConstantInt::getFalse(RHS->getContext());
+ if (Pred == ICmpInst::ICMP_ULE)
+ return ConstantInt::getTrue(RHS->getContext());
+ }
+ }
+ }
+
+ if (MaxRecurse && LBO && RBO && LBO->getOpcode() == RBO->getOpcode() &&
+ LBO->getOperand(1) == RBO->getOperand(1)) {
+ switch (LBO->getOpcode()) {
+ default:
+ break;
+ case Instruction::UDiv:
+ case Instruction::LShr:
+ if (ICmpInst::isSigned(Pred) || !Q.IIQ.isExact(LBO) ||
+ !Q.IIQ.isExact(RBO))
+ break;
+ if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0),
+ RBO->getOperand(0), Q, MaxRecurse - 1))
+ return V;
+ break;
+ case Instruction::SDiv:
+ if (!ICmpInst::isEquality(Pred) || !Q.IIQ.isExact(LBO) ||
+ !Q.IIQ.isExact(RBO))
+ break;
+ if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0),
+ RBO->getOperand(0), Q, MaxRecurse - 1))
+ return V;
+ break;
+ case Instruction::AShr:
+ if (!Q.IIQ.isExact(LBO) || !Q.IIQ.isExact(RBO))
+ break;
+ if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0),
+ RBO->getOperand(0), Q, MaxRecurse - 1))
+ return V;
+ break;
+ case Instruction::Shl: {
+ bool NUW = Q.IIQ.hasNoUnsignedWrap(LBO) && Q.IIQ.hasNoUnsignedWrap(RBO);
+ bool NSW = Q.IIQ.hasNoSignedWrap(LBO) && Q.IIQ.hasNoSignedWrap(RBO);
+ if (!NUW && !NSW)
+ break;
+ if (!NSW && ICmpInst::isSigned(Pred))
+ break;
+ if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0),
+ RBO->getOperand(0), Q, MaxRecurse - 1))
+ return V;
+ break;
+ }
+ }
+ }
+ return nullptr;
+}
+
+/// Simplify integer comparisons where at least one operand of the compare
+/// matches an integer min/max idiom.
+static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS,
+ Value *RHS, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ Type *ITy = GetCompareTy(LHS); // The return type.
+ Value *A, *B;
+ CmpInst::Predicate P = CmpInst::BAD_ICMP_PREDICATE;
+ CmpInst::Predicate EqP; // Chosen so that "A == max/min(A,B)" iff "A EqP B".
+
+ // Signed variants on "max(a,b)>=a -> true".
+ if (match(LHS, m_SMax(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) {
+ if (A != RHS)
+ std::swap(A, B); // smax(A, B) pred A.
+ EqP = CmpInst::ICMP_SGE; // "A == smax(A, B)" iff "A sge B".
+ // We analyze this as smax(A, B) pred A.
+ P = Pred;
+ } else if (match(RHS, m_SMax(m_Value(A), m_Value(B))) &&
+ (A == LHS || B == LHS)) {
+ if (A != LHS)
+ std::swap(A, B); // A pred smax(A, B).
+ EqP = CmpInst::ICMP_SGE; // "A == smax(A, B)" iff "A sge B".
+ // We analyze this as smax(A, B) swapped-pred A.
+ P = CmpInst::getSwappedPredicate(Pred);
+ } else if (match(LHS, m_SMin(m_Value(A), m_Value(B))) &&
+ (A == RHS || B == RHS)) {
+ if (A != RHS)
+ std::swap(A, B); // smin(A, B) pred A.
+ EqP = CmpInst::ICMP_SLE; // "A == smin(A, B)" iff "A sle B".
+ // We analyze this as smax(-A, -B) swapped-pred -A.
+ // Note that we do not need to actually form -A or -B thanks to EqP.
+ P = CmpInst::getSwappedPredicate(Pred);
+ } else if (match(RHS, m_SMin(m_Value(A), m_Value(B))) &&
+ (A == LHS || B == LHS)) {
+ if (A != LHS)
+ std::swap(A, B); // A pred smin(A, B).
+ EqP = CmpInst::ICMP_SLE; // "A == smin(A, B)" iff "A sle B".
+ // We analyze this as smax(-A, -B) pred -A.
+ // Note that we do not need to actually form -A or -B thanks to EqP.
+ P = Pred;
+ }
+ if (P != CmpInst::BAD_ICMP_PREDICATE) {
+ // Cases correspond to "max(A, B) p A".
+ switch (P) {
+ default:
+ break;
+ case CmpInst::ICMP_EQ:
+ case CmpInst::ICMP_SLE:
+ // Equivalent to "A EqP B". This may be the same as the condition tested
+ // in the max/min; if so, we can just return that.
+ if (Value *V = ExtractEquivalentCondition(LHS, EqP, A, B))
+ return V;
+ if (Value *V = ExtractEquivalentCondition(RHS, EqP, A, B))
+ return V;
+ // Otherwise, see if "A EqP B" simplifies.
+ if (MaxRecurse)
+ if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1))
+ return V;
+ break;
+ case CmpInst::ICMP_NE:
+ case CmpInst::ICMP_SGT: {
+ CmpInst::Predicate InvEqP = CmpInst::getInversePredicate(EqP);
+ // Equivalent to "A InvEqP B". This may be the same as the condition
+ // tested in the max/min; if so, we can just return that.
+ if (Value *V = ExtractEquivalentCondition(LHS, InvEqP, A, B))
+ return V;
+ if (Value *V = ExtractEquivalentCondition(RHS, InvEqP, A, B))
+ return V;
+ // Otherwise, see if "A InvEqP B" simplifies.
+ if (MaxRecurse)
+ if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1))
+ return V;
+ break;
+ }
+ case CmpInst::ICMP_SGE:
+ // Always true.
+ return getTrue(ITy);
+ case CmpInst::ICMP_SLT:
+ // Always false.
+ return getFalse(ITy);
+ }
+ }
+
+ // Unsigned variants on "max(a,b)>=a -> true".
+ P = CmpInst::BAD_ICMP_PREDICATE;
+ if (match(LHS, m_UMax(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) {
+ if (A != RHS)
+ std::swap(A, B); // umax(A, B) pred A.
+ EqP = CmpInst::ICMP_UGE; // "A == umax(A, B)" iff "A uge B".
+ // We analyze this as umax(A, B) pred A.
+ P = Pred;
+ } else if (match(RHS, m_UMax(m_Value(A), m_Value(B))) &&
+ (A == LHS || B == LHS)) {
+ if (A != LHS)
+ std::swap(A, B); // A pred umax(A, B).
+ EqP = CmpInst::ICMP_UGE; // "A == umax(A, B)" iff "A uge B".
+ // We analyze this as umax(A, B) swapped-pred A.
+ P = CmpInst::getSwappedPredicate(Pred);
+ } else if (match(LHS, m_UMin(m_Value(A), m_Value(B))) &&
+ (A == RHS || B == RHS)) {
+ if (A != RHS)
+ std::swap(A, B); // umin(A, B) pred A.
+ EqP = CmpInst::ICMP_ULE; // "A == umin(A, B)" iff "A ule B".
+ // We analyze this as umax(-A, -B) swapped-pred -A.
+ // Note that we do not need to actually form -A or -B thanks to EqP.
+ P = CmpInst::getSwappedPredicate(Pred);
+ } else if (match(RHS, m_UMin(m_Value(A), m_Value(B))) &&
+ (A == LHS || B == LHS)) {
+ if (A != LHS)
+ std::swap(A, B); // A pred umin(A, B).
+ EqP = CmpInst::ICMP_ULE; // "A == umin(A, B)" iff "A ule B".
+ // We analyze this as umax(-A, -B) pred -A.
+ // Note that we do not need to actually form -A or -B thanks to EqP.
+ P = Pred;
+ }
+ if (P != CmpInst::BAD_ICMP_PREDICATE) {
+ // Cases correspond to "max(A, B) p A".
+ switch (P) {
+ default:
+ break;
+ case CmpInst::ICMP_EQ:
+ case CmpInst::ICMP_ULE:
+ // Equivalent to "A EqP B". This may be the same as the condition tested
+ // in the max/min; if so, we can just return that.
+ if (Value *V = ExtractEquivalentCondition(LHS, EqP, A, B))
+ return V;
+ if (Value *V = ExtractEquivalentCondition(RHS, EqP, A, B))
+ return V;
+ // Otherwise, see if "A EqP B" simplifies.
+ if (MaxRecurse)
+ if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1))
+ return V;
+ break;
+ case CmpInst::ICMP_NE:
+ case CmpInst::ICMP_UGT: {
+ CmpInst::Predicate InvEqP = CmpInst::getInversePredicate(EqP);
+ // Equivalent to "A InvEqP B". This may be the same as the condition
+ // tested in the max/min; if so, we can just return that.
+ if (Value *V = ExtractEquivalentCondition(LHS, InvEqP, A, B))
+ return V;
+ if (Value *V = ExtractEquivalentCondition(RHS, InvEqP, A, B))
+ return V;
+ // Otherwise, see if "A InvEqP B" simplifies.
+ if (MaxRecurse)
+ if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1))
+ return V;
+ break;
+ }
+ case CmpInst::ICMP_UGE:
+ // Always true.
+ return getTrue(ITy);
+ case CmpInst::ICMP_ULT:
+ // Always false.
+ return getFalse(ITy);
+ }
+ }
+
+ // Variants on "max(x,y) >= min(x,z)".
+ Value *C, *D;
+ if (match(LHS, m_SMax(m_Value(A), m_Value(B))) &&
+ match(RHS, m_SMin(m_Value(C), m_Value(D))) &&
+ (A == C || A == D || B == C || B == D)) {
+ // max(x, ?) pred min(x, ?).
+ if (Pred == CmpInst::ICMP_SGE)
+ // Always true.
+ return getTrue(ITy);
+ if (Pred == CmpInst::ICMP_SLT)
+ // Always false.
+ return getFalse(ITy);
+ } else if (match(LHS, m_SMin(m_Value(A), m_Value(B))) &&
+ match(RHS, m_SMax(m_Value(C), m_Value(D))) &&
+ (A == C || A == D || B == C || B == D)) {
+ // min(x, ?) pred max(x, ?).
+ if (Pred == CmpInst::ICMP_SLE)
+ // Always true.
+ return getTrue(ITy);
+ if (Pred == CmpInst::ICMP_SGT)
+ // Always false.
+ return getFalse(ITy);
+ } else if (match(LHS, m_UMax(m_Value(A), m_Value(B))) &&
+ match(RHS, m_UMin(m_Value(C), m_Value(D))) &&
+ (A == C || A == D || B == C || B == D)) {
+ // max(x, ?) pred min(x, ?).
+ if (Pred == CmpInst::ICMP_UGE)
+ // Always true.
+ return getTrue(ITy);
+ if (Pred == CmpInst::ICMP_ULT)
+ // Always false.
+ return getFalse(ITy);
+ } else if (match(LHS, m_UMin(m_Value(A), m_Value(B))) &&
+ match(RHS, m_UMax(m_Value(C), m_Value(D))) &&
+ (A == C || A == D || B == C || B == D)) {
+ // min(x, ?) pred max(x, ?).
+ if (Pred == CmpInst::ICMP_ULE)
+ // Always true.
+ return getTrue(ITy);
+ if (Pred == CmpInst::ICMP_UGT)
+ // Always false.
+ return getFalse(ITy);
+ }
+
+ return nullptr;
+}
+
+/// Given operands for an ICmpInst, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate;
+ assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!");
+
+ if (Constant *CLHS = dyn_cast<Constant>(LHS)) {
+ if (Constant *CRHS = dyn_cast<Constant>(RHS))
+ return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI);
+
+ // If we have a constant, make sure it is on the RHS.
+ std::swap(LHS, RHS);
+ Pred = CmpInst::getSwappedPredicate(Pred);
+ }
+ assert(!isa<UndefValue>(LHS) && "Unexpected icmp undef,%X");
+
+ Type *ITy = GetCompareTy(LHS); // The return type.
+
+ // For EQ and NE, we can always pick a value for the undef to make the
+ // predicate pass or fail, so we can return undef.
+ // Matches behavior in llvm::ConstantFoldCompareInstruction.
+ if (isa<UndefValue>(RHS) && ICmpInst::isEquality(Pred))
+ return UndefValue::get(ITy);
+
+ // icmp X, X -> true/false
+ // icmp X, undef -> true/false because undef could be X.
+ if (LHS == RHS || isa<UndefValue>(RHS))
+ return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred));
+
+ if (Value *V = simplifyICmpOfBools(Pred, LHS, RHS, Q))
+ return V;
+
+ if (Value *V = simplifyICmpWithZero(Pred, LHS, RHS, Q))
+ return V;
+
+ if (Value *V = simplifyICmpWithConstant(Pred, LHS, RHS, Q.IIQ))
+ return V;
+
+ // If both operands have range metadata, use the metadata
+ // to simplify the comparison.
+ if (isa<Instruction>(RHS) && isa<Instruction>(LHS)) {
+ auto RHS_Instr = cast<Instruction>(RHS);
+ auto LHS_Instr = cast<Instruction>(LHS);
+
+ if (Q.IIQ.getMetadata(RHS_Instr, LLVMContext::MD_range) &&
+ Q.IIQ.getMetadata(LHS_Instr, LLVMContext::MD_range)) {
+ auto RHS_CR = getConstantRangeFromMetadata(
+ *RHS_Instr->getMetadata(LLVMContext::MD_range));
+ auto LHS_CR = getConstantRangeFromMetadata(
+ *LHS_Instr->getMetadata(LLVMContext::MD_range));
+
+ auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR);
+ if (Satisfied_CR.contains(LHS_CR))
+ return ConstantInt::getTrue(RHS->getContext());
+
+ auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion(
+ CmpInst::getInversePredicate(Pred), RHS_CR);
+ if (InversedSatisfied_CR.contains(LHS_CR))
+ return ConstantInt::getFalse(RHS->getContext());
+ }
+ }
+
+ // Compare of cast, for example (zext X) != 0 -> X != 0
+ if (isa<CastInst>(LHS) && (isa<Constant>(RHS) || isa<CastInst>(RHS))) {
+ Instruction *LI = cast<CastInst>(LHS);
+ Value *SrcOp = LI->getOperand(0);
+ Type *SrcTy = SrcOp->getType();
+ Type *DstTy = LI->getType();
+
+ // Turn icmp (ptrtoint x), (ptrtoint/constant) into a compare of the input
+ // if the integer type is the same size as the pointer type.
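+ // e.g. with 64-bit pointers, "icmp eq (ptrtoint i8* %p to i64),
+ // (ptrtoint i8* %q to i64)" is recursively simplified as "icmp eq %p, %q"
+ // with the casts stripped.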
+ if (MaxRecurse && isa<PtrToIntInst>(LI) &&
+ Q.DL.getTypeSizeInBits(SrcTy) == DstTy->getPrimitiveSizeInBits()) {
+ if (Constant *RHSC = dyn_cast<Constant>(RHS)) {
+ // Transfer the cast to the constant.
+ if (Value *V = SimplifyICmpInst(Pred, SrcOp,
+ ConstantExpr::getIntToPtr(RHSC, SrcTy),
+ Q, MaxRecurse-1))
+ return V;
+ } else if (PtrToIntInst *RI = dyn_cast<PtrToIntInst>(RHS)) {
+ if (RI->getOperand(0)->getType() == SrcTy)
+ // Compare without the cast.
+ if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0),
+ Q, MaxRecurse-1))
+ return V;
+ }
+ }
+
+ if (isa<ZExtInst>(LHS)) {
+ // Turn icmp (zext X), (zext Y) into a compare of X and Y if they have the
+ // same type.
+ if (ZExtInst *RI = dyn_cast<ZExtInst>(RHS)) {
+ if (MaxRecurse && SrcTy == RI->getOperand(0)->getType())
+ // Compare X and Y. Note that signed predicates become unsigned.
+ if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred),
+ SrcOp, RI->getOperand(0), Q,
+ MaxRecurse-1))
+ return V;
+ }
+ // Turn icmp (zext X), Cst into a compare of X and Cst if Cst is extended
+ // too. If not, then try to deduce the result of the comparison.
+ else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
+ // Compute the constant that would happen if we truncated to SrcTy then
+ // reextended to DstTy.
+ Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy);
+ Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy);
+
+ // If the re-extended constant didn't change then this is effectively
+ // also a case of comparing two zero-extended values.
+ if (RExt == CI && MaxRecurse)
+ if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred),
+ SrcOp, Trunc, Q, MaxRecurse-1))
+ return V;
+
+ // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit
+ // there. Use this to work out the result of the comparison.
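+ // e.g. with "zext i8 %x to i32" on the LHS and 300 on the RHS, the LHS is
+ // confined to [0, 255], so "eq" folds to false and "ult" folds to true.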
+ if (RExt != CI) {
+ switch (Pred) {
+ default: llvm_unreachable("Unknown ICmp predicate!");
+ // LHS <u RHS.
+ case ICmpInst::ICMP_EQ:
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_UGE:
+ return ConstantInt::getFalse(CI->getContext());
+
+ case ICmpInst::ICMP_NE:
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULE:
+ return ConstantInt::getTrue(CI->getContext());
+
+ // LHS is non-negative. If RHS is negative then LHS >s RHS. If RHS
+ // is non-negative then LHS <s RHS.
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_SGE:
+ return CI->getValue().isNegative() ?
+ ConstantInt::getTrue(CI->getContext()) :
+ ConstantInt::getFalse(CI->getContext());
+
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_SLE:
+ return CI->getValue().isNegative() ?
+ ConstantInt::getFalse(CI->getContext()) :
+ ConstantInt::getTrue(CI->getContext());
+ }
+ }
+ }
+ }
+
+ if (isa<SExtInst>(LHS)) {
+ // Turn icmp (sext X), (sext Y) into a compare of X and Y if they have the
+ // same type.
+ if (SExtInst *RI = dyn_cast<SExtInst>(RHS)) {
+ if (MaxRecurse && SrcTy == RI->getOperand(0)->getType())
+ // Compare X and Y. Note that the predicate does not change.
+ if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0),
+ Q, MaxRecurse-1))
+ return V;
+ }
+ // Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended
+ // too. If not, then try to deduce the result of the comparison.
+ else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
+ // Compute the constant that would happen if we truncated to SrcTy then
+ // reextended to DstTy.
+ Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy);
+ Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy);
+
+ // If the re-extended constant didn't change then this is effectively
+ // also a case of comparing two sign-extended values.
+ if (RExt == CI && MaxRecurse)
+ if (Value *V = SimplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse-1))
+ return V;
+
+ // Otherwise the upper bits of LHS are all equal, while RHS has varying
+ // bits there. Use this to work out the result of the comparison.
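+ // e.g. with "sext i8 %x to i32" on the LHS and 300 on the RHS, the LHS is
+ // confined to [-128, 127], so "eq" and "sgt" both fold to false.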
+ if (RExt != CI) {
+ switch (Pred) {
+ default: llvm_unreachable("Unknown ICmp predicate!");
+ case ICmpInst::ICMP_EQ:
+ return ConstantInt::getFalse(CI->getContext());
+ case ICmpInst::ICMP_NE:
+ return ConstantInt::getTrue(CI->getContext());
+
+ // If RHS is non-negative then LHS <s RHS. If RHS is negative then
+ // LHS >s RHS.
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_SGE:
+ return CI->getValue().isNegative() ?
+ ConstantInt::getTrue(CI->getContext()) :
+ ConstantInt::getFalse(CI->getContext());
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_SLE:
+ return CI->getValue().isNegative() ?
+ ConstantInt::getFalse(CI->getContext()) :
+ ConstantInt::getTrue(CI->getContext());
+
+ // If LHS is non-negative then LHS <u RHS. If LHS is negative then
+ // LHS >u RHS.
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_UGE:
+ // Comparison is true iff the LHS <s 0.
+ if (MaxRecurse)
+ if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SLT, SrcOp,
+ Constant::getNullValue(SrcTy),
+ Q, MaxRecurse-1))
+ return V;
+ break;
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULE:
+ // Comparison is true iff the LHS >=s 0.
+ if (MaxRecurse)
+ if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SGE, SrcOp,
+ Constant::getNullValue(SrcTy),
+ Q, MaxRecurse-1))
+ return V;
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ // icmp eq|ne X, Y -> false|true if X != Y
+ if (ICmpInst::isEquality(Pred) &&
+ isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo)) {
+ return Pred == ICmpInst::ICMP_NE ? getTrue(ITy) : getFalse(ITy);
+ }
+
+ if (Value *V = simplifyICmpWithBinOp(Pred, LHS, RHS, Q, MaxRecurse))
+ return V;
+
+ if (Value *V = simplifyICmpWithMinMax(Pred, LHS, RHS, Q, MaxRecurse))
+ return V;
+
+ // Simplify comparisons of related pointers using a powerful, recursive
+ // GEP-walk when we have target data available.
+ if (LHS->getType()->isPointerTy())
+ if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI,
+ Q.IIQ, LHS, RHS))
+ return C;
+ if (auto *CLHS = dyn_cast<PtrToIntOperator>(LHS))
+ if (auto *CRHS = dyn_cast<PtrToIntOperator>(RHS))
+ if (Q.DL.getTypeSizeInBits(CLHS->getPointerOperandType()) ==
+ Q.DL.getTypeSizeInBits(CLHS->getType()) &&
+ Q.DL.getTypeSizeInBits(CRHS->getPointerOperandType()) ==
+ Q.DL.getTypeSizeInBits(CRHS->getType()))
+ if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI,
+ Q.IIQ, CLHS->getPointerOperand(),
+ CRHS->getPointerOperand()))
+ return C;
+
+ if (GetElementPtrInst *GLHS = dyn_cast<GetElementPtrInst>(LHS)) {
+ if (GEPOperator *GRHS = dyn_cast<GEPOperator>(RHS)) {
+ if (GLHS->getPointerOperand() == GRHS->getPointerOperand() &&
+ GLHS->hasAllConstantIndices() && GRHS->hasAllConstantIndices() &&
+ (ICmpInst::isEquality(Pred) ||
+ (GLHS->isInBounds() && GRHS->isInBounds() &&
+ Pred == ICmpInst::getSignedPredicate(Pred)))) {
+ // The bases are equal and the indices are constant. Build a constant
+ // expression GEP with the same indices and a null base pointer to see
+ // what constant folding can make out of it.
+ Constant *Null = Constant::getNullValue(GLHS->getPointerOperandType());
+ SmallVector<Value *, 4> IndicesLHS(GLHS->idx_begin(), GLHS->idx_end());
+ Constant *NewLHS = ConstantExpr::getGetElementPtr(
+ GLHS->getSourceElementType(), Null, IndicesLHS);
+
+ SmallVector<Value *, 4> IndicesRHS(GRHS->idx_begin(), GRHS->idx_end());
+ Constant *NewRHS = ConstantExpr::getGetElementPtr(
+ GLHS->getSourceElementType(), Null, IndicesRHS);
+ return ConstantExpr::getICmp(Pred, NewLHS, NewRHS);
+ }
+ }
+ }
+
+ // If the comparison is with the result of a select instruction, check whether
+ // comparing with either branch of the select always yields the same value.
+ if (isa<SelectInst>(LHS) || isa<SelectInst>(RHS))
+ if (Value *V = ThreadCmpOverSelect(Pred, LHS, RHS, Q, MaxRecurse))
+ return V;
+
+ // If the comparison is with the result of a phi instruction, check whether
+ // doing the compare with each incoming phi value yields a common result.
+ if (isa<PHINode>(LHS) || isa<PHINode>(RHS))
+ if (Value *V = ThreadCmpOverPHI(Pred, LHS, RHS, Q, MaxRecurse))
+ return V;
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+ const SimplifyQuery &Q) {
+ return ::SimplifyICmpInst(Predicate, LHS, RHS, Q, RecursionLimit);
+}
+
+/// Given operands for an FCmpInst, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+ FastMathFlags FMF, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate;
+ assert(CmpInst::isFPPredicate(Pred) && "Not an FP compare!");
+
+ if (Constant *CLHS = dyn_cast<Constant>(LHS)) {
+ if (Constant *CRHS = dyn_cast<Constant>(RHS))
+ return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI);
+
+ // If we have a constant, make sure it is on the RHS.
+ std::swap(LHS, RHS);
+ Pred = CmpInst::getSwappedPredicate(Pred);
+ }
+
+ // Fold trivial predicates.
+ Type *RetTy = GetCompareTy(LHS);
+ if (Pred == FCmpInst::FCMP_FALSE)
+ return getFalse(RetTy);
+ if (Pred == FCmpInst::FCMP_TRUE)
+ return getTrue(RetTy);
+
+ // Fold (un)ordered comparison if we can determine there are no NaNs.
+ if (Pred == FCmpInst::FCMP_UNO || Pred == FCmpInst::FCMP_ORD)
+ if (FMF.noNaNs() ||
+ (isKnownNeverNaN(LHS, Q.TLI) && isKnownNeverNaN(RHS, Q.TLI)))
+ return ConstantInt::get(RetTy, Pred == FCmpInst::FCMP_ORD);
+
+ // NaN is unordered; NaN is not ordered.
+ assert((FCmpInst::isOrdered(Pred) || FCmpInst::isUnordered(Pred)) &&
+ "Comparison must be either ordered or unordered");
+ if (match(RHS, m_NaN()))
+ return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred));
+
+ // fcmp pred x, undef and fcmp pred undef, x
+ // fold to true if unordered, false if ordered
+ if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) {
+ // Choosing NaN for the undef will always make unordered comparison succeed
+ // and ordered comparison fail.
+ return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred));
+ }
+
+ // fcmp x,x -> true/false. Not all compares are foldable.
+ if (LHS == RHS) {
+ if (CmpInst::isTrueWhenEqual(Pred))
+ return getTrue(RetTy);
+ if (CmpInst::isFalseWhenEqual(Pred))
+ return getFalse(RetTy);
+ }
+
+ // Handle fcmp with constant RHS.
+ // TODO: Use match with a specific FP value, so these work with vectors with
+ // undef lanes.
+ const APFloat *C;
+ if (match(RHS, m_APFloat(C))) {
+ // Check whether the constant is an infinity.
+ if (C->isInfinity()) {
+ if (C->isNegative()) {
+ switch (Pred) {
+ case FCmpInst::FCMP_OLT:
+ // No value is ordered and less than negative infinity.
+ return getFalse(RetTy);
+ case FCmpInst::FCMP_UGE:
+ // All values are either unordered with negative infinity or at least negative infinity.
+ return getTrue(RetTy);
+ default:
+ break;
+ }
+ } else {
+ switch (Pred) {
+ case FCmpInst::FCMP_OGT:
+ // No value is ordered and greater than infinity.
+ return getFalse(RetTy);
+ case FCmpInst::FCMP_ULE:
+ // All values are either unordered with infinity or at most infinity.
+ return getTrue(RetTy);
+ default:
+ break;
+ }
+ }
+ }
+ if (C->isNegative() && !C->isNegZero()) {
+ assert(!C->isNaN() && "Unexpected NaN constant!");
+ // TODO: We can catch more cases by using a range check rather than
+ // relying on CannotBeOrderedLessThanZero.
+ switch (Pred) {
+ case FCmpInst::FCMP_UGE:
+ case FCmpInst::FCMP_UGT:
+ case FCmpInst::FCMP_UNE:
+ // (X >= 0) implies (X > C) when (C < 0)
+ if (CannotBeOrderedLessThanZero(LHS, Q.TLI))
+ return getTrue(RetTy);
+ break;
+ case FCmpInst::FCMP_OEQ:
+ case FCmpInst::FCMP_OLE:
+ case FCmpInst::FCMP_OLT:
+ // (X >= 0) implies !(X < C) when (C < 0)
+ if (CannotBeOrderedLessThanZero(LHS, Q.TLI))
+ return getFalse(RetTy);
+ break;
+ default:
+ break;
+ }
+ }
+
+ // Check comparison of [minnum/maxnum with constant] with other constant.
+ const APFloat *C2;
+ if ((match(LHS, m_Intrinsic<Intrinsic::minnum>(m_Value(), m_APFloat(C2))) &&
+ C2->compare(*C) == APFloat::cmpLessThan) ||
+ (match(LHS, m_Intrinsic<Intrinsic::maxnum>(m_Value(), m_APFloat(C2))) &&
+ C2->compare(*C) == APFloat::cmpGreaterThan)) {
+ bool IsMaxNum =
+ cast<IntrinsicInst>(LHS)->getIntrinsicID() == Intrinsic::maxnum;
+ // The ordered relationship and minnum/maxnum guarantee that we do not
+ // have NaN constants, so ordered/unordered preds are handled the same.
+ switch (Pred) {
+ case FCmpInst::FCMP_OEQ: case FCmpInst::FCMP_UEQ:
+ // minnum(X, LesserC) == C --> false
+ // maxnum(X, GreaterC) == C --> false
+ return getFalse(RetTy);
+ case FCmpInst::FCMP_ONE: case FCmpInst::FCMP_UNE:
+ // minnum(X, LesserC) != C --> true
+ // maxnum(X, GreaterC) != C --> true
+ return getTrue(RetTy);
+ case FCmpInst::FCMP_OGE: case FCmpInst::FCMP_UGE:
+ case FCmpInst::FCMP_OGT: case FCmpInst::FCMP_UGT:
+ // minnum(X, LesserC) >= C --> false
+ // minnum(X, LesserC) > C --> false
+ // maxnum(X, GreaterC) >= C --> true
+ // maxnum(X, GreaterC) > C --> true
+ return ConstantInt::get(RetTy, IsMaxNum);
+ case FCmpInst::FCMP_OLE: case FCmpInst::FCMP_ULE:
+ case FCmpInst::FCMP_OLT: case FCmpInst::FCMP_ULT:
+ // minnum(X, LesserC) <= C --> true
+ // minnum(X, LesserC) < C --> true
+ // maxnum(X, GreaterC) <= C --> false
+ // maxnum(X, GreaterC) < C --> false
+ return ConstantInt::get(RetTy, !IsMaxNum);
+ default:
+ // TRUE/FALSE/ORD/UNO should be handled before this.
+ llvm_unreachable("Unexpected fcmp predicate");
+ }
+ }
+ }
+
+ if (match(RHS, m_AnyZeroFP())) {
+ switch (Pred) {
+ case FCmpInst::FCMP_OGE:
+ case FCmpInst::FCMP_ULT:
+ // Positive or zero X >= 0.0 --> true
+ // Positive or zero X < 0.0 --> false
+ if ((FMF.noNaNs() || isKnownNeverNaN(LHS, Q.TLI)) &&
+ CannotBeOrderedLessThanZero(LHS, Q.TLI))
+ return Pred == FCmpInst::FCMP_OGE ? getTrue(RetTy) : getFalse(RetTy);
+ break;
+ case FCmpInst::FCMP_UGE:
+ case FCmpInst::FCMP_OLT:
+ // Positive or zero or nan X >= 0.0 --> true
+ // Positive or zero or nan X < 0.0 --> false
+ if (CannotBeOrderedLessThanZero(LHS, Q.TLI))
+ return Pred == FCmpInst::FCMP_UGE ? getTrue(RetTy) : getFalse(RetTy);
+ break;
+ default:
+ break;
+ }
+ }
+
+ // If the comparison is with the result of a select instruction, check whether
+ // comparing with either branch of the select always yields the same value.
+ if (isa<SelectInst>(LHS) || isa<SelectInst>(RHS))
+ if (Value *V = ThreadCmpOverSelect(Pred, LHS, RHS, Q, MaxRecurse))
+ return V;
+
+ // If the comparison is with the result of a phi instruction, check whether
+ // doing the compare with each incoming phi value yields a common result.
+ if (isa<PHINode>(LHS) || isa<PHINode>(RHS))
+ if (Value *V = ThreadCmpOverPHI(Pred, LHS, RHS, Q, MaxRecurse))
+ return V;
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+ FastMathFlags FMF, const SimplifyQuery &Q) {
+ return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit);
+}
+
+/// See if V simplifies when its operand Op is replaced with RepOp.
+static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+ const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ // Trivial replacement.
+ if (V == Op)
+ return RepOp;
+
+ // We cannot replace a constant, and shouldn't even try.
+ if (isa<Constant>(Op))
+ return nullptr;
+
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I)
+ return nullptr;
+
+ // If this is a binary operator, try to simplify it with the replaced op.
+ if (auto *B = dyn_cast<BinaryOperator>(I)) {
+ // Consider:
+ // %cmp = icmp eq i32 %x, 2147483647
+ // %add = add nsw i32 %x, 1
+ // %sel = select i1 %cmp, i32 -2147483648, i32 %add
+ //
+ // We can't replace %sel with %add unless we strip away the flags.
+ // TODO: This is an unusual limitation because better analysis results in
+ // worse simplification. InstCombine can do this fold more generally
+ // by dropping the flags. Remove this fold to save compile-time?
+ if (isa<OverflowingBinaryOperator>(B))
+ if (Q.IIQ.hasNoSignedWrap(B) || Q.IIQ.hasNoUnsignedWrap(B))
+ return nullptr;
+ if (isa<PossiblyExactOperator>(B) && Q.IIQ.isExact(B))
+ return nullptr;
+
+ if (MaxRecurse) {
+ if (B->getOperand(0) == Op)
+ return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), Q,
+ MaxRecurse - 1);
+ if (B->getOperand(1) == Op)
+ return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, Q,
+ MaxRecurse - 1);
+ }
+ }
+
+ // Same for CmpInsts.
+ if (CmpInst *C = dyn_cast<CmpInst>(I)) {
+ if (MaxRecurse) {
+ if (C->getOperand(0) == Op)
+ return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), Q,
+ MaxRecurse - 1);
+ if (C->getOperand(1) == Op)
+ return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, Q,
+ MaxRecurse - 1);
+ }
+ }
+
+ // Same for GEPs.
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) {
+ if (MaxRecurse) {
+ SmallVector<Value *, 8> NewOps(GEP->getNumOperands());
+ transform(GEP->operands(), NewOps.begin(),
+ [&](Value *V) { return V == Op ? RepOp : V; });
+ return SimplifyGEPInst(GEP->getSourceElementType(), NewOps, Q,
+ MaxRecurse - 1);
+ }
+ }
+
+ // TODO: We could hand off more cases to instsimplify here.
+
+ // If all operands are constant after substituting Op for RepOp then we can
+ // constant fold the instruction.
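+ // e.g. if V is "trunc i64 %x to i32" and %x is replaced by a constant, the
+ // remaining operand list is all-constant and folds via ConstantFoldInstOperands.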
+ if (Constant *CRepOp = dyn_cast<Constant>(RepOp)) {
+ // Build a list of all constant operands.
+ SmallVector<Constant *, 8> ConstOps;
+ for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
+ if (I->getOperand(i) == Op)
+ ConstOps.push_back(CRepOp);
+ else if (Constant *COp = dyn_cast<Constant>(I->getOperand(i)))
+ ConstOps.push_back(COp);
+ else
+ break;
+ }
+
+ // All operands were constants, fold it.
+ if (ConstOps.size() == I->getNumOperands()) {
+ if (CmpInst *C = dyn_cast<CmpInst>(I))
+ return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0],
+ ConstOps[1], Q.DL, Q.TLI);
+
+ if (LoadInst *LI = dyn_cast<LoadInst>(I))
+ if (!LI->isVolatile())
+ return ConstantFoldLoadFromConstPtr(ConstOps[0], LI->getType(), Q.DL);
+
+ return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI);
+ }
+ }
+
+ return nullptr;
+}
+
+/// Try to simplify a select instruction when its condition operand is an
+/// integer comparison where one operand of the compare is a constant.
+static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X,
+ const APInt *Y, bool TrueWhenUnset) {
+ const APInt *C;
+
+ // (X & Y) == 0 ? X & ~Y : X --> X
+ // (X & Y) != 0 ? X & ~Y : X --> X & ~Y
+ if (FalseVal == X && match(TrueVal, m_And(m_Specific(X), m_APInt(C))) &&
+ *Y == ~*C)
+ return TrueWhenUnset ? FalseVal : TrueVal;
+
+ // (X & Y) == 0 ? X : X & ~Y --> X & ~Y
+ // (X & Y) != 0 ? X : X & ~Y --> X
+ if (TrueVal == X && match(FalseVal, m_And(m_Specific(X), m_APInt(C))) &&
+ *Y == ~*C)
+ return TrueWhenUnset ? FalseVal : TrueVal;
+
+ if (Y->isPowerOf2()) {
+ // (X & Y) == 0 ? X | Y : X --> X | Y
+ // (X & Y) != 0 ? X | Y : X --> X
+ if (FalseVal == X && match(TrueVal, m_Or(m_Specific(X), m_APInt(C))) &&
+ *Y == *C)
+ return TrueWhenUnset ? TrueVal : FalseVal;
+
+ // (X & Y) == 0 ? X : X | Y --> X
+ // (X & Y) != 0 ? X : X | Y --> X | Y
+ if (TrueVal == X && match(FalseVal, m_Or(m_Specific(X), m_APInt(C))) &&
+ *Y == *C)
+ return TrueWhenUnset ? TrueVal : FalseVal;
+ }
+
+ return nullptr;
+}
+
+/// An alternative way to test if a bit is set or not uses sgt/slt instead of
+/// eq/ne.
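+/// For example, (icmp sgt i8 %x, -1) is really a test that the sign bit of %x
+/// is clear, equivalent to (icmp eq (and i8 %x, -128), 0), so it can feed the
+/// bit-test folds above.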
+static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS,
+ ICmpInst::Predicate Pred,
+ Value *TrueVal, Value *FalseVal) {
+ Value *X;
+ APInt Mask;
+ if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask))
+ return nullptr;
+
+ return simplifySelectBitTest(TrueVal, FalseVal, X, &Mask,
+ Pred == ICmpInst::ICMP_EQ);
+}
+
+/// Try to simplify a select instruction when its condition operand is an
+/// integer comparison.
+static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
+ Value *FalseVal, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ ICmpInst::Predicate Pred;
+ Value *CmpLHS, *CmpRHS;
+ if (!match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS))))
+ return nullptr;
+
+ if (ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero())) {
+ Value *X;
+ const APInt *Y;
+ if (match(CmpLHS, m_And(m_Value(X), m_APInt(Y))))
+ if (Value *V = simplifySelectBitTest(TrueVal, FalseVal, X, Y,
+ Pred == ICmpInst::ICMP_EQ))
+ return V;
+
+ // Test for a bogus zero-shift-guard-op around funnel-shift or rotate.
+ Value *ShAmt;
+ auto isFsh = m_CombineOr(m_Intrinsic<Intrinsic::fshl>(m_Value(X), m_Value(),
+ m_Value(ShAmt)),
+ m_Intrinsic<Intrinsic::fshr>(m_Value(), m_Value(X),
+ m_Value(ShAmt)));
+ // (ShAmt == 0) ? fshl(X, *, ShAmt) : X --> X
+ // (ShAmt == 0) ? fshr(*, X, ShAmt) : X --> X
+ if (match(TrueVal, isFsh) && FalseVal == X && CmpLHS == ShAmt &&
+ Pred == ICmpInst::ICMP_EQ)
+ return X;
+ // (ShAmt != 0) ? X : fshl(X, *, ShAmt) --> X
+ // (ShAmt != 0) ? X : fshr(*, X, ShAmt) --> X
+ if (match(FalseVal, isFsh) && TrueVal == X && CmpLHS == ShAmt &&
+ Pred == ICmpInst::ICMP_NE)
+ return X;
+
+ // Test for a zero-shift-guard-op around rotates. These are used to
+ // avoid UB from oversized shifts in raw IR rotate patterns, but the
+ // intrinsics do not have that problem.
+ // We do not allow this transform for the general funnel shift case because
+ // that would not preserve the poison safety of the original code.
+ auto isRotate = m_CombineOr(m_Intrinsic<Intrinsic::fshl>(m_Value(X),
+ m_Deferred(X),
+ m_Value(ShAmt)),
+ m_Intrinsic<Intrinsic::fshr>(m_Value(X),
+ m_Deferred(X),
+ m_Value(ShAmt)));
+ // (ShAmt != 0) ? fshl(X, X, ShAmt) : X --> fshl(X, X, ShAmt)
+ // (ShAmt != 0) ? fshr(X, X, ShAmt) : X --> fshr(X, X, ShAmt)
+ if (match(TrueVal, isRotate) && FalseVal == X && CmpLHS == ShAmt &&
+ Pred == ICmpInst::ICMP_NE)
+ return TrueVal;
+ // (ShAmt == 0) ? X : fshl(X, X, ShAmt) --> fshl(X, X, ShAmt)
+ // (ShAmt == 0) ? X : fshr(X, X, ShAmt) --> fshr(X, X, ShAmt)
+ if (match(FalseVal, isRotate) && TrueVal == X && CmpLHS == ShAmt &&
+ Pred == ICmpInst::ICMP_EQ)
+ return FalseVal;
+ }
+
+ // Check for other compares that behave like a bit test.
+ if (Value *V = simplifySelectWithFakeICmpEq(CmpLHS, CmpRHS, Pred,
+ TrueVal, FalseVal))
+ return V;
+
+ // If we have an equality comparison, then we know the value in one of the
+ // arms of the select. See if substituting this value into the arm and
+ // simplifying the result yields the same value as the other arm.
+ if (Pred == ICmpInst::ICMP_EQ) {
+ if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
+ TrueVal ||
+ SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
+ TrueVal)
+ return FalseVal;
+ if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
+ FalseVal ||
+ SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
+ FalseVal)
+ return FalseVal;
+ } else if (Pred == ICmpInst::ICMP_NE) {
+ if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
+ FalseVal ||
+ SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
+ FalseVal)
+ return TrueVal;
+ if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
+ TrueVal ||
+ SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
+ TrueVal)
+ return TrueVal;
+ }
+
+ return nullptr;
+}
+
+/// Try to simplify a select instruction when its condition operand is a
+/// floating-point comparison.
+static Value *simplifySelectWithFCmp(Value *Cond, Value *T, Value *F) {
+ FCmpInst::Predicate Pred;
+ if (!match(Cond, m_FCmp(Pred, m_Specific(T), m_Specific(F))) &&
+ !match(Cond, m_FCmp(Pred, m_Specific(F), m_Specific(T))))
+ return nullptr;
+
+ // TODO: The transform may not be valid with -0.0. An incomplete way of
+ // testing for that possibility is to check if at least one operand is a
+ // non-zero constant.
+ const APFloat *C;
+ if ((match(T, m_APFloat(C)) && C->isNonZero()) ||
+ (match(F, m_APFloat(C)) && C->isNonZero())) {
+ // (T == F) ? T : F --> F
+ // (F == T) ? T : F --> F
+ if (Pred == FCmpInst::FCMP_OEQ)
+ return F;
+
+ // (T != F) ? T : F --> T
+ // (F != T) ? T : F --> T
+ if (Pred == FCmpInst::FCMP_UNE)
+ return T;
+ }
+
+ return nullptr;
+}
+
+/// Given operands for a SelectInst, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (auto *CondC = dyn_cast<Constant>(Cond)) {
+ if (auto *TrueC = dyn_cast<Constant>(TrueVal))
+ if (auto *FalseC = dyn_cast<Constant>(FalseVal))
+ return ConstantFoldSelectInstruction(CondC, TrueC, FalseC);
+
+ // select undef, X, Y -> X or Y
+ if (isa<UndefValue>(CondC))
+ return isa<Constant>(FalseVal) ? FalseVal : TrueVal;
+
+ // TODO: Vector constants with undef elements don't simplify.
+
+ // select true, X, Y -> X
+ if (CondC->isAllOnesValue())
+ return TrueVal;
+ // select false, X, Y -> Y
+ if (CondC->isNullValue())
+ return FalseVal;
+ }
+
+ // select ?, X, X -> X
+ if (TrueVal == FalseVal)
+ return TrueVal;
+
+ if (isa<UndefValue>(TrueVal)) // select ?, undef, X -> X
+ return FalseVal;
+ if (isa<UndefValue>(FalseVal)) // select ?, X, undef -> X
+ return TrueVal;
+
+ if (Value *V =
+ simplifySelectWithICmpCond(Cond, TrueVal, FalseVal, Q, MaxRecurse))
+ return V;
+
+ if (Value *V = simplifySelectWithFCmp(Cond, TrueVal, FalseVal))
+ return V;
+
+ if (Value *V = foldSelectWithBinaryOp(Cond, TrueVal, FalseVal))
+ return V;
+
+ Optional<bool> Imp = isImpliedByDomCondition(Cond, Q.CxtI, Q.DL);
+ if (Imp)
+ return *Imp ? TrueVal : FalseVal;
+
+ return nullptr;
+}
+
+Value *llvm::SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
+ const SimplifyQuery &Q) {
+ return ::SimplifySelectInst(Cond, TrueVal, FalseVal, Q, RecursionLimit);
+}
+
+/// Given operands for a GetElementPtrInst, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,
+ const SimplifyQuery &Q, unsigned) {
+ // The type of the GEP pointer operand.
+ unsigned AS =
+ cast<PointerType>(Ops[0]->getType()->getScalarType())->getAddressSpace();
+
+ // getelementptr P -> P.
+ if (Ops.size() == 1)
+ return Ops[0];
+
+ // Compute the (pointer) type returned by the GEP instruction.
+ Type *LastType = GetElementPtrInst::getIndexedType(SrcTy, Ops.slice(1));
+ Type *GEPTy = PointerType::get(LastType, AS);
+ if (VectorType *VT = dyn_cast<VectorType>(Ops[0]->getType()))
+ GEPTy = VectorType::get(GEPTy, VT->getNumElements());
+ else if (VectorType *VT = dyn_cast<VectorType>(Ops[1]->getType()))
+ GEPTy = VectorType::get(GEPTy, VT->getNumElements());
+
+ if (isa<UndefValue>(Ops[0]))
+ return UndefValue::get(GEPTy);
+
+ if (Ops.size() == 2) {
+ // getelementptr P, 0 -> P.
+ if (match(Ops[1], m_Zero()) && Ops[0]->getType() == GEPTy)
+ return Ops[0];
+
+ Type *Ty = SrcTy;
+ if (Ty->isSized()) {
+ Value *P;
+ uint64_t C;
+ uint64_t TyAllocSize = Q.DL.getTypeAllocSize(Ty);
+ // getelementptr P, N -> P if P points to a type of zero size.
+ if (TyAllocSize == 0 && Ops[0]->getType() == GEPTy)
+ return Ops[0];
+
+ // The following transforms are only safe if the ptrtoint cast
+ // doesn't truncate the pointers.
+ if (Ops[1]->getType()->getScalarSizeInBits() ==
+ Q.DL.getIndexSizeInBits(AS)) {
+ auto PtrToIntOrZero = [GEPTy](Value *P) -> Value * {
+ if (match(P, m_Zero()))
+ return Constant::getNullValue(GEPTy);
+ Value *Temp;
+ if (match(P, m_PtrToInt(m_Value(Temp))))
+ if (Temp->getType() == GEPTy)
+ return Temp;
+ return nullptr;
+ };
+
+ // getelementptr V, (sub P, V) -> P if P points to a type of size 1.
+ if (TyAllocSize == 1 &&
+ match(Ops[1], m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0])))))
+ if (Value *R = PtrToIntOrZero(P))
+ return R;
+
+ // getelementptr V, (ashr (sub P, V), C) -> P
+ // if P points to a type of size 1 << C.
+ if (match(Ops[1],
+ m_AShr(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))),
+ m_ConstantInt(C))) &&
+ TyAllocSize == 1ULL << C)
+ if (Value *R = PtrToIntOrZero(P))
+ return R;
+
+ // getelementptr V, (sdiv (sub P, V), C) -> P
+ // if P points to a type of size C.
+ if (match(Ops[1],
+ m_SDiv(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))),
+ m_SpecificInt(TyAllocSize))))
+ if (Value *R = PtrToIntOrZero(P))
+ return R;
+ }
+ }
+ }
+
+ if (Q.DL.getTypeAllocSize(LastType) == 1 &&
+ all_of(Ops.slice(1).drop_back(1),
+ [](Value *Idx) { return match(Idx, m_Zero()); })) {
+ unsigned IdxWidth =
+ Q.DL.getIndexSizeInBits(Ops[0]->getType()->getPointerAddressSpace());
+ if (Q.DL.getTypeSizeInBits(Ops.back()->getType()) == IdxWidth) {
+ APInt BasePtrOffset(IdxWidth, 0);
+ Value *StrippedBasePtr =
+ Ops[0]->stripAndAccumulateInBoundsConstantOffsets(Q.DL,
+ BasePtrOffset);
+
+ // gep (gep V, C), (sub 0, V) -> C
+ if (match(Ops.back(),
+ m_Sub(m_Zero(), m_PtrToInt(m_Specific(StrippedBasePtr))))) {
+ auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset);
+ return ConstantExpr::getIntToPtr(CI, GEPTy);
+ }
+ // gep (gep V, C), (xor V, -1) -> C-1
+ if (match(Ops.back(),
+ m_Xor(m_PtrToInt(m_Specific(StrippedBasePtr)), m_AllOnes()))) {
+ auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset - 1);
+ return ConstantExpr::getIntToPtr(CI, GEPTy);
+ }
+ }
+ }
+
+ // Check to see if this is constant foldable.
+ if (!all_of(Ops, [](Value *V) { return isa<Constant>(V); }))
+ return nullptr;
+
+ auto *CE = ConstantExpr::getGetElementPtr(SrcTy, cast<Constant>(Ops[0]),
+ Ops.slice(1));
+ if (auto *CEFolded = ConstantFoldConstant(CE, Q.DL))
+ return CEFolded;
+ return CE;
+}
+
+Value *llvm::SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,
+ const SimplifyQuery &Q) {
+ return ::SimplifyGEPInst(SrcTy, Ops, Q, RecursionLimit);
+}
+
+/// Given operands for an InsertValueInst, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyInsertValueInst(Value *Agg, Value *Val,
+ ArrayRef<unsigned> Idxs, const SimplifyQuery &Q,
+ unsigned) {
+ if (Constant *CAgg = dyn_cast<Constant>(Agg))
+ if (Constant *CVal = dyn_cast<Constant>(Val))
+ return ConstantFoldInsertValueInstruction(CAgg, CVal, Idxs);
+
+ // insertvalue x, undef, n -> x
+ if (match(Val, m_Undef()))
+ return Agg;
+
+ // insertvalue x, (extractvalue y, n), n
+ if (ExtractValueInst *EV = dyn_cast<ExtractValueInst>(Val))
+ if (EV->getAggregateOperand()->getType() == Agg->getType() &&
+ EV->getIndices() == Idxs) {
+ // insertvalue undef, (extractvalue y, n), n -> y
+ if (match(Agg, m_Undef()))
+ return EV->getAggregateOperand();
+
+ // insertvalue y, (extractvalue y, n), n -> y
+ if (Agg == EV->getAggregateOperand())
+ return Agg;
+ }
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyInsertValueInst(Value *Agg, Value *Val,
+ ArrayRef<unsigned> Idxs,
+ const SimplifyQuery &Q) {
+ return ::SimplifyInsertValueInst(Agg, Val, Idxs, Q, RecursionLimit);
+}
+
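+/// Given operands for an InsertElementInst, see if we can fold the result.
+/// If not, this returns null.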
+Value *llvm::SimplifyInsertElementInst(Value *Vec, Value *Val, Value *Idx,
+ const SimplifyQuery &Q) {
+ // Try to constant fold.
+ auto *VecC = dyn_cast<Constant>(Vec);
+ auto *ValC = dyn_cast<Constant>(Val);
+ auto *IdxC = dyn_cast<Constant>(Idx);
+ if (VecC && ValC && IdxC)
+ return ConstantFoldInsertElementInstruction(VecC, ValC, IdxC);
+
+ // Fold into undef if index is out of bounds.
+ if (auto *CI = dyn_cast<ConstantInt>(Idx)) {
+ uint64_t NumElements = cast<VectorType>(Vec->getType())->getNumElements();
+ if (CI->uge(NumElements))
+ return UndefValue::get(Vec->getType());
+ }
+
+ // If index is undef, it might be out of bounds (see above case)
+ if (isa<UndefValue>(Idx))
+ return UndefValue::get(Vec->getType());
+
+ // Inserting an undef scalar? Assume it is the same value as the existing
+ // vector element.
+ if (isa<UndefValue>(Val))
+ return Vec;
+
+ // If we are extracting a value from a vector, then inserting it into the same
+ // place, that's the input vector:
+ // insertelt Vec, (extractelt Vec, Idx), Idx --> Vec
+ if (match(Val, m_ExtractElement(m_Specific(Vec), m_Specific(Idx))))
+ return Vec;
+
+ return nullptr;
+}
+
+/// Given operands for an ExtractValueInst, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs,
+ const SimplifyQuery &, unsigned) {
+ if (auto *CAgg = dyn_cast<Constant>(Agg))
+ return ConstantFoldExtractValueInstruction(CAgg, Idxs);
+
+ // extractvalue (insertvalue y, elt, n), n -> elt
+ unsigned NumIdxs = Idxs.size();
+ for (auto *IVI = dyn_cast<InsertValueInst>(Agg); IVI != nullptr;
+ IVI = dyn_cast<InsertValueInst>(IVI->getAggregateOperand())) {
+ ArrayRef<unsigned> InsertValueIdxs = IVI->getIndices();
+ unsigned NumInsertValueIdxs = InsertValueIdxs.size();
+ unsigned NumCommonIdxs = std::min(NumInsertValueIdxs, NumIdxs);
+ if (InsertValueIdxs.slice(0, NumCommonIdxs) ==
+ Idxs.slice(0, NumCommonIdxs)) {
+ if (NumIdxs == NumInsertValueIdxs)
+ return IVI->getInsertedValueOperand();
+ break;
+ }
+ }
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs,
+ const SimplifyQuery &Q) {
+ return ::SimplifyExtractValueInst(Agg, Idxs, Q, RecursionLimit);
+}
+
+/// Given operands for an ExtractElementInst, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const SimplifyQuery &,
+ unsigned) {
+ if (auto *CVec = dyn_cast<Constant>(Vec)) {
+ if (auto *CIdx = dyn_cast<Constant>(Idx))
+ return ConstantFoldExtractElementInstruction(CVec, CIdx);
+
+ // The index is not relevant if our vector is a splat.
+ if (auto *Splat = CVec->getSplatValue())
+ return Splat;
+
+ if (isa<UndefValue>(Vec))
+ return UndefValue::get(Vec->getType()->getVectorElementType());
+ }
+
+ // If extracting a specified index from the vector, see if we can recursively
+ // find a previously computed scalar that was inserted into the vector.
+ if (auto *IdxC = dyn_cast<ConstantInt>(Idx)) {
+ if (IdxC->getValue().uge(Vec->getType()->getVectorNumElements()))
+ // definitely out of bounds, thus undefined result
+ return UndefValue::get(Vec->getType()->getVectorElementType());
+ if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue()))
+ return Elt;
+ }
+
+ // An undef extract index can be arbitrarily chosen to be an out-of-range
+ // index value, which would result in the instruction being undef.
+ if (isa<UndefValue>(Idx))
+ return UndefValue::get(Vec->getType()->getVectorElementType());
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyExtractElementInst(Value *Vec, Value *Idx,
+ const SimplifyQuery &Q) {
+ return ::SimplifyExtractElementInst(Vec, Idx, Q, RecursionLimit);
+}
+
+/// See if we can fold the given phi. If not, returns null.
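+/// For example, a phi whose incoming values are all %x (ignoring undefs and
+/// uses of the phi itself) folds to %x, provided %x dominates the phi's block
+/// when undef inputs are present.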
+static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) {
+ // If all of the PHI's incoming values are the same then replace the PHI node
+ // with the common value.
+ Value *CommonValue = nullptr;
+ bool HasUndefInput = false;
+ for (Value *Incoming : PN->incoming_values()) {
+ // If the incoming value is the phi node itself, it can safely be skipped.
+ if (Incoming == PN) continue;
+ if (isa<UndefValue>(Incoming)) {
+ // Remember that we saw an undef value, but otherwise ignore them.
+ HasUndefInput = true;
+ continue;
+ }
+ if (CommonValue && Incoming != CommonValue)
+ return nullptr; // Not the same, bail out.
+ CommonValue = Incoming;
+ }
+
+ // If CommonValue is null then all of the incoming values were either undef or
+ // equal to the phi node itself.
+ if (!CommonValue)
+ return UndefValue::get(PN->getType());
+
+ // If we have a PHI node like phi(X, undef, X), where X is defined by some
+ // instruction, we cannot return X as the result of the PHI node unless it
+ // dominates the PHI block.
+ if (HasUndefInput)
+ return valueDominatesPHI(CommonValue, PN, Q.DT) ? CommonValue : nullptr;
+
+ return CommonValue;
+}
+
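+/// Given a cast opcode, operand, and destination type, see if we can fold the
+/// result. If not, this returns null. For example, a bitcast of a bitcast back
+/// to the original type simplifies to the original value.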
+static Value *SimplifyCastInst(unsigned CastOpc, Value *Op,
+ Type *Ty, const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (auto *C = dyn_cast<Constant>(Op))
+ return ConstantFoldCastOperand(CastOpc, C, Ty, Q.DL);
+
+ if (auto *CI = dyn_cast<CastInst>(Op)) {
+ auto *Src = CI->getOperand(0);
+ Type *SrcTy = Src->getType();
+ Type *MidTy = CI->getType();
+ Type *DstTy = Ty;
+ if (Src->getType() == Ty) {
+ auto FirstOp = static_cast<Instruction::CastOps>(CI->getOpcode());
+ auto SecondOp = static_cast<Instruction::CastOps>(CastOpc);
+ Type *SrcIntPtrTy =
+ SrcTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(SrcTy) : nullptr;
+ Type *MidIntPtrTy =
+ MidTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(MidTy) : nullptr;
+ Type *DstIntPtrTy =
+ DstTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(DstTy) : nullptr;
+ if (CastInst::isEliminableCastPair(FirstOp, SecondOp, SrcTy, MidTy, DstTy,
+ SrcIntPtrTy, MidIntPtrTy,
+ DstIntPtrTy) == Instruction::BitCast)
+ return Src;
+ }
+ }
+
+ // bitcast x -> x
+ if (CastOpc == Instruction::BitCast)
+ if (Op->getType() == Ty)
+ return Op;
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty,
+ const SimplifyQuery &Q) {
+ return ::SimplifyCastInst(CastOpc, Op, Ty, Q, RecursionLimit);
+}
+
+/// For the given destination element of a shuffle, peek through shuffles to
+/// match a root vector source operand that contains that element in the same
+/// vector lane (i.e., the same mask index), so we can eliminate the shuffle(s).
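+/// For example, two back-to-back lane reversals of a <2 x i32> vector:
+///   %r = shufflevector <2 x i32> %v, <2 x i32> undef, <2 x i32> <i32 1, i32 0>
+///   %s = shufflevector <2 x i32> %r, <2 x i32> undef, <2 x i32> <i32 1, i32 0>
+/// bring every element back to its original lane, so %s can be replaced by %v.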
+static Value *foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1,
+ int MaskVal, Value *RootVec,
+ unsigned MaxRecurse) {
+ if (!MaxRecurse--)
+ return nullptr;
+
+ // Bail out if any mask value is undefined. That kind of shuffle may be
+ // simplified further based on demanded bits or other folds.
+ if (MaskVal == -1)
+ return nullptr;
+
+ // The mask value chooses which source operand we need to look at next.
+ int InVecNumElts = Op0->getType()->getVectorNumElements();
+ int RootElt = MaskVal;
+ Value *SourceOp = Op0;
+ if (MaskVal >= InVecNumElts) {
+ RootElt = MaskVal - InVecNumElts;
+ SourceOp = Op1;
+ }
+
+ // If the source operand is a shuffle itself, look through it to find the
+ // matching root vector.
+ if (auto *SourceShuf = dyn_cast<ShuffleVectorInst>(SourceOp)) {
+ return foldIdentityShuffles(
+ DestElt, SourceShuf->getOperand(0), SourceShuf->getOperand(1),
+ SourceShuf->getMaskValue(RootElt), RootVec, MaxRecurse);
+ }
+
+ // TODO: Look through bitcasts? What if the bitcast changes the vector element
+ // size?
+
+ // The source operand is not a shuffle. Initialize the root vector value for
+ // this shuffle if that has not been done yet.
+ if (!RootVec)
+ RootVec = SourceOp;
+
+ // Give up as soon as a source operand does not match the existing root value.
+ if (RootVec != SourceOp)
+ return nullptr;
+
+ // The element must be coming from the same lane in the source vector
+ // (although it may have crossed lanes in intermediate shuffles).
+ if (RootElt != DestElt)
+ return nullptr;
+
+ return RootVec;
+}
+
+static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask,
+ Type *RetTy, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ if (isa<UndefValue>(Mask))
+ return UndefValue::get(RetTy);
+
+ Type *InVecTy = Op0->getType();
+ unsigned MaskNumElts = Mask->getType()->getVectorNumElements();
+ unsigned InVecNumElts = InVecTy->getVectorNumElements();
+
+ SmallVector<int, 32> Indices;
+ ShuffleVectorInst::getShuffleMask(Mask, Indices);
+ assert(MaskNumElts == Indices.size() &&
+ "Size of Indices not same as number of mask elements?");
+
+ // Canonicalization: If mask does not select elements from an input vector,
+ // replace that input vector with undef.
+ bool MaskSelects0 = false, MaskSelects1 = false;
+ for (unsigned i = 0; i != MaskNumElts; ++i) {
+ if (Indices[i] == -1)
+ continue;
+ if ((unsigned)Indices[i] < InVecNumElts)
+ MaskSelects0 = true;
+ else
+ MaskSelects1 = true;
+ }
+ if (!MaskSelects0)
+ Op0 = UndefValue::get(InVecTy);
+ if (!MaskSelects1)
+ Op1 = UndefValue::get(InVecTy);
+
+ auto *Op0Const = dyn_cast<Constant>(Op0);
+ auto *Op1Const = dyn_cast<Constant>(Op1);
+
+ // If all operands are constant, constant fold the shuffle.
+ if (Op0Const && Op1Const)
+ return ConstantFoldShuffleVectorInstruction(Op0Const, Op1Const, Mask);
+
+ // Canonicalization: if only one input vector is constant, it shall be the
+ // second one.
+ if (Op0Const && !Op1Const) {
+ std::swap(Op0, Op1);
+ ShuffleVectorInst::commuteShuffleMask(Indices, InVecNumElts);
+ }
+
+ // A shuffle of a splat is always the splat itself. Legal if the shuffle's
+ // value type is the same as the input vectors' type.
+ if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op0))
+ if (isa<UndefValue>(Op1) && RetTy == InVecTy &&
+ OpShuf->getMask()->getSplatValue())
+ return Op0;
+
+ // Don't fold a shuffle with undef mask elements. This may get folded in a
+ // better way using demanded bits or other analysis.
+ // TODO: Should we allow this?
+ if (find(Indices, -1) != Indices.end())
+ return nullptr;
+
+ // Check if every element of this shuffle can be mapped back to the
+ // corresponding element of a single root vector. If so, we don't need this
+ // shuffle. This handles simple identity shuffles as well as chains of
+ // shuffles that may widen/narrow and/or move elements across lanes and back.
+ Value *RootVec = nullptr;
+ for (unsigned i = 0; i != MaskNumElts; ++i) {
+ // Note that recursion is limited for each vector element, so if any element
+ // exceeds the limit, this will fail to simplify.
+ RootVec =
+ foldIdentityShuffles(i, Op0, Op1, Indices[i], RootVec, MaxRecurse);
+
+ // We can't replace a widening/narrowing shuffle with one of its operands.
+ if (!RootVec || RootVec->getType() != RetTy)
+ return nullptr;
+ }
+ return RootVec;
+}
+
+/// Given operands for a ShuffleVectorInst, fold the result or return null.
+Value *llvm::SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask,
+ Type *RetTy, const SimplifyQuery &Q) {
+ return ::SimplifyShuffleVectorInst(Op0, Op1, Mask, RetTy, Q, RecursionLimit);
+}
+
+static Constant *foldConstant(Instruction::UnaryOps Opcode,
+ Value *&Op, const SimplifyQuery &Q) {
+ if (auto *C = dyn_cast<Constant>(Op))
+ return ConstantFoldUnaryOpOperand(Opcode, C, Q.DL);
+ return nullptr;
+}
+
+/// Given the operand for an FNeg, see if we can fold the result. If not, this
+/// returns null.
+static Value *simplifyFNegInst(Value *Op, FastMathFlags FMF,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (Constant *C = foldConstant(Instruction::FNeg, Op, Q))
+ return C;
+
+ Value *X;
+ // fneg (fneg X) ==> X
+ if (match(Op, m_FNeg(m_Value(X))))
+ return X;
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyFNegInst(Value *Op, FastMathFlags FMF,
+ const SimplifyQuery &Q) {
+ return ::simplifyFNegInst(Op, FMF, Q, RecursionLimit);
+}
+
+static Constant *propagateNaN(Constant *In) {
+ // If the input is a vector with undef elements, just return a default NaN.
+ if (!In->isNaN())
+ return ConstantFP::getNaN(In->getType());
+
+ // Propagate the existing NaN constant when possible.
+ // TODO: Should we quiet a signaling NaN?
+ return In;
+}
+
+/// Perform folds that are common to any floating-point operation. This
+/// includes folds based on undef and NaN operands, where the specific
+/// operation makes no difference to the result.
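+/// For example, (fadd %x, undef) can be folded to NaN: undef may be chosen to
+/// be NaN, and NaN added to any value is NaN.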
+static Constant *simplifyFPOp(ArrayRef<Value *> Ops) {
+ if (any_of(Ops, [](Value *V) { return isa<UndefValue>(V); }))
+ return ConstantFP::getNaN(Ops[0]->getType());
+
+ for (Value *V : Ops)
+ if (match(V, m_NaN()))
+ return propagateNaN(cast<Constant>(V));
+
+ return nullptr;
+}
+
+/// Given operands for an FAdd, see if we can fold the result. If not, this
+/// returns null.
+static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q))
+ return C;
+
+ if (Constant *C = simplifyFPOp({Op0, Op1}))
+ return C;
+
+ // fadd X, -0 ==> X
+ if (match(Op1, m_NegZeroFP()))
+ return Op0;
+
+ // fadd X, 0 ==> X, when we know X is not -0
+ if (match(Op1, m_PosZeroFP()) &&
+ (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI)))
+ return Op0;
+
+ // With nnan: -X + X --> 0.0 (and commuted variant)
+ // We don't have to explicitly exclude infinities (ninf): INF + -INF == NaN.
+ // Negative zeros are allowed because we always end up with positive zero:
+ // X = -0.0: (-0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0
+ // X = -0.0: ( 0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0
+ // X = 0.0: (-0.0 - ( 0.0)) + ( 0.0) == (-0.0) + ( 0.0) == 0.0
+ // X = 0.0: ( 0.0 - ( 0.0)) + ( 0.0) == ( 0.0) + ( 0.0) == 0.0
+ if (FMF.noNaNs()) {
+ if (match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1))) ||
+ match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0))))
+ return ConstantFP::getNullValue(Op0->getType());
+
+ if (match(Op0, m_FNeg(m_Specific(Op1))) ||
+ match(Op1, m_FNeg(m_Specific(Op0))))
+ return ConstantFP::getNullValue(Op0->getType());
+ }
+
+ // (X - Y) + Y --> X
+ // Y + (X - Y) --> X
+ Value *X;
+ if (FMF.noSignedZeros() && FMF.allowReassoc() &&
+ (match(Op0, m_FSub(m_Value(X), m_Specific(Op1))) ||
+ match(Op1, m_FSub(m_Value(X), m_Specific(Op0)))))
+ return X;
+
+ return nullptr;
+}
+
+/// Given operands for an FSub, see if we can fold the result. If not, this
+/// returns null.
+static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q))
+ return C;
+
+ if (Constant *C = simplifyFPOp({Op0, Op1}))
+ return C;
+
+ // fsub X, +0 ==> X
+ if (match(Op1, m_PosZeroFP()))
+ return Op0;
+
+ // fsub X, -0 ==> X, when we know X is not -0
+ if (match(Op1, m_NegZeroFP()) &&
+ (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI)))
+ return Op0;
+
+ // fsub -0.0, (fsub -0.0, X) ==> X
+ // fsub -0.0, (fneg X) ==> X
+ Value *X;
+ if (match(Op0, m_NegZeroFP()) &&
+ match(Op1, m_FNeg(m_Value(X))))
+ return X;
+
+ // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored.
+ // fsub 0.0, (fneg X) ==> X if signed zeros are ignored.
+ if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) &&
+ (match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))) ||
+ match(Op1, m_FNeg(m_Value(X)))))
+ return X;
+
+ // fsub nnan x, x ==> 0.0
+ if (FMF.noNaNs() && Op0 == Op1)
+ return Constant::getNullValue(Op0->getType());
+
+ // Y - (Y - X) --> X
+ // (X + Y) - Y --> X
+ if (FMF.noSignedZeros() && FMF.allowReassoc() &&
+ (match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) ||
+ match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X)))))
+ return X;
+
+ return nullptr;
+}
+
+static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (Constant *C = simplifyFPOp({Op0, Op1}))
+ return C;
+
+ // fmul X, 1.0 ==> X
+ if (match(Op1, m_FPOne()))
+ return Op0;
+
+ // fmul 1.0, X ==> X
+ if (match(Op0, m_FPOne()))
+ return Op1;
+
+ // fmul nnan nsz X, 0 ==> 0
+ if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZeroFP()))
+ return ConstantFP::getNullValue(Op0->getType());
+
+ // fmul nnan nsz 0, X ==> 0
+ if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()))
+ return ConstantFP::getNullValue(Op1->getType());
+
+ // sqrt(X) * sqrt(X) --> X, if we can:
+ // 1. Remove the intermediate rounding (reassociate).
+ // 2. Ignore non-zero negative numbers because sqrt would produce NAN.
+ // 3. Ignore -0.0 because sqrt(-0.0) == -0.0, but -0.0 * -0.0 == 0.0.
+ Value *X;
+ if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) &&
+ FMF.allowReassoc() && FMF.noNaNs() && FMF.noSignedZeros())
+ return X;
+
+ return nullptr;
+}
+
+/// Given the operands for an FMul, see if we can fold the result. If not, this
+/// returns null.
+static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
+ return C;
+
+ // Now apply simplifications that do not require rounding.
+ return SimplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse);
+}
+
+Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q) {
+ return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit);
+}
+
+Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q) {
+ return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit);
+}
+
+Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q) {
+ return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit);
+}
+
+Value *llvm::SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q) {
+ return ::SimplifyFMAFMul(Op0, Op1, FMF, Q, RecursionLimit);
+}
+
+static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q, unsigned) {
+ if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q))
+ return C;
+
+ if (Constant *C = simplifyFPOp({Op0, Op1}))
+ return C;
+
+ // X / 1.0 -> X
+ if (match(Op1, m_FPOne()))
+ return Op0;
+
+ // 0 / X -> 0
+ // Requires that NaNs are ignored (X could be zero, and 0/0 is NaN) and that
+ // signed zeros are ignored (X could be negative, so the sign of the result
+ // is unknown).
+ if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()))
+ return ConstantFP::getNullValue(Op0->getType());
+
+ if (FMF.noNaNs()) {
+ // X / X -> 1.0 is legal when NaNs are ignored.
+ // We can ignore infinities because INF/INF is NaN.
+ if (Op0 == Op1)
+ return ConstantFP::get(Op0->getType(), 1.0);
+
+ // (X * Y) / Y --> X if we can reassociate to the above form.
+ Value *X;
+ if (FMF.allowReassoc() && match(Op0, m_c_FMul(m_Value(X), m_Specific(Op1))))
+ return X;
+
+ // -X / X -> -1.0 and
+ // X / -X -> -1.0 are legal when NaNs are ignored.
+ // We can ignore signed zeros because +-0.0/+-0.0 is NaN and ignored.
+ if (match(Op0, m_FNegNSZ(m_Specific(Op1))) ||
+ match(Op1, m_FNegNSZ(m_Specific(Op0))))
+ return ConstantFP::get(Op0->getType(), -1.0);
+ }
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q) {
+ return ::SimplifyFDivInst(Op0, Op1, FMF, Q, RecursionLimit);
+}
+
+static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q, unsigned) {
+ if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q))
+ return C;
+
+ if (Constant *C = simplifyFPOp({Op0, Op1}))
+ return C;
+
+ // Unlike fdiv, the result of frem always matches the sign of the dividend.
+ // The constant match may include undef elements in a vector, so return a full
+ // zero constant as the result.
+ if (FMF.noNaNs()) {
+ // +0 % X -> 0
+ if (match(Op0, m_PosZeroFP()))
+ return ConstantFP::getNullValue(Op0->getType());
+ // -0 % X -> -0
+ if (match(Op0, m_NegZeroFP()))
+ return ConstantFP::getNegativeZero(Op0->getType());
+ }
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q) {
+ return ::SimplifyFRemInst(Op0, Op1, FMF, Q, RecursionLimit);
+}
+
+//=== Helper functions for higher up the class hierarchy.
+
+/// Given the operand for a UnaryOperator, see if we can fold the result.
+/// If not, this returns null.
+static Value *simplifyUnOp(unsigned Opcode, Value *Op, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ switch (Opcode) {
+ case Instruction::FNeg:
+ return simplifyFNegInst(Op, FastMathFlags(), Q, MaxRecurse);
+ default:
+ llvm_unreachable("Unexpected opcode");
+ }
+}
+
+/// Given the operand for a UnaryOperator, see if we can fold the result.
+/// If not, this returns null.
+/// Try to use FastMathFlags when folding the result.
+static Value *simplifyFPUnOp(unsigned Opcode, Value *Op,
+ const FastMathFlags &FMF,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ switch (Opcode) {
+ case Instruction::FNeg:
+ return simplifyFNegInst(Op, FMF, Q, MaxRecurse);
+ default:
+ return simplifyUnOp(Opcode, Op, Q, MaxRecurse);
+ }
+}
+
+Value *llvm::SimplifyUnOp(unsigned Opcode, Value *Op, const SimplifyQuery &Q) {
+ return ::simplifyUnOp(Opcode, Op, Q, RecursionLimit);
+}
+
+Value *llvm::SimplifyUnOp(unsigned Opcode, Value *Op, FastMathFlags FMF,
+ const SimplifyQuery &Q) {
+ return ::simplifyFPUnOp(Opcode, Op, FMF, Q, RecursionLimit);
+}
+
+/// Given operands for a BinaryOperator, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ switch (Opcode) {
+ case Instruction::Add:
+ return SimplifyAddInst(LHS, RHS, false, false, Q, MaxRecurse);
+ case Instruction::Sub:
+ return SimplifySubInst(LHS, RHS, false, false, Q, MaxRecurse);
+ case Instruction::Mul:
+ return SimplifyMulInst(LHS, RHS, Q, MaxRecurse);
+ case Instruction::SDiv:
+ return SimplifySDivInst(LHS, RHS, Q, MaxRecurse);
+ case Instruction::UDiv:
+ return SimplifyUDivInst(LHS, RHS, Q, MaxRecurse);
+ case Instruction::SRem:
+ return SimplifySRemInst(LHS, RHS, Q, MaxRecurse);
+ case Instruction::URem:
+ return SimplifyURemInst(LHS, RHS, Q, MaxRecurse);
+ case Instruction::Shl:
+ return SimplifyShlInst(LHS, RHS, false, false, Q, MaxRecurse);
+ case Instruction::LShr:
+ return SimplifyLShrInst(LHS, RHS, false, Q, MaxRecurse);
+ case Instruction::AShr:
+ return SimplifyAShrInst(LHS, RHS, false, Q, MaxRecurse);
+ case Instruction::And:
+ return SimplifyAndInst(LHS, RHS, Q, MaxRecurse);
+ case Instruction::Or:
+ return SimplifyOrInst(LHS, RHS, Q, MaxRecurse);
+ case Instruction::Xor:
+ return SimplifyXorInst(LHS, RHS, Q, MaxRecurse);
+ case Instruction::FAdd:
+ return SimplifyFAddInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse);
+ case Instruction::FSub:
+ return SimplifyFSubInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse);
+ case Instruction::FMul:
+ return SimplifyFMulInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse);
+ case Instruction::FDiv:
+ return SimplifyFDivInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse);
+ case Instruction::FRem:
+ return SimplifyFRemInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse);
+ default:
+ llvm_unreachable("Unexpected opcode");
+ }
+}
+
+/// Given operands for a BinaryOperator, see if we can fold the result.
+/// If not, this returns null.
+/// Try to use FastMathFlags when folding the result.
+static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS,
+ const FastMathFlags &FMF, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ switch (Opcode) {
+ case Instruction::FAdd:
+ return SimplifyFAddInst(LHS, RHS, FMF, Q, MaxRecurse);
+ case Instruction::FSub:
+ return SimplifyFSubInst(LHS, RHS, FMF, Q, MaxRecurse);
+ case Instruction::FMul:
+ return SimplifyFMulInst(LHS, RHS, FMF, Q, MaxRecurse);
+ case Instruction::FDiv:
+ return SimplifyFDivInst(LHS, RHS, FMF, Q, MaxRecurse);
+ default:
+ return SimplifyBinOp(Opcode, LHS, RHS, Q, MaxRecurse);
+ }
+}
+
+Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS,
+ const SimplifyQuery &Q) {
+ return ::SimplifyBinOp(Opcode, LHS, RHS, Q, RecursionLimit);
+}
+
+Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS,
+ FastMathFlags FMF, const SimplifyQuery &Q) {
+ return ::SimplifyBinOp(Opcode, LHS, RHS, FMF, Q, RecursionLimit);
+}
+
+/// Given operands for a CmpInst, see if we can fold the result.
+static Value *SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (CmpInst::isIntPredicate((CmpInst::Predicate)Predicate))
+ return SimplifyICmpInst(Predicate, LHS, RHS, Q, MaxRecurse);
+ return SimplifyFCmpInst(Predicate, LHS, RHS, FastMathFlags(), Q, MaxRecurse);
+}
+
+Value *llvm::SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+ const SimplifyQuery &Q) {
+ return ::SimplifyCmpInst(Predicate, LHS, RHS, Q, RecursionLimit);
+}
+
+static bool IsIdempotent(Intrinsic::ID ID) {
+ switch (ID) {
+ default: return false;
+
+ // Unary idempotent: f(f(x)) = f(x)
+ case Intrinsic::fabs:
+ case Intrinsic::floor:
+ case Intrinsic::ceil:
+ case Intrinsic::trunc:
+ case Intrinsic::rint:
+ case Intrinsic::nearbyint:
+ case Intrinsic::round:
+ case Intrinsic::canonicalize:
+ return true;
+ }
+}
+
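+/// Try to fold a call to llvm.load.relative with constant operands: if the
+/// value loaded from Ptr + Offset is the constant difference between some
+/// global Target and Ptr itself, the call can be folded to Target.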
+static Value *SimplifyRelativeLoad(Constant *Ptr, Constant *Offset,
+ const DataLayout &DL) {
+ GlobalValue *PtrSym;
+ APInt PtrOffset;
+ if (!IsConstantOffsetFromGlobal(Ptr, PtrSym, PtrOffset, DL))
+ return nullptr;
+
+ Type *Int8PtrTy = Type::getInt8PtrTy(Ptr->getContext());
+ Type *Int32Ty = Type::getInt32Ty(Ptr->getContext());
+ Type *Int32PtrTy = Int32Ty->getPointerTo();
+ Type *Int64Ty = Type::getInt64Ty(Ptr->getContext());
+
+ auto *OffsetConstInt = dyn_cast<ConstantInt>(Offset);
+ if (!OffsetConstInt || OffsetConstInt->getType()->getBitWidth() > 64)
+ return nullptr;
+
+ uint64_t OffsetInt = OffsetConstInt->getSExtValue();
+ if (OffsetInt % 4 != 0)
+ return nullptr;
+
+ Constant *C = ConstantExpr::getGetElementPtr(
+ Int32Ty, ConstantExpr::getBitCast(Ptr, Int32PtrTy),
+ ConstantInt::get(Int64Ty, OffsetInt / 4));
+ Constant *Loaded = ConstantFoldLoadFromConstPtr(C, Int32Ty, DL);
+ if (!Loaded)
+ return nullptr;
+
+ auto *LoadedCE = dyn_cast<ConstantExpr>(Loaded);
+ if (!LoadedCE)
+ return nullptr;
+
+ if (LoadedCE->getOpcode() == Instruction::Trunc) {
+ LoadedCE = dyn_cast<ConstantExpr>(LoadedCE->getOperand(0));
+ if (!LoadedCE)
+ return nullptr;
+ }
+
+ if (LoadedCE->getOpcode() != Instruction::Sub)
+ return nullptr;
+
+ auto *LoadedLHS = dyn_cast<ConstantExpr>(LoadedCE->getOperand(0));
+ if (!LoadedLHS || LoadedLHS->getOpcode() != Instruction::PtrToInt)
+ return nullptr;
+ auto *LoadedLHSPtr = LoadedLHS->getOperand(0);
+
+ Constant *LoadedRHS = LoadedCE->getOperand(1);
+ GlobalValue *LoadedRHSSym;
+ APInt LoadedRHSOffset;
+ if (!IsConstantOffsetFromGlobal(LoadedRHS, LoadedRHSSym, LoadedRHSOffset,
+ DL) ||
+ PtrSym != LoadedRHSSym || PtrOffset != LoadedRHSOffset)
+ return nullptr;
+
+ return ConstantExpr::getBitCast(LoadedLHSPtr, Int8PtrTy);
+}
+
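+/// Given the callee and argument of a one-operand intrinsic call, see if we
+/// can fold the result. If not, this returns null.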
+static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0,
+ const SimplifyQuery &Q) {
+ // Idempotent functions return the same result when called repeatedly.
+ Intrinsic::ID IID = F->getIntrinsicID();
+ if (IsIdempotent(IID))
+ if (auto *II = dyn_cast<IntrinsicInst>(Op0))
+ if (II->getIntrinsicID() == IID)
+ return II;
+
+ Value *X;
+ switch (IID) {
+ case Intrinsic::fabs:
+ if (SignBitMustBeZero(Op0, Q.TLI)) return Op0;
+ break;
+ case Intrinsic::bswap:
+ // bswap(bswap(x)) -> x
+ if (match(Op0, m_BSwap(m_Value(X)))) return X;
+ break;
+ case Intrinsic::bitreverse:
+ // bitreverse(bitreverse(x)) -> x
+ if (match(Op0, m_BitReverse(m_Value(X)))) return X;
+ break;
+ case Intrinsic::exp:
+ // exp(log(x)) -> x
+ if (Q.CxtI->hasAllowReassoc() &&
+ match(Op0, m_Intrinsic<Intrinsic::log>(m_Value(X)))) return X;
+ break;
+ case Intrinsic::exp2:
+ // exp2(log2(x)) -> x
+ if (Q.CxtI->hasAllowReassoc() &&
+ match(Op0, m_Intrinsic<Intrinsic::log2>(m_Value(X)))) return X;
+ break;
+ case Intrinsic::log:
+ // log(exp(x)) -> x
+ if (Q.CxtI->hasAllowReassoc() &&
+ match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X)))) return X;
+ break;
+ case Intrinsic::log2:
+ // log2(exp2(x)) -> x
+ if (Q.CxtI->hasAllowReassoc() &&
+ (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) ||
+ match(Op0, m_Intrinsic<Intrinsic::pow>(m_SpecificFP(2.0),
+ m_Value(X))))) return X;
+ break;
+ case Intrinsic::log10:
+ // log10(pow(10.0, x)) -> x
+ if (Q.CxtI->hasAllowReassoc() &&
+ match(Op0, m_Intrinsic<Intrinsic::pow>(m_SpecificFP(10.0),
+ m_Value(X)))) return X;
+ break;
+ case Intrinsic::floor:
+ case Intrinsic::trunc:
+ case Intrinsic::ceil:
+ case Intrinsic::round:
+ case Intrinsic::nearbyint:
+ case Intrinsic::rint: {
+ // floor (sitofp x) -> sitofp x
+ // floor (uitofp x) -> uitofp x
+ //
+ // Converting from int always results in a finite integral number or
+ // infinity. For either of those inputs, these rounding functions always
+ // return the same value, so the rounding can be eliminated.
+ if (match(Op0, m_SIToFP(m_Value())) || match(Op0, m_UIToFP(m_Value())))
+ return Op0;
+ break;
+ }
+ default:
+ break;
+ }
+
+ return nullptr;
+}
+
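+/// Given the callee and arguments of a two-operand intrinsic call, see if we
+/// can fold the result. If not, this returns null.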
+static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
+ const SimplifyQuery &Q) {
+ Intrinsic::ID IID = F->getIntrinsicID();
+ Type *ReturnType = F->getReturnType();
+ switch (IID) {
+ case Intrinsic::usub_with_overflow:
+ case Intrinsic::ssub_with_overflow:
+ // X - X -> { 0, false }
+ if (Op0 == Op1)
+ return Constant::getNullValue(ReturnType);
+ LLVM_FALLTHROUGH;
+ case Intrinsic::uadd_with_overflow:
+ case Intrinsic::sadd_with_overflow:
+ // X - undef -> { undef, false }
+ // undef - X -> { undef, false }
+ // X + undef -> { undef, false }
+ // undef + X -> { undef, false }
+ if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1)) {
+ return ConstantStruct::get(
+ cast<StructType>(ReturnType),
+ {UndefValue::get(ReturnType->getStructElementType(0)),
+ Constant::getNullValue(ReturnType->getStructElementType(1))});
+ }
+ break;
+ case Intrinsic::umul_with_overflow:
+ case Intrinsic::smul_with_overflow:
+ // 0 * X -> { 0, false }
+ // X * 0 -> { 0, false }
+ if (match(Op0, m_Zero()) || match(Op1, m_Zero()))
+ return Constant::getNullValue(ReturnType);
+ // undef * X -> { 0, false }
+ // X * undef -> { 0, false }
+ if (match(Op0, m_Undef()) || match(Op1, m_Undef()))
+ return Constant::getNullValue(ReturnType);
+ break;
+ case Intrinsic::uadd_sat:
+ // sat(MAX + X) -> MAX
+ // sat(X + MAX) -> MAX
+ if (match(Op0, m_AllOnes()) || match(Op1, m_AllOnes()))
+ return Constant::getAllOnesValue(ReturnType);
+ LLVM_FALLTHROUGH;
+ case Intrinsic::sadd_sat:
+ // sat(X + undef) -> -1
+ // sat(undef + X) -> -1
+ // For unsigned: Assume undef is MAX, thus we saturate to MAX (-1).
+ // For signed: Assume undef is ~X, in which case X + ~X = -1.
+ if (match(Op0, m_Undef()) || match(Op1, m_Undef()))
+ return Constant::getAllOnesValue(ReturnType);
+
+ // X + 0 -> X
+ if (match(Op1, m_Zero()))
+ return Op0;
+ // 0 + X -> X
+ if (match(Op0, m_Zero()))
+ return Op1;
+ break;
+ case Intrinsic::usub_sat:
+ // sat(0 - X) -> 0, sat(X - MAX) -> 0
+ if (match(Op0, m_Zero()) || match(Op1, m_AllOnes()))
+ return Constant::getNullValue(ReturnType);
+ LLVM_FALLTHROUGH;
+ case Intrinsic::ssub_sat:
+ // X - X -> 0, X - undef -> 0, undef - X -> 0
+ if (Op0 == Op1 || match(Op0, m_Undef()) || match(Op1, m_Undef()))
+ return Constant::getNullValue(ReturnType);
+ // X - 0 -> X
+ if (match(Op1, m_Zero()))
+ return Op0;
+ break;
+ case Intrinsic::load_relative:
+ if (auto *C0 = dyn_cast<Constant>(Op0))
+ if (auto *C1 = dyn_cast<Constant>(Op1))
+ return SimplifyRelativeLoad(C0, C1, Q.DL);
+ break;
+ case Intrinsic::powi:
+ if (auto *Power = dyn_cast<ConstantInt>(Op1)) {
+ // powi(x, 0) -> 1.0
+ if (Power->isZero())
+ return ConstantFP::get(Op0->getType(), 1.0);
+ // powi(x, 1) -> x
+ if (Power->isOne())
+ return Op0;
+ }
+ break;
+ case Intrinsic::maxnum:
+ case Intrinsic::minnum:
+ case Intrinsic::maximum:
+ case Intrinsic::minimum: {
+ // If the arguments are the same, this is a no-op.
+ if (Op0 == Op1) return Op0;
+
+ // If one argument is undef, return the other argument.
+ if (match(Op0, m_Undef()))
+ return Op1;
+ if (match(Op1, m_Undef()))
+ return Op0;
+
+ // If one argument is NaN, return other or NaN appropriately.
+ bool PropagateNaN = IID == Intrinsic::minimum || IID == Intrinsic::maximum;
+ if (match(Op0, m_NaN()))
+ return PropagateNaN ? Op0 : Op1;
+ if (match(Op1, m_NaN()))
+ return PropagateNaN ? Op1 : Op0;
+
+ // Min/max of the same operation with common operand:
+ // m(m(X, Y)), X --> m(X, Y) (4 commuted variants)
+ if (auto *M0 = dyn_cast<IntrinsicInst>(Op0))
+ if (M0->getIntrinsicID() == IID &&
+ (M0->getOperand(0) == Op1 || M0->getOperand(1) == Op1))
+ return Op0;
+ if (auto *M1 = dyn_cast<IntrinsicInst>(Op1))
+ if (M1->getIntrinsicID() == IID &&
+ (M1->getOperand(0) == Op0 || M1->getOperand(1) == Op0))
+ return Op1;
+
+ // min(X, -Inf) --> -Inf (and commuted variant)
+ // max(X, +Inf) --> +Inf (and commuted variant)
+ bool UseNegInf = IID == Intrinsic::minnum || IID == Intrinsic::minimum;
+ const APFloat *C;
+ if ((match(Op0, m_APFloat(C)) && C->isInfinity() &&
+ C->isNegative() == UseNegInf) ||
+ (match(Op1, m_APFloat(C)) && C->isInfinity() &&
+ C->isNegative() == UseNegInf))
+ return ConstantFP::getInfinity(ReturnType, UseNegInf);
+
+ // TODO: minnum(nnan x, inf) -> x
+ // TODO: minnum(nnan ninf x, flt_max) -> x
+ // TODO: maxnum(nnan x, -inf) -> x
+ // TODO: maxnum(nnan ninf x, -flt_max) -> x
+ break;
+ }
+ default:
+ break;
+ }
+
+ return nullptr;
+}
+
+static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
+
+ // Intrinsics with no operands have some kind of side effect. Don't simplify.
+ unsigned NumOperands = Call->getNumArgOperands();
+ if (!NumOperands)
+ return nullptr;
+
+ Function *F = cast<Function>(Call->getCalledFunction());
+ Intrinsic::ID IID = F->getIntrinsicID();
+ if (NumOperands == 1)
+ return simplifyUnaryIntrinsic(F, Call->getArgOperand(0), Q);
+
+ if (NumOperands == 2)
+ return simplifyBinaryIntrinsic(F, Call->getArgOperand(0),
+ Call->getArgOperand(1), Q);
+
+ // Handle intrinsics with 3 or more arguments.
+ switch (IID) {
+ case Intrinsic::masked_load:
+ case Intrinsic::masked_gather: {
+ Value *MaskArg = Call->getArgOperand(2);
+ Value *PassthruArg = Call->getArgOperand(3);
+ // If the mask is all zeros or undef, the "passthru" argument is the result.
+ if (maskIsAllZeroOrUndef(MaskArg))
+ return PassthruArg;
+ return nullptr;
+ }
+ case Intrinsic::fshl:
+ case Intrinsic::fshr: {
+ Value *Op0 = Call->getArgOperand(0), *Op1 = Call->getArgOperand(1),
+ *ShAmtArg = Call->getArgOperand(2);
+
+ // If both operands are undef, the result is undef.
+ if (match(Op0, m_Undef()) && match(Op1, m_Undef()))
+ return UndefValue::get(F->getReturnType());
+
+ // If the shift amount is undef, it can be chosen to be zero, in which case
+ // fshl returns its first operand and fshr returns its second operand.
+ if (match(ShAmtArg, m_Undef()))
+ return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1);
+
+ const APInt *ShAmtC;
+ if (match(ShAmtArg, m_APInt(ShAmtC))) {
+ // If there's effectively no shift, return the 1st arg or 2nd arg.
+ APInt BitWidth = APInt(ShAmtC->getBitWidth(), ShAmtC->getBitWidth());
+ if (ShAmtC->urem(BitWidth).isNullValue())
+ return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1);
+ }
+ return nullptr;
+ }
+ case Intrinsic::fma:
+ case Intrinsic::fmuladd: {
+ Value *Op0 = Call->getArgOperand(0);
+ Value *Op1 = Call->getArgOperand(1);
+ Value *Op2 = Call->getArgOperand(2);
+ if (Value *V = simplifyFPOp({ Op0, Op1, Op2 }))
+ return V;
+ return nullptr;
+ }
+ default:
+ return nullptr;
+ }
+}
+
+Value *llvm::SimplifyCall(CallBase *Call, const SimplifyQuery &Q) {
+ Value *Callee = Call->getCalledValue();
+
+ // call undef -> undef
+ // call null -> undef
+ if (isa<UndefValue>(Callee) || isa<ConstantPointerNull>(Callee))
+ return UndefValue::get(Call->getType());
+
+ Function *F = dyn_cast<Function>(Callee);
+ if (!F)
+ return nullptr;
+
+ if (F->isIntrinsic())
+ if (Value *Ret = simplifyIntrinsic(Call, Q))
+ return Ret;
+
+ if (!canConstantFoldCallTo(Call, F))
+ return nullptr;
+
+ SmallVector<Constant *, 4> ConstantArgs;
+ unsigned NumArgs = Call->getNumArgOperands();
+ ConstantArgs.reserve(NumArgs);
+ for (auto &Arg : Call->args()) {
+ Constant *C = dyn_cast<Constant>(&Arg);
+ if (!C)
+ return nullptr;
+ ConstantArgs.push_back(C);
+ }
+
+ return ConstantFoldCall(Call, F, ConstantArgs, Q.TLI);
+}
+
+/// See if we can compute a simplified version of this instruction.
+/// If not, this returns null.
+Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ,
+ OptimizationRemarkEmitter *ORE) {
+ const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I);
+ Value *Result;
+
+ switch (I->getOpcode()) {
+ default:
+ Result = ConstantFoldInstruction(I, Q.DL, Q.TLI);
+ break;
+ case Instruction::FNeg:
+ Result = SimplifyFNegInst(I->getOperand(0), I->getFastMathFlags(), Q);
+ break;
+ case Instruction::FAdd:
+ Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1),
+ I->getFastMathFlags(), Q);
+ break;
+ case Instruction::Add:
+ Result =
+ SimplifyAddInst(I->getOperand(0), I->getOperand(1),
+ Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
+ Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
+ break;
+ case Instruction::FSub:
+ Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1),
+ I->getFastMathFlags(), Q);
+ break;
+ case Instruction::Sub:
+ Result =
+ SimplifySubInst(I->getOperand(0), I->getOperand(1),
+ Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
+ Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
+ break;
+ case Instruction::FMul:
+ Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1),
+ I->getFastMathFlags(), Q);
+ break;
+ case Instruction::Mul:
+ Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), Q);
+ break;
+ case Instruction::SDiv:
+ Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), Q);
+ break;
+ case Instruction::UDiv:
+ Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), Q);
+ break;
+ case Instruction::FDiv:
+ Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1),
+ I->getFastMathFlags(), Q);
+ break;
+ case Instruction::SRem:
+ Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), Q);
+ break;
+ case Instruction::URem:
+ Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), Q);
+ break;
+ case Instruction::FRem:
+ Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1),
+ I->getFastMathFlags(), Q);
+ break;
+ case Instruction::Shl:
+ Result =
+ SimplifyShlInst(I->getOperand(0), I->getOperand(1),
+ Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
+ Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
+ break;
+ case Instruction::LShr:
+ Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1),
+ Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);
+ break;
+ case Instruction::AShr:
+ Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1),
+ Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);
+ break;
+ case Instruction::And:
+ Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), Q);
+ break;
+ case Instruction::Or:
+ Result = SimplifyOrInst(I->getOperand(0), I->getOperand(1), Q);
+ break;
+ case Instruction::Xor:
+ Result = SimplifyXorInst(I->getOperand(0), I->getOperand(1), Q);
+ break;
+ case Instruction::ICmp:
+ Result = SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(),
+ I->getOperand(0), I->getOperand(1), Q);
+ break;
+ case Instruction::FCmp:
+ Result =
+ SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), I->getOperand(0),
+ I->getOperand(1), I->getFastMathFlags(), Q);
+ break;
+ case Instruction::Select:
+ Result = SimplifySelectInst(I->getOperand(0), I->getOperand(1),
+ I->getOperand(2), Q);
+ break;
+ case Instruction::GetElementPtr: {
+ SmallVector<Value *, 8> Ops(I->op_begin(), I->op_end());
+ Result = SimplifyGEPInst(cast<GetElementPtrInst>(I)->getSourceElementType(),
+ Ops, Q);
+ break;
+ }
+ case Instruction::InsertValue: {
+ InsertValueInst *IV = cast<InsertValueInst>(I);
+ Result = SimplifyInsertValueInst(IV->getAggregateOperand(),
+ IV->getInsertedValueOperand(),
+ IV->getIndices(), Q);
+ break;
+ }
+ case Instruction::InsertElement: {
+ auto *IE = cast<InsertElementInst>(I);
+ Result = SimplifyInsertElementInst(IE->getOperand(0), IE->getOperand(1),
+ IE->getOperand(2), Q);
+ break;
+ }
+ case Instruction::ExtractValue: {
+ auto *EVI = cast<ExtractValueInst>(I);
+ Result = SimplifyExtractValueInst(EVI->getAggregateOperand(),
+ EVI->getIndices(), Q);
+ break;
+ }
+ case Instruction::ExtractElement: {
+ auto *EEI = cast<ExtractElementInst>(I);
+ Result = SimplifyExtractElementInst(EEI->getVectorOperand(),
+ EEI->getIndexOperand(), Q);
+ break;
+ }
+ case Instruction::ShuffleVector: {
+ auto *SVI = cast<ShuffleVectorInst>(I);
+ Result = SimplifyShuffleVectorInst(SVI->getOperand(0), SVI->getOperand(1),
+ SVI->getMask(), SVI->getType(), Q);
+ break;
+ }
+ case Instruction::PHI:
+ Result = SimplifyPHINode(cast<PHINode>(I), Q);
+ break;
+ case Instruction::Call: {
+ Result = SimplifyCall(cast<CallInst>(I), Q);
+ break;
+ }
+#define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc:
+#include "llvm/IR/Instruction.def"
+#undef HANDLE_CAST_INST
+ Result =
+ SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), Q);
+ break;
+ case Instruction::Alloca:
+ // No simplifications for Alloca and it can't be constant folded.
+ Result = nullptr;
+ break;
+ }
+
+ // In general, it is possible for computeKnownBits to determine all bits in a
+ // value even when the operands are not all constants.
+ if (!Result && I->getType()->isIntOrIntVectorTy()) {
+ KnownBits Known = computeKnownBits(I, Q.DL, /*Depth*/ 0, Q.AC, I, Q.DT, ORE);
+ if (Known.isConstant())
+ Result = ConstantInt::get(I->getType(), Known.getConstant());
+ }
+
+ // If called on unreachable code, the above logic may report that the
+ // instruction simplified to itself. Make life easier for users by
+ // detecting that case here, returning a safe value instead.
+ return Result == I ? UndefValue::get(I->getType()) : Result;
+}
+
+/// Implementation of recursive simplification through an instruction's
+/// uses.
+///
+/// This is the common implementation of the recursive simplification routines.
+/// If we have a pre-simplified value in 'SimpleV', that is forcibly used to
+/// replace the instruction 'I'. Otherwise, we simply add 'I' to the list of
+/// instructions to process and attempt to simplify it using
+/// InstructionSimplify. Recursively visited users which could not be
+/// simplified themselves are added to the optional UnsimplifiedUsers set for
+/// further processing by the caller.
+///
+/// This routine returns 'true' only when *it* simplifies something. The passed
+/// in simplified value does not count toward this.
+static bool replaceAndRecursivelySimplifyImpl(
+ Instruction *I, Value *SimpleV, const TargetLibraryInfo *TLI,
+ const DominatorTree *DT, AssumptionCache *AC,
+ SmallSetVector<Instruction *, 8> *UnsimplifiedUsers = nullptr) {
+ bool Simplified = false;
+ SmallSetVector<Instruction *, 8> Worklist;
+ const DataLayout &DL = I->getModule()->getDataLayout();
+
+ // If we have an explicit value to collapse to, do that round of the
+ // simplification loop by hand initially.
+ if (SimpleV) {
+ for (User *U : I->users())
+ if (U != I)
+ Worklist.insert(cast<Instruction>(U));
+
+ // Replace the instruction with its simplified value.
+ I->replaceAllUsesWith(SimpleV);
+
+ // Gracefully handle edge cases where the instruction is not wired into any
+ // parent block.
+ if (I->getParent() && !I->isEHPad() && !I->isTerminator() &&
+ !I->mayHaveSideEffects())
+ I->eraseFromParent();
+ } else {
+ Worklist.insert(I);
+ }
+
+ // Note that we must test the size on each iteration, as the worklist can grow.
+ for (unsigned Idx = 0; Idx != Worklist.size(); ++Idx) {
+ I = Worklist[Idx];
+
+ // See if this instruction simplifies.
+ SimpleV = SimplifyInstruction(I, {DL, TLI, DT, AC});
+ if (!SimpleV) {
+ if (UnsimplifiedUsers)
+ UnsimplifiedUsers->insert(I);
+ continue;
+ }
+
+ Simplified = true;
+
+ // Stash away all the uses of the old instruction so we can check them for
+ // recursive simplifications after a RAUW. This is cheaper than checking all
+ // uses of the replacement value on the recursive step in most cases.
+ for (User *U : I->users())
+ Worklist.insert(cast<Instruction>(U));
+
+ // Replace the instruction with its simplified value.
+ I->replaceAllUsesWith(SimpleV);
+
+ // Gracefully handle edge cases where the instruction is not wired into any
+ // parent block.
+ if (I->getParent() && !I->isEHPad() && !I->isTerminator() &&
+ !I->mayHaveSideEffects())
+ I->eraseFromParent();
+ }
+ return Simplified;
+}
+
+bool llvm::recursivelySimplifyInstruction(Instruction *I,
+ const TargetLibraryInfo *TLI,
+ const DominatorTree *DT,
+ AssumptionCache *AC) {
+ return replaceAndRecursivelySimplifyImpl(I, nullptr, TLI, DT, AC, nullptr);
+}
+
+bool llvm::replaceAndRecursivelySimplify(
+ Instruction *I, Value *SimpleV, const TargetLibraryInfo *TLI,
+ const DominatorTree *DT, AssumptionCache *AC,
+ SmallSetVector<Instruction *, 8> *UnsimplifiedUsers) {
+ assert(I != SimpleV && "replaceAndRecursivelySimplify(X,X) is not valid!");
+ assert(SimpleV && "Must provide a simplified value.");
+ return replaceAndRecursivelySimplifyImpl(I, SimpleV, TLI, DT, AC,
+ UnsimplifiedUsers);
+}
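+// Usage sketch (illustrative only; getProvenReplacement is a hypothetical
+// stand-in for whatever analysis the client uses to prove the replacement,
+// and any of the TLI/DT/AC arguments may be null):
+//
+//   if (Value *V = getProvenReplacement(I))
+//     replaceAndRecursivelySimplify(I, V, &TLI, &DT, &AC);
+//   else
+//     recursivelySimplifyInstruction(I, &TLI, &DT, &AC);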
+
+namespace llvm {
+const SimplifyQuery getBestSimplifyQuery(Pass &P, Function &F) {
+ auto *DTWP = P.getAnalysisIfAvailable<DominatorTreeWrapperPass>();
+ auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
+ auto *TLIWP = P.getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
+ auto *TLI = TLIWP ? &TLIWP->getTLI(F) : nullptr;
+ auto *ACWP = P.getAnalysisIfAvailable<AssumptionCacheTracker>();
+ auto *AC = ACWP ? &ACWP->getAssumptionCache(F) : nullptr;
+ return {F.getParent()->getDataLayout(), TLI, DT, AC};
+}
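+// Usage sketch (illustrative): inside a legacy pass's runOnFunction, the
+// overload above picks up whichever analyses happen to be available:
+//
+//   const SimplifyQuery Q = getBestSimplifyQuery(*this, F);
+//   for (Instruction &I : instructions(F))      // llvm/IR/InstIterator.h
+//     if (Value *V = SimplifyInstruction(&I, Q))
+//       I.replaceAllUsesWith(V);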
+
+const SimplifyQuery getBestSimplifyQuery(LoopStandardAnalysisResults &AR,
+ const DataLayout &DL) {
+ return {DL, &AR.TLI, &AR.DT, &AR.AC};
+}
+
+template <class T, class... TArgs>
+const SimplifyQuery getBestSimplifyQuery(AnalysisManager<T, TArgs...> &AM,
+ Function &F) {
+ auto *DT = AM.template getCachedResult<DominatorTreeAnalysis>(F);
+ auto *TLI = AM.template getCachedResult<TargetLibraryAnalysis>(F);
+ auto *AC = AM.template getCachedResult<AssumptionAnalysis>(F);
+ return {F.getParent()->getDataLayout(), TLI, DT, AC};
+}
+template const SimplifyQuery getBestSimplifyQuery(AnalysisManager<Function> &,
+ Function &);
+}
diff --git a/llvm/lib/Analysis/Interval.cpp b/llvm/lib/Analysis/Interval.cpp
new file mode 100644
index 000000000000..07d6e27c13be
--- /dev/null
+++ b/llvm/lib/Analysis/Interval.cpp
@@ -0,0 +1,51 @@
+//===- Interval.cpp - Interval class code ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definition of the Interval class, which represents a
+// partition of a control flow graph of some kind.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/Interval.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+//===----------------------------------------------------------------------===//
+// Interval Implementation
+//===----------------------------------------------------------------------===//
+
+// isLoop - Find out if there is a back edge in this interval...
+bool Interval::isLoop() const {
+ // There is a loop in this interval iff one of the predecessors of the header
+ // node lives in the interval.
+ for (::pred_iterator I = ::pred_begin(HeaderNode), E = ::pred_end(HeaderNode);
+ I != E; ++I)
+ if (contains(*I))
+ return true;
+ return false;
+}
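+// Example (illustrative): for an interval {H, L} where block L branches back
+// to the header H, H has a predecessor (L) inside the interval, so isLoop()
+// returns true; an interval whose header is only entered from outside the
+// interval is not a loop.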
+
+void Interval::print(raw_ostream &OS) const {
+ OS << "-------------------------------------------------------------\n"
+ << "Interval Contents:\n";
+
+ // Print out all of the basic blocks in the interval...
+ for (const BasicBlock *Node : Nodes)
+ OS << *Node << "\n";
+
+ OS << "Interval Predecessors:\n";
+ for (const BasicBlock *Predecessor : Predecessors)
+ OS << *Predecessor << "\n";
+
+ OS << "Interval Successors:\n";
+ for (const BasicBlock *Successor : Successors)
+ OS << *Successor << "\n";
+}
diff --git a/llvm/lib/Analysis/IntervalPartition.cpp b/llvm/lib/Analysis/IntervalPartition.cpp
new file mode 100644
index 000000000000..d12db010db6a
--- /dev/null
+++ b/llvm/lib/Analysis/IntervalPartition.cpp
@@ -0,0 +1,113 @@
+//===- IntervalPartition.cpp - Interval Partition module code -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definition of the IntervalPartition class, which
+// calculates and represents the interval partition of a function.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/IntervalPartition.h"
+#include "llvm/Analysis/Interval.h"
+#include "llvm/Analysis/IntervalIterator.h"
+#include "llvm/Pass.h"
+#include <cassert>
+#include <utility>
+
+using namespace llvm;
+
+char IntervalPartition::ID = 0;
+
+INITIALIZE_PASS(IntervalPartition, "intervals",
+ "Interval Partition Construction", true, true)
+
+//===----------------------------------------------------------------------===//
+// IntervalPartition Implementation
+//===----------------------------------------------------------------------===//
+
+// releaseMemory - Reset state back to before function was analyzed
+void IntervalPartition::releaseMemory() {
+ for (unsigned i = 0, e = Intervals.size(); i != e; ++i)
+ delete Intervals[i];
+ IntervalMap.clear();
+ Intervals.clear();
+ RootInterval = nullptr;
+}
+
+void IntervalPartition::print(raw_ostream &O, const Module*) const {
+  for (unsigned i = 0, e = Intervals.size(); i != e; ++i)
+ Intervals[i]->print(O);
+}
+
+// addIntervalToPartition - Add an interval to the internal list of intervals,
+// and then add mappings from all of the basic blocks in the interval to the
+// interval itself (in the IntervalMap).
+void IntervalPartition::addIntervalToPartition(Interval *I) {
+ Intervals.push_back(I);
+
+ // Add mappings for all of the basic blocks in I to the IntervalPartition
+ for (Interval::node_iterator It = I->Nodes.begin(), End = I->Nodes.end();
+ It != End; ++It)
+ IntervalMap.insert(std::make_pair(*It, I));
+}
+
+// updatePredecessors - Interval generation only sets the successor fields of
+// the interval data structures. After interval generation is complete,
+// run through all of the intervals and propagate successor info as
+// predecessor info.
+void IntervalPartition::updatePredecessors(Interval *Int) {
+ BasicBlock *Header = Int->getHeaderNode();
+ for (BasicBlock *Successor : Int->Successors)
+ getBlockInterval(Successor)->Predecessors.push_back(Header);
+}
+
+// runOnFunction - Build the first level interval partition for the specified
+// function...
+bool IntervalPartition::runOnFunction(Function &F) {
+  // Pass false to intervals_begin because we take ownership of its memory
+ function_interval_iterator I = intervals_begin(&F, false);
+ assert(I != intervals_end(&F) && "No intervals in function!?!?!");
+
+ addIntervalToPartition(RootInterval = *I);
+
+ ++I; // After the first one...
+
+ // Add the rest of the intervals to the partition.
+ for (function_interval_iterator E = intervals_end(&F); I != E; ++I)
+ addIntervalToPartition(*I);
+
+ // Now that we know all of the successor information, propagate this to the
+ // predecessors for each block.
+ for (unsigned i = 0, e = Intervals.size(); i != e; ++i)
+ updatePredecessors(Intervals[i]);
+ return false;
+}
+
+// IntervalPartition ctor - Build a reduced interval partition from an
+// existing interval graph. This takes an additional boolean parameter to
+// distinguish it from a copy constructor. Always pass in false for now.
+IntervalPartition::IntervalPartition(IntervalPartition &IP, bool)
+ : FunctionPass(ID) {
+ assert(IP.getRootInterval() && "Cannot operate on empty IntervalPartitions!");
+
+  // Pass false to intervals_begin because we take ownership of its memory
+ interval_part_interval_iterator I = intervals_begin(IP, false);
+ assert(I != intervals_end(IP) && "No intervals in interval partition!?!?!");
+
+ addIntervalToPartition(RootInterval = *I);
+
+ ++I; // After the first one...
+
+ // Add the rest of the intervals to the partition.
+ for (interval_part_interval_iterator E = intervals_end(IP); I != E; ++I)
+ addIntervalToPartition(*I);
+
+ // Now that we know all of the successor information, propagate this to the
+ // predecessors for each block.
+ for (unsigned i = 0, e = Intervals.size(); i != e; ++i)
+ updatePredecessors(Intervals[i]);
+}
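+// Client sketch (illustrative; MyPass is a hypothetical legacy pass that has
+// declared a dependency on IntervalPartition in its getAnalysisUsage):
+//
+//   IntervalPartition &IP = getAnalysis<IntervalPartition>();
+//   if (Interval *Root = IP.getRootInterval())
+//     Root->print(errs());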
diff --git a/llvm/lib/Analysis/LazyBlockFrequencyInfo.cpp b/llvm/lib/Analysis/LazyBlockFrequencyInfo.cpp
new file mode 100644
index 000000000000..439758560284
--- /dev/null
+++ b/llvm/lib/Analysis/LazyBlockFrequencyInfo.cpp
@@ -0,0 +1,71 @@
+//===- LazyBlockFrequencyInfo.cpp - Lazy Block Frequency Analysis ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is an alternative analysis pass to BlockFrequencyInfoWrapperPass. The
+// difference is that with this pass the block frequencies are not computed when
+// the analysis pass is executed but rather when the BFI result is explicitly
+// requested by the analysis client.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LazyBlockFrequencyInfo.h"
+#include "llvm/Analysis/LazyBranchProbabilityInfo.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/Dominators.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "lazy-block-freq"
+
+INITIALIZE_PASS_BEGIN(LazyBlockFrequencyInfoPass, DEBUG_TYPE,
+ "Lazy Block Frequency Analysis", true, true)
+INITIALIZE_PASS_DEPENDENCY(LazyBPIPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(LazyBlockFrequencyInfoPass, DEBUG_TYPE,
+ "Lazy Block Frequency Analysis", true, true)
+
+char LazyBlockFrequencyInfoPass::ID = 0;
+
+LazyBlockFrequencyInfoPass::LazyBlockFrequencyInfoPass() : FunctionPass(ID) {
+ initializeLazyBlockFrequencyInfoPassPass(*PassRegistry::getPassRegistry());
+}
+
+void LazyBlockFrequencyInfoPass::print(raw_ostream &OS, const Module *) const {
+ LBFI.getCalculated().print(OS);
+}
+
+void LazyBlockFrequencyInfoPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ LazyBranchProbabilityInfoPass::getLazyBPIAnalysisUsage(AU);
+ // We require DT so it's available when LI is available. The LI updating code
+ // asserts that DT is also present so if we don't make sure that we have DT
+ // here, that assert will trigger.
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.setPreservesAll();
+}
+
+void LazyBlockFrequencyInfoPass::releaseMemory() { LBFI.releaseMemory(); }
+
+bool LazyBlockFrequencyInfoPass::runOnFunction(Function &F) {
+ auto &BPIPass = getAnalysis<LazyBranchProbabilityInfoPass>();
+ LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ LBFI.setAnalysis(&F, &BPIPass, &LI);
+ return false;
+}
+
+void LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AnalysisUsage &AU) {
+ LazyBranchProbabilityInfoPass::getLazyBPIAnalysisUsage(AU);
+ AU.addRequired<LazyBlockFrequencyInfoPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+}
+
+void llvm::initializeLazyBFIPassPass(PassRegistry &Registry) {
+ initializeLazyBPIPassPass(Registry);
+ INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass);
+ INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass);
+}
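+// Client sketch (illustrative; MyClientPass is hypothetical and the getBFI()
+// accessor is assumed from the companion LazyBlockFrequencyInfo.h header):
+//
+//   void MyClientPass::getAnalysisUsage(AnalysisUsage &AU) const {
+//     LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU);
+//   }
+//   // Inside MyClientPass::runOnFunction; the frequencies are only computed
+//   // here, on first use, not when LazyBlockFrequencyInfoPass itself ran:
+//   BlockFrequencyInfo &BFI =
+//       getAnalysis<LazyBlockFrequencyInfoPass>().getBFI();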
diff --git a/llvm/lib/Analysis/LazyBranchProbabilityInfo.cpp b/llvm/lib/Analysis/LazyBranchProbabilityInfo.cpp
new file mode 100644
index 000000000000..e727de468a0d
--- /dev/null
+++ b/llvm/lib/Analysis/LazyBranchProbabilityInfo.cpp
@@ -0,0 +1,74 @@
+//===- LazyBranchProbabilityInfo.cpp - Lazy Branch Probability Analysis ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is an alternative analysis pass to BranchProbabilityInfoWrapperPass.
+// The difference is that with this pass the branch probabilities are not
+// computed when the analysis pass is executed but rather when the BPI result
+// is explicitly requested by the analysis client.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LazyBranchProbabilityInfo.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Dominators.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "lazy-branch-prob"
+
+INITIALIZE_PASS_BEGIN(LazyBranchProbabilityInfoPass, DEBUG_TYPE,
+ "Lazy Branch Probability Analysis", true, true)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(LazyBranchProbabilityInfoPass, DEBUG_TYPE,
+ "Lazy Branch Probability Analysis", true, true)
+
+char LazyBranchProbabilityInfoPass::ID = 0;
+
+LazyBranchProbabilityInfoPass::LazyBranchProbabilityInfoPass()
+ : FunctionPass(ID) {
+ initializeLazyBranchProbabilityInfoPassPass(*PassRegistry::getPassRegistry());
+}
+
+void LazyBranchProbabilityInfoPass::print(raw_ostream &OS,
+ const Module *) const {
+ LBPI->getCalculated().print(OS);
+}
+
+void LazyBranchProbabilityInfoPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ // We require DT so it's available when LI is available. The LI updating code
+ // asserts that DT is also present so if we don't make sure that we have DT
+ // here, that assert will trigger.
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
+ AU.setPreservesAll();
+}
+
+void LazyBranchProbabilityInfoPass::releaseMemory() { LBPI.reset(); }
+
+bool LazyBranchProbabilityInfoPass::runOnFunction(Function &F) {
+ LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ TargetLibraryInfo &TLI =
+ getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ LBPI = std::make_unique<LazyBranchProbabilityInfo>(&F, &LI, &TLI);
+ return false;
+}
+
+void LazyBranchProbabilityInfoPass::getLazyBPIAnalysisUsage(AnalysisUsage &AU) {
+ AU.addRequired<LazyBranchProbabilityInfoPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
+}
+
+void llvm::initializeLazyBPIPassPass(PassRegistry &Registry) {
+ INITIALIZE_PASS_DEPENDENCY(LazyBranchProbabilityInfoPass);
+ INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass);
+ INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass);
+}
diff --git a/llvm/lib/Analysis/LazyCallGraph.cpp b/llvm/lib/Analysis/LazyCallGraph.cpp
new file mode 100644
index 000000000000..ef31c1e0ba8c
--- /dev/null
+++ b/llvm/lib/Analysis/LazyCallGraph.cpp
@@ -0,0 +1,1816 @@
+//===- LazyCallGraph.cpp - Analysis of a Module's call graph --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LazyCallGraph.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <iterator>
+#include <string>
+#include <tuple>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "lcg"
+
+void LazyCallGraph::EdgeSequence::insertEdgeInternal(Node &TargetN,
+ Edge::Kind EK) {
+ EdgeIndexMap.insert({&TargetN, Edges.size()});
+ Edges.emplace_back(TargetN, EK);
+}
+
+void LazyCallGraph::EdgeSequence::setEdgeKind(Node &TargetN, Edge::Kind EK) {
+ Edges[EdgeIndexMap.find(&TargetN)->second].setKind(EK);
+}
+
+bool LazyCallGraph::EdgeSequence::removeEdgeInternal(Node &TargetN) {
+ auto IndexMapI = EdgeIndexMap.find(&TargetN);
+ if (IndexMapI == EdgeIndexMap.end())
+ return false;
+
+ Edges[IndexMapI->second] = Edge();
+ EdgeIndexMap.erase(IndexMapI);
+ return true;
+}
+
+static void addEdge(SmallVectorImpl<LazyCallGraph::Edge> &Edges,
+ DenseMap<LazyCallGraph::Node *, int> &EdgeIndexMap,
+ LazyCallGraph::Node &N, LazyCallGraph::Edge::Kind EK) {
+ if (!EdgeIndexMap.insert({&N, Edges.size()}).second)
+ return;
+
+ LLVM_DEBUG(dbgs() << " Added callable function: " << N.getName() << "\n");
+ Edges.emplace_back(LazyCallGraph::Edge(N, EK));
+}
+
+LazyCallGraph::EdgeSequence &LazyCallGraph::Node::populateSlow() {
+ assert(!Edges && "Must not have already populated the edges for this node!");
+
+ LLVM_DEBUG(dbgs() << " Adding functions called by '" << getName()
+ << "' to the graph.\n");
+
+ Edges = EdgeSequence();
+
+ SmallVector<Constant *, 16> Worklist;
+ SmallPtrSet<Function *, 4> Callees;
+ SmallPtrSet<Constant *, 16> Visited;
+
+ // Find all the potential call graph edges in this function. We track both
+ // actual call edges and indirect references to functions. The direct calls
+ // are trivially added, but to accumulate the latter we walk the instructions
+ // and add every operand which is a constant to the worklist to process
+ // afterward.
+ //
+ // Note that we consider *any* function with a definition to be a viable
+ // edge. Even if the function's definition is subject to replacement by
+ // some other module (say, a weak definition) there may still be
+ // optimizations which essentially speculate based on the definition and
+ // a way to check that the specific definition is in fact the one being
+ // used. For example, this could be done by moving the weak definition to
+ // a strong (internal) definition and making the weak definition be an
+ // alias. Then a test of the address of the weak function against the new
+ // strong definition's address would be an effective way to determine the
+ // safety of optimizing a direct call edge.
+ for (BasicBlock &BB : *F)
+ for (Instruction &I : BB) {
+ if (auto CS = CallSite(&I))
+ if (Function *Callee = CS.getCalledFunction())
+ if (!Callee->isDeclaration())
+ if (Callees.insert(Callee).second) {
+ Visited.insert(Callee);
+ addEdge(Edges->Edges, Edges->EdgeIndexMap, G->get(*Callee),
+ LazyCallGraph::Edge::Call);
+ }
+
+ for (Value *Op : I.operand_values())
+ if (Constant *C = dyn_cast<Constant>(Op))
+ if (Visited.insert(C).second)
+ Worklist.push_back(C);
+ }
+
+ // We've collected all the constant (and thus potentially function or
+ // function containing) operands to all of the instructions in the function.
+ // Process them (recursively) collecting every function found.
+ visitReferences(Worklist, Visited, [&](Function &F) {
+ addEdge(Edges->Edges, Edges->EdgeIndexMap, G->get(F),
+ LazyCallGraph::Edge::Ref);
+ });
+
+ // Add implicit reference edges to any defined libcall functions (if we
+ // haven't found an explicit edge).
+ for (auto *F : G->LibFunctions)
+ if (!Visited.count(F))
+ addEdge(Edges->Edges, Edges->EdgeIndexMap, G->get(*F),
+ LazyCallGraph::Edge::Ref);
+
+ return *Edges;
+}
+
+void LazyCallGraph::Node::replaceFunction(Function &NewF) {
+ assert(F != &NewF && "Must not replace a function with itself!");
+ F = &NewF;
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void LazyCallGraph::Node::dump() const {
+ dbgs() << *this << '\n';
+}
+#endif
+
+static bool isKnownLibFunction(Function &F, TargetLibraryInfo &TLI) {
+ LibFunc LF;
+
+ // Either this is a normal library function or a "vectorizable" function.
+ return TLI.getLibFunc(F, LF) || TLI.isFunctionVectorizable(F.getName());
+}
+
+LazyCallGraph::LazyCallGraph(
+ Module &M, function_ref<TargetLibraryInfo &(Function &)> GetTLI) {
+ LLVM_DEBUG(dbgs() << "Building CG for module: " << M.getModuleIdentifier()
+ << "\n");
+ for (Function &F : M) {
+ if (F.isDeclaration())
+ continue;
+ // If this function is a known lib function to LLVM then we want to
+ // synthesize reference edges to it to model the fact that LLVM can turn
+ // arbitrary code into a library function call.
+ if (isKnownLibFunction(F, GetTLI(F)))
+ LibFunctions.insert(&F);
+
+ if (F.hasLocalLinkage())
+ continue;
+
+ // External linkage defined functions have edges to them from other
+ // modules.
+ LLVM_DEBUG(dbgs() << " Adding '" << F.getName()
+ << "' to entry set of the graph.\n");
+ addEdge(EntryEdges.Edges, EntryEdges.EdgeIndexMap, get(F), Edge::Ref);
+ }
+
+ // Externally visible aliases of internal functions are also viable entry
+ // edges to the module.
+ for (auto &A : M.aliases()) {
+ if (A.hasLocalLinkage())
+ continue;
+    if (Function *F = dyn_cast<Function>(A.getAliasee())) {
+ LLVM_DEBUG(dbgs() << " Adding '" << F->getName()
+ << "' with alias '" << A.getName()
+ << "' to entry set of the graph.\n");
+ addEdge(EntryEdges.Edges, EntryEdges.EdgeIndexMap, get(*F), Edge::Ref);
+ }
+ }
+
+ // Now add entry nodes for functions reachable via initializers to globals.
+ SmallVector<Constant *, 16> Worklist;
+ SmallPtrSet<Constant *, 16> Visited;
+ for (GlobalVariable &GV : M.globals())
+ if (GV.hasInitializer())
+ if (Visited.insert(GV.getInitializer()).second)
+ Worklist.push_back(GV.getInitializer());
+
+ LLVM_DEBUG(
+ dbgs() << " Adding functions referenced by global initializers to the "
+ "entry set.\n");
+ visitReferences(Worklist, Visited, [&](Function &F) {
+ addEdge(EntryEdges.Edges, EntryEdges.EdgeIndexMap, get(F),
+ LazyCallGraph::Edge::Ref);
+ });
+}
+
+LazyCallGraph::LazyCallGraph(LazyCallGraph &&G)
+ : BPA(std::move(G.BPA)), NodeMap(std::move(G.NodeMap)),
+ EntryEdges(std::move(G.EntryEdges)), SCCBPA(std::move(G.SCCBPA)),
+ SCCMap(std::move(G.SCCMap)),
+ LibFunctions(std::move(G.LibFunctions)) {
+ updateGraphPtrs();
+}
+
+LazyCallGraph &LazyCallGraph::operator=(LazyCallGraph &&G) {
+ BPA = std::move(G.BPA);
+ NodeMap = std::move(G.NodeMap);
+ EntryEdges = std::move(G.EntryEdges);
+ SCCBPA = std::move(G.SCCBPA);
+ SCCMap = std::move(G.SCCMap);
+ LibFunctions = std::move(G.LibFunctions);
+ updateGraphPtrs();
+ return *this;
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void LazyCallGraph::SCC::dump() const {
+ dbgs() << *this << '\n';
+}
+#endif
+
+#ifndef NDEBUG
+void LazyCallGraph::SCC::verify() {
+ assert(OuterRefSCC && "Can't have a null RefSCC!");
+ assert(!Nodes.empty() && "Can't have an empty SCC!");
+
+ for (Node *N : Nodes) {
+ assert(N && "Can't have a null node!");
+ assert(OuterRefSCC->G->lookupSCC(*N) == this &&
+ "Node does not map to this SCC!");
+ assert(N->DFSNumber == -1 &&
+ "Must set DFS numbers to -1 when adding a node to an SCC!");
+ assert(N->LowLink == -1 &&
+ "Must set low link to -1 when adding a node to an SCC!");
+ for (Edge &E : **N)
+ assert(E.getNode().isPopulated() && "Can't have an unpopulated node!");
+ }
+}
+#endif
+
+bool LazyCallGraph::SCC::isParentOf(const SCC &C) const {
+ if (this == &C)
+ return false;
+
+ for (Node &N : *this)
+ for (Edge &E : N->calls())
+ if (OuterRefSCC->G->lookupSCC(E.getNode()) == &C)
+ return true;
+
+ // No edges found.
+ return false;
+}
+
+bool LazyCallGraph::SCC::isAncestorOf(const SCC &TargetC) const {
+ if (this == &TargetC)
+ return false;
+
+ LazyCallGraph &G = *OuterRefSCC->G;
+
+ // Start with this SCC.
+ SmallPtrSet<const SCC *, 16> Visited = {this};
+ SmallVector<const SCC *, 16> Worklist = {this};
+
+ // Walk down the graph until we run out of edges or find a path to TargetC.
+ do {
+ const SCC &C = *Worklist.pop_back_val();
+ for (Node &N : C)
+ for (Edge &E : N->calls()) {
+ SCC *CalleeC = G.lookupSCC(E.getNode());
+ if (!CalleeC)
+ continue;
+
+ // If the callee's SCC is the TargetC, we're done.
+ if (CalleeC == &TargetC)
+ return true;
+
+ // If this is the first time we've reached this SCC, put it on the
+ // worklist to recurse through.
+ if (Visited.insert(CalleeC).second)
+ Worklist.push_back(CalleeC);
+ }
+ } while (!Worklist.empty());
+
+ // No paths found.
+ return false;
+}
+
+LazyCallGraph::RefSCC::RefSCC(LazyCallGraph &G) : G(&G) {}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void LazyCallGraph::RefSCC::dump() const {
+ dbgs() << *this << '\n';
+}
+#endif
+
+#ifndef NDEBUG
+void LazyCallGraph::RefSCC::verify() {
+ assert(G && "Can't have a null graph!");
+ assert(!SCCs.empty() && "Can't have an empty SCC!");
+
+ // Verify basic properties of the SCCs.
+ SmallPtrSet<SCC *, 4> SCCSet;
+ for (SCC *C : SCCs) {
+ assert(C && "Can't have a null SCC!");
+ C->verify();
+ assert(&C->getOuterRefSCC() == this &&
+ "SCC doesn't think it is inside this RefSCC!");
+ bool Inserted = SCCSet.insert(C).second;
+ assert(Inserted && "Found a duplicate SCC!");
+ auto IndexIt = SCCIndices.find(C);
+ assert(IndexIt != SCCIndices.end() &&
+ "Found an SCC that doesn't have an index!");
+ }
+
+ // Check that our indices map correctly.
+ for (auto &SCCIndexPair : SCCIndices) {
+ SCC *C = SCCIndexPair.first;
+ int i = SCCIndexPair.second;
+ assert(C && "Can't have a null SCC in the indices!");
+ assert(SCCSet.count(C) && "Found an index for an SCC not in the RefSCC!");
+ assert(SCCs[i] == C && "Index doesn't point to SCC!");
+ }
+
+ // Check that the SCCs are in fact in post-order.
+ for (int i = 0, Size = SCCs.size(); i < Size; ++i) {
+ SCC &SourceSCC = *SCCs[i];
+ for (Node &N : SourceSCC)
+ for (Edge &E : *N) {
+ if (!E.isCall())
+ continue;
+ SCC &TargetSCC = *G->lookupSCC(E.getNode());
+ if (&TargetSCC.getOuterRefSCC() == this) {
+ assert(SCCIndices.find(&TargetSCC)->second <= i &&
+ "Edge between SCCs violates post-order relationship.");
+ continue;
+ }
+ }
+ }
+}
+#endif
+
+bool LazyCallGraph::RefSCC::isParentOf(const RefSCC &RC) const {
+ if (&RC == this)
+ return false;
+
+ // Search all edges to see if this is a parent.
+ for (SCC &C : *this)
+ for (Node &N : C)
+ for (Edge &E : *N)
+ if (G->lookupRefSCC(E.getNode()) == &RC)
+ return true;
+
+ return false;
+}
+
+bool LazyCallGraph::RefSCC::isAncestorOf(const RefSCC &RC) const {
+ if (&RC == this)
+ return false;
+
+ // For each descendant of this RefSCC, see if one of its children is the
+ // argument. If not, add that descendant to the worklist and continue
+ // searching.
+ SmallVector<const RefSCC *, 4> Worklist;
+ SmallPtrSet<const RefSCC *, 4> Visited;
+ Worklist.push_back(this);
+ Visited.insert(this);
+ do {
+ const RefSCC &DescendantRC = *Worklist.pop_back_val();
+ for (SCC &C : DescendantRC)
+ for (Node &N : C)
+ for (Edge &E : *N) {
+ auto *ChildRC = G->lookupRefSCC(E.getNode());
+ if (ChildRC == &RC)
+ return true;
+ if (!ChildRC || !Visited.insert(ChildRC).second)
+ continue;
+ Worklist.push_back(ChildRC);
+ }
+ } while (!Worklist.empty());
+
+ return false;
+}
+
+/// Generic helper that updates a postorder sequence of SCCs for a potentially
+/// cycle-introducing edge insertion.
+///
+/// A postorder sequence of SCCs of a directed graph has one fundamental
+/// property: all edges in the DAG of SCCs point "up" the sequence. That is,
+/// all edges in the SCC DAG point to prior SCCs in the sequence.
+///
+/// This routine both updates a postorder sequence and uses that sequence to
+/// compute the set of SCCs connected into a cycle. It should only be called to
+/// insert a "downward" edge which will require changing the sequence to
+/// restore it to a postorder.
+///
+/// When inserting an edge from an earlier SCC to a later SCC in some postorder
+/// sequence, all of the SCCs which may be impacted are in the closed range of
+/// those two within the postorder sequence. The algorithm used here to restore
+/// the state is as follows:
+///
+/// 1) Starting from the source SCC, construct a set of SCCs which reach the
+/// source SCC consisting of just the source SCC. Then scan toward the
+/// target SCC in postorder and for each SCC, if it has an edge to an SCC
+///    in the set, add it to the set. Otherwise, the source SCC is not a
+///    (transitive) successor of that SCC, so move that SCC in the postorder
+///    sequence to immediately before the source SCC, shifting the source SCC
+///    and all SCCs in the set one position toward the target SCC. Stop
+///    scanning after processing the
+/// target SCC.
+/// 2) If the source SCC is now past the target SCC in the postorder sequence,
+/// and thus the new edge will flow toward the start, we are done.
+/// 3) Otherwise, starting from the target SCC, walk all edges which reach an
+/// SCC between the source and the target, and add them to the set of
+/// connected SCCs, then recurse through them. Once a complete set of the
+/// SCCs the target connects to is known, hoist the remaining SCCs between
+/// the source and the target to be above the target. Note that there is no
+/// need to process the source SCC, it is already known to connect.
+/// 4) At this point, all of the SCCs in the closed range between the source
+/// SCC and the target SCC in the postorder sequence are connected,
+/// including the target SCC and the source SCC. Inserting the edge from
+/// the source SCC to the target SCC will form a cycle out of precisely
+/// these SCCs. Thus we can merge all of the SCCs in this closed range into
+/// a single SCC.
+///
+/// This process has various important properties:
+/// - Only mutates the SCCs when adding the edge actually changes the SCC
+/// structure.
+/// - Never mutates SCCs which are unaffected by the change.
+/// - Updates the postorder sequence to correctly satisfy the postorder
+/// constraint after the edge is inserted.
+/// - Only reorders SCCs in the closed postorder sequence from the source to
+/// the target, so easy to bound how much has changed even in the ordering.
+/// - Big-O is the number of edges in the closed postorder range of SCCs from
+/// source to target.
+///
+/// This helper routine, in addition to updating the postorder sequence itself
+/// will also update a map from SCCs to indices within that sequence.
+///
+/// The sequence and the map must operate on pointers to the SCC type.
+///
+/// Two callbacks must be provided. The first computes the subset of SCCs in
+/// the postorder closed range from the source to the target which connect to
+/// the source SCC via some (transitive) set of edges. The second computes the
+/// subset of the same range which the target SCC connects to via some
+/// (transitive) set of edges. Both callbacks should populate the set argument
+/// provided.
+template <typename SCCT, typename PostorderSequenceT, typename SCCIndexMapT,
+ typename ComputeSourceConnectedSetCallableT,
+ typename ComputeTargetConnectedSetCallableT>
+static iterator_range<typename PostorderSequenceT::iterator>
+updatePostorderSequenceForEdgeInsertion(
+ SCCT &SourceSCC, SCCT &TargetSCC, PostorderSequenceT &SCCs,
+ SCCIndexMapT &SCCIndices,
+ ComputeSourceConnectedSetCallableT ComputeSourceConnectedSet,
+ ComputeTargetConnectedSetCallableT ComputeTargetConnectedSet) {
+ int SourceIdx = SCCIndices[&SourceSCC];
+ int TargetIdx = SCCIndices[&TargetSCC];
+ assert(SourceIdx < TargetIdx && "Cannot have equal indices here!");
+
+ SmallPtrSet<SCCT *, 4> ConnectedSet;
+
+ // Compute the SCCs which (transitively) reach the source.
+ ComputeSourceConnectedSet(ConnectedSet);
+
+  // Partition the SCCs in this part of the postorder sequence so only SCCs
+ // connecting to the source remain between it and the target. This is
+ // a benign partition as it preserves postorder.
+ auto SourceI = std::stable_partition(
+ SCCs.begin() + SourceIdx, SCCs.begin() + TargetIdx + 1,
+ [&ConnectedSet](SCCT *C) { return !ConnectedSet.count(C); });
+ for (int i = SourceIdx, e = TargetIdx + 1; i < e; ++i)
+ SCCIndices.find(SCCs[i])->second = i;
+
+ // If the target doesn't connect to the source, then we've corrected the
+ // post-order and there are no cycles formed.
+ if (!ConnectedSet.count(&TargetSCC)) {
+ assert(SourceI > (SCCs.begin() + SourceIdx) &&
+ "Must have moved the source to fix the post-order.");
+ assert(*std::prev(SourceI) == &TargetSCC &&
+ "Last SCC to move should have bene the target.");
+
+ // Return an empty range at the target SCC indicating there is nothing to
+ // merge.
+ return make_range(std::prev(SourceI), std::prev(SourceI));
+ }
+
+ assert(SCCs[TargetIdx] == &TargetSCC &&
+ "Should not have moved target if connected!");
+ SourceIdx = SourceI - SCCs.begin();
+ assert(SCCs[SourceIdx] == &SourceSCC &&
+ "Bad updated index computation for the source SCC!");
+
+ // See whether there are any remaining intervening SCCs between the source
+  // and target. If so, we need to make sure they are all reachable from the
+ // target.
+ if (SourceIdx + 1 < TargetIdx) {
+ ConnectedSet.clear();
+ ComputeTargetConnectedSet(ConnectedSet);
+
+ // Partition SCCs so that only SCCs reached from the target remain between
+ // the source and the target. This preserves postorder.
+ auto TargetI = std::stable_partition(
+ SCCs.begin() + SourceIdx + 1, SCCs.begin() + TargetIdx + 1,
+ [&ConnectedSet](SCCT *C) { return ConnectedSet.count(C); });
+ for (int i = SourceIdx + 1, e = TargetIdx + 1; i < e; ++i)
+ SCCIndices.find(SCCs[i])->second = i;
+ TargetIdx = std::prev(TargetI) - SCCs.begin();
+ assert(SCCs[TargetIdx] == &TargetSCC &&
+ "Should always end with the target!");
+ }
+
+ // At this point, we know that connecting source to target forms a cycle
+ // because target connects back to source, and we know that all of the SCCs
+ // between the source and target in the postorder sequence participate in that
+ // cycle.
+ return make_range(SCCs.begin() + SourceIdx, SCCs.begin() + TargetIdx);
+}
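+// Worked example (illustrative): given the postorder sequence [D, C, B, A]
+// where A calls B, B calls C, and C calls D, every edge points to a prior
+// entry. Inserting a new edge D -> A makes the source (D, index 0) precede
+// the target (A, index 3); every SCC in the closed range reaches the source,
+// so the target lands in the connected set and the range [D, C, B] is
+// returned for the caller to merge into A. SCCs that do not participate in
+// the new cycle are instead shuffled out of the closed range (below the
+// source or above the target) by the two partition steps before the merge.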
+
+bool
+LazyCallGraph::RefSCC::switchInternalEdgeToCall(
+ Node &SourceN, Node &TargetN,
+ function_ref<void(ArrayRef<SCC *> MergeSCCs)> MergeCB) {
+ assert(!(*SourceN)[TargetN].isCall() && "Must start with a ref edge!");
+ SmallVector<SCC *, 1> DeletedSCCs;
+
+#ifndef NDEBUG
+ // In a debug build, verify the RefSCC is valid to start with and when this
+ // routine finishes.
+ verify();
+ auto VerifyOnExit = make_scope_exit([&]() { verify(); });
+#endif
+
+ SCC &SourceSCC = *G->lookupSCC(SourceN);
+ SCC &TargetSCC = *G->lookupSCC(TargetN);
+
+ // If the two nodes are already part of the same SCC, we're also done as
+ // we've just added more connectivity.
+ if (&SourceSCC == &TargetSCC) {
+ SourceN->setEdgeKind(TargetN, Edge::Call);
+ return false; // No new cycle.
+ }
+
+ // At this point we leverage the postorder list of SCCs to detect when the
+ // insertion of an edge changes the SCC structure in any way.
+ //
+ // First and foremost, we can eliminate the need for any changes when the
+ // edge is toward the beginning of the postorder sequence because all edges
+ // flow in that direction already. Thus adding a new one cannot form a cycle.
+ int SourceIdx = SCCIndices[&SourceSCC];
+ int TargetIdx = SCCIndices[&TargetSCC];
+ if (TargetIdx < SourceIdx) {
+ SourceN->setEdgeKind(TargetN, Edge::Call);
+ return false; // No new cycle.
+ }
+
+ // Compute the SCCs which (transitively) reach the source.
+ auto ComputeSourceConnectedSet = [&](SmallPtrSetImpl<SCC *> &ConnectedSet) {
+#ifndef NDEBUG
+ // Check that the RefSCC is still valid before computing this as the
+    // results will be nonsensical if we've broken its invariants.
+ verify();
+#endif
+ ConnectedSet.insert(&SourceSCC);
+ auto IsConnected = [&](SCC &C) {
+ for (Node &N : C)
+ for (Edge &E : N->calls())
+ if (ConnectedSet.count(G->lookupSCC(E.getNode())))
+ return true;
+
+ return false;
+ };
+
+ for (SCC *C :
+ make_range(SCCs.begin() + SourceIdx + 1, SCCs.begin() + TargetIdx + 1))
+ if (IsConnected(*C))
+ ConnectedSet.insert(C);
+ };
+
+ // Use a normal worklist to find which SCCs the target connects to. We still
+ // bound the search based on the range in the postorder list we care about,
+ // but because this is forward connectivity we just "recurse" through the
+ // edges.
+ auto ComputeTargetConnectedSet = [&](SmallPtrSetImpl<SCC *> &ConnectedSet) {
+#ifndef NDEBUG
+ // Check that the RefSCC is still valid before computing this as the
+    // results will be nonsensical if we've broken its invariants.
+ verify();
+#endif
+ ConnectedSet.insert(&TargetSCC);
+ SmallVector<SCC *, 4> Worklist;
+ Worklist.push_back(&TargetSCC);
+ do {
+ SCC &C = *Worklist.pop_back_val();
+ for (Node &N : C)
+ for (Edge &E : *N) {
+ if (!E.isCall())
+ continue;
+ SCC &EdgeC = *G->lookupSCC(E.getNode());
+ if (&EdgeC.getOuterRefSCC() != this)
+ // Not in this RefSCC...
+ continue;
+ if (SCCIndices.find(&EdgeC)->second <= SourceIdx)
+ // Not in the postorder sequence between source and target.
+ continue;
+
+ if (ConnectedSet.insert(&EdgeC).second)
+ Worklist.push_back(&EdgeC);
+ }
+ } while (!Worklist.empty());
+ };
+
+ // Use a generic helper to update the postorder sequence of SCCs and return
+ // a range of any SCCs connected into a cycle by inserting this edge. This
+ // routine will also take care of updating the indices into the postorder
+ // sequence.
+ auto MergeRange = updatePostorderSequenceForEdgeInsertion(
+ SourceSCC, TargetSCC, SCCs, SCCIndices, ComputeSourceConnectedSet,
+ ComputeTargetConnectedSet);
+
+ // Run the user's callback on the merged SCCs before we actually merge them.
+ if (MergeCB)
+ MergeCB(makeArrayRef(MergeRange.begin(), MergeRange.end()));
+
+ // If the merge range is empty, then adding the edge didn't actually form any
+ // new cycles. We're done.
+ if (MergeRange.empty()) {
+ // Now that the SCC structure is finalized, flip the kind to call.
+ SourceN->setEdgeKind(TargetN, Edge::Call);
+ return false; // No new cycle.
+ }
+
+#ifndef NDEBUG
+ // Before merging, check that the RefSCC remains valid after all the
+ // postorder updates.
+ verify();
+#endif
+
+ // Otherwise we need to merge all of the SCCs in the cycle into a single
+ // result SCC.
+ //
+ // NB: We merge into the target because all of these functions were already
+ // reachable from the target, meaning any SCC-wide properties deduced about it
+ // other than the set of functions within it will not have changed.
+ for (SCC *C : MergeRange) {
+ assert(C != &TargetSCC &&
+ "We merge *into* the target and shouldn't process it here!");
+ SCCIndices.erase(C);
+ TargetSCC.Nodes.append(C->Nodes.begin(), C->Nodes.end());
+ for (Node *N : C->Nodes)
+ G->SCCMap[N] = &TargetSCC;
+ C->clear();
+ DeletedSCCs.push_back(C);
+ }
+
+ // Erase the merged SCCs from the list and update the indices of the
+ // remaining SCCs.
+ int IndexOffset = MergeRange.end() - MergeRange.begin();
+ auto EraseEnd = SCCs.erase(MergeRange.begin(), MergeRange.end());
+ for (SCC *C : make_range(EraseEnd, SCCs.end()))
+ SCCIndices[C] -= IndexOffset;
+
+ // Now that the SCC structure is finalized, flip the kind to call.
+ SourceN->setEdgeKind(TargetN, Edge::Call);
+
+ // And we're done, but we did form a new cycle.
+ return true;
+}
+
+void LazyCallGraph::RefSCC::switchTrivialInternalEdgeToRef(Node &SourceN,
+ Node &TargetN) {
+ assert((*SourceN)[TargetN].isCall() && "Must start with a call edge!");
+
+#ifndef NDEBUG
+ // In a debug build, verify the RefSCC is valid to start with and when this
+ // routine finishes.
+ verify();
+ auto VerifyOnExit = make_scope_exit([&]() { verify(); });
+#endif
+
+ assert(G->lookupRefSCC(SourceN) == this &&
+ "Source must be in this RefSCC.");
+ assert(G->lookupRefSCC(TargetN) == this &&
+ "Target must be in this RefSCC.");
+ assert(G->lookupSCC(SourceN) != G->lookupSCC(TargetN) &&
+ "Source and Target must be in separate SCCs for this to be trivial!");
+
+ // Set the edge kind.
+ SourceN->setEdgeKind(TargetN, Edge::Ref);
+}
+
+iterator_range<LazyCallGraph::RefSCC::iterator>
+LazyCallGraph::RefSCC::switchInternalEdgeToRef(Node &SourceN, Node &TargetN) {
+ assert((*SourceN)[TargetN].isCall() && "Must start with a call edge!");
+
+#ifndef NDEBUG
+ // In a debug build, verify the RefSCC is valid to start with and when this
+ // routine finishes.
+ verify();
+ auto VerifyOnExit = make_scope_exit([&]() { verify(); });
+#endif
+
+ assert(G->lookupRefSCC(SourceN) == this &&
+ "Source must be in this RefSCC.");
+ assert(G->lookupRefSCC(TargetN) == this &&
+ "Target must be in this RefSCC.");
+
+ SCC &TargetSCC = *G->lookupSCC(TargetN);
+ assert(G->lookupSCC(SourceN) == &TargetSCC && "Source and Target must be in "
+ "the same SCC to require the "
+ "full CG update.");
+
+ // Set the edge kind.
+ SourceN->setEdgeKind(TargetN, Edge::Ref);
+
+ // Otherwise we are removing a call edge from a single SCC. This may break
+ // the cycle. In order to compute the new set of SCCs, we need to do a small
+ // DFS over the nodes within the SCC to form any sub-cycles that remain as
+ // distinct SCCs and compute a postorder over the resulting SCCs.
+ //
+ // However, we specially handle the target node. The target node is known to
+ // reach all other nodes in the original SCC by definition. This means that
+ // we want the old SCC to be replaced with an SCC containing that node as it
+ // will be the root of whatever SCC DAG results from the DFS. Assumptions
+ // about an SCC such as the set of functions called will continue to hold,
+ // etc.
+
+ SCC &OldSCC = TargetSCC;
+ SmallVector<std::pair<Node *, EdgeSequence::call_iterator>, 16> DFSStack;
+ SmallVector<Node *, 16> PendingSCCStack;
+ SmallVector<SCC *, 4> NewSCCs;
+
+ // Prepare the nodes for a fresh DFS.
+ SmallVector<Node *, 16> Worklist;
+ Worklist.swap(OldSCC.Nodes);
+ for (Node *N : Worklist) {
+ N->DFSNumber = N->LowLink = 0;
+ G->SCCMap.erase(N);
+ }
+
+ // Force the target node to be in the old SCC. This also enables us to take
+ // a very significant short-cut in the standard Tarjan walk to re-form SCCs
+ // below: whenever we build an edge that reaches the target node, we know
+ // that the target node eventually connects back to all other nodes in our
+ // walk. As a consequence, we can detect and handle participants in that
+ // cycle without walking all the edges that form this connection, and instead
+ // by relying on the fundamental guarantee coming into this operation (all
+ // nodes are reachable from the target due to previously forming an SCC).
+ TargetN.DFSNumber = TargetN.LowLink = -1;
+ OldSCC.Nodes.push_back(&TargetN);
+ G->SCCMap[&TargetN] = &OldSCC;
+
+ // Scan down the stack and DFS across the call edges.
+ for (Node *RootN : Worklist) {
+ assert(DFSStack.empty() &&
+ "Cannot begin a new root with a non-empty DFS stack!");
+ assert(PendingSCCStack.empty() &&
+ "Cannot begin a new root with pending nodes for an SCC!");
+
+ // Skip any nodes we've already reached in the DFS.
+ if (RootN->DFSNumber != 0) {
+ assert(RootN->DFSNumber == -1 &&
+ "Shouldn't have any mid-DFS root nodes!");
+ continue;
+ }
+
+ RootN->DFSNumber = RootN->LowLink = 1;
+ int NextDFSNumber = 2;
+
+ DFSStack.push_back({RootN, (*RootN)->call_begin()});
+ do {
+ Node *N;
+ EdgeSequence::call_iterator I;
+ std::tie(N, I) = DFSStack.pop_back_val();
+ auto E = (*N)->call_end();
+ while (I != E) {
+ Node &ChildN = I->getNode();
+ if (ChildN.DFSNumber == 0) {
+ // We haven't yet visited this child, so descend, pushing the current
+ // node onto the stack.
+ DFSStack.push_back({N, I});
+
+ assert(!G->SCCMap.count(&ChildN) &&
+ "Found a node with 0 DFS number but already in an SCC!");
+ ChildN.DFSNumber = ChildN.LowLink = NextDFSNumber++;
+ N = &ChildN;
+ I = (*N)->call_begin();
+ E = (*N)->call_end();
+ continue;
+ }
+
+ // Check for the child already being part of some component.
+ if (ChildN.DFSNumber == -1) {
+ if (G->lookupSCC(ChildN) == &OldSCC) {
+ // If the child is part of the old SCC, we know that it can reach
+ // every other node, so we have formed a cycle. Pull the entire DFS
+ // and pending stacks into it. See the comment above about setting
+ // up the old SCC for why we do this.
+ int OldSize = OldSCC.size();
+ OldSCC.Nodes.push_back(N);
+ OldSCC.Nodes.append(PendingSCCStack.begin(), PendingSCCStack.end());
+ PendingSCCStack.clear();
+ while (!DFSStack.empty())
+ OldSCC.Nodes.push_back(DFSStack.pop_back_val().first);
+ for (Node &N : make_range(OldSCC.begin() + OldSize, OldSCC.end())) {
+ N.DFSNumber = N.LowLink = -1;
+ G->SCCMap[&N] = &OldSCC;
+ }
+ N = nullptr;
+ break;
+ }
+
+ // If the child has already been added to some child component, it
+ // couldn't impact the low-link of this parent because it isn't
+ // connected, and thus its low-link isn't relevant so skip it.
+ ++I;
+ continue;
+ }
+
+ // Track the lowest linked child as the lowest link for this node.
+ assert(ChildN.LowLink > 0 && "Must have a positive low-link number!");
+ if (ChildN.LowLink < N->LowLink)
+ N->LowLink = ChildN.LowLink;
+
+ // Move to the next edge.
+ ++I;
+ }
+ if (!N)
+ // Cleared the DFS early, start another round.
+ break;
+
+ // We've finished processing N and its descendants, put it on our pending
+ // SCC stack to eventually get merged into an SCC of nodes.
+ PendingSCCStack.push_back(N);
+
+ // If this node is linked to some lower entry, continue walking up the
+ // stack.
+ if (N->LowLink != N->DFSNumber)
+ continue;
+
+ // Otherwise, we've completed an SCC. Append it to our post order list of
+ // SCCs.
+ int RootDFSNumber = N->DFSNumber;
+ // Find the range of the node stack by walking down until we pass the
+ // root DFS number.
+ auto SCCNodes = make_range(
+ PendingSCCStack.rbegin(),
+ find_if(reverse(PendingSCCStack), [RootDFSNumber](const Node *N) {
+ return N->DFSNumber < RootDFSNumber;
+ }));
+
+ // Form a new SCC out of these nodes and then clear them off our pending
+ // stack.
+ NewSCCs.push_back(G->createSCC(*this, SCCNodes));
+ for (Node &N : *NewSCCs.back()) {
+ N.DFSNumber = N.LowLink = -1;
+ G->SCCMap[&N] = NewSCCs.back();
+ }
+ PendingSCCStack.erase(SCCNodes.end().base(), PendingSCCStack.end());
+ } while (!DFSStack.empty());
+ }
+
+ // Insert the remaining SCCs before the old one. The old SCC can reach all
+ // other SCCs we form because it contains the target node of the removed edge
+ // of the old SCC. This means that we will have edges into all of the new
+ // SCCs, which means the old one must come last for postorder.
+ int OldIdx = SCCIndices[&OldSCC];
+ SCCs.insert(SCCs.begin() + OldIdx, NewSCCs.begin(), NewSCCs.end());
+
+  // Update the mapping from SCC* to index to cover the new SCCs and everything
+  // shifted after them, including the old SCC.
+ for (int Idx = OldIdx, Size = SCCs.size(); Idx < Size; ++Idx)
+ SCCIndices[SCCs[Idx]] = Idx;
+
+ return make_range(SCCs.begin() + OldIdx,
+ SCCs.begin() + OldIdx + NewSCCs.size());
+}
+
+void LazyCallGraph::RefSCC::switchOutgoingEdgeToCall(Node &SourceN,
+ Node &TargetN) {
+ assert(!(*SourceN)[TargetN].isCall() && "Must start with a ref edge!");
+
+ assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC.");
+ assert(G->lookupRefSCC(TargetN) != this &&
+ "Target must not be in this RefSCC.");
+#ifdef EXPENSIVE_CHECKS
+ assert(G->lookupRefSCC(TargetN)->isDescendantOf(*this) &&
+ "Target must be a descendant of the Source.");
+#endif
+
+ // Edges between RefSCCs are the same regardless of call or ref, so we can
+ // just flip the edge here.
+ SourceN->setEdgeKind(TargetN, Edge::Call);
+
+#ifndef NDEBUG
+ // Check that the RefSCC is still valid.
+ verify();
+#endif
+}
+
+void LazyCallGraph::RefSCC::switchOutgoingEdgeToRef(Node &SourceN,
+ Node &TargetN) {
+ assert((*SourceN)[TargetN].isCall() && "Must start with a call edge!");
+
+ assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC.");
+ assert(G->lookupRefSCC(TargetN) != this &&
+ "Target must not be in this RefSCC.");
+#ifdef EXPENSIVE_CHECKS
+ assert(G->lookupRefSCC(TargetN)->isDescendantOf(*this) &&
+ "Target must be a descendant of the Source.");
+#endif
+
+ // Edges between RefSCCs are the same regardless of call or ref, so we can
+ // just flip the edge here.
+ SourceN->setEdgeKind(TargetN, Edge::Ref);
+
+#ifndef NDEBUG
+ // Check that the RefSCC is still valid.
+ verify();
+#endif
+}
+
+void LazyCallGraph::RefSCC::insertInternalRefEdge(Node &SourceN,
+ Node &TargetN) {
+ assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC.");
+ assert(G->lookupRefSCC(TargetN) == this && "Target must be in this RefSCC.");
+
+ SourceN->insertEdgeInternal(TargetN, Edge::Ref);
+
+#ifndef NDEBUG
+ // Check that the RefSCC is still valid.
+ verify();
+#endif
+}
+
+void LazyCallGraph::RefSCC::insertOutgoingEdge(Node &SourceN, Node &TargetN,
+ Edge::Kind EK) {
+ // First insert it into the caller.
+ SourceN->insertEdgeInternal(TargetN, EK);
+
+ assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC.");
+
+ assert(G->lookupRefSCC(TargetN) != this &&
+ "Target must not be in this RefSCC.");
+#ifdef EXPENSIVE_CHECKS
+ assert(G->lookupRefSCC(TargetN)->isDescendantOf(*this) &&
+ "Target must be a descendant of the Source.");
+#endif
+
+#ifndef NDEBUG
+ // Check that the RefSCC is still valid.
+ verify();
+#endif
+}
+
+SmallVector<LazyCallGraph::RefSCC *, 1>
+LazyCallGraph::RefSCC::insertIncomingRefEdge(Node &SourceN, Node &TargetN) {
+ assert(G->lookupRefSCC(TargetN) == this && "Target must be in this RefSCC.");
+ RefSCC &SourceC = *G->lookupRefSCC(SourceN);
+ assert(&SourceC != this && "Source must not be in this RefSCC.");
+#ifdef EXPENSIVE_CHECKS
+ assert(SourceC.isDescendantOf(*this) &&
+ "Source must be a descendant of the Target.");
+#endif
+
+ SmallVector<RefSCC *, 1> DeletedRefSCCs;
+
+#ifndef NDEBUG
+ // In a debug build, verify the RefSCC is valid to start with and when this
+ // routine finishes.
+ verify();
+ auto VerifyOnExit = make_scope_exit([&]() { verify(); });
+#endif
+
+ int SourceIdx = G->RefSCCIndices[&SourceC];
+ int TargetIdx = G->RefSCCIndices[this];
+ assert(SourceIdx < TargetIdx &&
+ "Postorder list doesn't see edge as incoming!");
+
+  // Compute the RefSCCs which (transitively) reach the source. We seed the
+  // set with the source RefSCC and then scan the postorder sequence from the
+  // source toward the target, adding each RefSCC that has an edge into the
+  // set built so far. Only RefSCCs within the postorder range between the
+  // source and the target are examined here.
+ auto ComputeSourceConnectedSet = [&](SmallPtrSetImpl<RefSCC *> &Set) {
+ Set.insert(&SourceC);
+ auto IsConnected = [&](RefSCC &RC) {
+ for (SCC &C : RC)
+ for (Node &N : C)
+ for (Edge &E : *N)
+ if (Set.count(G->lookupRefSCC(E.getNode())))
+ return true;
+
+ return false;
+ };
+
+ for (RefSCC *C : make_range(G->PostOrderRefSCCs.begin() + SourceIdx + 1,
+ G->PostOrderRefSCCs.begin() + TargetIdx + 1))
+ if (IsConnected(*C))
+ Set.insert(C);
+ };
+
+ // Use a normal worklist to find which SCCs the target connects to. We still
+ // bound the search based on the range in the postorder list we care about,
+ // but because this is forward connectivity we just "recurse" through the
+ // edges.
+ auto ComputeTargetConnectedSet = [&](SmallPtrSetImpl<RefSCC *> &Set) {
+ Set.insert(this);
+ SmallVector<RefSCC *, 4> Worklist;
+ Worklist.push_back(this);
+ do {
+ RefSCC &RC = *Worklist.pop_back_val();
+ for (SCC &C : RC)
+ for (Node &N : C)
+ for (Edge &E : *N) {
+ RefSCC &EdgeRC = *G->lookupRefSCC(E.getNode());
+ if (G->getRefSCCIndex(EdgeRC) <= SourceIdx)
+ // Not in the postorder sequence between source and target.
+ continue;
+
+ if (Set.insert(&EdgeRC).second)
+ Worklist.push_back(&EdgeRC);
+ }
+ } while (!Worklist.empty());
+ };
+
+ // Use a generic helper to update the postorder sequence of RefSCCs and return
+ // a range of any RefSCCs connected into a cycle by inserting this edge. This
+ // routine will also take care of updating the indices into the postorder
+ // sequence.
+ iterator_range<SmallVectorImpl<RefSCC *>::iterator> MergeRange =
+ updatePostorderSequenceForEdgeInsertion(
+ SourceC, *this, G->PostOrderRefSCCs, G->RefSCCIndices,
+ ComputeSourceConnectedSet, ComputeTargetConnectedSet);
+
+ // Build a set so we can do fast tests for whether a RefSCC will end up as
+ // part of the merged RefSCC.
+ SmallPtrSet<RefSCC *, 16> MergeSet(MergeRange.begin(), MergeRange.end());
+
+ // This RefSCC will always be part of that set, so just insert it here.
+ MergeSet.insert(this);
+
+ // Now that we have identified all of the SCCs which need to be merged into
+ // a connected set with the inserted edge, merge all of them into this SCC.
+ SmallVector<SCC *, 16> MergedSCCs;
+ int SCCIndex = 0;
+ for (RefSCC *RC : MergeRange) {
+ assert(RC != this && "We're merging into the target RefSCC, so it "
+ "shouldn't be in the range.");
+
+ // Walk the inner SCCs to update their up-pointer and walk all the edges to
+ // update any parent sets.
+ // FIXME: We should try to find a way to avoid this (rather expensive) edge
+ // walk by updating the parent sets in some other manner.
+ for (SCC &InnerC : *RC) {
+ InnerC.OuterRefSCC = this;
+ SCCIndices[&InnerC] = SCCIndex++;
+ for (Node &N : InnerC)
+ G->SCCMap[&N] = &InnerC;
+ }
+
+ // Now merge in the SCCs. We can actually move here so try to reuse storage
+ // the first time through.
+ if (MergedSCCs.empty())
+ MergedSCCs = std::move(RC->SCCs);
+ else
+ MergedSCCs.append(RC->SCCs.begin(), RC->SCCs.end());
+ RC->SCCs.clear();
+ DeletedRefSCCs.push_back(RC);
+ }
+
+ // Append our original SCCs to the merged list and move it into place.
+ for (SCC &InnerC : *this)
+ SCCIndices[&InnerC] = SCCIndex++;
+ MergedSCCs.append(SCCs.begin(), SCCs.end());
+ SCCs = std::move(MergedSCCs);
+
+ // Remove the merged away RefSCCs from the post order sequence.
+ for (RefSCC *RC : MergeRange)
+ G->RefSCCIndices.erase(RC);
+ int IndexOffset = MergeRange.end() - MergeRange.begin();
+ auto EraseEnd =
+ G->PostOrderRefSCCs.erase(MergeRange.begin(), MergeRange.end());
+ for (RefSCC *RC : make_range(EraseEnd, G->PostOrderRefSCCs.end()))
+ G->RefSCCIndices[RC] -= IndexOffset;
+
+ // At this point we have a merged RefSCC with a post-order SCCs list, just
+ // connect the nodes to form the new edge.
+ SourceN->insertEdgeInternal(TargetN, Edge::Ref);
+
+ // We return the list of SCCs which were merged so that callers can
+ // invalidate any data they have associated with those SCCs. Note that these
+ // SCCs are no longer in an interesting state (they are totally empty) but
+ // the pointers will remain stable for the life of the graph itself.
+ return DeletedRefSCCs;
+}
+
+void LazyCallGraph::RefSCC::removeOutgoingEdge(Node &SourceN, Node &TargetN) {
+ assert(G->lookupRefSCC(SourceN) == this &&
+ "The source must be a member of this RefSCC.");
+ assert(G->lookupRefSCC(TargetN) != this &&
+ "The target must not be a member of this RefSCC");
+
+#ifndef NDEBUG
+ // In a debug build, verify the RefSCC is valid to start with and when this
+ // routine finishes.
+ verify();
+ auto VerifyOnExit = make_scope_exit([&]() { verify(); });
+#endif
+
+ // First remove it from the node.
+ bool Removed = SourceN->removeEdgeInternal(TargetN);
+ (void)Removed;
+ assert(Removed && "Target not in the edge set for this caller?");
+}
+
+SmallVector<LazyCallGraph::RefSCC *, 1>
+LazyCallGraph::RefSCC::removeInternalRefEdge(Node &SourceN,
+ ArrayRef<Node *> TargetNs) {
+ // We return a list of the resulting *new* RefSCCs in post-order.
+ SmallVector<RefSCC *, 1> Result;
+
+#ifndef NDEBUG
+ // In a debug build, verify the RefSCC is valid to start with and that either
+ // we return an empty list of result RefSCCs and this RefSCC remains valid,
+ // or we return new RefSCCs and this RefSCC is dead.
+ verify();
+ auto VerifyOnExit = make_scope_exit([&]() {
+ // If we didn't replace our RefSCC with new ones, check that this one
+ // remains valid.
+ if (G)
+ verify();
+ });
+#endif
+
+ // First remove the actual edges.
+ for (Node *TargetN : TargetNs) {
+ assert(!(*SourceN)[*TargetN].isCall() &&
+ "Cannot remove a call edge, it must first be made a ref edge");
+
+ bool Removed = SourceN->removeEdgeInternal(*TargetN);
+ (void)Removed;
+ assert(Removed && "Target not in the edge set for this caller?");
+ }
+
+ // Direct self references don't impact the ref graph at all.
+ if (llvm::all_of(TargetNs,
+ [&](Node *TargetN) { return &SourceN == TargetN; }))
+ return Result;
+
+  // If all targets are in the same SCC as the source, there is no RefSCC
+  // structure change because no call edges were removed.
+ SCC &SourceC = *G->lookupSCC(SourceN);
+ if (llvm::all_of(TargetNs, [&](Node *TargetN) {
+ return G->lookupSCC(*TargetN) == &SourceC;
+ }))
+ return Result;
+
+ // We build somewhat synthetic new RefSCCs by providing a postorder mapping
+ // for each inner SCC. We store these inside the low-link field of the nodes
+ // rather than associated with SCCs because this saves a round-trip through
+ // the node->SCC map and in the common case, SCCs are small. We will verify
+ // that we always give the same number to every node in the SCC such that
+ // these are equivalent.
+ int PostOrderNumber = 0;
+
+ // Reset all the other nodes to prepare for a DFS over them, and add them to
+ // our worklist.
+ SmallVector<Node *, 8> Worklist;
+ for (SCC *C : SCCs) {
+ for (Node &N : *C)
+ N.DFSNumber = N.LowLink = 0;
+
+ Worklist.append(C->Nodes.begin(), C->Nodes.end());
+ }
+
+ // Track the number of nodes in this RefSCC so that we can quickly recognize
+ // an important special case of the edge removal not breaking the cycle of
+ // this RefSCC.
+ const int NumRefSCCNodes = Worklist.size();
+
+ SmallVector<std::pair<Node *, EdgeSequence::iterator>, 4> DFSStack;
+ SmallVector<Node *, 4> PendingRefSCCStack;
+ do {
+ assert(DFSStack.empty() &&
+ "Cannot begin a new root with a non-empty DFS stack!");
+ assert(PendingRefSCCStack.empty() &&
+ "Cannot begin a new root with pending nodes for an SCC!");
+
+ Node *RootN = Worklist.pop_back_val();
+ // Skip any nodes we've already reached in the DFS.
+ if (RootN->DFSNumber != 0) {
+ assert(RootN->DFSNumber == -1 &&
+ "Shouldn't have any mid-DFS root nodes!");
+ continue;
+ }
+
+ RootN->DFSNumber = RootN->LowLink = 1;
+ int NextDFSNumber = 2;
+
+ DFSStack.push_back({RootN, (*RootN)->begin()});
+ do {
+ Node *N;
+ EdgeSequence::iterator I;
+ std::tie(N, I) = DFSStack.pop_back_val();
+ auto E = (*N)->end();
+
+ assert(N->DFSNumber != 0 && "We should always assign a DFS number "
+ "before processing a node.");
+
+ while (I != E) {
+ Node &ChildN = I->getNode();
+ if (ChildN.DFSNumber == 0) {
+          // Mark that we should start at this child the next time this node
+          // is on top of the stack. We don't start at the next child so that
+          // this child's lowlink is reflected.
+ DFSStack.push_back({N, I});
+
+ // Continue, resetting to the child node.
+ ChildN.LowLink = ChildN.DFSNumber = NextDFSNumber++;
+ N = &ChildN;
+ I = ChildN->begin();
+ E = ChildN->end();
+ continue;
+ }
+ if (ChildN.DFSNumber == -1) {
+ // If this child isn't currently in this RefSCC, no need to process
+ // it.
+ ++I;
+ continue;
+ }
+
+ // Track the lowest link of the children, if any are still in the stack.
+ // Any child not on the stack will have a LowLink of -1.
+ assert(ChildN.LowLink != 0 &&
+ "Low-link must not be zero with a non-zero DFS number.");
+ if (ChildN.LowLink >= 0 && ChildN.LowLink < N->LowLink)
+ N->LowLink = ChildN.LowLink;
+ ++I;
+ }
+
+ // We've finished processing N and its descendants, put it on our pending
+ // stack to eventually get merged into a RefSCC.
+ PendingRefSCCStack.push_back(N);
+
+ // If this node is linked to some lower entry, continue walking up the
+ // stack.
+ if (N->LowLink != N->DFSNumber) {
+ assert(!DFSStack.empty() &&
+ "We never found a viable root for a RefSCC to pop off!");
+ continue;
+ }
+
+ // Otherwise, form a new RefSCC from the top of the pending node stack.
+ int RefSCCNumber = PostOrderNumber++;
+ int RootDFSNumber = N->DFSNumber;
+
+ // Find the range of the node stack by walking down until we pass the
+ // root DFS number. Update the DFS numbers and low link numbers in the
+ // process to avoid re-walking this list where possible.
+ auto StackRI = find_if(reverse(PendingRefSCCStack), [&](Node *N) {
+ if (N->DFSNumber < RootDFSNumber)
+ // We've found the bottom.
+ return true;
+
+ // Update this node and keep scanning.
+ N->DFSNumber = -1;
+ // Save the post-order number in the lowlink field so that we can use
+ // it to map SCCs into new RefSCCs after we finish the DFS.
+ N->LowLink = RefSCCNumber;
+ return false;
+ });
+ auto RefSCCNodes = make_range(StackRI.base(), PendingRefSCCStack.end());
+
+ // If we find a cycle containing all nodes originally in this RefSCC then
+ // the removal hasn't changed the structure at all. This is an important
+ // special case and we can directly exit the entire routine more
+ // efficiently as soon as we discover it.
+ if (llvm::size(RefSCCNodes) == NumRefSCCNodes) {
+ // Clear out the low link field as we won't need it.
+ for (Node *N : RefSCCNodes)
+ N->LowLink = -1;
+ // Return the empty result immediately.
+ return Result;
+ }
+
+ // We've already marked the nodes internally with the RefSCC number so
+ // just clear them off the stack and continue.
+ PendingRefSCCStack.erase(RefSCCNodes.begin(), PendingRefSCCStack.end());
+ } while (!DFSStack.empty());
+
+ assert(DFSStack.empty() && "Didn't flush the entire DFS stack!");
+ assert(PendingRefSCCStack.empty() && "Didn't flush all pending nodes!");
+ } while (!Worklist.empty());
+
+ assert(PostOrderNumber > 1 &&
+ "Should never finish the DFS when the existing RefSCC remains valid!");
+
+ // Otherwise we create a collection of new RefSCC nodes and build
+ // a radix-sort style map from postorder number to these new RefSCCs. We then
+ // append SCCs to each of these RefSCCs in the order they occurred in the
+ // original SCCs container.
+ for (int i = 0; i < PostOrderNumber; ++i)
+ Result.push_back(G->createRefSCC(*G));
+
+ // Insert the resulting postorder sequence into the global graph postorder
+ // sequence before the current RefSCC in that sequence, and then remove the
+ // current one.
+ //
+ // FIXME: It'd be nice to change the APIs so that we returned an iterator
+ // range over the global postorder sequence and generally use that sequence
+ // rather than building a separate result vector here.
+ int Idx = G->getRefSCCIndex(*this);
+ G->PostOrderRefSCCs.erase(G->PostOrderRefSCCs.begin() + Idx);
+ G->PostOrderRefSCCs.insert(G->PostOrderRefSCCs.begin() + Idx, Result.begin(),
+ Result.end());
+ for (int i : seq<int>(Idx, G->PostOrderRefSCCs.size()))
+ G->RefSCCIndices[G->PostOrderRefSCCs[i]] = i;
+
+ for (SCC *C : SCCs) {
+ // We store the SCC number in the node's low-link field above.
+ int SCCNumber = C->begin()->LowLink;
+    // Clear out all of the SCC's nodes' low-link fields now that we're done
+    // using them as side-storage.
+ for (Node &N : *C) {
+ assert(N.LowLink == SCCNumber &&
+ "Cannot have different numbers for nodes in the same SCC!");
+ N.LowLink = -1;
+ }
+
+ RefSCC &RC = *Result[SCCNumber];
+ int SCCIndex = RC.SCCs.size();
+ RC.SCCs.push_back(C);
+ RC.SCCIndices[C] = SCCIndex;
+ C->OuterRefSCC = &RC;
+ }
+
+ // Now that we've moved things into the new RefSCCs, clear out our current
+ // one.
+ G = nullptr;
+ SCCs.clear();
+ SCCIndices.clear();
+
+#ifndef NDEBUG
+ // Verify the new RefSCCs we've built.
+ for (RefSCC *RC : Result)
+ RC->verify();
+#endif
+
+  // Return the new list of RefSCCs.
+ return Result;
+}
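+
+// A small illustrative example of the splitting above (node names are
+// hypothetical): if a RefSCC {A, B} is held together only by the ref edges
+// A -> B and B -> A, removing the internal ref edge B -> A leaves just
+// A -> B, so the RefSCC splits into two new RefSCCs returned in post-order
+// as [{B}, {A}], with ref-edge targets before their sources.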
+
+void LazyCallGraph::RefSCC::handleTrivialEdgeInsertion(Node &SourceN,
+ Node &TargetN) {
+  // The only trivial case that requires any graph updates is when we add a new
+  // ref edge that may connect different RefSCCs along that path. This is only
+  // because of the parent sets. Every other part of the graph remains constant
+  // after this edge insertion.
+ assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC.");
+ RefSCC &TargetRC = *G->lookupRefSCC(TargetN);
+ if (&TargetRC == this)
+ return;
+
+#ifdef EXPENSIVE_CHECKS
+ assert(TargetRC.isDescendantOf(*this) &&
+ "Target must be a descendant of the Source.");
+#endif
+}
+
+void LazyCallGraph::RefSCC::insertTrivialCallEdge(Node &SourceN,
+ Node &TargetN) {
+#ifndef NDEBUG
+ // Check that the RefSCC is still valid when we finish.
+ auto ExitVerifier = make_scope_exit([this] { verify(); });
+
+#ifdef EXPENSIVE_CHECKS
+ // Check that we aren't breaking some invariants of the SCC graph. Note that
+ // this is quadratic in the number of edges in the call graph!
+ SCC &SourceC = *G->lookupSCC(SourceN);
+ SCC &TargetC = *G->lookupSCC(TargetN);
+ if (&SourceC != &TargetC)
+ assert(SourceC.isAncestorOf(TargetC) &&
+ "Call edge is not trivial in the SCC graph!");
+#endif // EXPENSIVE_CHECKS
+#endif // NDEBUG
+
+ // First insert it into the source or find the existing edge.
+ auto InsertResult =
+ SourceN->EdgeIndexMap.insert({&TargetN, SourceN->Edges.size()});
+ if (!InsertResult.second) {
+ // Already an edge, just update it.
+ Edge &E = SourceN->Edges[InsertResult.first->second];
+ if (E.isCall())
+ return; // Nothing to do!
+ E.setKind(Edge::Call);
+ } else {
+ // Create the new edge.
+ SourceN->Edges.emplace_back(TargetN, Edge::Call);
+ }
+
+ // Now that we have the edge, handle the graph fallout.
+ handleTrivialEdgeInsertion(SourceN, TargetN);
+}
+
+void LazyCallGraph::RefSCC::insertTrivialRefEdge(Node &SourceN, Node &TargetN) {
+#ifndef NDEBUG
+ // Check that the RefSCC is still valid when we finish.
+ auto ExitVerifier = make_scope_exit([this] { verify(); });
+
+#ifdef EXPENSIVE_CHECKS
+ // Check that we aren't breaking some invariants of the RefSCC graph.
+ RefSCC &SourceRC = *G->lookupRefSCC(SourceN);
+ RefSCC &TargetRC = *G->lookupRefSCC(TargetN);
+ if (&SourceRC != &TargetRC)
+ assert(SourceRC.isAncestorOf(TargetRC) &&
+ "Ref edge is not trivial in the RefSCC graph!");
+#endif // EXPENSIVE_CHECKS
+#endif // NDEBUG
+
+ // First insert it into the source or find the existing edge.
+ auto InsertResult =
+ SourceN->EdgeIndexMap.insert({&TargetN, SourceN->Edges.size()});
+ if (!InsertResult.second)
+ // Already an edge, we're done.
+ return;
+
+ // Create the new edge.
+ SourceN->Edges.emplace_back(TargetN, Edge::Ref);
+
+ // Now that we have the edge, handle the graph fallout.
+ handleTrivialEdgeInsertion(SourceN, TargetN);
+}
+
+void LazyCallGraph::RefSCC::replaceNodeFunction(Node &N, Function &NewF) {
+ Function &OldF = N.getFunction();
+
+#ifndef NDEBUG
+ // Check that the RefSCC is still valid when we finish.
+ auto ExitVerifier = make_scope_exit([this] { verify(); });
+
+ assert(G->lookupRefSCC(N) == this &&
+ "Cannot replace the function of a node outside this RefSCC.");
+
+ assert(G->NodeMap.find(&NewF) == G->NodeMap.end() &&
+         "Must not have already walked the new function!");
+
+ // It is important that this replacement not introduce graph changes so we
+ // insist that the caller has already removed every use of the original
+ // function and that all uses of the new function correspond to existing
+ // edges in the graph. The common and expected way to use this is when
+ // replacing the function itself in the IR without changing the call graph
+ // shape and just updating the analysis based on that.
+ assert(&OldF != &NewF && "Cannot replace a function with itself!");
+ assert(OldF.use_empty() &&
+ "Must have moved all uses from the old function to the new!");
+#endif
+
+ N.replaceFunction(NewF);
+
+ // Update various call graph maps.
+ G->NodeMap.erase(&OldF);
+ G->NodeMap[&NewF] = &N;
+}
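+
+// An illustrative sketch of the intended use of replaceNodeFunction (variable
+// names are hypothetical): a transformation that rebuilds a function under a
+// new Function object without changing which functions it references might do
+// roughly:
+//
+//   OldF.replaceAllUsesWith(&NewF); // All callers now reference NewF.
+//   RC.replaceNodeFunction(N, NewF); // Keep the lazy call graph in sync.
+//
+// After the first step OldF is use-empty and every use of NewF corresponds to
+// an edge already present in the graph, which is what the comments and asserts
+// above require.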
+
+void LazyCallGraph::insertEdge(Node &SourceN, Node &TargetN, Edge::Kind EK) {
+ assert(SCCMap.empty() &&
+ "This method cannot be called after SCCs have been formed!");
+
+ return SourceN->insertEdgeInternal(TargetN, EK);
+}
+
+void LazyCallGraph::removeEdge(Node &SourceN, Node &TargetN) {
+ assert(SCCMap.empty() &&
+ "This method cannot be called after SCCs have been formed!");
+
+ bool Removed = SourceN->removeEdgeInternal(TargetN);
+ (void)Removed;
+ assert(Removed && "Target not in the edge set for this caller?");
+}
+
+void LazyCallGraph::removeDeadFunction(Function &F) {
+ // FIXME: This is unnecessarily restrictive. We should be able to remove
+ // functions which recursively call themselves.
+ assert(F.use_empty() &&
+ "This routine should only be called on trivially dead functions!");
+
+ // We shouldn't remove library functions as they are never really dead while
+ // the call graph is in use -- every function definition refers to them.
+ assert(!isLibFunction(F) &&
+ "Must not remove lib functions from the call graph!");
+
+ auto NI = NodeMap.find(&F);
+ if (NI == NodeMap.end())
+ // Not in the graph at all!
+ return;
+
+ Node &N = *NI->second;
+ NodeMap.erase(NI);
+
+ // Remove this from the entry edges if present.
+ EntryEdges.removeEdgeInternal(N);
+
+ if (SCCMap.empty()) {
+ // No SCCs have been formed, so removing this is fine and there is nothing
+ // else necessary at this point but clearing out the node.
+ N.clear();
+ return;
+ }
+
+ // Cannot remove a function which has yet to be visited in the DFS walk, so
+ // if we have a node at all then we must have an SCC and RefSCC.
+ auto CI = SCCMap.find(&N);
+ assert(CI != SCCMap.end() &&
+ "Tried to remove a node without an SCC after DFS walk started!");
+ SCC &C = *CI->second;
+ SCCMap.erase(CI);
+ RefSCC &RC = C.getOuterRefSCC();
+
+ // This node must be the only member of its SCC as it has no callers, and
+ // that SCC must be the only member of a RefSCC as it has no references.
+ // Validate these properties first.
+ assert(C.size() == 1 && "Dead functions must be in a singular SCC");
+ assert(RC.size() == 1 && "Dead functions must be in a singular RefSCC");
+
+ auto RCIndexI = RefSCCIndices.find(&RC);
+ int RCIndex = RCIndexI->second;
+ PostOrderRefSCCs.erase(PostOrderRefSCCs.begin() + RCIndex);
+ RefSCCIndices.erase(RCIndexI);
+ for (int i = RCIndex, Size = PostOrderRefSCCs.size(); i < Size; ++i)
+ RefSCCIndices[PostOrderRefSCCs[i]] = i;
+
+ // Finally clear out all the data structures from the node down through the
+ // components.
+ N.clear();
+ N.G = nullptr;
+ N.F = nullptr;
+ C.clear();
+ RC.clear();
+ RC.G = nullptr;
+
+ // Nothing to delete as all the objects are allocated in stable bump pointer
+ // allocators.
+}
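+
+// A minimal usage sketch (names are illustrative; CG is assumed to be the
+// LazyCallGraph for F's module): a transformation deleting a trivially dead,
+// non-library function updates the graph before erasing the IR, roughly:
+//
+//   assert(F.use_empty());
+//   CG.removeDeadFunction(F);
+//   F.eraseFromParent();
+//
+// Library functions are excluded by the assert above because every definition
+// in the graph is treated as referencing them.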
+
+LazyCallGraph::Node &LazyCallGraph::insertInto(Function &F, Node *&MappedN) {
+ return *new (MappedN = BPA.Allocate()) Node(*this, F);
+}
+
+void LazyCallGraph::updateGraphPtrs() {
+  // Walk the node map to update the nodes' graph pointers. While this iterates
+  // in an unstable order, the order has no effect on the result, so this
+  // remains correct.
+ for (auto &FunctionNodePair : NodeMap)
+ FunctionNodePair.second->G = this;
+
+ for (auto *RC : PostOrderRefSCCs)
+ RC->G = this;
+}
+
+template <typename RootsT, typename GetBeginT, typename GetEndT,
+ typename GetNodeT, typename FormSCCCallbackT>
+void LazyCallGraph::buildGenericSCCs(RootsT &&Roots, GetBeginT &&GetBegin,
+ GetEndT &&GetEnd, GetNodeT &&GetNode,
+ FormSCCCallbackT &&FormSCC) {
+ using EdgeItT = decltype(GetBegin(std::declval<Node &>()));
+
+ SmallVector<std::pair<Node *, EdgeItT>, 16> DFSStack;
+ SmallVector<Node *, 16> PendingSCCStack;
+
+ // Scan down the stack and DFS across the call edges.
+ for (Node *RootN : Roots) {
+ assert(DFSStack.empty() &&
+ "Cannot begin a new root with a non-empty DFS stack!");
+ assert(PendingSCCStack.empty() &&
+ "Cannot begin a new root with pending nodes for an SCC!");
+
+ // Skip any nodes we've already reached in the DFS.
+ if (RootN->DFSNumber != 0) {
+ assert(RootN->DFSNumber == -1 &&
+ "Shouldn't have any mid-DFS root nodes!");
+ continue;
+ }
+
+ RootN->DFSNumber = RootN->LowLink = 1;
+ int NextDFSNumber = 2;
+
+ DFSStack.push_back({RootN, GetBegin(*RootN)});
+ do {
+ Node *N;
+ EdgeItT I;
+ std::tie(N, I) = DFSStack.pop_back_val();
+ auto E = GetEnd(*N);
+ while (I != E) {
+ Node &ChildN = GetNode(I);
+ if (ChildN.DFSNumber == 0) {
+ // We haven't yet visited this child, so descend, pushing the current
+ // node onto the stack.
+ DFSStack.push_back({N, I});
+
+ ChildN.DFSNumber = ChildN.LowLink = NextDFSNumber++;
+ N = &ChildN;
+ I = GetBegin(*N);
+ E = GetEnd(*N);
+ continue;
+ }
+
+        // If the child has already been added to some child component, it
+        // can't impact the low-link of this parent because it isn't connected
+        // back into the current DFS; its low-link isn't relevant, so skip it.
+ if (ChildN.DFSNumber == -1) {
+ ++I;
+ continue;
+ }
+
+ // Track the lowest linked child as the lowest link for this node.
+ assert(ChildN.LowLink > 0 && "Must have a positive low-link number!");
+ if (ChildN.LowLink < N->LowLink)
+ N->LowLink = ChildN.LowLink;
+
+ // Move to the next edge.
+ ++I;
+ }
+
+ // We've finished processing N and its descendants, put it on our pending
+ // SCC stack to eventually get merged into an SCC of nodes.
+ PendingSCCStack.push_back(N);
+
+ // If this node is linked to some lower entry, continue walking up the
+ // stack.
+ if (N->LowLink != N->DFSNumber)
+ continue;
+
+ // Otherwise, we've completed an SCC. Append it to our post order list of
+ // SCCs.
+ int RootDFSNumber = N->DFSNumber;
+ // Find the range of the node stack by walking down until we pass the
+ // root DFS number.
+ auto SCCNodes = make_range(
+ PendingSCCStack.rbegin(),
+ find_if(reverse(PendingSCCStack), [RootDFSNumber](const Node *N) {
+ return N->DFSNumber < RootDFSNumber;
+ }));
+ // Form a new SCC out of these nodes and then clear them off our pending
+ // stack.
+ FormSCC(SCCNodes);
+ PendingSCCStack.erase(SCCNodes.end().base(), PendingSCCStack.end());
+ } while (!DFSStack.empty());
+ }
+}
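+
+// A small worked example of the DFS above (function names are illustrative):
+// given call edges A -> B, B -> A, and C -> A, the walk forms the SCC {A, B}
+// once the DFS returns to A with LowLink == DFSNumber, and then the singleton
+// SCC {C}, so FormSCC is invoked in post-order with callees before callers.
+// buildSCCs and buildRefSCCs below instantiate this template over call edges
+// and ref edges respectively.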
+
+/// Build the internal SCCs for a RefSCC from a sequence of nodes.
+///
+/// Appends the SCCs to the provided vector and updates the map with their
+/// indices. Both the vector and map must be empty when passed into this
+/// routine.
+void LazyCallGraph::buildSCCs(RefSCC &RC, node_stack_range Nodes) {
+ assert(RC.SCCs.empty() && "Already built SCCs!");
+ assert(RC.SCCIndices.empty() && "Already mapped SCC indices!");
+
+ for (Node *N : Nodes) {
+ assert(N->LowLink >= (*Nodes.begin())->LowLink &&
+ "We cannot have a low link in an SCC lower than its root on the "
+ "stack!");
+
+ // This node will go into the next RefSCC, clear out its DFS and low link
+ // as we scan.
+ N->DFSNumber = N->LowLink = 0;
+ }
+
+ // Each RefSCC contains a DAG of the call SCCs. To build these, we do
+ // a direct walk of the call edges using Tarjan's algorithm. We reuse the
+ // internal storage as we won't need it for the outer graph's DFS any longer.
+ buildGenericSCCs(
+ Nodes, [](Node &N) { return N->call_begin(); },
+ [](Node &N) { return N->call_end(); },
+ [](EdgeSequence::call_iterator I) -> Node & { return I->getNode(); },
+ [this, &RC](node_stack_range Nodes) {
+ RC.SCCs.push_back(createSCC(RC, Nodes));
+ for (Node &N : *RC.SCCs.back()) {
+ N.DFSNumber = N.LowLink = -1;
+ SCCMap[&N] = RC.SCCs.back();
+ }
+ });
+
+ // Wire up the SCC indices.
+ for (int i = 0, Size = RC.SCCs.size(); i < Size; ++i)
+ RC.SCCIndices[RC.SCCs[i]] = i;
+}
+
+void LazyCallGraph::buildRefSCCs() {
+ if (EntryEdges.empty() || !PostOrderRefSCCs.empty())
+ // RefSCCs are either non-existent or already built!
+ return;
+
+ assert(RefSCCIndices.empty() && "Already mapped RefSCC indices!");
+
+ SmallVector<Node *, 16> Roots;
+ for (Edge &E : *this)
+ Roots.push_back(&E.getNode());
+
+  // The roots will be popped off a stack, so use reverse to get a less
+  // surprising order. This doesn't change any of the semantics anywhere.
+ std::reverse(Roots.begin(), Roots.end());
+
+ buildGenericSCCs(
+ Roots,
+ [](Node &N) {
+ // We need to populate each node as we begin to walk its edges.
+ N.populate();
+ return N->begin();
+ },
+ [](Node &N) { return N->end(); },
+ [](EdgeSequence::iterator I) -> Node & { return I->getNode(); },
+ [this](node_stack_range Nodes) {
+ RefSCC *NewRC = createRefSCC(*this);
+ buildSCCs(*NewRC, Nodes);
+
+        // Push the new RefSCC into the postorder list and remember its
+        // position in the index map.
+ bool Inserted =
+ RefSCCIndices.insert({NewRC, PostOrderRefSCCs.size()}).second;
+ (void)Inserted;
+ assert(Inserted && "Cannot already have this RefSCC in the index map!");
+ PostOrderRefSCCs.push_back(NewRC);
+#ifndef NDEBUG
+ NewRC->verify();
+#endif
+ });
+}
+
+AnalysisKey LazyCallGraphAnalysis::Key;
+
+LazyCallGraphPrinterPass::LazyCallGraphPrinterPass(raw_ostream &OS) : OS(OS) {}
+
+static void printNode(raw_ostream &OS, LazyCallGraph::Node &N) {
+ OS << " Edges in function: " << N.getFunction().getName() << "\n";
+ for (LazyCallGraph::Edge &E : N.populate())
+ OS << " " << (E.isCall() ? "call" : "ref ") << " -> "
+ << E.getFunction().getName() << "\n";
+
+ OS << "\n";
+}
+
+static void printSCC(raw_ostream &OS, LazyCallGraph::SCC &C) {
+ OS << " SCC with " << C.size() << " functions:\n";
+
+ for (LazyCallGraph::Node &N : C)
+ OS << " " << N.getFunction().getName() << "\n";
+}
+
+static void printRefSCC(raw_ostream &OS, LazyCallGraph::RefSCC &C) {
+ OS << " RefSCC with " << C.size() << " call SCCs:\n";
+
+ for (LazyCallGraph::SCC &InnerC : C)
+ printSCC(OS, InnerC);
+
+ OS << "\n";
+}
+
+PreservedAnalyses LazyCallGraphPrinterPass::run(Module &M,
+ ModuleAnalysisManager &AM) {
+ LazyCallGraph &G = AM.getResult<LazyCallGraphAnalysis>(M);
+
+ OS << "Printing the call graph for module: " << M.getModuleIdentifier()
+ << "\n\n";
+
+ for (Function &F : M)
+ printNode(OS, G.get(F));
+
+ G.buildRefSCCs();
+ for (LazyCallGraph::RefSCC &C : G.postorder_ref_sccs())
+ printRefSCC(OS, C);
+
+ return PreservedAnalyses::all();
+}
+
+LazyCallGraphDOTPrinterPass::LazyCallGraphDOTPrinterPass(raw_ostream &OS)
+ : OS(OS) {}
+
+static void printNodeDOT(raw_ostream &OS, LazyCallGraph::Node &N) {
+ std::string Name = "\"" + DOT::EscapeString(N.getFunction().getName()) + "\"";
+
+ for (LazyCallGraph::Edge &E : N.populate()) {
+ OS << " " << Name << " -> \""
+ << DOT::EscapeString(E.getFunction().getName()) << "\"";
+ if (!E.isCall()) // It is a ref edge.
+ OS << " [style=dashed,label=\"ref\"]";
+ OS << ";\n";
+ }
+
+ OS << "\n";
+}
+
+PreservedAnalyses LazyCallGraphDOTPrinterPass::run(Module &M,
+ ModuleAnalysisManager &AM) {
+ LazyCallGraph &G = AM.getResult<LazyCallGraphAnalysis>(M);
+
+ OS << "digraph \"" << DOT::EscapeString(M.getModuleIdentifier()) << "\" {\n";
+
+ for (Function &F : M)
+ printNodeDOT(OS, G.get(F));
+
+ OS << "}\n";
+
+ return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp
new file mode 100644
index 000000000000..96722f32e355
--- /dev/null
+++ b/llvm/lib/Analysis/LazyValueInfo.cpp
@@ -0,0 +1,2042 @@
+//===- LazyValueInfo.cpp - Value constraint analysis ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the interface for lazy computation of value constraint
+// information.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LazyValueInfo.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/ValueLattice.h"
+#include "llvm/IR/AssemblyAnnotationWriter.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormattedStream.h"
+#include "llvm/Support/raw_ostream.h"
+#include <map>
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "lazy-value-info"
+
+// This is the number of worklist items we will process to try to discover an
+// answer for a given value.
+static const unsigned MaxProcessedPerValue = 500;
+
+char LazyValueInfoWrapperPass::ID = 0;
+INITIALIZE_PASS_BEGIN(LazyValueInfoWrapperPass, "lazy-value-info",
+ "Lazy Value Information Analysis", false, true)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(LazyValueInfoWrapperPass, "lazy-value-info",
+ "Lazy Value Information Analysis", false, true)
+
+namespace llvm {
+ FunctionPass *createLazyValueInfoPass() { return new LazyValueInfoWrapperPass(); }
+}
+
+AnalysisKey LazyValueAnalysis::Key;
+
+/// Returns true if this lattice value represents at most one possible value.
+/// This is as precise as any lattice value can get while still representing
+/// reachable code.
+static bool hasSingleValue(const ValueLatticeElement &Val) {
+ if (Val.isConstantRange() &&
+ Val.getConstantRange().isSingleElement())
+ // Integer constants are single element ranges
+ return true;
+ if (Val.isConstant())
+ // Non integer constants
+ return true;
+ return false;
+}
+
+/// Combine two sets of facts about the same value into a single set of
+/// facts. Note that this method is not suitable for merging facts along
+/// different paths in a CFG; that's what the mergeIn function is for. This
+/// is for merging facts gathered about the same value at the same location
+/// through two independent means.
+/// Notes:
+/// * This method does not promise to return the most precise possible lattice
+/// value implied by A and B. It is allowed to return any lattice element
+/// which is at least as strong as *either* A or B (unless our facts
+/// conflict, see below).
+/// * Due to unreachable code, the intersection of two lattice values could be
+/// contradictory. If this happens, we return some valid lattice value so as
+///   not to confuse the rest of LVI. Ideally, we'd always return Undefined, but
+/// we do not make this guarantee. TODO: This would be a useful enhancement.
+static ValueLatticeElement intersect(const ValueLatticeElement &A,
+ const ValueLatticeElement &B) {
+ // Undefined is the strongest state. It means the value is known to be along
+ // an unreachable path.
+ if (A.isUndefined())
+ return A;
+ if (B.isUndefined())
+ return B;
+
+  // If we gave up for one, but got a usable fact from the other, use it.
+ if (A.isOverdefined())
+ return B;
+ if (B.isOverdefined())
+ return A;
+
+ // Can't get any more precise than constants.
+ if (hasSingleValue(A))
+ return A;
+ if (hasSingleValue(B))
+ return B;
+
+ // Could be either constant range or not constant here.
+ if (!A.isConstantRange() || !B.isConstantRange()) {
+ // TODO: Arbitrary choice, could be improved
+ return A;
+ }
+
+ // Intersect two constant ranges
+ ConstantRange Range =
+ A.getConstantRange().intersectWith(B.getConstantRange());
+ // Note: An empty range is implicitly converted to overdefined internally.
+ // TODO: We could instead use Undefined here since we've proven a conflict
+ // and thus know this path must be unreachable.
+ return ValueLatticeElement::getRange(std::move(Range));
+}
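+
+// For illustration (hypothetical values): if A is the range [0, 10) and B is
+// the range [5, 20), intersect returns the range [5, 10); if either input is
+// overdefined, the other is returned unchanged; and a constant or
+// single-element range short-circuits before the range intersection.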
+
+//===----------------------------------------------------------------------===//
+// LazyValueInfoCache Decl
+//===----------------------------------------------------------------------===//
+
+namespace {
+ /// A callback value handle updates the cache when values are erased.
+ class LazyValueInfoCache;
+ struct LVIValueHandle final : public CallbackVH {
+ // Needs to access getValPtr(), which is protected.
+ friend struct DenseMapInfo<LVIValueHandle>;
+
+ LazyValueInfoCache *Parent;
+
+ LVIValueHandle(Value *V, LazyValueInfoCache *P)
+ : CallbackVH(V), Parent(P) { }
+
+ void deleted() override;
+ void allUsesReplacedWith(Value *V) override {
+ deleted();
+ }
+ };
+} // end anonymous namespace
+
+namespace {
+  /// This is the cache kept by LazyValueInfo; it maintains the results of
+  /// previous queries across all of its clients.
+ class LazyValueInfoCache {
+    /// This is all of the cached block information for exactly one Value*.
+    /// The entries are keyed by the BasicBlock* of the entry, allowing us to
+    /// look up the lattice value for a given block directly.
+    /// Over-defined lattice values are recorded in OverDefinedCache to reduce
+    /// memory overhead.
+ struct ValueCacheEntryTy {
+ ValueCacheEntryTy(Value *V, LazyValueInfoCache *P) : Handle(V, P) {}
+ LVIValueHandle Handle;
+ SmallDenseMap<PoisoningVH<BasicBlock>, ValueLatticeElement, 4> BlockVals;
+ };
+
+ /// This tracks, on a per-block basis, the set of values that are
+ /// over-defined at the end of that block.
+ typedef DenseMap<PoisoningVH<BasicBlock>, SmallPtrSet<Value *, 4>>
+ OverDefinedCacheTy;
+ /// Keep track of all blocks that we have ever seen, so we
+ /// don't spend time removing unused blocks from our caches.
+ DenseSet<PoisoningVH<BasicBlock> > SeenBlocks;
+
+ /// This is all of the cached information for all values,
+ /// mapped from Value* to key information.
+ DenseMap<Value *, std::unique_ptr<ValueCacheEntryTy>> ValueCache;
+ OverDefinedCacheTy OverDefinedCache;
+
+
+ public:
+ void insertResult(Value *Val, BasicBlock *BB,
+ const ValueLatticeElement &Result) {
+ SeenBlocks.insert(BB);
+
+ // Insert over-defined values into their own cache to reduce memory
+ // overhead.
+ if (Result.isOverdefined())
+ OverDefinedCache[BB].insert(Val);
+ else {
+ auto It = ValueCache.find_as(Val);
+ if (It == ValueCache.end()) {
+ ValueCache[Val] = std::make_unique<ValueCacheEntryTy>(Val, this);
+ It = ValueCache.find_as(Val);
+ assert(It != ValueCache.end() && "Val was just added to the map!");
+ }
+ It->second->BlockVals[BB] = Result;
+ }
+ }
+
+ bool isOverdefined(Value *V, BasicBlock *BB) const {
+ auto ODI = OverDefinedCache.find(BB);
+
+ if (ODI == OverDefinedCache.end())
+ return false;
+
+ return ODI->second.count(V);
+ }
+
+ bool hasCachedValueInfo(Value *V, BasicBlock *BB) const {
+ if (isOverdefined(V, BB))
+ return true;
+
+ auto I = ValueCache.find_as(V);
+ if (I == ValueCache.end())
+ return false;
+
+ return I->second->BlockVals.count(BB);
+ }
+
+ ValueLatticeElement getCachedValueInfo(Value *V, BasicBlock *BB) const {
+ if (isOverdefined(V, BB))
+ return ValueLatticeElement::getOverdefined();
+
+ auto I = ValueCache.find_as(V);
+ if (I == ValueCache.end())
+ return ValueLatticeElement();
+ auto BBI = I->second->BlockVals.find(BB);
+ if (BBI == I->second->BlockVals.end())
+ return ValueLatticeElement();
+ return BBI->second;
+ }
+
+    /// Empty the cache.
+ void clear() {
+ SeenBlocks.clear();
+ ValueCache.clear();
+ OverDefinedCache.clear();
+ }
+
+ /// Inform the cache that a given value has been deleted.
+ void eraseValue(Value *V);
+
+ /// This is part of the update interface to inform the cache
+ /// that a block has been deleted.
+ void eraseBlock(BasicBlock *BB);
+
+ /// Updates the cache to remove any influence an overdefined value in
+ /// OldSucc might have (unless also overdefined in NewSucc). This just
+ /// flushes elements from the cache and does not add any.
+    void threadEdgeImpl(BasicBlock *OldSucc, BasicBlock *NewSucc);
+
+ friend struct LVIValueHandle;
+ };
+}
+
+void LazyValueInfoCache::eraseValue(Value *V) {
+ for (auto I = OverDefinedCache.begin(), E = OverDefinedCache.end(); I != E;) {
+ // Copy and increment the iterator immediately so we can erase behind
+ // ourselves.
+ auto Iter = I++;
+ SmallPtrSetImpl<Value *> &ValueSet = Iter->second;
+ ValueSet.erase(V);
+ if (ValueSet.empty())
+ OverDefinedCache.erase(Iter);
+ }
+
+ ValueCache.erase(V);
+}
+
+void LVIValueHandle::deleted() {
+ // This erasure deallocates *this, so it MUST happen after we're done
+ // using any and all members of *this.
+ Parent->eraseValue(*this);
+}
+
+void LazyValueInfoCache::eraseBlock(BasicBlock *BB) {
+ // Shortcut if we have never seen this block.
+ DenseSet<PoisoningVH<BasicBlock> >::iterator I = SeenBlocks.find(BB);
+ if (I == SeenBlocks.end())
+ return;
+ SeenBlocks.erase(I);
+
+ auto ODI = OverDefinedCache.find(BB);
+ if (ODI != OverDefinedCache.end())
+ OverDefinedCache.erase(ODI);
+
+ for (auto &I : ValueCache)
+ I.second->BlockVals.erase(BB);
+}
+
+void LazyValueInfoCache::threadEdgeImpl(BasicBlock *OldSucc,
+ BasicBlock *NewSucc) {
+ // When an edge in the graph has been threaded, values that we could not
+ // determine a value for before (i.e. were marked overdefined) may be
+ // possible to solve now. We do NOT try to proactively update these values.
+ // Instead, we clear their entries from the cache, and allow lazy updating to
+ // recompute them when needed.
+
+ // The updating process is fairly simple: we need to drop cached info
+ // for all values that were marked overdefined in OldSucc, and for those same
+ // values in any successor of OldSucc (except NewSucc) in which they were
+ // also marked overdefined.
+ std::vector<BasicBlock*> worklist;
+ worklist.push_back(OldSucc);
+
+ auto I = OverDefinedCache.find(OldSucc);
+ if (I == OverDefinedCache.end())
+ return; // Nothing to process here.
+ SmallVector<Value *, 4> ValsToClear(I->second.begin(), I->second.end());
+
+ // Use a worklist to perform a depth-first search of OldSucc's successors.
+ // NOTE: We do not need a visited list since any blocks we have already
+ // visited will have had their overdefined markers cleared already, and we
+ // thus won't loop to their successors.
+ while (!worklist.empty()) {
+ BasicBlock *ToUpdate = worklist.back();
+ worklist.pop_back();
+
+ // Skip blocks only accessible through NewSucc.
+ if (ToUpdate == NewSucc) continue;
+
+ // If a value was marked overdefined in OldSucc, and is here too...
+ auto OI = OverDefinedCache.find(ToUpdate);
+ if (OI == OverDefinedCache.end())
+ continue;
+ SmallPtrSetImpl<Value *> &ValueSet = OI->second;
+
+ bool changed = false;
+ for (Value *V : ValsToClear) {
+ if (!ValueSet.erase(V))
+ continue;
+
+      // If we removed anything, then we potentially need to update
+      // the block's successors too.
+ changed = true;
+
+ if (ValueSet.empty()) {
+ OverDefinedCache.erase(OI);
+ break;
+ }
+ }
+
+ if (!changed) continue;
+
+ worklist.insert(worklist.end(), succ_begin(ToUpdate), succ_end(ToUpdate));
+ }
+}
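+
+// A small illustrative example (block and value names are hypothetical): if
+// the edge Pred -> OldSucc is threaded to Pred -> NewSucc and %x was cached
+// as overdefined in OldSucc, the overdefined entry for %x is dropped from
+// OldSucc and from any transitive successor other than NewSucc where it was
+// also overdefined, so those blocks can be re-solved lazily against the
+// narrower CFG.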
+
+
+namespace {
+/// An assembly annotator class to print LazyValueCache information in
+/// comments.
+class LazyValueInfoImpl;
+class LazyValueInfoAnnotatedWriter : public AssemblyAnnotationWriter {
+ LazyValueInfoImpl *LVIImpl;
+ // While analyzing which blocks we can solve values for, we need the dominator
+ // information. Since this is an optional parameter in LVI, we require this
+ // DomTreeAnalysis pass in the printer pass, and pass the dominator
+ // tree to the LazyValueInfoAnnotatedWriter.
+ DominatorTree &DT;
+
+public:
+ LazyValueInfoAnnotatedWriter(LazyValueInfoImpl *L, DominatorTree &DTree)
+ : LVIImpl(L), DT(DTree) {}
+
+ virtual void emitBasicBlockStartAnnot(const BasicBlock *BB,
+ formatted_raw_ostream &OS);
+
+ virtual void emitInstructionAnnot(const Instruction *I,
+ formatted_raw_ostream &OS);
+};
+}
+namespace {
+  // The actual implementation of the lazy analysis and update. The query
+  // cache is held as a member (TheCache) rather than via inheritance; results
+  // are computed lazily and inserted into the cache as queries are solved.
+ class LazyValueInfoImpl {
+
+ /// Cached results from previous queries
+ LazyValueInfoCache TheCache;
+
+ /// This stack holds the state of the value solver during a query.
+ /// It basically emulates the callstack of the naive
+ /// recursive value lookup process.
+ SmallVector<std::pair<BasicBlock*, Value*>, 8> BlockValueStack;
+
+ /// Keeps track of which block-value pairs are in BlockValueStack.
+ DenseSet<std::pair<BasicBlock*, Value*> > BlockValueSet;
+
+ /// Push BV onto BlockValueStack unless it's already in there.
+ /// Returns true on success.
+ bool pushBlockValue(const std::pair<BasicBlock *, Value *> &BV) {
+ if (!BlockValueSet.insert(BV).second)
+ return false; // It's already in the stack.
+
+ LLVM_DEBUG(dbgs() << "PUSH: " << *BV.second << " in "
+ << BV.first->getName() << "\n");
+ BlockValueStack.push_back(BV);
+ return true;
+ }
+
+ AssumptionCache *AC; ///< A pointer to the cache of @llvm.assume calls.
+ const DataLayout &DL; ///< A mandatory DataLayout
+ DominatorTree *DT; ///< An optional DT pointer.
+ DominatorTree *DisabledDT; ///< Stores DT if it's disabled.
+
+ ValueLatticeElement getBlockValue(Value *Val, BasicBlock *BB);
+ bool getEdgeValue(Value *V, BasicBlock *F, BasicBlock *T,
+ ValueLatticeElement &Result, Instruction *CxtI = nullptr);
+ bool hasBlockValue(Value *Val, BasicBlock *BB);
+
+ // These methods process one work item and may add more. A false value
+ // returned means that the work item was not completely processed and must
+ // be revisited after going through the new items.
+ bool solveBlockValue(Value *Val, BasicBlock *BB);
+ bool solveBlockValueImpl(ValueLatticeElement &Res, Value *Val,
+ BasicBlock *BB);
+ bool solveBlockValueNonLocal(ValueLatticeElement &BBLV, Value *Val,
+ BasicBlock *BB);
+ bool solveBlockValuePHINode(ValueLatticeElement &BBLV, PHINode *PN,
+ BasicBlock *BB);
+ bool solveBlockValueSelect(ValueLatticeElement &BBLV, SelectInst *S,
+ BasicBlock *BB);
+ Optional<ConstantRange> getRangeForOperand(unsigned Op, Instruction *I,
+ BasicBlock *BB);
+ bool solveBlockValueBinaryOpImpl(
+ ValueLatticeElement &BBLV, Instruction *I, BasicBlock *BB,
+ std::function<ConstantRange(const ConstantRange &,
+ const ConstantRange &)> OpFn);
+ bool solveBlockValueBinaryOp(ValueLatticeElement &BBLV, BinaryOperator *BBI,
+ BasicBlock *BB);
+ bool solveBlockValueCast(ValueLatticeElement &BBLV, CastInst *CI,
+ BasicBlock *BB);
+ bool solveBlockValueOverflowIntrinsic(
+ ValueLatticeElement &BBLV, WithOverflowInst *WO, BasicBlock *BB);
+ bool solveBlockValueIntrinsic(ValueLatticeElement &BBLV, IntrinsicInst *II,
+ BasicBlock *BB);
+ bool solveBlockValueExtractValue(ValueLatticeElement &BBLV,
+ ExtractValueInst *EVI, BasicBlock *BB);
+ void intersectAssumeOrGuardBlockValueConstantRange(Value *Val,
+ ValueLatticeElement &BBLV,
+ Instruction *BBI);
+
+ void solve();
+
+ public:
+ /// This is the query interface to determine the lattice
+ /// value for the specified Value* at the end of the specified block.
+ ValueLatticeElement getValueInBlock(Value *V, BasicBlock *BB,
+ Instruction *CxtI = nullptr);
+
+ /// This is the query interface to determine the lattice
+ /// value for the specified Value* at the specified instruction (generally
+ /// from an assume intrinsic).
+ ValueLatticeElement getValueAt(Value *V, Instruction *CxtI);
+
+ /// This is the query interface to determine the lattice
+ /// value for the specified Value* that is true on the specified edge.
+ ValueLatticeElement getValueOnEdge(Value *V, BasicBlock *FromBB,
+ BasicBlock *ToBB,
+ Instruction *CxtI = nullptr);
+
+    /// Completely flush all previously computed values.
+ void clear() {
+ TheCache.clear();
+ }
+
+    /// Print the LazyValueInfo analysis.
+ void printLVI(Function &F, DominatorTree &DTree, raw_ostream &OS) {
+ LazyValueInfoAnnotatedWriter Writer(this, DTree);
+ F.print(OS, &Writer);
+ }
+
+ /// This is part of the update interface to inform the cache
+ /// that a block has been deleted.
+ void eraseBlock(BasicBlock *BB) {
+ TheCache.eraseBlock(BB);
+ }
+
+ /// Disables use of the DominatorTree within LVI.
+ void disableDT() {
+ if (DT) {
+ assert(!DisabledDT && "Both DT and DisabledDT are not nullptr!");
+ std::swap(DT, DisabledDT);
+ }
+ }
+
+ /// Enables use of the DominatorTree within LVI. Does nothing if the class
+ /// instance was initialized without a DT pointer.
+ void enableDT() {
+ if (DisabledDT) {
+ assert(!DT && "Both DT and DisabledDT are not nullptr!");
+ std::swap(DT, DisabledDT);
+ }
+ }
+
+ /// This is the update interface to inform the cache that an edge from
+ /// PredBB to OldSucc has been threaded to be from PredBB to NewSucc.
+    void threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc,
+                    BasicBlock *NewSucc);
+
+ LazyValueInfoImpl(AssumptionCache *AC, const DataLayout &DL,
+ DominatorTree *DT = nullptr)
+ : AC(AC), DL(DL), DT(DT), DisabledDT(nullptr) {}
+ };
+} // end anonymous namespace
+
+
+void LazyValueInfoImpl::solve() {
+ SmallVector<std::pair<BasicBlock *, Value *>, 8> StartingStack(
+ BlockValueStack.begin(), BlockValueStack.end());
+
+ unsigned processedCount = 0;
+ while (!BlockValueStack.empty()) {
+ processedCount++;
+    // Abort if we have to process too many values to get a result for this
+    // one. Because the overdefined cache is currently per-block (i.e. it can
+    // give different results for the same value in different blocks),
+    // overdefined results don't get cached globally, which in turn means we
+    // will often try to rediscover the same overdefined result again and
+    // again. Once something like PredicateInfo is used in LVI or CVP, we
+    // should be able to make the overdefined cache global, and remove this
+    // throttle.
+ if (processedCount > MaxProcessedPerValue) {
+ LLVM_DEBUG(
+ dbgs() << "Giving up on stack because we are getting too deep\n");
+ // Fill in the original values
+ while (!StartingStack.empty()) {
+ std::pair<BasicBlock *, Value *> &e = StartingStack.back();
+ TheCache.insertResult(e.second, e.first,
+ ValueLatticeElement::getOverdefined());
+ StartingStack.pop_back();
+ }
+ BlockValueSet.clear();
+ BlockValueStack.clear();
+ return;
+ }
+ std::pair<BasicBlock *, Value *> e = BlockValueStack.back();
+ assert(BlockValueSet.count(e) && "Stack value should be in BlockValueSet!");
+
+ if (solveBlockValue(e.second, e.first)) {
+ // The work item was completely processed.
+ assert(BlockValueStack.back() == e && "Nothing should have been pushed!");
+ assert(TheCache.hasCachedValueInfo(e.second, e.first) &&
+ "Result should be in cache!");
+
+ LLVM_DEBUG(
+ dbgs() << "POP " << *e.second << " in " << e.first->getName() << " = "
+ << TheCache.getCachedValueInfo(e.second, e.first) << "\n");
+
+ BlockValueStack.pop_back();
+ BlockValueSet.erase(e);
+ } else {
+ // More work needs to be done before revisiting.
+ assert(BlockValueStack.back() != e && "Stack should have been pushed!");
+ }
+ }
+}
+
+bool LazyValueInfoImpl::hasBlockValue(Value *Val, BasicBlock *BB) {
+ // If already a constant, there is nothing to compute.
+ if (isa<Constant>(Val))
+ return true;
+
+ return TheCache.hasCachedValueInfo(Val, BB);
+}
+
+ValueLatticeElement LazyValueInfoImpl::getBlockValue(Value *Val,
+ BasicBlock *BB) {
+ // If already a constant, there is nothing to compute.
+ if (Constant *VC = dyn_cast<Constant>(Val))
+ return ValueLatticeElement::get(VC);
+
+ return TheCache.getCachedValueInfo(Val, BB);
+}
+
+static ValueLatticeElement getFromRangeMetadata(Instruction *BBI) {
+ switch (BBI->getOpcode()) {
+ default: break;
+ case Instruction::Load:
+ case Instruction::Call:
+ case Instruction::Invoke:
+ if (MDNode *Ranges = BBI->getMetadata(LLVMContext::MD_range))
+ if (isa<IntegerType>(BBI->getType())) {
+ return ValueLatticeElement::getRange(
+ getConstantRangeFromMetadata(*Ranges));
+ }
+ break;
+ };
+ // Nothing known - will be intersected with other facts
+ return ValueLatticeElement::getOverdefined();
+}
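+
+// For example (illustrative IR): a load annotated with range metadata such as
+//   %v = load i32, i32* %p, !range !0
+//   !0 = !{i32 0, i32 10}
+// yields the constant range [0, 10) here, which is later intersected with any
+// other facts discovered about %v.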
+
+bool LazyValueInfoImpl::solveBlockValue(Value *Val, BasicBlock *BB) {
+ if (isa<Constant>(Val))
+ return true;
+
+ if (TheCache.hasCachedValueInfo(Val, BB)) {
+ // If we have a cached value, use that.
+ LLVM_DEBUG(dbgs() << " reuse BB '" << BB->getName() << "' val="
+ << TheCache.getCachedValueInfo(Val, BB) << '\n');
+
+ // Since we're reusing a cached value, we don't need to update the
+ // OverDefinedCache. The cache will have been properly updated whenever the
+ // cached value was inserted.
+ return true;
+ }
+
+ // Hold off inserting this value into the Cache in case we have to return
+ // false and come back later.
+ ValueLatticeElement Res;
+ if (!solveBlockValueImpl(Res, Val, BB))
+ // Work pushed, will revisit
+ return false;
+
+ TheCache.insertResult(Val, BB, Res);
+ return true;
+}
+
+bool LazyValueInfoImpl::solveBlockValueImpl(ValueLatticeElement &Res,
+ Value *Val, BasicBlock *BB) {
+
+ Instruction *BBI = dyn_cast<Instruction>(Val);
+ if (!BBI || BBI->getParent() != BB)
+ return solveBlockValueNonLocal(Res, Val, BB);
+
+ if (PHINode *PN = dyn_cast<PHINode>(BBI))
+ return solveBlockValuePHINode(Res, PN, BB);
+
+ if (auto *SI = dyn_cast<SelectInst>(BBI))
+ return solveBlockValueSelect(Res, SI, BB);
+
+  // If this value is a nonnull pointer, record its range and bail out. Note
+  // that for all other pointer typed values, we terminate the search at the
+  // definition.  We could easily extend this to look through geps, bitcasts,
+  // and the like to prove non-nullness, but it's not clear that's worth the
+  // compile time.  The context-insensitive value walk done inside
+ // isKnownNonZero gets most of the profitable cases at much less expense.
+ // This does mean that we have a sensitivity to where the defining
+ // instruction is placed, even if it could legally be hoisted much higher.
+ // That is unfortunate.
+ PointerType *PT = dyn_cast<PointerType>(BBI->getType());
+ if (PT && isKnownNonZero(BBI, DL)) {
+ Res = ValueLatticeElement::getNot(ConstantPointerNull::get(PT));
+ return true;
+ }
+ if (BBI->getType()->isIntegerTy()) {
+ if (auto *CI = dyn_cast<CastInst>(BBI))
+ return solveBlockValueCast(Res, CI, BB);
+
+ if (BinaryOperator *BO = dyn_cast<BinaryOperator>(BBI))
+ return solveBlockValueBinaryOp(Res, BO, BB);
+
+ if (auto *EVI = dyn_cast<ExtractValueInst>(BBI))
+ return solveBlockValueExtractValue(Res, EVI, BB);
+
+ if (auto *II = dyn_cast<IntrinsicInst>(BBI))
+ return solveBlockValueIntrinsic(Res, II, BB);
+ }
+
+ LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName()
+ << "' - unknown inst def found.\n");
+ Res = getFromRangeMetadata(BBI);
+ return true;
+}
+
+static bool InstructionDereferencesPointer(Instruction *I, Value *Ptr) {
+ if (LoadInst *L = dyn_cast<LoadInst>(I)) {
+ return L->getPointerAddressSpace() == 0 &&
+ GetUnderlyingObject(L->getPointerOperand(),
+ L->getModule()->getDataLayout()) == Ptr;
+ }
+ if (StoreInst *S = dyn_cast<StoreInst>(I)) {
+ return S->getPointerAddressSpace() == 0 &&
+ GetUnderlyingObject(S->getPointerOperand(),
+ S->getModule()->getDataLayout()) == Ptr;
+ }
+ if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
+ if (MI->isVolatile()) return false;
+
+    // FIXME: check whether it has a value range that excludes zero?
+ ConstantInt *Len = dyn_cast<ConstantInt>(MI->getLength());
+ if (!Len || Len->isZero()) return false;
+
+ if (MI->getDestAddressSpace() == 0)
+ if (GetUnderlyingObject(MI->getRawDest(),
+ MI->getModule()->getDataLayout()) == Ptr)
+ return true;
+ if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI))
+ if (MTI->getSourceAddressSpace() == 0)
+ if (GetUnderlyingObject(MTI->getRawSource(),
+ MTI->getModule()->getDataLayout()) == Ptr)
+ return true;
+ }
+ return false;
+}
+
+/// Return true if the allocation associated with Val is ever dereferenced
+/// within the given basic block. This establishes the fact Val is not null,
+/// but does not imply that the memory at Val is dereferenceable. (Val may
+/// point off the end of the dereferenceable part of the object.)
+static bool isObjectDereferencedInBlock(Value *Val, BasicBlock *BB) {
+ assert(Val->getType()->isPointerTy());
+
+ const DataLayout &DL = BB->getModule()->getDataLayout();
+ Value *UnderlyingVal = GetUnderlyingObject(Val, DL);
+ // If 'GetUnderlyingObject' didn't converge, skip it. It won't converge
+ // inside InstructionDereferencesPointer either.
+ if (UnderlyingVal == GetUnderlyingObject(UnderlyingVal, DL, 1))
+ for (Instruction &I : *BB)
+ if (InstructionDereferencesPointer(&I, UnderlyingVal))
+ return true;
+ return false;
+}
+
+bool LazyValueInfoImpl::solveBlockValueNonLocal(ValueLatticeElement &BBLV,
+ Value *Val, BasicBlock *BB) {
+ ValueLatticeElement Result; // Start Undefined.
+
+ // If this is the entry block, we must be asking about an argument. The
+ // value is overdefined.
+ if (BB == &BB->getParent()->getEntryBlock()) {
+ assert(isa<Argument>(Val) && "Unknown live-in to the entry block");
+ // Before giving up, see if we can prove the pointer non-null local to
+ // this particular block.
+ PointerType *PTy = dyn_cast<PointerType>(Val->getType());
+ if (PTy &&
+ (isKnownNonZero(Val, DL) ||
+ (isObjectDereferencedInBlock(Val, BB) &&
+ !NullPointerIsDefined(BB->getParent(), PTy->getAddressSpace())))) {
+ Result = ValueLatticeElement::getNot(ConstantPointerNull::get(PTy));
+ } else {
+ Result = ValueLatticeElement::getOverdefined();
+ }
+ BBLV = Result;
+ return true;
+ }
+
+ // Loop over all of our predecessors, merging what we know from them into
+ // result. If we encounter an unexplored predecessor, we eagerly explore it
+ // in a depth first manner. In practice, this has the effect of discovering
+  // paths we can't analyze eagerly without spending compile time analyzing
+ // other paths. This heuristic benefits from the fact that predecessors are
+ // frequently arranged such that dominating ones come first and we quickly
+ // find a path to function entry. TODO: We should consider explicitly
+ // canonicalizing to make this true rather than relying on this happy
+ // accident.
+ for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) {
+ ValueLatticeElement EdgeResult;
+ if (!getEdgeValue(Val, *PI, BB, EdgeResult))
+ // Explore that input, then return here
+ return false;
+
+ Result.mergeIn(EdgeResult, DL);
+
+ // If we hit overdefined, exit early. The BlockVals entry is already set
+ // to overdefined.
+ if (Result.isOverdefined()) {
+ LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName()
+ << "' - overdefined because of pred (non local).\n");
+ // Before giving up, see if we can prove the pointer non-null local to
+ // this particular block.
+ PointerType *PTy = dyn_cast<PointerType>(Val->getType());
+ if (PTy && isObjectDereferencedInBlock(Val, BB) &&
+ !NullPointerIsDefined(BB->getParent(), PTy->getAddressSpace())) {
+ Result = ValueLatticeElement::getNot(ConstantPointerNull::get(PTy));
+ }
+
+ BBLV = Result;
+ return true;
+ }
+ }
+
+ // Return the merged value, which is more precise than 'overdefined'.
+ assert(!Result.isOverdefined());
+ BBLV = Result;
+ return true;
+}
+
+bool LazyValueInfoImpl::solveBlockValuePHINode(ValueLatticeElement &BBLV,
+ PHINode *PN, BasicBlock *BB) {
+ ValueLatticeElement Result; // Start Undefined.
+
+ // Loop over all of our predecessors, merging what we know from them into
+ // result. See the comment about the chosen traversal order in
+ // solveBlockValueNonLocal; the same reasoning applies here.
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+ BasicBlock *PhiBB = PN->getIncomingBlock(i);
+ Value *PhiVal = PN->getIncomingValue(i);
+ ValueLatticeElement EdgeResult;
+ // Note that we can provide PN as the context value to getEdgeValue, even
+ // though the results will be cached, because PN is the value being used as
+ // the cache key in the caller.
+ if (!getEdgeValue(PhiVal, PhiBB, BB, EdgeResult, PN))
+ // Explore that input, then return here
+ return false;
+
+ Result.mergeIn(EdgeResult, DL);
+
+ // If we hit overdefined, exit early. The BlockVals entry is already set
+ // to overdefined.
+ if (Result.isOverdefined()) {
+ LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName()
+ << "' - overdefined because of pred (local).\n");
+
+ BBLV = Result;
+ return true;
+ }
+ }
+
+ // Return the merged value, which is more precise than 'overdefined'.
+ assert(!Result.isOverdefined() && "Possible PHI in entry block?");
+ BBLV = Result;
+ return true;
+}
+
+static ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond,
+ bool isTrueDest = true);
+
+// If we can determine a constraint on the value given conditions assumed by
+// the program, intersect those constraints with BBLV
+void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange(
+ Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) {
+ BBI = BBI ? BBI : dyn_cast<Instruction>(Val);
+ if (!BBI)
+ return;
+
+ for (auto &AssumeVH : AC->assumptionsFor(Val)) {
+ if (!AssumeVH)
+ continue;
+ auto *I = cast<CallInst>(AssumeVH);
+ if (!isValidAssumeForContext(I, BBI, DT))
+ continue;
+
+ BBLV = intersect(BBLV, getValueFromCondition(Val, I->getArgOperand(0)));
+ }
+
+ // If guards are not used in the module, don't spend time looking for them
+ auto *GuardDecl = BBI->getModule()->getFunction(
+ Intrinsic::getName(Intrinsic::experimental_guard));
+ if (!GuardDecl || GuardDecl->use_empty())
+ return;
+
+ if (BBI->getIterator() == BBI->getParent()->begin())
+ return;
+ for (Instruction &I : make_range(std::next(BBI->getIterator().getReverse()),
+ BBI->getParent()->rend())) {
+ Value *Cond = nullptr;
+ if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>(m_Value(Cond))))
+ BBLV = intersect(BBLV, getValueFromCondition(Val, Cond));
+ }
+}
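+
+// A sketch of the kind of refinement this enables (illustrative IR): given
+//   %cmp = icmp ult i32 %x, 10
+//   call void @llvm.assume(i1 %cmp)
+// a query about %x at a context instruction for which the assume is valid
+// intersects the current lattice value with the range [0, 10) derived from
+// the condition.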
+
+bool LazyValueInfoImpl::solveBlockValueSelect(ValueLatticeElement &BBLV,
+ SelectInst *SI, BasicBlock *BB) {
+
+ // Recurse on our inputs if needed
+ if (!hasBlockValue(SI->getTrueValue(), BB)) {
+ if (pushBlockValue(std::make_pair(BB, SI->getTrueValue())))
+ return false;
+ BBLV = ValueLatticeElement::getOverdefined();
+ return true;
+ }
+ ValueLatticeElement TrueVal = getBlockValue(SI->getTrueValue(), BB);
+ // If we hit overdefined, don't ask more queries. We want to avoid poisoning
+ // extra slots in the table if we can.
+ if (TrueVal.isOverdefined()) {
+ BBLV = ValueLatticeElement::getOverdefined();
+ return true;
+ }
+
+ if (!hasBlockValue(SI->getFalseValue(), BB)) {
+ if (pushBlockValue(std::make_pair(BB, SI->getFalseValue())))
+ return false;
+ BBLV = ValueLatticeElement::getOverdefined();
+ return true;
+ }
+ ValueLatticeElement FalseVal = getBlockValue(SI->getFalseValue(), BB);
+ // If we hit overdefined, don't ask more queries. We want to avoid poisoning
+ // extra slots in the table if we can.
+ if (FalseVal.isOverdefined()) {
+ BBLV = ValueLatticeElement::getOverdefined();
+ return true;
+ }
+
+ if (TrueVal.isConstantRange() && FalseVal.isConstantRange()) {
+ const ConstantRange &TrueCR = TrueVal.getConstantRange();
+ const ConstantRange &FalseCR = FalseVal.getConstantRange();
+ Value *LHS = nullptr;
+ Value *RHS = nullptr;
+ SelectPatternResult SPR = matchSelectPattern(SI, LHS, RHS);
+ // Is this a min specifically of our two inputs? (Avoid the risk of
+ // ValueTracking getting smarter looking back past our immediate inputs.)
+ if (SelectPatternResult::isMinOrMax(SPR.Flavor) &&
+ LHS == SI->getTrueValue() && RHS == SI->getFalseValue()) {
+ ConstantRange ResultCR = [&]() {
+ switch (SPR.Flavor) {
+ default:
+ llvm_unreachable("unexpected minmax type!");
+ case SPF_SMIN: /// Signed minimum
+ return TrueCR.smin(FalseCR);
+ case SPF_UMIN: /// Unsigned minimum
+ return TrueCR.umin(FalseCR);
+ case SPF_SMAX: /// Signed maximum
+ return TrueCR.smax(FalseCR);
+ case SPF_UMAX: /// Unsigned maximum
+ return TrueCR.umax(FalseCR);
+ };
+ }();
+ BBLV = ValueLatticeElement::getRange(ResultCR);
+ return true;
+ }
+
+ if (SPR.Flavor == SPF_ABS) {
+ if (LHS == SI->getTrueValue()) {
+ BBLV = ValueLatticeElement::getRange(TrueCR.abs());
+ return true;
+ }
+ if (LHS == SI->getFalseValue()) {
+ BBLV = ValueLatticeElement::getRange(FalseCR.abs());
+ return true;
+ }
+ }
+
+ if (SPR.Flavor == SPF_NABS) {
+ ConstantRange Zero(APInt::getNullValue(TrueCR.getBitWidth()));
+ if (LHS == SI->getTrueValue()) {
+ BBLV = ValueLatticeElement::getRange(Zero.sub(TrueCR.abs()));
+ return true;
+ }
+ if (LHS == SI->getFalseValue()) {
+ BBLV = ValueLatticeElement::getRange(Zero.sub(FalseCR.abs()));
+ return true;
+ }
+ }
+ }
+
+ // Can we constrain the facts about the true and false values by using the
+ // condition itself? This shows up with idioms like e.g. select(a > 5, a, 5).
+ // TODO: We could potentially refine an overdefined true value above.
+ Value *Cond = SI->getCondition();
+ TrueVal = intersect(TrueVal,
+ getValueFromCondition(SI->getTrueValue(), Cond, true));
+ FalseVal = intersect(FalseVal,
+ getValueFromCondition(SI->getFalseValue(), Cond, false));
+
+ // Handle clamp idioms such as:
+ // %24 = constantrange<0, 17>
+ // %39 = icmp eq i32 %24, 0
+ // %40 = add i32 %24, -1
+ // %siv.next = select i1 %39, i32 16, i32 %40
+ // %siv.next = constantrange<0, 17> not <-1, 17>
+ // In general, this can handle any clamp idiom which tests the edge
+ // condition via an equality or inequality.
+ if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
+ ICmpInst::Predicate Pred = ICI->getPredicate();
+ Value *A = ICI->getOperand(0);
+ if (ConstantInt *CIBase = dyn_cast<ConstantInt>(ICI->getOperand(1))) {
+ auto addConstants = [](ConstantInt *A, ConstantInt *B) {
+ assert(A->getType() == B->getType());
+ return ConstantInt::get(A->getType(), A->getValue() + B->getValue());
+ };
+ // See if either input is A + C2, subject to the constraint from the
+ // condition that A != C when that input is used. We can assume that
+ // that input doesn't include C + C2.
+ ConstantInt *CIAdded;
+ switch (Pred) {
+ default: break;
+ case ICmpInst::ICMP_EQ:
+ if (match(SI->getFalseValue(), m_Add(m_Specific(A),
+ m_ConstantInt(CIAdded)))) {
+ auto ResNot = addConstants(CIBase, CIAdded);
+ FalseVal = intersect(FalseVal,
+ ValueLatticeElement::getNot(ResNot));
+ }
+ break;
+ case ICmpInst::ICMP_NE:
+ if (match(SI->getTrueValue(), m_Add(m_Specific(A),
+ m_ConstantInt(CIAdded)))) {
+ auto ResNot = addConstants(CIBase, CIAdded);
+ TrueVal = intersect(TrueVal,
+ ValueLatticeElement::getNot(ResNot));
+ }
+ break;
+ };
+ }
+ }
+
+ ValueLatticeElement Result; // Start Undefined.
+ Result.mergeIn(TrueVal, DL);
+ Result.mergeIn(FalseVal, DL);
+ BBLV = Result;
+ return true;
+}
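+
+// A minimal illustrative sketch (made-up 8-bit ranges, not taken from this
+// pass) of how the min/max transfer rule above combines two operand ranges:
+//   ConstantRange A(APInt(8, 0), APInt(8, 10));   // values [0, 10)
+//   ConstantRange B(APInt(8, 5), APInt(8, 20));   // values [5, 20)
+//   A.smin(B);   // == [0, 10): the signed minimum never exceeds 9
+//   A.umax(B);   // == [5, 20): the unsigned maximum is at least 5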
+
+Optional<ConstantRange> LazyValueInfoImpl::getRangeForOperand(unsigned Op,
+ Instruction *I,
+ BasicBlock *BB) {
+ if (!hasBlockValue(I->getOperand(Op), BB))
+ if (pushBlockValue(std::make_pair(BB, I->getOperand(Op))))
+ return None;
+
+ const unsigned OperandBitWidth =
+ DL.getTypeSizeInBits(I->getOperand(Op)->getType());
+ ConstantRange Range = ConstantRange::getFull(OperandBitWidth);
+ if (hasBlockValue(I->getOperand(Op), BB)) {
+ ValueLatticeElement Val = getBlockValue(I->getOperand(Op), BB);
+ intersectAssumeOrGuardBlockValueConstantRange(I->getOperand(Op), Val, I);
+ if (Val.isConstantRange())
+ Range = Val.getConstantRange();
+ }
+ return Range;
+}
+
+bool LazyValueInfoImpl::solveBlockValueCast(ValueLatticeElement &BBLV,
+ CastInst *CI,
+ BasicBlock *BB) {
+ if (!CI->getOperand(0)->getType()->isSized()) {
+ // Without knowing how wide the input is, we can't analyze it in any useful
+ // way.
+ BBLV = ValueLatticeElement::getOverdefined();
+ return true;
+ }
+
+ // Filter out casts we don't know how to reason about before attempting to
+ // recurse on our operand. This can cut a long search short if we know we're
+ // not going to be able to get any useful information anyway.
+ switch (CI->getOpcode()) {
+ case Instruction::Trunc:
+ case Instruction::SExt:
+ case Instruction::ZExt:
+ case Instruction::BitCast:
+ break;
+ default:
+ // Unhandled instructions are overdefined.
+ LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName()
+ << "' - overdefined (unknown cast).\n");
+ BBLV = ValueLatticeElement::getOverdefined();
+ return true;
+ }
+
+ // Figure out the range of the LHS. If that fails, we still apply the
+ // transfer rule on the full set since we may be able to locally infer
+ // interesting facts.
+ Optional<ConstantRange> LHSRes = getRangeForOperand(0, CI, BB);
+ if (!LHSRes.hasValue())
+ // More work to do before applying this transfer rule.
+ return false;
+ ConstantRange LHSRange = LHSRes.getValue();
+
+ const unsigned ResultBitWidth = CI->getType()->getIntegerBitWidth();
+
+ // NOTE: We're currently limited by the set of operations that ConstantRange
+ // can evaluate symbolically. Enhancing that set will allow us to analyze
+ // more definitions.
+ BBLV = ValueLatticeElement::getRange(LHSRange.castOp(CI->getOpcode(),
+ ResultBitWidth));
+ return true;
+}
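+
+// For illustration only (hypothetical values), the same ConstantRange::castOp
+// used above evaluates a zext symbolically like so:
+//   ConstantRange Op(APInt(8, 0), APInt(8, 10));          // i8 in [0, 10)
+//   Op.castOp(Instruction::ZExt, /*ResultBitWidth=*/32);  // i32 in [0, 10)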
+
+bool LazyValueInfoImpl::solveBlockValueBinaryOpImpl(
+ ValueLatticeElement &BBLV, Instruction *I, BasicBlock *BB,
+ std::function<ConstantRange(const ConstantRange &,
+ const ConstantRange &)> OpFn) {
+ // Figure out the ranges of the operands. If that fails, use a
+ // conservative range, but apply the transfer rule anyways. This
+ // lets us pick up facts from expressions like "and i32 (call i32
+ // @foo()), 32"
+ Optional<ConstantRange> LHSRes = getRangeForOperand(0, I, BB);
+ Optional<ConstantRange> RHSRes = getRangeForOperand(1, I, BB);
+ if (!LHSRes.hasValue() || !RHSRes.hasValue())
+ // More work to do before applying this transfer rule.
+ return false;
+
+ ConstantRange LHSRange = LHSRes.getValue();
+ ConstantRange RHSRange = RHSRes.getValue();
+ BBLV = ValueLatticeElement::getRange(OpFn(LHSRange, RHSRange));
+ return true;
+}
+
+bool LazyValueInfoImpl::solveBlockValueBinaryOp(ValueLatticeElement &BBLV,
+ BinaryOperator *BO,
+ BasicBlock *BB) {
+
+ assert(BO->getOperand(0)->getType()->isSized() &&
+ "all operands to binary operators are sized");
+ if (BO->getOpcode() == Instruction::Xor) {
+ // Xor is the only operation not supported by ConstantRange::binaryOp().
+ LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName()
+ << "' - overdefined (unknown binary operator).\n");
+ BBLV = ValueLatticeElement::getOverdefined();
+ return true;
+ }
+
+ return solveBlockValueBinaryOpImpl(BBLV, BO, BB,
+ [BO](const ConstantRange &CR1, const ConstantRange &CR2) {
+ return CR1.binaryOp(BO->getOpcode(), CR2);
+ });
+}
+
+bool LazyValueInfoImpl::solveBlockValueOverflowIntrinsic(
+ ValueLatticeElement &BBLV, WithOverflowInst *WO, BasicBlock *BB) {
+ return solveBlockValueBinaryOpImpl(BBLV, WO, BB,
+ [WO](const ConstantRange &CR1, const ConstantRange &CR2) {
+ return CR1.binaryOp(WO->getBinaryOp(), CR2);
+ });
+}
+
+bool LazyValueInfoImpl::solveBlockValueIntrinsic(
+ ValueLatticeElement &BBLV, IntrinsicInst *II, BasicBlock *BB) {
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::uadd_sat:
+ return solveBlockValueBinaryOpImpl(BBLV, II, BB,
+ [](const ConstantRange &CR1, const ConstantRange &CR2) {
+ return CR1.uadd_sat(CR2);
+ });
+ case Intrinsic::usub_sat:
+ return solveBlockValueBinaryOpImpl(BBLV, II, BB,
+ [](const ConstantRange &CR1, const ConstantRange &CR2) {
+ return CR1.usub_sat(CR2);
+ });
+ case Intrinsic::sadd_sat:
+ return solveBlockValueBinaryOpImpl(BBLV, II, BB,
+ [](const ConstantRange &CR1, const ConstantRange &CR2) {
+ return CR1.sadd_sat(CR2);
+ });
+ case Intrinsic::ssub_sat:
+ return solveBlockValueBinaryOpImpl(BBLV, II, BB,
+ [](const ConstantRange &CR1, const ConstantRange &CR2) {
+ return CR1.ssub_sat(CR2);
+ });
+ default:
+ LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName()
+ << "' - overdefined (unknown intrinsic).\n");
+ BBLV = ValueLatticeElement::getOverdefined();
+ return true;
+ }
+}
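+
+// A small sketch of the saturating transfer rules above, using hypothetical
+// 8-bit ranges (illustration only):
+//   ConstantRange A(APInt(8, 0), APInt(8, 100));    // [0, 100)
+//   A.uadd_sat(A);    // == [0, 199): 99 + 99 = 198, so no sum saturates
+//   ConstantRange B(APInt(8, 200), APInt(8, 250));  // [200, 250)
+//   B.uadd_sat(B);    // == {255}: every sum saturates at the 8-bit maximum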
+
+bool LazyValueInfoImpl::solveBlockValueExtractValue(
+ ValueLatticeElement &BBLV, ExtractValueInst *EVI, BasicBlock *BB) {
+ if (auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand()))
+ if (EVI->getNumIndices() == 1 && *EVI->idx_begin() == 0)
+ return solveBlockValueOverflowIntrinsic(BBLV, WO, BB);
+
+ // Handle extractvalue of insertvalue to allow further simplification
+ // based on with.overflow intrinsics that have already been replaced.
+ if (Value *V = SimplifyExtractValueInst(
+ EVI->getAggregateOperand(), EVI->getIndices(),
+ EVI->getModule()->getDataLayout())) {
+ if (!hasBlockValue(V, BB)) {
+ if (pushBlockValue({ BB, V }))
+ return false;
+ BBLV = ValueLatticeElement::getOverdefined();
+ return true;
+ }
+ BBLV = getBlockValue(V, BB);
+ return true;
+ }
+
+ LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName()
+ << "' - overdefined (unknown extractvalue).\n");
+ BBLV = ValueLatticeElement::getOverdefined();
+ return true;
+}
+
+static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
+ bool isTrueDest) {
+ Value *LHS = ICI->getOperand(0);
+ Value *RHS = ICI->getOperand(1);
+ CmpInst::Predicate Predicate = ICI->getPredicate();
+
+ if (isa<Constant>(RHS)) {
+ if (ICI->isEquality() && LHS == Val) {
+ // We know that V has the RHS constant if this is a true SETEQ or
+ // false SETNE.
+ if (isTrueDest == (Predicate == ICmpInst::ICMP_EQ))
+ return ValueLatticeElement::get(cast<Constant>(RHS));
+ else
+ return ValueLatticeElement::getNot(cast<Constant>(RHS));
+ }
+ }
+
+ if (!Val->getType()->isIntegerTy())
+ return ValueLatticeElement::getOverdefined();
+
+ // Use ConstantRange::makeAllowedICmpRegion in order to determine the possible
+ // range of Val guaranteed by the condition. Recognize comparisons in the form
+ // of:
+ // icmp <pred> Val, ...
+ // icmp <pred> (add Val, Offset), ...
+ // The latter is the range checking idiom that InstCombine produces. Subtract
+ // the offset from the allowed range for RHS in this case.
+
+ // Val or (add Val, Offset) can be on either hand of the comparison
+ if (LHS != Val && !match(LHS, m_Add(m_Specific(Val), m_ConstantInt()))) {
+ std::swap(LHS, RHS);
+ Predicate = CmpInst::getSwappedPredicate(Predicate);
+ }
+
+ ConstantInt *Offset = nullptr;
+ if (LHS != Val)
+ match(LHS, m_Add(m_Specific(Val), m_ConstantInt(Offset)));
+
+ if (LHS == Val || Offset) {
+ // Calculate the range of values that are allowed by the comparison
+ ConstantRange RHSRange(RHS->getType()->getIntegerBitWidth(),
+ /*isFullSet=*/true);
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS))
+ RHSRange = ConstantRange(CI->getValue());
+ else if (Instruction *I = dyn_cast<Instruction>(RHS))
+ if (auto *Ranges = I->getMetadata(LLVMContext::MD_range))
+ RHSRange = getConstantRangeFromMetadata(*Ranges);
+
+ // If we're interested in the false dest, invert the condition
+ CmpInst::Predicate Pred =
+ isTrueDest ? Predicate : CmpInst::getInversePredicate(Predicate);
+ ConstantRange TrueValues =
+ ConstantRange::makeAllowedICmpRegion(Pred, RHSRange);
+
+ if (Offset) // Apply the offset from above.
+ TrueValues = TrueValues.subtract(Offset->getValue());
+
+ return ValueLatticeElement::getRange(std::move(TrueValues));
+ }
+
+ return ValueLatticeElement::getOverdefined();
+}
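+
+// As an example (hypothetical IR, shown only to illustrate the two recognized
+// forms): on the true edge of
+//   %c = icmp ult i32 %x, 10
+// makeAllowedICmpRegion(ult, [10, 11)) yields [0, 10) for %x directly, while
+// for the InstCombine range-check idiom
+//   %a = add i32 %x, -5
+//   %c = icmp ult i32 %a, 10
+// the same region is computed for %a and then shifted back by the offset,
+// giving %x in [5, 15).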
+
+// Handle conditions of the form
+// extractvalue(op.with.overflow(%x, C), 1).
+static ValueLatticeElement getValueFromOverflowCondition(
+ Value *Val, WithOverflowInst *WO, bool IsTrueDest) {
+ // TODO: This only works with a constant RHS for now. We could also compute
+ // the range of the RHS, but this doesn't fit into the current structure of
+ // the edge value calculation.
+ const APInt *C;
+ if (WO->getLHS() != Val || !match(WO->getRHS(), m_APInt(C)))
+ return ValueLatticeElement::getOverdefined();
+
+ // Calculate the possible values of %x for which no overflow occurs.
+ ConstantRange NWR = ConstantRange::makeExactNoWrapRegion(
+ WO->getBinaryOp(), *C, WO->getNoWrapKind());
+
+ // If overflow is false, %x is constrained to NWR. If overflow is true, %x is
+ // constrained to its inverse (all values that might cause overflow).
+ if (IsTrueDest)
+ NWR = NWR.inverse();
+ return ValueLatticeElement::getRange(NWR);
+}
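+
+// A concrete (made-up) example of the no-wrap region computed above: for
+//   %s = call { i8, i1 } @llvm.uadd.with.overflow.i8(i8 %x, i8 100)
+//   %ov = extractvalue { i8, i1 } %s, 1
+// makeExactNoWrapRegion(Add, 100, NoUnsignedWrap) is [0, 156), so %x lies in
+// [0, 156) on the edge where %ov is false and in the inverse range
+// {156, ..., 255} on the edge where %ov is true.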
+
+static ValueLatticeElement
+getValueFromCondition(Value *Val, Value *Cond, bool isTrueDest,
+ DenseMap<Value*, ValueLatticeElement> &Visited);
+
+static ValueLatticeElement
+getValueFromConditionImpl(Value *Val, Value *Cond, bool isTrueDest,
+ DenseMap<Value*, ValueLatticeElement> &Visited) {
+ if (ICmpInst *ICI = dyn_cast<ICmpInst>(Cond))
+ return getValueFromICmpCondition(Val, ICI, isTrueDest);
+
+ if (auto *EVI = dyn_cast<ExtractValueInst>(Cond))
+ if (auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand()))
+ if (EVI->getNumIndices() == 1 && *EVI->idx_begin() == 1)
+ return getValueFromOverflowCondition(Val, WO, isTrueDest);
+
+ // Handle conditions in the form of (cond1 && cond2), we know that on the
+ // true dest path both of the conditions hold. Similarly for conditions of
+ // the form (cond1 || cond2), we know that on the false dest path neither
+ // condition holds.
+ BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond);
+ if (!BO || (isTrueDest && BO->getOpcode() != BinaryOperator::And) ||
+ (!isTrueDest && BO->getOpcode() != BinaryOperator::Or))
+ return ValueLatticeElement::getOverdefined();
+
+ // Prevent infinite recursion if Cond references itself as in this example:
+ // Cond: "%tmp4 = and i1 %tmp4, undef"
+ // BL: "%tmp4 = and i1 %tmp4, undef"
+ // BR: "i1 undef"
+ Value *BL = BO->getOperand(0);
+ Value *BR = BO->getOperand(1);
+ if (BL == Cond || BR == Cond)
+ return ValueLatticeElement::getOverdefined();
+
+ return intersect(getValueFromCondition(Val, BL, isTrueDest, Visited),
+ getValueFromCondition(Val, BR, isTrueDest, Visited));
+}
+
+static ValueLatticeElement
+getValueFromCondition(Value *Val, Value *Cond, bool isTrueDest,
+ DenseMap<Value*, ValueLatticeElement> &Visited) {
+ auto I = Visited.find(Cond);
+ if (I != Visited.end())
+ return I->second;
+
+ auto Result = getValueFromConditionImpl(Val, Cond, isTrueDest, Visited);
+ Visited[Cond] = Result;
+ return Result;
+}
+
+ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond,
+ bool isTrueDest) {
+ assert(Cond && "precondition");
+ DenseMap<Value*, ValueLatticeElement> Visited;
+ return getValueFromCondition(Val, Cond, isTrueDest, Visited);
+}
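+
+// For instance (illustrative IR only), on the true edge of
+//   %a = icmp ult i32 %x, 10
+//   %b = icmp ugt i32 %x, 2
+//   %c = and i1 %a, %b
+//   br i1 %c, label %then, label %else
+// both comparisons are known to hold, so the recursion above intersects
+// [0, 10) with [3, 0) (everything above 2) and reports %x in [3, 10) on the
+// edge into %then.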
+
+// Return true if Usr has Op as an operand, otherwise false.
+static bool usesOperand(User *Usr, Value *Op) {
+ return find(Usr->operands(), Op) != Usr->op_end();
+}
+
+// Return true if the instruction type of Usr is supported by
+// constantFoldUser(); currently only CastInst and BinaryOperator are. Call
+// this before calling constantFoldUser() to find out whether it's even worth
+// attempting to call it.
+static bool isOperationFoldable(User *Usr) {
+ return isa<CastInst>(Usr) || isa<BinaryOperator>(Usr);
+}
+
+// Check if Usr can be simplified to an integer constant when the value of one
+// of its operands Op is an integer constant OpConstVal. If so, return it as a
+// lattice value range with a single element or otherwise return an overdefined
+// lattice value.
+static ValueLatticeElement constantFoldUser(User *Usr, Value *Op,
+ const APInt &OpConstVal,
+ const DataLayout &DL) {
+ assert(isOperationFoldable(Usr) && "Precondition");
+ Constant* OpConst = Constant::getIntegerValue(Op->getType(), OpConstVal);
+ // Check if Usr can be simplified to a constant.
+ if (auto *CI = dyn_cast<CastInst>(Usr)) {
+ assert(CI->getOperand(0) == Op && "Operand 0 isn't Op");
+ if (auto *C = dyn_cast_or_null<ConstantInt>(
+ SimplifyCastInst(CI->getOpcode(), OpConst,
+ CI->getDestTy(), DL))) {
+ return ValueLatticeElement::getRange(ConstantRange(C->getValue()));
+ }
+ } else if (auto *BO = dyn_cast<BinaryOperator>(Usr)) {
+ bool Op0Match = BO->getOperand(0) == Op;
+ bool Op1Match = BO->getOperand(1) == Op;
+ assert((Op0Match || Op1Match) &&
+ "Neither Operand 0 nor Operand 1 is a match");
+ Value *LHS = Op0Match ? OpConst : BO->getOperand(0);
+ Value *RHS = Op1Match ? OpConst : BO->getOperand(1);
+ if (auto *C = dyn_cast_or_null<ConstantInt>(
+ SimplifyBinOp(BO->getOpcode(), LHS, RHS, DL))) {
+ return ValueLatticeElement::getRange(ConstantRange(C->getValue()));
+ }
+ }
+ return ValueLatticeElement::getOverdefined();
+}
+
+/// Compute the value of Val on the edge BBFrom -> BBTo. Returns false if
+/// Val is not constrained on the edge. Result is unspecified if return value
+/// is false.
+static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
+ BasicBlock *BBTo, ValueLatticeElement &Result) {
+ // TODO: Handle more complex conditionals. If (v == 0 || v2 < 1) is false, we
+ // know that v != 0.
+ if (BranchInst *BI = dyn_cast<BranchInst>(BBFrom->getTerminator())) {
+ // If this is a conditional branch and only one successor goes to BBTo, then
+ // we may be able to infer something from the condition.
+ if (BI->isConditional() &&
+ BI->getSuccessor(0) != BI->getSuccessor(1)) {
+ bool isTrueDest = BI->getSuccessor(0) == BBTo;
+ assert(BI->getSuccessor(!isTrueDest) == BBTo &&
+ "BBTo isn't a successor of BBFrom");
+ Value *Condition = BI->getCondition();
+
+ // If V is the condition of the branch itself, then we know exactly what
+ // it is.
+ if (Condition == Val) {
+ Result = ValueLatticeElement::get(ConstantInt::get(
+ Type::getInt1Ty(Val->getContext()), isTrueDest));
+ return true;
+ }
+
+ // If the condition of the branch is an equality comparison, we may be
+ // able to infer the value.
+ Result = getValueFromCondition(Val, Condition, isTrueDest);
+ if (!Result.isOverdefined())
+ return true;
+
+ if (User *Usr = dyn_cast<User>(Val)) {
+ assert(Result.isOverdefined() && "Result isn't overdefined");
+ // Check with isOperationFoldable() first to avoid linearly iterating
+ // over the operands unnecessarily which can be expensive for
+ // instructions with many operands.
+ if (isa<IntegerType>(Usr->getType()) && isOperationFoldable(Usr)) {
+ const DataLayout &DL = BBTo->getModule()->getDataLayout();
+ if (usesOperand(Usr, Condition)) {
+ // If Val has Condition as an operand and Val can be folded into a
+ // constant with either Condition == true or Condition == false,
+ // propagate the constant.
+ // eg.
+ // ; %Val is true on the edge to %then.
+ // %Val = and i1 %Condition, true.
+ // br %Condition, label %then, label %else
+ APInt ConditionVal(1, isTrueDest ? 1 : 0);
+ Result = constantFoldUser(Usr, Condition, ConditionVal, DL);
+ } else {
+ // If one of Val's operands has an inferred value, we may be able to
+ // infer the value of Val.
+ // eg.
+ // ; %Val is 94 on the edge to %then.
+ // %Val = add i8 %Op, 1
+ // %Condition = icmp eq i8 %Op, 93
+ // br i1 %Condition, label %then, label %else
+ for (unsigned i = 0; i < Usr->getNumOperands(); ++i) {
+ Value *Op = Usr->getOperand(i);
+ ValueLatticeElement OpLatticeVal =
+ getValueFromCondition(Op, Condition, isTrueDest);
+ if (Optional<APInt> OpConst = OpLatticeVal.asConstantInteger()) {
+ Result = constantFoldUser(Usr, Op, OpConst.getValue(), DL);
+ break;
+ }
+ }
+ }
+ }
+ }
+ if (!Result.isOverdefined())
+ return true;
+ }
+ }
+
+ // If the edge was formed by a switch on the value, then we may know exactly
+ // what it is.
+ if (SwitchInst *SI = dyn_cast<SwitchInst>(BBFrom->getTerminator())) {
+ Value *Condition = SI->getCondition();
+ if (!isa<IntegerType>(Val->getType()))
+ return false;
+ bool ValUsesConditionAndMayBeFoldable = false;
+ if (Condition != Val) {
+ // Check if Val has Condition as an operand.
+ if (User *Usr = dyn_cast<User>(Val))
+ ValUsesConditionAndMayBeFoldable = isOperationFoldable(Usr) &&
+ usesOperand(Usr, Condition);
+ if (!ValUsesConditionAndMayBeFoldable)
+ return false;
+ }
+ assert((Condition == Val || ValUsesConditionAndMayBeFoldable) &&
+ "Either Condition == Val or Val must use Condition");
+
+ bool DefaultCase = SI->getDefaultDest() == BBTo;
+ unsigned BitWidth = Val->getType()->getIntegerBitWidth();
+ ConstantRange EdgesVals(BitWidth, DefaultCase/*isFullSet*/);
+
+ for (auto Case : SI->cases()) {
+ APInt CaseValue = Case.getCaseValue()->getValue();
+ ConstantRange EdgeVal(CaseValue);
+ if (ValUsesConditionAndMayBeFoldable) {
+ User *Usr = cast<User>(Val);
+ const DataLayout &DL = BBTo->getModule()->getDataLayout();
+ ValueLatticeElement EdgeLatticeVal =
+ constantFoldUser(Usr, Condition, CaseValue, DL);
+ if (EdgeLatticeVal.isOverdefined())
+ return false;
+ EdgeVal = EdgeLatticeVal.getConstantRange();
+ }
+ if (DefaultCase) {
+ // It is possible that the default destination is the destination of
+ // some cases. We cannot perform difference for those cases.
+ // We know Condition != CaseValue in BBTo. In some cases we can use
+ // this to infer Val == f(Condition) is != f(CaseValue). For now, we
+ // only do this when f is identity (i.e. Val == Condition), but we
+ // should be able to do this for any injective f.
+ if (Case.getCaseSuccessor() != BBTo && Condition == Val)
+ EdgesVals = EdgesVals.difference(EdgeVal);
+ } else if (Case.getCaseSuccessor() == BBTo)
+ EdgesVals = EdgesVals.unionWith(EdgeVal);
+ }
+ Result = ValueLatticeElement::getRange(std::move(EdgesVals));
+ return true;
+ }
+ return false;
+}
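+
+// A brief illustration (hypothetical IR) of the switch handling above:
+//   switch i32 %v, label %other [ i32 1, label %bb
+//                                 i32 3, label %bb ]
+// On the edge to %bb, EdgesVals starts empty and the case values are unioned
+// into [1, 4), the smallest single range containing {1, 3}; on the edge to
+// %other, each case value is instead subtracted from the full set, to the
+// extent a single ConstantRange can express the result.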
+
+/// Compute the value of Val on the edge BBFrom -> BBTo or the value at
+/// the basic block if the edge does not constrain Val.
+bool LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom,
+ BasicBlock *BBTo,
+ ValueLatticeElement &Result,
+ Instruction *CxtI) {
+ // If already a constant, there is nothing to compute.
+ if (Constant *VC = dyn_cast<Constant>(Val)) {
+ Result = ValueLatticeElement::get(VC);
+ return true;
+ }
+
+ ValueLatticeElement LocalResult;
+ if (!getEdgeValueLocal(Val, BBFrom, BBTo, LocalResult))
+ // If we couldn't constrain the value on the edge, LocalResult doesn't
+ // provide any information.
+ LocalResult = ValueLatticeElement::getOverdefined();
+
+ if (hasSingleValue(LocalResult)) {
+ // Can't get any more precise here
+ Result = LocalResult;
+ return true;
+ }
+
+ if (!hasBlockValue(Val, BBFrom)) {
+ if (pushBlockValue(std::make_pair(BBFrom, Val)))
+ return false;
+ // No new information.
+ Result = LocalResult;
+ return true;
+ }
+
+ // Try to intersect ranges of the BB and the constraint on the edge.
+ ValueLatticeElement InBlock = getBlockValue(Val, BBFrom);
+ intersectAssumeOrGuardBlockValueConstantRange(Val, InBlock,
+ BBFrom->getTerminator());
+ // We can use the context instruction (generically the ultimate instruction
+ // the calling pass is trying to simplify) here, even though the result of
+ // this function is generally cached when called from the solve* functions
+ // (and that cached result might be used with queries using a different
+ // context instruction), because when this function is called from the solve*
+ // functions, the context instruction is not provided. When called from
+ // LazyValueInfoImpl::getValueOnEdge, the context instruction is provided,
+ // but then the result is not cached.
+ intersectAssumeOrGuardBlockValueConstantRange(Val, InBlock, CxtI);
+
+ Result = intersect(LocalResult, InBlock);
+ return true;
+}
+
+ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB,
+ Instruction *CxtI) {
+ LLVM_DEBUG(dbgs() << "LVI Getting block end value " << *V << " at '"
+ << BB->getName() << "'\n");
+
+ assert(BlockValueStack.empty() && BlockValueSet.empty());
+ if (!hasBlockValue(V, BB)) {
+ pushBlockValue(std::make_pair(BB, V));
+ solve();
+ }
+ ValueLatticeElement Result = getBlockValue(V, BB);
+ intersectAssumeOrGuardBlockValueConstantRange(V, Result, CxtI);
+
+ LLVM_DEBUG(dbgs() << " Result = " << Result << "\n");
+ return Result;
+}
+
+ValueLatticeElement LazyValueInfoImpl::getValueAt(Value *V, Instruction *CxtI) {
+ LLVM_DEBUG(dbgs() << "LVI Getting value " << *V << " at '" << CxtI->getName()
+ << "'\n");
+
+ if (auto *C = dyn_cast<Constant>(V))
+ return ValueLatticeElement::get(C);
+
+ ValueLatticeElement Result = ValueLatticeElement::getOverdefined();
+ if (auto *I = dyn_cast<Instruction>(V))
+ Result = getFromRangeMetadata(I);
+ intersectAssumeOrGuardBlockValueConstantRange(V, Result, CxtI);
+
+ LLVM_DEBUG(dbgs() << " Result = " << Result << "\n");
+ return Result;
+}
+
+ValueLatticeElement LazyValueInfoImpl::
+getValueOnEdge(Value *V, BasicBlock *FromBB, BasicBlock *ToBB,
+ Instruction *CxtI) {
+ LLVM_DEBUG(dbgs() << "LVI Getting edge value " << *V << " from '"
+ << FromBB->getName() << "' to '" << ToBB->getName()
+ << "'\n");
+
+ ValueLatticeElement Result;
+ if (!getEdgeValue(V, FromBB, ToBB, Result, CxtI)) {
+ solve();
+ bool WasFastQuery = getEdgeValue(V, FromBB, ToBB, Result, CxtI);
+ (void)WasFastQuery;
+ assert(WasFastQuery && "More work to do after problem solved?");
+ }
+
+ LLVM_DEBUG(dbgs() << " Result = " << Result << "\n");
+ return Result;
+}
+
+void LazyValueInfoImpl::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc,
+ BasicBlock *NewSucc) {
+ TheCache.threadEdgeImpl(OldSucc, NewSucc);
+}
+
+//===----------------------------------------------------------------------===//
+// LazyValueInfo Impl
+//===----------------------------------------------------------------------===//
+
+/// This lazily constructs the LazyValueInfoImpl.
+static LazyValueInfoImpl &getImpl(void *&PImpl, AssumptionCache *AC,
+ const DataLayout *DL,
+ DominatorTree *DT = nullptr) {
+ if (!PImpl) {
+ assert(DL && "getImpl() called with a null DataLayout");
+ PImpl = new LazyValueInfoImpl(AC, *DL, DT);
+ }
+ return *static_cast<LazyValueInfoImpl*>(PImpl);
+}
+
+bool LazyValueInfoWrapperPass::runOnFunction(Function &F) {
+ Info.AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+ const DataLayout &DL = F.getParent()->getDataLayout();
+
+ DominatorTreeWrapperPass *DTWP =
+ getAnalysisIfAvailable<DominatorTreeWrapperPass>();
+ Info.DT = DTWP ? &DTWP->getDomTree() : nullptr;
+ Info.TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+
+ if (Info.PImpl)
+ getImpl(Info.PImpl, Info.AC, &DL, Info.DT).clear();
+
+ // Fully lazy.
+ return false;
+}
+
+void LazyValueInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequired<AssumptionCacheTracker>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
+}
+
+LazyValueInfo &LazyValueInfoWrapperPass::getLVI() { return Info; }
+
+LazyValueInfo::~LazyValueInfo() { releaseMemory(); }
+
+void LazyValueInfo::releaseMemory() {
+ // If the cache was allocated, free it.
+ if (PImpl) {
+ delete &getImpl(PImpl, AC, nullptr);
+ PImpl = nullptr;
+ }
+}
+
+bool LazyValueInfo::invalidate(Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &Inv) {
+ // We need to invalidate if we have either failed to preserve this analysis's
+ // result directly or if any of its dependencies have been invalidated.
+ auto PAC = PA.getChecker<LazyValueAnalysis>();
+ if (!(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
+ (DT && Inv.invalidate<DominatorTreeAnalysis>(F, PA)))
+ return true;
+
+ return false;
+}
+
+void LazyValueInfoWrapperPass::releaseMemory() { Info.releaseMemory(); }
+
+LazyValueInfo LazyValueAnalysis::run(Function &F,
+ FunctionAnalysisManager &FAM) {
+ auto &AC = FAM.getResult<AssumptionAnalysis>(F);
+ auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
+ auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F);
+
+ return LazyValueInfo(&AC, &F.getParent()->getDataLayout(), &TLI, DT);
+}
+
+/// Returns true if we can statically tell that this value will never be a
+/// "useful" constant. In practice, this means we've got something like an
+/// alloca or a malloc call for which a comparison against a constant can
+/// only be guarding dead code. Note that we are potentially giving up some
+/// precision in dead code (a constant result) in favour of avoiding an
+/// expensive search for an easily answered common query.
+static bool isKnownNonConstant(Value *V) {
+ V = V->stripPointerCasts();
+ // The return value of an alloca cannot be a Constant.
+ if (isa<AllocaInst>(V))
+ return true;
+ return false;
+}
+
+Constant *LazyValueInfo::getConstant(Value *V, BasicBlock *BB,
+ Instruction *CxtI) {
+ // Bail out early if V is known not to be a Constant.
+ if (isKnownNonConstant(V))
+ return nullptr;
+
+ const DataLayout &DL = BB->getModule()->getDataLayout();
+ ValueLatticeElement Result =
+ getImpl(PImpl, AC, &DL, DT).getValueInBlock(V, BB, CxtI);
+
+ if (Result.isConstant())
+ return Result.getConstant();
+ if (Result.isConstantRange()) {
+ const ConstantRange &CR = Result.getConstantRange();
+ if (const APInt *SingleVal = CR.getSingleElement())
+ return ConstantInt::get(V->getContext(), *SingleVal);
+ }
+ return nullptr;
+}
+
+ConstantRange LazyValueInfo::getConstantRange(Value *V, BasicBlock *BB,
+ Instruction *CxtI) {
+ assert(V->getType()->isIntegerTy());
+ unsigned Width = V->getType()->getIntegerBitWidth();
+ const DataLayout &DL = BB->getModule()->getDataLayout();
+ ValueLatticeElement Result =
+ getImpl(PImpl, AC, &DL, DT).getValueInBlock(V, BB, CxtI);
+ if (Result.isUndefined())
+ return ConstantRange::getEmpty(Width);
+ if (Result.isConstantRange())
+ return Result.getConstantRange();
+ // We represent ConstantInt constants as constant ranges but other kinds
+ // of integer constants, i.e. ConstantExpr, will be tagged as constants
+ assert(!(Result.isConstant() && isa<ConstantInt>(Result.getConstant())) &&
+ "ConstantInt value must be represented as constantrange");
+ return ConstantRange::getFull(Width);
+}
+
+/// Determine whether the specified value is known to be a
+/// constant on the specified edge. Return null if not.
+Constant *LazyValueInfo::getConstantOnEdge(Value *V, BasicBlock *FromBB,
+ BasicBlock *ToBB,
+ Instruction *CxtI) {
+ const DataLayout &DL = FromBB->getModule()->getDataLayout();
+ ValueLatticeElement Result =
+ getImpl(PImpl, AC, &DL, DT).getValueOnEdge(V, FromBB, ToBB, CxtI);
+
+ if (Result.isConstant())
+ return Result.getConstant();
+ if (Result.isConstantRange()) {
+ const ConstantRange &CR = Result.getConstantRange();
+ if (const APInt *SingleVal = CR.getSingleElement())
+ return ConstantInt::get(V->getContext(), *SingleVal);
+ }
+ return nullptr;
+}
+
+ConstantRange LazyValueInfo::getConstantRangeOnEdge(Value *V,
+ BasicBlock *FromBB,
+ BasicBlock *ToBB,
+ Instruction *CxtI) {
+ unsigned Width = V->getType()->getIntegerBitWidth();
+ const DataLayout &DL = FromBB->getModule()->getDataLayout();
+ ValueLatticeElement Result =
+ getImpl(PImpl, AC, &DL, DT).getValueOnEdge(V, FromBB, ToBB, CxtI);
+
+ if (Result.isUndefined())
+ return ConstantRange::getEmpty(Width);
+ if (Result.isConstantRange())
+ return Result.getConstantRange();
+ // We represent ConstantInt constants as constant ranges but other kinds
+ // of integer constants, i.e. ConstantExpr, will be tagged as constants
+ assert(!(Result.isConstant() && isa<ConstantInt>(Result.getConstant())) &&
+ "ConstantInt value must be represented as constantrange");
+ return ConstantRange::getFull(Width);
+}
+
+static LazyValueInfo::Tristate
+getPredicateResult(unsigned Pred, Constant *C, const ValueLatticeElement &Val,
+ const DataLayout &DL, TargetLibraryInfo *TLI) {
+ // If we know the value is a constant, evaluate the conditional.
+ Constant *Res = nullptr;
+ if (Val.isConstant()) {
+ Res = ConstantFoldCompareInstOperands(Pred, Val.getConstant(), C, DL, TLI);
+ if (ConstantInt *ResCI = dyn_cast<ConstantInt>(Res))
+ return ResCI->isZero() ? LazyValueInfo::False : LazyValueInfo::True;
+ return LazyValueInfo::Unknown;
+ }
+
+ if (Val.isConstantRange()) {
+ ConstantInt *CI = dyn_cast<ConstantInt>(C);
+ if (!CI) return LazyValueInfo::Unknown;
+
+ const ConstantRange &CR = Val.getConstantRange();
+ if (Pred == ICmpInst::ICMP_EQ) {
+ if (!CR.contains(CI->getValue()))
+ return LazyValueInfo::False;
+
+ if (CR.isSingleElement())
+ return LazyValueInfo::True;
+ } else if (Pred == ICmpInst::ICMP_NE) {
+ if (!CR.contains(CI->getValue()))
+ return LazyValueInfo::True;
+
+ if (CR.isSingleElement())
+ return LazyValueInfo::False;
+ } else {
+ // Handle more complex predicates.
+ ConstantRange TrueValues = ConstantRange::makeExactICmpRegion(
+ (ICmpInst::Predicate)Pred, CI->getValue());
+ if (TrueValues.contains(CR))
+ return LazyValueInfo::True;
+ if (TrueValues.inverse().contains(CR))
+ return LazyValueInfo::False;
+ }
+ return LazyValueInfo::Unknown;
+ }
+
+ if (Val.isNotConstant()) {
+ // If this is an equality comparison, we can try to fold it knowing that
+ // "V != C1".
+ if (Pred == ICmpInst::ICMP_EQ) {
+ // !C1 == C -> false iff C1 == C.
+ Res = ConstantFoldCompareInstOperands(ICmpInst::ICMP_NE,
+ Val.getNotConstant(), C, DL,
+ TLI);
+ if (Res->isNullValue())
+ return LazyValueInfo::False;
+ } else if (Pred == ICmpInst::ICMP_NE) {
+ // !C1 != C -> true iff C1 == C.
+ Res = ConstantFoldCompareInstOperands(ICmpInst::ICMP_NE,
+ Val.getNotConstant(), C, DL,
+ TLI);
+ if (Res->isNullValue())
+ return LazyValueInfo::True;
+ }
+ return LazyValueInfo::Unknown;
+ }
+
+ return LazyValueInfo::Unknown;
+}
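+
+// To illustrate the "more complex predicates" branch above with made-up
+// numbers: for Pred == ICMP_ULT and C == 100, makeExactICmpRegion gives
+// TrueValues == [0, 100). A lattice range of [0, 50) is contained in it, so
+// the query folds to True; a range of [200, 300) is contained in the inverse
+// region [100, 0), so the query folds to False; a range straddling 100 stays
+// Unknown.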
+
+/// Determine whether the specified value comparison with a constant is known to
+/// be true or false on the specified CFG edge. Pred is a CmpInst predicate.
+LazyValueInfo::Tristate
+LazyValueInfo::getPredicateOnEdge(unsigned Pred, Value *V, Constant *C,
+ BasicBlock *FromBB, BasicBlock *ToBB,
+ Instruction *CxtI) {
+ const DataLayout &DL = FromBB->getModule()->getDataLayout();
+ ValueLatticeElement Result =
+ getImpl(PImpl, AC, &DL, DT).getValueOnEdge(V, FromBB, ToBB, CxtI);
+
+ return getPredicateResult(Pred, C, Result, DL, TLI);
+}
+
+LazyValueInfo::Tristate
+LazyValueInfo::getPredicateAt(unsigned Pred, Value *V, Constant *C,
+ Instruction *CxtI) {
+ // Whether a value is or is not NonNull is a commonly queried predicate. If
+ // isKnownNonZero can tell us the result of the predicate, we can
+ // return it quickly. But this is only a fast path, and falling
+ // through would still be correct.
+ const DataLayout &DL = CxtI->getModule()->getDataLayout();
+ if (V->getType()->isPointerTy() && C->isNullValue() &&
+ isKnownNonZero(V->stripPointerCastsSameRepresentation(), DL)) {
+ if (Pred == ICmpInst::ICMP_EQ)
+ return LazyValueInfo::False;
+ else if (Pred == ICmpInst::ICMP_NE)
+ return LazyValueInfo::True;
+ }
+ ValueLatticeElement Result = getImpl(PImpl, AC, &DL, DT).getValueAt(V, CxtI);
+ Tristate Ret = getPredicateResult(Pred, C, Result, DL, TLI);
+ if (Ret != Unknown)
+ return Ret;
+
+ // Note: The following bit of code is somewhat distinct from the rest of LVI;
+ // LVI as a whole tries to compute a lattice value which is conservatively
+ // correct at a given location. In this case, we have a predicate which we
+ // weren't able to prove about the merged result, and we're pushing that
+ // predicate back along each incoming edge to see if we can prove it
+ // separately for each input. As a motivating example, consider:
+ // bb1:
+ // %v1 = ... ; constantrange<1, 5>
+ // br label %merge
+ // bb2:
+ // %v2 = ... ; constantrange<10, 20>
+ // br label %merge
+ // merge:
+ // %phi = phi [%v1, %v2] ; constantrange<1,20>
+ // %pred = icmp eq i32 %phi, 8
+ // We can't tell from the lattice value for '%phi' that '%pred' is false
+ // along each path, but by checking the predicate over each input separately,
+ // we can.
+ // We limit the search to one step backwards from the current BB and value.
+ // We could consider extending this to search further backwards through the
+ // CFG and/or value graph, but there are non-obvious compile time vs quality
+ // tradeoffs.
+ if (CxtI) {
+ BasicBlock *BB = CxtI->getParent();
+
+ // Function entry or an unreachable block. Bail to avoid confusing
+ // analysis below.
+ pred_iterator PI = pred_begin(BB), PE = pred_end(BB);
+ if (PI == PE)
+ return Unknown;
+
+ // If V is a PHI node in the same block as the context, we need to ask
+ // questions about the predicate as applied to the incoming value along
+ // each edge. This is useful for eliminating cases where the predicate is
+ // known along all incoming edges.
+ if (auto *PHI = dyn_cast<PHINode>(V))
+ if (PHI->getParent() == BB) {
+ Tristate Baseline = Unknown;
+ for (unsigned i = 0, e = PHI->getNumIncomingValues(); i < e; i++) {
+ Value *Incoming = PHI->getIncomingValue(i);
+ BasicBlock *PredBB = PHI->getIncomingBlock(i);
+ // Note that PredBB may be BB itself.
+ Tristate Result = getPredicateOnEdge(Pred, Incoming, C, PredBB, BB,
+ CxtI);
+
+ // Keep going as long as we've seen a consistent known result for
+ // all inputs.
+ Baseline = (i == 0) ? Result /* First iteration */
+ : (Baseline == Result ? Baseline : Unknown); /* All others */
+ if (Baseline == Unknown)
+ break;
+ }
+ if (Baseline != Unknown)
+ return Baseline;
+ }
+
+ // For a comparison where V is defined outside this block, it's possible
+ // that we've branched on it before. Look to see if the value is known
+ // on all incoming edges.
+ if (!isa<Instruction>(V) ||
+ cast<Instruction>(V)->getParent() != BB) {
+ // For each predecessor edge, determine if the comparison is true or false
+ // on that edge. If they're all true or all false, we can conclude
+ // the value of the comparison in this block.
+ Tristate Baseline = getPredicateOnEdge(Pred, V, C, *PI, BB, CxtI);
+ if (Baseline != Unknown) {
+ // Check that all remaining incoming values match the first one.
+ while (++PI != PE) {
+ Tristate Ret = getPredicateOnEdge(Pred, V, C, *PI, BB, CxtI);
+ if (Ret != Baseline) break;
+ }
+ // If we terminated early, then one of the values didn't match.
+ if (PI == PE) {
+ return Baseline;
+ }
+ }
+ }
+ }
+ return Unknown;
+}
+
+void LazyValueInfo::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc,
+ BasicBlock *NewSucc) {
+ if (PImpl) {
+ const DataLayout &DL = PredBB->getModule()->getDataLayout();
+ getImpl(PImpl, AC, &DL, DT).threadEdge(PredBB, OldSucc, NewSucc);
+ }
+}
+
+void LazyValueInfo::eraseBlock(BasicBlock *BB) {
+ if (PImpl) {
+ const DataLayout &DL = BB->getModule()->getDataLayout();
+ getImpl(PImpl, AC, &DL, DT).eraseBlock(BB);
+ }
+}
+
+
+void LazyValueInfo::printLVI(Function &F, DominatorTree &DTree, raw_ostream &OS) {
+ if (PImpl) {
+ getImpl(PImpl, AC, DL, DT).printLVI(F, DTree, OS);
+ }
+}
+
+void LazyValueInfo::disableDT() {
+ if (PImpl)
+ getImpl(PImpl, AC, DL, DT).disableDT();
+}
+
+void LazyValueInfo::enableDT() {
+ if (PImpl)
+ getImpl(PImpl, AC, DL, DT).enableDT();
+}
+
+// Print the LVI for the function arguments at the start of each basic block.
+void LazyValueInfoAnnotatedWriter::emitBasicBlockStartAnnot(
+ const BasicBlock *BB, formatted_raw_ostream &OS) {
+ // Find if there are lattice values defined for arguments of the function.
+ auto *F = BB->getParent();
+ for (auto &Arg : F->args()) {
+ ValueLatticeElement Result = LVIImpl->getValueInBlock(
+ const_cast<Argument *>(&Arg), const_cast<BasicBlock *>(BB));
+ if (Result.isUndefined())
+ continue;
+ OS << "; LatticeVal for: '" << Arg << "' is: " << Result << "\n";
+ }
+}
+
+// This function prints the LVI analysis for the instruction I at the beginning
+// of various basic blocks. It relies on calculated values that are stored in
+// the LazyValueInfoCache; in the absence of cached values, it recalculates the
+// LazyValueInfo for `I` and prints that info.
+void LazyValueInfoAnnotatedWriter::emitInstructionAnnot(
+ const Instruction *I, formatted_raw_ostream &OS) {
+
+ auto *ParentBB = I->getParent();
+ SmallPtrSet<const BasicBlock*, 16> BlocksContainingLVI;
+ // We can generate (solve) LVI values only for blocks that are dominated by
+ // I's parent. However, to avoid generating LVI for all such dominated blocks,
+ // which would contain redundant/uninteresting information, we print LVI for
+ // blocks that may use this LVI information (such as immediate successor
+ // blocks, and blocks that contain uses of `I`).
+ auto printResult = [&](const BasicBlock *BB) {
+ if (!BlocksContainingLVI.insert(BB).second)
+ return;
+ ValueLatticeElement Result = LVIImpl->getValueInBlock(
+ const_cast<Instruction *>(I), const_cast<BasicBlock *>(BB));
+ OS << "; LatticeVal for: '" << *I << "' in BB: '";
+ BB->printAsOperand(OS, false);
+ OS << "' is: " << Result << "\n";
+ };
+
+ printResult(ParentBB);
+ // Print the LVI analysis results for the immediate successor blocks that
+ // are dominated by `ParentBB`.
+ for (auto *BBSucc : successors(ParentBB))
+ if (DT.dominates(ParentBB, BBSucc))
+ printResult(BBSucc);
+
+ // Print LVI in blocks where `I` is used.
+ for (auto *U : I->users())
+ if (auto *UseI = dyn_cast<Instruction>(U))
+ if (!isa<PHINode>(UseI) || DT.dominates(ParentBB, UseI->getParent()))
+ printResult(UseI->getParent());
+
+}
+
+namespace {
+// Printer class for LazyValueInfo results.
+class LazyValueInfoPrinter : public FunctionPass {
+public:
+ static char ID; // Pass identification, replacement for typeid
+ LazyValueInfoPrinter() : FunctionPass(ID) {
+ initializeLazyValueInfoPrinterPass(*PassRegistry::getPassRegistry());
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ AU.addRequired<LazyValueInfoWrapperPass>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ }
+
+ // Get the mandatory dominator tree analysis and pass this in to the
+ // LVIPrinter. We cannot rely on the LVI's DT, since it's optional.
+ bool runOnFunction(Function &F) override {
+ dbgs() << "LVI for function '" << F.getName() << "':\n";
+ auto &LVI = getAnalysis<LazyValueInfoWrapperPass>().getLVI();
+ auto &DTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ LVI.printLVI(F, DTree, dbgs());
+ return false;
+ }
+};
+}
+
+char LazyValueInfoPrinter::ID = 0;
+INITIALIZE_PASS_BEGIN(LazyValueInfoPrinter, "print-lazy-value-info",
+ "Lazy Value Info Printer Pass", false, false)
+INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass)
+INITIALIZE_PASS_END(LazyValueInfoPrinter, "print-lazy-value-info",
+ "Lazy Value Info Printer Pass", false, false)
diff --git a/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp b/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp
new file mode 100644
index 000000000000..7de9d2cbfddb
--- /dev/null
+++ b/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp
@@ -0,0 +1,404 @@
+//===- LegacyDivergenceAnalysis.cpp - Legacy Divergence Analysis Impl ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements divergence analysis which determines whether a branch
+// in a GPU program is divergent. It can help branch optimizations such as jump
+// threading and loop unswitching to make better decisions.
+//
+// GPU programs typically use the SIMD execution model, where multiple threads
+// in the same execution group have to execute in lock-step. Therefore, if the
+// code contains divergent branches (i.e., threads in a group do not agree on
+// which path of the branch to take), the group of threads has to execute all
+// the paths from that branch with different subsets of threads enabled until
+// they converge at the immediately post-dominating BB of the paths.
+//
+// Due to this execution model, some optimizations such as jump
+// threading and loop unswitching can be unfortunately harmful when performed on
+// divergent branches. Therefore, an analysis that computes which branches in a
+// GPU program are divergent can help the compiler to selectively run these
+// optimizations.
+//
+// This file defines divergence analysis which computes a conservative but
+// non-trivial approximation of all divergent branches in a GPU program. It
+// partially implements the approach described in
+//
+// Divergence Analysis
+// Sampaio, Souza, Collange, Pereira
+// TOPLAS '13
+//
+// The divergence analysis identifies the sources of divergence (e.g., special
+// variables that hold the thread ID), and recursively marks variables that are
+// data or sync dependent on a source of divergence as divergent.
+//
+// While data dependency is a well-known concept, the notion of sync dependency
+// is worth more explanation. Sync dependence characterizes the control flow
+// aspect of the propagation of branch divergence. For example,
+//
+// %cond = icmp slt i32 %tid, 10
+// br i1 %cond, label %then, label %else
+// then:
+// br label %merge
+// else:
+// br label %merge
+// merge:
+// %a = phi i32 [ 0, %then ], [ 1, %else ]
+//
+// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
+// because %tid is not on its use-def chains, %a is sync dependent on %tid
+// because the branch "br i1 %cond" depends on %tid and affects which value %a
+// is assigned to.
+//
+// The current implementation has the following limitations:
+// 1. intra-procedural. It conservatively considers the arguments of a
+// non-kernel-entry function and the return value of a function call as
+// divergent.
+// 2. memory as black box. It conservatively considers values loaded from
+// generic or local address as divergent. This can be improved by leveraging
+// pointer analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/DivergenceAnalysis.h"
+#include "llvm/Analysis/LegacyDivergenceAnalysis.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <vector>
+using namespace llvm;
+
+#define DEBUG_TYPE "divergence"
+
+// transparently use the GPUDivergenceAnalysis
+static cl::opt<bool> UseGPUDA("use-gpu-divergence-analysis", cl::init(false),
+ cl::Hidden,
+ cl::desc("turn the LegacyDivergenceAnalysis into "
+ "a wrapper for GPUDivergenceAnalysis"));
+
+namespace {
+
+class DivergencePropagator {
+public:
+ DivergencePropagator(Function &F, TargetTransformInfo &TTI, DominatorTree &DT,
+ PostDominatorTree &PDT, DenseSet<const Value *> &DV,
+ DenseSet<const Use *> &DU)
+ : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV), DU(DU) {}
+ void populateWithSourcesOfDivergence();
+ void propagate();
+
+private:
+ // A helper function that explores data dependents of V.
+ void exploreDataDependency(Value *V);
+ // A helper function that explores sync dependents of TI.
+ void exploreSyncDependency(Instruction *TI);
+ // Computes the influence region from Start to End. This region includes all
+ // basic blocks on any simple path from Start to End.
+ void computeInfluenceRegion(BasicBlock *Start, BasicBlock *End,
+ DenseSet<BasicBlock *> &InfluenceRegion);
+ // Finds all users of I that are outside the influence region, and adds these
+ // users to Worklist.
+ void findUsersOutsideInfluenceRegion(
+ Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion);
+
+ Function &F;
+ TargetTransformInfo &TTI;
+ DominatorTree &DT;
+ PostDominatorTree &PDT;
+ std::vector<Value *> Worklist; // Stack for DFS.
+ DenseSet<const Value *> &DV; // Stores all divergent values.
+ DenseSet<const Use *> &DU; // Stores divergent uses of possibly uniform
+ // values.
+};
+
+void DivergencePropagator::populateWithSourcesOfDivergence() {
+ Worklist.clear();
+ DV.clear();
+ DU.clear();
+ for (auto &I : instructions(F)) {
+ if (TTI.isSourceOfDivergence(&I)) {
+ Worklist.push_back(&I);
+ DV.insert(&I);
+ }
+ }
+ for (auto &Arg : F.args()) {
+ if (TTI.isSourceOfDivergence(&Arg)) {
+ Worklist.push_back(&Arg);
+ DV.insert(&Arg);
+ }
+ }
+}
+
+void DivergencePropagator::exploreSyncDependency(Instruction *TI) {
+ // Propagation rule 1: if branch TI is divergent, all PHINodes in TI's
+ // immediate post dominator are divergent. This rule handles if-then-else
+ // patterns. For example,
+ //
+ // if (tid < 5)
+ // a1 = 1;
+ // else
+ // a2 = 2;
+ // a = phi(a1, a2); // sync dependent on (tid < 5)
+ BasicBlock *ThisBB = TI->getParent();
+
+ // Unreachable blocks may not be in the dominator tree.
+ if (!DT.isReachableFromEntry(ThisBB))
+ return;
+
+ // If the function has no exit blocks or doesn't reach any exit blocks, the
+ // post dominator may be null.
+ DomTreeNode *ThisNode = PDT.getNode(ThisBB);
+ if (!ThisNode)
+ return;
+
+ BasicBlock *IPostDom = ThisNode->getIDom()->getBlock();
+ if (IPostDom == nullptr)
+ return;
+
+ for (auto I = IPostDom->begin(); isa<PHINode>(I); ++I) {
+ // A PHINode is uniform if it returns the same value no matter which path is
+ // taken.
+ if (!cast<PHINode>(I)->hasConstantOrUndefValue() && DV.insert(&*I).second)
+ Worklist.push_back(&*I);
+ }
+
+ // Propagation rule 2: if a value defined in a loop is used outside, the user
+ // is sync dependent on the condition of the loop exits that dominate the
+ // user. For example,
+ //
+ // int i = 0;
+ // do {
+ // i++;
+ // if (foo(i)) ... // uniform
+ // } while (i < tid);
+ // if (bar(i)) ... // divergent
+ //
+ // A program may contain unstructured loops. Therefore, we cannot leverage
+ // LoopInfo, which only recognizes natural loops.
+ //
+ // The algorithm used here handles both natural and unstructured loops. Given
+ // a branch TI, we first compute its influence region, the union of all simple
+ // paths from TI to its immediate post dominator (IPostDom). Then, we search
+ // for all the values defined in the influence region but used outside. All
+ // these users are sync dependent on TI.
+ DenseSet<BasicBlock *> InfluenceRegion;
+ computeInfluenceRegion(ThisBB, IPostDom, InfluenceRegion);
+ // An insight that can speed up the search process is that all the in-region
+ // values that are used outside must dominate TI. Therefore, instead of
+ // searching every basic block in the influence region, we search all the
+ // dominators of TI until it is outside the influence region.
+ BasicBlock *InfluencedBB = ThisBB;
+ while (InfluenceRegion.count(InfluencedBB)) {
+ for (auto &I : *InfluencedBB) {
+ if (!DV.count(&I))
+ findUsersOutsideInfluenceRegion(I, InfluenceRegion);
+ }
+ DomTreeNode *IDomNode = DT.getNode(InfluencedBB)->getIDom();
+ if (IDomNode == nullptr)
+ break;
+ InfluencedBB = IDomNode->getBlock();
+ }
+}
+
+void DivergencePropagator::findUsersOutsideInfluenceRegion(
+ Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion) {
+ for (Use &Use : I.uses()) {
+ Instruction *UserInst = cast<Instruction>(Use.getUser());
+ if (!InfluenceRegion.count(UserInst->getParent())) {
+ DU.insert(&Use);
+ if (DV.insert(UserInst).second)
+ Worklist.push_back(UserInst);
+ }
+ }
+}
+
+// A helper function for computeInfluenceRegion that adds successors of "ThisBB"
+// to the influence region.
+static void
+addSuccessorsToInfluenceRegion(BasicBlock *ThisBB, BasicBlock *End,
+ DenseSet<BasicBlock *> &InfluenceRegion,
+ std::vector<BasicBlock *> &InfluenceStack) {
+ for (BasicBlock *Succ : successors(ThisBB)) {
+ if (Succ != End && InfluenceRegion.insert(Succ).second)
+ InfluenceStack.push_back(Succ);
+ }
+}
+
+void DivergencePropagator::computeInfluenceRegion(
+ BasicBlock *Start, BasicBlock *End,
+ DenseSet<BasicBlock *> &InfluenceRegion) {
+ assert(PDT.properlyDominates(End, Start) &&
+ "End does not properly post-dominate Start");
+
+ // The influence region starts from the end of "Start" to the beginning of
+ // "End". Therefore, "Start" should not be in the region unless "Start" is in
+ // a loop that doesn't contain "End".
+ std::vector<BasicBlock *> InfluenceStack;
+ addSuccessorsToInfluenceRegion(Start, End, InfluenceRegion, InfluenceStack);
+ while (!InfluenceStack.empty()) {
+ BasicBlock *BB = InfluenceStack.back();
+ InfluenceStack.pop_back();
+ addSuccessorsToInfluenceRegion(BB, End, InfluenceRegion, InfluenceStack);
+ }
+}
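+
+// As a made-up example of the traversal above: for a divergent branch in
+// %start whose immediate post dominator is %end, with edges
+//   %start -> %then -> %end
+//   %start -> %else -> %loop, %loop -> %else, %loop -> %end
+// the influence region is {%then, %else, %loop}: every block on a simple path
+// from %start to %end, excluding %end itself and excluding %start because no
+// back edge re-enters it.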
+
+void DivergencePropagator::exploreDataDependency(Value *V) {
+ // Follow def-use chains of V.
+ for (User *U : V->users()) {
+ if (!TTI.isAlwaysUniform(U) && DV.insert(U).second)
+ Worklist.push_back(U);
+ }
+}
+
+void DivergencePropagator::propagate() {
+ // Traverse the dependency graph using DFS.
+ while (!Worklist.empty()) {
+ Value *V = Worklist.back();
+ Worklist.pop_back();
+ if (Instruction *I = dyn_cast<Instruction>(V)) {
+ // Terminators with less than two successors won't introduce sync
+ // dependency. Ignore them.
+ if (I->isTerminator() && I->getNumSuccessors() > 1)
+ exploreSyncDependency(I);
+ }
+ exploreDataDependency(V);
+ }
+}
+
+} // namespace
+
+// Register this pass.
+char LegacyDivergenceAnalysis::ID = 0;
+INITIALIZE_PASS_BEGIN(LegacyDivergenceAnalysis, "divergence",
+ "Legacy Divergence Analysis", false, true)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(LegacyDivergenceAnalysis, "divergence",
+ "Legacy Divergence Analysis", false, true)
+
+FunctionPass *llvm::createLegacyDivergenceAnalysisPass() {
+ return new LegacyDivergenceAnalysis();
+}
+
+void LegacyDivergenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<PostDominatorTreeWrapperPass>();
+ if (UseGPUDA)
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.setPreservesAll();
+}
+
+bool LegacyDivergenceAnalysis::shouldUseGPUDivergenceAnalysis(
+ const Function &F) const {
+ if (!UseGPUDA)
+ return false;
+
+ // GPUDivergenceAnalysis requires a reducible CFG.
+ auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ using RPOTraversal = ReversePostOrderTraversal<const Function *>;
+ RPOTraversal FuncRPOT(&F);
+ return !containsIrreducibleCFG<const BasicBlock *, const RPOTraversal,
+ const LoopInfo>(FuncRPOT, LI);
+}
+
+bool LegacyDivergenceAnalysis::runOnFunction(Function &F) {
+ auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
+ if (TTIWP == nullptr)
+ return false;
+
+ TargetTransformInfo &TTI = TTIWP->getTTI(F);
+ // Fast path: if the target does not have branch divergence, we do not mark
+ // any branch as divergent.
+ if (!TTI.hasBranchDivergence())
+ return false;
+
+ DivergentValues.clear();
+ DivergentUses.clear();
+ gpuDA = nullptr;
+
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
+
+ if (shouldUseGPUDivergenceAnalysis(F)) {
+ // run the new GPU divergence analysis
+ auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ gpuDA = std::make_unique<GPUDivergenceAnalysis>(F, DT, PDT, LI, TTI);
+
+ } else {
+ // run LLVM's existing DivergenceAnalysis
+ DivergencePropagator DP(F, TTI, DT, PDT, DivergentValues, DivergentUses);
+ DP.populateWithSourcesOfDivergence();
+ DP.propagate();
+ }
+
+ LLVM_DEBUG(dbgs() << "\nAfter divergence analysis on " << F.getName()
+ << ":\n";
+ print(dbgs(), F.getParent()));
+
+ return false;
+}
+
+bool LegacyDivergenceAnalysis::isDivergent(const Value *V) const {
+ if (gpuDA) {
+ return gpuDA->isDivergent(*V);
+ }
+ return DivergentValues.count(V);
+}
+
+bool LegacyDivergenceAnalysis::isDivergentUse(const Use *U) const {
+ if (gpuDA) {
+ return gpuDA->isDivergentUse(*U);
+ }
+ return DivergentValues.count(U->get()) || DivergentUses.count(U);
+}
+
+void LegacyDivergenceAnalysis::print(raw_ostream &OS, const Module *) const {
+ if ((!gpuDA || !gpuDA->hasDivergence()) && DivergentValues.empty())
+ return;
+
+ const Function *F = nullptr;
+ if (!DivergentValues.empty()) {
+ const Value *FirstDivergentValue = *DivergentValues.begin();
+ if (const Argument *Arg = dyn_cast<Argument>(FirstDivergentValue)) {
+ F = Arg->getParent();
+ } else if (const Instruction *I =
+ dyn_cast<Instruction>(FirstDivergentValue)) {
+ F = I->getParent()->getParent();
+ } else {
+ llvm_unreachable("Only arguments and instructions can be divergent");
+ }
+ } else if (gpuDA) {
+ F = &gpuDA->getFunction();
+ }
+ if (!F)
+ return;
+
+ // Dumps all divergent values in F, arguments and then instructions.
+ for (auto &Arg : F->args()) {
+ OS << (isDivergent(&Arg) ? "DIVERGENT: " : " ");
+ OS << Arg << "\n";
+ }
+ // Iterate over basic blocks and instructions in program order for determinism.
+ for (auto BI = F->begin(), BE = F->end(); BI != BE; ++BI) {
+ auto &BB = *BI;
+ OS << "\n " << BB.getName() << ":\n";
+ for (auto &I : BB.instructionsWithoutDebug()) {
+ OS << (isDivergent(&I) ? "DIVERGENT: " : " ");
+ OS << I << "\n";
+ }
+ }
+ OS << "\n";
+}
diff --git a/llvm/lib/Analysis/Lint.cpp b/llvm/lib/Analysis/Lint.cpp
new file mode 100644
index 000000000000..db18716c64cf
--- /dev/null
+++ b/llvm/lib/Analysis/Lint.cpp
@@ -0,0 +1,756 @@
+//===-- Lint.cpp - Check for common errors in LLVM IR ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass statically checks for common and easily-identified constructs
+// which produce undefined or likely unintended behavior in LLVM IR.
+//
+// It is not a guarantee of correctness, for two reasons. First, it isn't
+// comprehensive. There are checks which could be done statically which are
+// not yet implemented. Some of these are indicated by TODO comments, but
+// those aren't comprehensive either. Second, many conditions cannot be
+// checked statically. This pass does no dynamic instrumentation, so it
+// can't check for all possible problems.
+//
+// Another limitation is that it assumes all code will be executed. A store
+// through a null pointer in a basic block which is never reached is harmless,
+// but this pass will warn about it anyway. This is the main reason why most
+// of these checks live here instead of in the Verifier pass.
+//
+// Optimization passes may make conditions that this pass checks for more or
+// less obvious. If an optimization pass appears to be introducing a warning,
+// it may be that the optimization pass is merely exposing an existing
+// condition in the code.
+//
+// This code may be run before instcombine. In many cases, instcombine checks
+// for the same kinds of things and turns instructions with undefined behavior
+// into unreachable (or equivalent). Because of this, this pass makes some
+// effort to look through bitcasts and so on.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/Lint.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/KnownBits.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+#include <cstdint>
+#include <iterator>
+#include <string>
+
+using namespace llvm;
+
+namespace {
+ namespace MemRef {
+ static const unsigned Read = 1;
+ static const unsigned Write = 2;
+ static const unsigned Callee = 4;
+ static const unsigned Branchee = 8;
+ } // end namespace MemRef
+
+ class Lint : public FunctionPass, public InstVisitor<Lint> {
+ friend class InstVisitor<Lint>;
+
+ void visitFunction(Function &F);
+
+ void visitCallSite(CallSite CS);
+ void visitMemoryReference(Instruction &I, Value *Ptr,
+ uint64_t Size, unsigned Align,
+ Type *Ty, unsigned Flags);
+ void visitEHBeginCatch(IntrinsicInst *II);
+ void visitEHEndCatch(IntrinsicInst *II);
+
+ void visitCallInst(CallInst &I);
+ void visitInvokeInst(InvokeInst &I);
+ void visitReturnInst(ReturnInst &I);
+ void visitLoadInst(LoadInst &I);
+ void visitStoreInst(StoreInst &I);
+ void visitXor(BinaryOperator &I);
+ void visitSub(BinaryOperator &I);
+ void visitLShr(BinaryOperator &I);
+ void visitAShr(BinaryOperator &I);
+ void visitShl(BinaryOperator &I);
+ void visitSDiv(BinaryOperator &I);
+ void visitUDiv(BinaryOperator &I);
+ void visitSRem(BinaryOperator &I);
+ void visitURem(BinaryOperator &I);
+ void visitAllocaInst(AllocaInst &I);
+ void visitVAArgInst(VAArgInst &I);
+ void visitIndirectBrInst(IndirectBrInst &I);
+ void visitExtractElementInst(ExtractElementInst &I);
+ void visitInsertElementInst(InsertElementInst &I);
+ void visitUnreachableInst(UnreachableInst &I);
+
+ Value *findValue(Value *V, bool OffsetOk) const;
+ Value *findValueImpl(Value *V, bool OffsetOk,
+ SmallPtrSetImpl<Value *> &Visited) const;
+
+ public:
+ Module *Mod;
+ const DataLayout *DL;
+ AliasAnalysis *AA;
+ AssumptionCache *AC;
+ DominatorTree *DT;
+ TargetLibraryInfo *TLI;
+
+ std::string Messages;
+ raw_string_ostream MessagesStr;
+
+ static char ID; // Pass identification, replacement for typeid
+ Lint() : FunctionPass(ID), MessagesStr(Messages) {
+ initializeLintPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override;
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ AU.addRequired<AAResultsWrapperPass>();
+ AU.addRequired<AssumptionCacheTracker>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ }
+ void print(raw_ostream &O, const Module *M) const override {}
+
+ void WriteValues(ArrayRef<const Value *> Vs) {
+ for (const Value *V : Vs) {
+ if (!V)
+ continue;
+ if (isa<Instruction>(V)) {
+ MessagesStr << *V << '\n';
+ } else {
+ V->printAsOperand(MessagesStr, true, Mod);
+ MessagesStr << '\n';
+ }
+ }
+ }
+
+ /// A check failed, so print out the condition and the message.
+ ///
+ /// This provides a nice place to put a breakpoint if you want to see why
+ /// something is not correct.
+ void CheckFailed(const Twine &Message) { MessagesStr << Message << '\n'; }
+
+ /// A check failed (with values to print).
+ ///
+ /// This calls the Message-only version so that the above is easier to set
+ /// a breakpoint on.
+ template <typename T1, typename... Ts>
+ void CheckFailed(const Twine &Message, const T1 &V1, const Ts &...Vs) {
+ CheckFailed(Message);
+ WriteValues({V1, Vs...});
+ }
+ };
+} // end anonymous namespace
+
+char Lint::ID = 0;
+INITIALIZE_PASS_BEGIN(Lint, "lint", "Statically lint-checks LLVM IR",
+ false, true)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_END(Lint, "lint", "Statically lint-checks LLVM IR",
+ false, true)
+
+// Assert - C should be true; if not, print an error message and bail out.
+#define Assert(C, ...) \
+ do { if (!(C)) { CheckFailed(__VA_ARGS__); return; } } while (false)
+
+// Lint::runOnFunction - This is the main analysis entry point for a
+// function.
+//
+bool Lint::runOnFunction(Function &F) {
+ Mod = F.getParent();
+ DL = &F.getParent()->getDataLayout();
+ AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
+ AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+ DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ visit(F);
+ dbgs() << MessagesStr.str();
+ Messages.clear();
+ return false;
+}
+
+void Lint::visitFunction(Function &F) {
+ // This isn't undefined behavior, it's just a little unusual, and it's a
+ // fairly common mistake to neglect to name a function.
+ Assert(F.hasName() || F.hasLocalLinkage(),
+ "Unusual: Unnamed function with non-local linkage", &F);
+
+ // TODO: Check for irreducible control flow.
+}
+
+void Lint::visitCallSite(CallSite CS) {
+ Instruction &I = *CS.getInstruction();
+ Value *Callee = CS.getCalledValue();
+
+ visitMemoryReference(I, Callee, MemoryLocation::UnknownSize, 0, nullptr,
+ MemRef::Callee);
+
+ if (Function *F = dyn_cast<Function>(findValue(Callee,
+ /*OffsetOk=*/false))) {
+ Assert(CS.getCallingConv() == F->getCallingConv(),
+ "Undefined behavior: Caller and callee calling convention differ",
+ &I);
+
+ FunctionType *FT = F->getFunctionType();
+ unsigned NumActualArgs = CS.arg_size();
+
+ Assert(FT->isVarArg() ? FT->getNumParams() <= NumActualArgs
+ : FT->getNumParams() == NumActualArgs,
+ "Undefined behavior: Call argument count mismatches callee "
+ "argument count",
+ &I);
+
+ Assert(FT->getReturnType() == I.getType(),
+ "Undefined behavior: Call return type mismatches "
+ "callee return type",
+ &I);
+
+ // Check argument types (in case the callee was casted) and attributes.
+ // TODO: Verify that caller and callee attributes are compatible.
+ Function::arg_iterator PI = F->arg_begin(), PE = F->arg_end();
+ CallSite::arg_iterator AI = CS.arg_begin(), AE = CS.arg_end();
+ for (; AI != AE; ++AI) {
+ Value *Actual = *AI;
+ if (PI != PE) {
+ Argument *Formal = &*PI++;
+ Assert(Formal->getType() == Actual->getType(),
+ "Undefined behavior: Call argument type mismatches "
+ "callee parameter type",
+ &I);
+
+ // Check that noalias arguments don't alias other arguments. This is
+ // not fully precise because we don't know the sizes of the dereferenced
+ // memory regions.
+ if (Formal->hasNoAliasAttr() && Actual->getType()->isPointerTy()) {
+ AttributeList PAL = CS.getAttributes();
+ unsigned ArgNo = 0;
+ for (CallSite::arg_iterator BI = CS.arg_begin(); BI != AE;
+ ++BI, ++ArgNo) {
+ // Skip ByVal arguments since they will be memcpy'd to the callee's
+ // stack so we're not really passing the pointer anyway.
+ if (PAL.hasParamAttribute(ArgNo, Attribute::ByVal))
+ continue;
+ // If both arguments are readonly, they have no dependence.
+ if (Formal->onlyReadsMemory() && CS.onlyReadsMemory(ArgNo))
+ continue;
+ if (AI != BI && (*BI)->getType()->isPointerTy()) {
+ AliasResult Result = AA->alias(*AI, *BI);
+ Assert(Result != MustAlias && Result != PartialAlias,
+ "Unusual: noalias argument aliases another argument", &I);
+ }
+ }
+ }
+
+ // Check that an sret argument points to valid memory.
+ if (Formal->hasStructRetAttr() && Actual->getType()->isPointerTy()) {
+ Type *Ty =
+ cast<PointerType>(Formal->getType())->getElementType();
+ visitMemoryReference(I, Actual, DL->getTypeStoreSize(Ty),
+ DL->getABITypeAlignment(Ty), Ty,
+ MemRef::Read | MemRef::Write);
+ }
+ }
+ }
+ }
+
+ if (CS.isCall()) {
+ const CallInst *CI = cast<CallInst>(CS.getInstruction());
+ if (CI->isTailCall()) {
+ const AttributeList &PAL = CI->getAttributes();
+ unsigned ArgNo = 0;
+ for (Value *Arg : CS.args()) {
+ // Skip ByVal arguments since they will be memcpy'd to the callee's
+ // stack anyway.
+ if (PAL.hasParamAttribute(ArgNo++, Attribute::ByVal))
+ continue;
+ Value *Obj = findValue(Arg, /*OffsetOk=*/true);
+ Assert(!isa<AllocaInst>(Obj),
+ "Undefined behavior: Call with \"tail\" keyword references "
+ "alloca",
+ &I);
+ }
+ }
+ }
+
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I))
+ switch (II->getIntrinsicID()) {
+ default: break;
+
+ // TODO: Check more intrinsics
+
+ case Intrinsic::memcpy: {
+ MemCpyInst *MCI = cast<MemCpyInst>(&I);
+ // TODO: If the size is known, use it.
+ visitMemoryReference(I, MCI->getDest(), MemoryLocation::UnknownSize,
+ MCI->getDestAlignment(), nullptr, MemRef::Write);
+ visitMemoryReference(I, MCI->getSource(), MemoryLocation::UnknownSize,
+ MCI->getSourceAlignment(), nullptr, MemRef::Read);
+
+ // Check that the memcpy arguments don't overlap. The AliasAnalysis API
+ // isn't expressive enough for what we really want to do. Known partial
+ // overlap is not distinguished from the case where nothing is known.
+ auto Size = LocationSize::unknown();
+ if (const ConstantInt *Len =
+ dyn_cast<ConstantInt>(findValue(MCI->getLength(),
+ /*OffsetOk=*/false)))
+ if (Len->getValue().isIntN(32))
+ Size = LocationSize::precise(Len->getValue().getZExtValue());
+ Assert(AA->alias(MCI->getSource(), Size, MCI->getDest(), Size) !=
+ MustAlias,
+ "Undefined behavior: memcpy source and destination overlap", &I);
+ break;
+ }
+ case Intrinsic::memmove: {
+ MemMoveInst *MMI = cast<MemMoveInst>(&I);
+ // TODO: If the size is known, use it.
+ visitMemoryReference(I, MMI->getDest(), MemoryLocation::UnknownSize,
+ MMI->getDestAlignment(), nullptr, MemRef::Write);
+ visitMemoryReference(I, MMI->getSource(), MemoryLocation::UnknownSize,
+ MMI->getSourceAlignment(), nullptr, MemRef::Read);
+ break;
+ }
+ case Intrinsic::memset: {
+ MemSetInst *MSI = cast<MemSetInst>(&I);
+ // TODO: If the size is known, use it.
+ visitMemoryReference(I, MSI->getDest(), MemoryLocation::UnknownSize,
+ MSI->getDestAlignment(), nullptr, MemRef::Write);
+ break;
+ }
+
+ case Intrinsic::vastart:
+ Assert(I.getParent()->getParent()->isVarArg(),
+ "Undefined behavior: va_start called in a non-varargs function",
+ &I);
+
+ visitMemoryReference(I, CS.getArgument(0), MemoryLocation::UnknownSize, 0,
+ nullptr, MemRef::Read | MemRef::Write);
+ break;
+ case Intrinsic::vacopy:
+ visitMemoryReference(I, CS.getArgument(0), MemoryLocation::UnknownSize, 0,
+ nullptr, MemRef::Write);
+ visitMemoryReference(I, CS.getArgument(1), MemoryLocation::UnknownSize, 0,
+ nullptr, MemRef::Read);
+ break;
+ case Intrinsic::vaend:
+ visitMemoryReference(I, CS.getArgument(0), MemoryLocation::UnknownSize, 0,
+ nullptr, MemRef::Read | MemRef::Write);
+ break;
+
+ case Intrinsic::stackrestore:
+ // Stackrestore doesn't read or write memory, but it sets the
+ // stack pointer, which the compiler may read from or write to
+ // at any time, so check it for both readability and writeability.
+ visitMemoryReference(I, CS.getArgument(0), MemoryLocation::UnknownSize, 0,
+ nullptr, MemRef::Read | MemRef::Write);
+ break;
+ }
+}
+
+void Lint::visitCallInst(CallInst &I) {
+ return visitCallSite(&I);
+}
+
+void Lint::visitInvokeInst(InvokeInst &I) {
+ return visitCallSite(&I);
+}
+
+void Lint::visitReturnInst(ReturnInst &I) {
+ Function *F = I.getParent()->getParent();
+ Assert(!F->doesNotReturn(),
+ "Unusual: Return statement in function with noreturn attribute", &I);
+
+ if (Value *V = I.getReturnValue()) {
+ Value *Obj = findValue(V, /*OffsetOk=*/true);
+ Assert(!isa<AllocaInst>(Obj), "Unusual: Returning alloca value", &I);
+ }
+}
+
+// TODO: Check that the reference is in bounds.
+// TODO: Check readnone/readonly function attributes.
+void Lint::visitMemoryReference(Instruction &I,
+ Value *Ptr, uint64_t Size, unsigned Align,
+ Type *Ty, unsigned Flags) {
+ // If no memory is being referenced, it doesn't matter if the pointer
+ // is valid.
+ if (Size == 0)
+ return;
+
+ Value *UnderlyingObject = findValue(Ptr, /*OffsetOk=*/true);
+ Assert(!isa<ConstantPointerNull>(UnderlyingObject),
+ "Undefined behavior: Null pointer dereference", &I);
+ Assert(!isa<UndefValue>(UnderlyingObject),
+ "Undefined behavior: Undef pointer dereference", &I);
+ Assert(!isa<ConstantInt>(UnderlyingObject) ||
+ !cast<ConstantInt>(UnderlyingObject)->isMinusOne(),
+ "Unusual: All-ones pointer dereference", &I);
+ Assert(!isa<ConstantInt>(UnderlyingObject) ||
+ !cast<ConstantInt>(UnderlyingObject)->isOne(),
+ "Unusual: Address one pointer dereference", &I);
+
+ if (Flags & MemRef::Write) {
+ if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(UnderlyingObject))
+ Assert(!GV->isConstant(), "Undefined behavior: Write to read-only memory",
+ &I);
+ Assert(!isa<Function>(UnderlyingObject) &&
+ !isa<BlockAddress>(UnderlyingObject),
+ "Undefined behavior: Write to text section", &I);
+ }
+ if (Flags & MemRef::Read) {
+ Assert(!isa<Function>(UnderlyingObject), "Unusual: Load from function body",
+ &I);
+ Assert(!isa<BlockAddress>(UnderlyingObject),
+ "Undefined behavior: Load from block address", &I);
+ }
+ if (Flags & MemRef::Callee) {
+ Assert(!isa<BlockAddress>(UnderlyingObject),
+ "Undefined behavior: Call to block address", &I);
+ }
+ if (Flags & MemRef::Branchee) {
+ Assert(!isa<Constant>(UnderlyingObject) ||
+ isa<BlockAddress>(UnderlyingObject),
+ "Undefined behavior: Branch to non-blockaddress", &I);
+ }
+
+ // Check for buffer overflows and misalignment.
+ // Only handles memory references that read/write something simple like an
+ // alloca instruction or a global variable.
+ int64_t Offset = 0;
+ if (Value *Base = GetPointerBaseWithConstantOffset(Ptr, Offset, *DL)) {
+ // OK, so the access is to a constant offset from Ptr. Check that Ptr is
+ // something we can handle and if so extract the size of this base object
+ // along with its alignment.
+ uint64_t BaseSize = MemoryLocation::UnknownSize;
+ unsigned BaseAlign = 0;
+
+ if (AllocaInst *AI = dyn_cast<AllocaInst>(Base)) {
+ Type *ATy = AI->getAllocatedType();
+ if (!AI->isArrayAllocation() && ATy->isSized())
+ BaseSize = DL->getTypeAllocSize(ATy);
+ BaseAlign = AI->getAlignment();
+ if (BaseAlign == 0 && ATy->isSized())
+ BaseAlign = DL->getABITypeAlignment(ATy);
+ } else if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Base)) {
+ // If the global may be defined differently in another compilation unit
+ // then don't warn about funky memory accesses.
+ if (GV->hasDefinitiveInitializer()) {
+ Type *GTy = GV->getValueType();
+ if (GTy->isSized())
+ BaseSize = DL->getTypeAllocSize(GTy);
+ BaseAlign = GV->getAlignment();
+ if (BaseAlign == 0 && GTy->isSized())
+ BaseAlign = DL->getABITypeAlignment(GTy);
+ }
+ }
+
+ // Accesses from before the start or after the end of the object are not
+ // defined.
+ Assert(Size == MemoryLocation::UnknownSize ||
+ BaseSize == MemoryLocation::UnknownSize ||
+ (Offset >= 0 && Offset + Size <= BaseSize),
+ "Undefined behavior: Buffer overflow", &I);
+
+ // Accesses that say that the memory is more aligned than it is are not
+ // defined.
+ if (Align == 0 && Ty && Ty->isSized())
+ Align = DL->getABITypeAlignment(Ty);
+ Assert(!BaseAlign || Align <= MinAlign(BaseAlign, Offset),
+ "Undefined behavior: Memory reference address is misaligned", &I);
+ }
+}
+
+void Lint::visitLoadInst(LoadInst &I) {
+ visitMemoryReference(I, I.getPointerOperand(),
+ DL->getTypeStoreSize(I.getType()), I.getAlignment(),
+ I.getType(), MemRef::Read);
+}
+
+void Lint::visitStoreInst(StoreInst &I) {
+ visitMemoryReference(I, I.getPointerOperand(),
+ DL->getTypeStoreSize(I.getOperand(0)->getType()),
+ I.getAlignment(),
+ I.getOperand(0)->getType(), MemRef::Write);
+}
+
+void Lint::visitXor(BinaryOperator &I) {
+ Assert(!isa<UndefValue>(I.getOperand(0)) || !isa<UndefValue>(I.getOperand(1)),
+ "Undefined result: xor(undef, undef)", &I);
+}
+
+void Lint::visitSub(BinaryOperator &I) {
+ Assert(!isa<UndefValue>(I.getOperand(0)) || !isa<UndefValue>(I.getOperand(1)),
+ "Undefined result: sub(undef, undef)", &I);
+}
+
+void Lint::visitLShr(BinaryOperator &I) {
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(findValue(I.getOperand(1),
+ /*OffsetOk=*/false)))
+ Assert(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()),
+ "Undefined result: Shift count out of range", &I);
+}
+
+void Lint::visitAShr(BinaryOperator &I) {
+ if (ConstantInt *CI =
+ dyn_cast<ConstantInt>(findValue(I.getOperand(1), /*OffsetOk=*/false)))
+ Assert(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()),
+ "Undefined result: Shift count out of range", &I);
+}
+
+void Lint::visitShl(BinaryOperator &I) {
+ if (ConstantInt *CI =
+ dyn_cast<ConstantInt>(findValue(I.getOperand(1), /*OffsetOk=*/false)))
+ Assert(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()),
+ "Undefined result: Shift count out of range", &I);
+}
+
+static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT,
+ AssumptionCache *AC) {
+ // Assume undef could be zero.
+ if (isa<UndefValue>(V))
+ return true;
+
+ VectorType *VecTy = dyn_cast<VectorType>(V->getType());
+ if (!VecTy) {
+ KnownBits Known = computeKnownBits(V, DL, 0, AC, dyn_cast<Instruction>(V), DT);
+ return Known.isZero();
+ }
+
+ // Per-component check doesn't work with zeroinitializer
+ Constant *C = dyn_cast<Constant>(V);
+ if (!C)
+ return false;
+
+ if (C->isZeroValue())
+ return true;
+
+ // For a vector, Known.isZero() is only true when every element is zero, so
+ // check each element individually.
+ for (unsigned I = 0, N = VecTy->getNumElements(); I != N; ++I) {
+ Constant *Elem = C->getAggregateElement(I);
+ if (isa<UndefValue>(Elem))
+ return true;
+
+ KnownBits Known = computeKnownBits(Elem, DL);
+ if (Known.isZero())
+ return true;
+ }
+
+ return false;
+}
+
+void Lint::visitSDiv(BinaryOperator &I) {
+ Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC),
+ "Undefined behavior: Division by zero", &I);
+}
+
+void Lint::visitUDiv(BinaryOperator &I) {
+ Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC),
+ "Undefined behavior: Division by zero", &I);
+}
+
+void Lint::visitSRem(BinaryOperator &I) {
+ Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC),
+ "Undefined behavior: Division by zero", &I);
+}
+
+void Lint::visitURem(BinaryOperator &I) {
+ Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC),
+ "Undefined behavior: Division by zero", &I);
+}
+
+void Lint::visitAllocaInst(AllocaInst &I) {
+ if (isa<ConstantInt>(I.getArraySize()))
+ // This isn't undefined behavior, it's just an obvious pessimization.
+ Assert(&I.getParent()->getParent()->getEntryBlock() == I.getParent(),
+ "Pessimization: Static alloca outside of entry block", &I);
+
+ // TODO: Check for an unusual size (MSB set?)
+}
+
+void Lint::visitVAArgInst(VAArgInst &I) {
+ visitMemoryReference(I, I.getOperand(0), MemoryLocation::UnknownSize, 0,
+ nullptr, MemRef::Read | MemRef::Write);
+}
+
+void Lint::visitIndirectBrInst(IndirectBrInst &I) {
+ visitMemoryReference(I, I.getAddress(), MemoryLocation::UnknownSize, 0,
+ nullptr, MemRef::Branchee);
+
+ Assert(I.getNumDestinations() != 0,
+ "Undefined behavior: indirectbr with no destinations", &I);
+}
+
+void Lint::visitExtractElementInst(ExtractElementInst &I) {
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(findValue(I.getIndexOperand(),
+ /*OffsetOk=*/false)))
+ Assert(CI->getValue().ult(I.getVectorOperandType()->getNumElements()),
+ "Undefined result: extractelement index out of range", &I);
+}
+
+void Lint::visitInsertElementInst(InsertElementInst &I) {
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(findValue(I.getOperand(2),
+ /*OffsetOk=*/false)))
+ Assert(CI->getValue().ult(I.getType()->getNumElements()),
+ "Undefined result: insertelement index out of range", &I);
+}
+
+void Lint::visitUnreachableInst(UnreachableInst &I) {
+ // This isn't undefined behavior, it's merely suspicious.
+ Assert(&I == &I.getParent()->front() ||
+ std::prev(I.getIterator())->mayHaveSideEffects(),
+ "Unusual: unreachable immediately preceded by instruction without "
+ "side effects",
+ &I);
+}
+
+/// findValue - Look through bitcasts and simple memory reference patterns
+/// to identify an equivalent, but more informative, value. If OffsetOk
+/// is true, look through getelementptrs with non-zero offsets too.
+///
+/// Most analysis passes don't require this logic, because instcombine
+/// will simplify most of these kinds of things away. But it's a goal of
+/// this Lint pass to be useful even on non-optimized IR.
+Value *Lint::findValue(Value *V, bool OffsetOk) const {
+ SmallPtrSet<Value *, 4> Visited;
+ return findValueImpl(V, OffsetOk, Visited);
+}
+
+/// findValueImpl - Implementation helper for findValue.
+Value *Lint::findValueImpl(Value *V, bool OffsetOk,
+ SmallPtrSetImpl<Value *> &Visited) const {
+ // Detect self-referential values.
+ if (!Visited.insert(V).second)
+ return UndefValue::get(V->getType());
+
+ // TODO: Look through sext or zext cast, when the result is known to
+ // be interpreted as signed or unsigned, respectively.
+ // TODO: Look through eliminable cast pairs.
+ // TODO: Look through calls with unique return values.
+ // TODO: Look through vector insert/extract/shuffle.
+ V = OffsetOk ? GetUnderlyingObject(V, *DL) : V->stripPointerCasts();
+ if (LoadInst *L = dyn_cast<LoadInst>(V)) {
+ BasicBlock::iterator BBI = L->getIterator();
+ BasicBlock *BB = L->getParent();
+ SmallPtrSet<BasicBlock *, 4> VisitedBlocks;
+ for (;;) {
+ if (!VisitedBlocks.insert(BB).second)
+ break;
+ if (Value *U =
+ FindAvailableLoadedValue(L, BB, BBI, DefMaxInstsToScan, AA))
+ return findValueImpl(U, OffsetOk, Visited);
+ if (BBI != BB->begin()) break;
+ BB = BB->getUniquePredecessor();
+ if (!BB) break;
+ BBI = BB->end();
+ }
+ } else if (PHINode *PN = dyn_cast<PHINode>(V)) {
+ if (Value *W = PN->hasConstantValue())
+ if (W != V)
+ return findValueImpl(W, OffsetOk, Visited);
+ } else if (CastInst *CI = dyn_cast<CastInst>(V)) {
+ if (CI->isNoopCast(*DL))
+ return findValueImpl(CI->getOperand(0), OffsetOk, Visited);
+ } else if (ExtractValueInst *Ex = dyn_cast<ExtractValueInst>(V)) {
+ if (Value *W = FindInsertedValue(Ex->getAggregateOperand(),
+ Ex->getIndices()))
+ if (W != V)
+ return findValueImpl(W, OffsetOk, Visited);
+ } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
+ // Same as above, but for ConstantExpr instead of Instruction.
+ if (Instruction::isCast(CE->getOpcode())) {
+ if (CastInst::isNoopCast(Instruction::CastOps(CE->getOpcode()),
+ CE->getOperand(0)->getType(), CE->getType(),
+ *DL))
+ return findValueImpl(CE->getOperand(0), OffsetOk, Visited);
+ } else if (CE->getOpcode() == Instruction::ExtractValue) {
+ ArrayRef<unsigned> Indices = CE->getIndices();
+ if (Value *W = FindInsertedValue(CE->getOperand(0), Indices))
+ if (W != V)
+ return findValueImpl(W, OffsetOk, Visited);
+ }
+ }
+
+ // As a last resort, try SimplifyInstruction or constant folding.
+ if (Instruction *Inst = dyn_cast<Instruction>(V)) {
+ if (Value *W = SimplifyInstruction(Inst, {*DL, TLI, DT, AC}))
+ return findValueImpl(W, OffsetOk, Visited);
+ } else if (auto *C = dyn_cast<Constant>(V)) {
+ if (Value *W = ConstantFoldConstant(C, *DL, TLI))
+ if (W && W != V)
+ return findValueImpl(W, OffsetOk, Visited);
+ }
+
+ return V;
+}
+
+//===----------------------------------------------------------------------===//
+// Implement the public interfaces to this file...
+//===----------------------------------------------------------------------===//
+
+FunctionPass *llvm::createLintPass() {
+ return new Lint();
+}
+
+/// lintFunction - Check a function for errors, printing messages on stderr.
+///
+void llvm::lintFunction(const Function &f) {
+ Function &F = const_cast<Function&>(f);
+ assert(!F.isDeclaration() && "Cannot lint external functions");
+
+ legacy::FunctionPassManager FPM(F.getParent());
+ Lint *V = new Lint();
+ FPM.add(V);
+ FPM.run(F);
+}
+
+/// lintModule - Check a module for errors, printing messages on stderr.
+///
+void llvm::lintModule(const Module &M) {
+ legacy::PassManager PM;
+ Lint *V = new Lint();
+ PM.add(V);
+ PM.run(const_cast<Module&>(M));
+}
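
As a usage illustration only (not part of the patch): given a module that has already been parsed, the public entry points above can be driven as in the sketch below. The helper name runLintChecks is made up for the example; diagnostics go to stderr, as the lintFunction/lintModule comments describe.

#include "llvm/Analysis/Lint.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"

// Hypothetical helper: lint the whole module, then re-lint each defined
// function individually.
static void runLintChecks(llvm::Module &M) {
  llvm::lintModule(M);

  for (llvm::Function &F : M)
    if (!F.isDeclaration())
      llvm::lintFunction(F);
}
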
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
new file mode 100644
index 000000000000..641e92eac781
--- /dev/null
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -0,0 +1,481 @@
+//===- Loads.cpp - Local load analysis ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines simple local analyses for load instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/GlobalAlias.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/Statepoint.h"
+
+using namespace llvm;
+
+static MaybeAlign getBaseAlign(const Value *Base, const DataLayout &DL) {
+ if (const MaybeAlign PA = Base->getPointerAlignment(DL))
+ return *PA;
+ Type *const Ty = Base->getType()->getPointerElementType();
+ if (!Ty->isSized())
+ return None;
+ return Align(DL.getABITypeAlignment(Ty));
+}
+
+static bool isAligned(const Value *Base, const APInt &Offset, Align Alignment,
+ const DataLayout &DL) {
+ if (MaybeAlign BA = getBaseAlign(Base, DL)) {
+ const APInt APBaseAlign(Offset.getBitWidth(), BA->value());
+ const APInt APAlign(Offset.getBitWidth(), Alignment.value());
+ assert(APAlign.isPowerOf2() && "must be a power of 2!");
+ return APBaseAlign.uge(APAlign) && !(Offset & (APAlign - 1));
+ }
+ return false;
+}
+
+/// Test if V is always a pointer to allocated and suitably aligned memory for
+/// a simple load or store.
+static bool isDereferenceableAndAlignedPointer(
+ const Value *V, Align Alignment, const APInt &Size, const DataLayout &DL,
+ const Instruction *CtxI, const DominatorTree *DT,
+ SmallPtrSetImpl<const Value *> &Visited) {
+ // Already visited? Bail out, we've likely hit unreachable code.
+ if (!Visited.insert(V).second)
+ return false;
+
+ // Note that it is not safe to speculate into a malloc'd region because
+ // malloc may return null.
+
+ // bitcast instructions are no-ops as far as dereferenceability is concerned.
+ if (const BitCastOperator *BC = dyn_cast<BitCastOperator>(V))
+ return isDereferenceableAndAlignedPointer(BC->getOperand(0), Alignment,
+ Size, DL, CtxI, DT, Visited);
+
+ bool CheckForNonNull = false;
+ APInt KnownDerefBytes(Size.getBitWidth(),
+ V->getPointerDereferenceableBytes(DL, CheckForNonNull));
+ if (KnownDerefBytes.getBoolValue() && KnownDerefBytes.uge(Size))
+ if (!CheckForNonNull || isKnownNonZero(V, DL, 0, nullptr, CtxI, DT)) {
+ // As we recursed through GEPs to get here, we've incrementally checked
+ // that each step advanced by a multiple of the alignment. If our base is
+ // properly aligned, then the original offset accessed must also be.
+ Type *Ty = V->getType();
+ assert(Ty->isSized() && "must be sized");
+ APInt Offset(DL.getTypeStoreSizeInBits(Ty), 0);
+ return isAligned(V, Offset, Alignment, DL);
+ }
+
+ // For GEPs, determine if the indexing lands within the allocated object.
+ if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
+ const Value *Base = GEP->getPointerOperand();
+
+ APInt Offset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
+ if (!GEP->accumulateConstantOffset(DL, Offset) || Offset.isNegative() ||
+ !Offset.urem(APInt(Offset.getBitWidth(), Alignment.value()))
+ .isMinValue())
+ return false;
+
+ // If the base pointer is dereferenceable for Offset+Size bytes, then the
+ // GEP (== Base + Offset) is dereferenceable for Size bytes. If the base
+ // pointer is aligned to Align bytes, and the Offset is divisible by Align
+ // then the GEP (== Base + Offset == k_0 * Align + k_1 * Align) is also
+ // aligned to Align bytes.
+
+ // Offset and Size may have different bit widths if we have visited an
+ // addrspacecast, so we can't do arithmetic directly on the APInt values.
+ return isDereferenceableAndAlignedPointer(
+ Base, Alignment, Offset + Size.sextOrTrunc(Offset.getBitWidth()), DL,
+ CtxI, DT, Visited);
+ }
+
+ // For gc.relocate, look through relocations
+ if (const GCRelocateInst *RelocateInst = dyn_cast<GCRelocateInst>(V))
+ return isDereferenceableAndAlignedPointer(
+ RelocateInst->getDerivedPtr(), Alignment, Size, DL, CtxI, DT, Visited);
+
+ if (const AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(V))
+ return isDereferenceableAndAlignedPointer(ASC->getOperand(0), Alignment,
+ Size, DL, CtxI, DT, Visited);
+
+ if (const auto *Call = dyn_cast<CallBase>(V))
+ if (auto *RP = getArgumentAliasingToReturnedPointer(Call, true))
+ return isDereferenceableAndAlignedPointer(RP, Alignment, Size, DL, CtxI,
+ DT, Visited);
+
+ // If we don't know, assume the worst.
+ return false;
+}
+
+bool llvm::isDereferenceableAndAlignedPointer(const Value *V, Align Alignment,
+ const APInt &Size,
+ const DataLayout &DL,
+ const Instruction *CtxI,
+ const DominatorTree *DT) {
+ // Note: At the moment, Size can be zero. This ends up being interpreted as
+ // a query of whether [Base, V] is dereferenceable and V is aligned (since
+ // that's what the implementation happened to do). It's unclear if this is
+ // the desired semantic, but at least SelectionDAG does exercise this case.
+
+ SmallPtrSet<const Value *, 32> Visited;
+ return ::isDereferenceableAndAlignedPointer(V, Alignment, Size, DL, CtxI, DT,
+ Visited);
+}
+
+bool llvm::isDereferenceableAndAlignedPointer(const Value *V, Type *Ty,
+ MaybeAlign MA,
+ const DataLayout &DL,
+ const Instruction *CtxI,
+ const DominatorTree *DT) {
+ if (!Ty->isSized())
+ return false;
+
+ // When dereferenceability information is provided by a dereferenceable
+ // attribute, we know exactly how many bytes are dereferenceable. If we can
+ // determine the exact offset to the attributed variable, we can use that
+ // information here.
+
+ // Require ABI alignment for loads without alignment specification
+ const Align Alignment = DL.getValueOrABITypeAlignment(MA, Ty);
+ APInt AccessSize(DL.getIndexTypeSizeInBits(V->getType()),
+ DL.getTypeStoreSize(Ty));
+ return isDereferenceableAndAlignedPointer(V, Alignment, AccessSize, DL, CtxI,
+ DT);
+}
+
+bool llvm::isDereferenceablePointer(const Value *V, Type *Ty,
+ const DataLayout &DL,
+ const Instruction *CtxI,
+ const DominatorTree *DT) {
+ return isDereferenceableAndAlignedPointer(V, Ty, Align::None(), DL, CtxI, DT);
+}
+
+/// Test if A and B will obviously have the same value.
+///
+/// This includes recognizing that %t0 and %t1 will have the same
+/// value in code like this:
+/// \code
+/// %t0 = getelementptr \@a, 0, 3
+/// store i32 0, i32* %t0
+/// %t1 = getelementptr \@a, 0, 3
+/// %t2 = load i32* %t1
+/// \endcode
+///
+static bool AreEquivalentAddressValues(const Value *A, const Value *B) {
+ // Test if the values are trivially equivalent.
+ if (A == B)
+ return true;
+
+ // Test if the values come from identical arithmetic instructions.
+ // Use isIdenticalToWhenDefined instead of isIdenticalTo because
+ // this function is only used when one address use dominates the
+ // other, which means that they'll always either have the same
+ // value or one of them will have an undefined value.
+ if (isa<BinaryOperator>(A) || isa<CastInst>(A) || isa<PHINode>(A) ||
+ isa<GetElementPtrInst>(A))
+ if (const Instruction *BI = dyn_cast<Instruction>(B))
+ if (cast<Instruction>(A)->isIdenticalToWhenDefined(BI))
+ return true;
+
+ // Otherwise they may not be equivalent.
+ return false;
+}
+
+bool llvm::isDereferenceableAndAlignedInLoop(LoadInst *LI, Loop *L,
+ ScalarEvolution &SE,
+ DominatorTree &DT) {
+ auto &DL = LI->getModule()->getDataLayout();
+ Value *Ptr = LI->getPointerOperand();
+
+ APInt EltSize(DL.getIndexTypeSizeInBits(Ptr->getType()),
+ DL.getTypeStoreSize(LI->getType()));
+ const Align Alignment = DL.getValueOrABITypeAlignment(
+ MaybeAlign(LI->getAlignment()), LI->getType());
+
+ Instruction *HeaderFirstNonPHI = L->getHeader()->getFirstNonPHI();
+
+ // If given a uniform (i.e. non-varying) address, see if we can prove the
+ // access is safe within the loop w/o needing predication.
+ if (L->isLoopInvariant(Ptr))
+ return isDereferenceableAndAlignedPointer(Ptr, Alignment, EltSize, DL,
+ HeaderFirstNonPHI, &DT);
+
+ // Otherwise, check to see if we have a repeating access pattern where we can
+ // prove that all accesses are well aligned and dereferenceable.
+ auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Ptr));
+ if (!AddRec || AddRec->getLoop() != L || !AddRec->isAffine())
+ return false;
+ auto* Step = dyn_cast<SCEVConstant>(AddRec->getStepRecurrence(SE));
+ if (!Step)
+ return false;
+ // TODO: generalize to access patterns which have gaps
+ if (Step->getAPInt() != EltSize)
+ return false;
+
+ // TODO: If the symbolic trip count has a small bound (max count), we might
+ // be able to prove safety.
+ auto TC = SE.getSmallConstantTripCount(L);
+ if (!TC)
+ return false;
+
+ const APInt AccessSize = TC * EltSize;
+
+ auto *StartS = dyn_cast<SCEVUnknown>(AddRec->getStart());
+ if (!StartS)
+ return false;
+ assert(SE.isLoopInvariant(StartS, L) && "implied by addrec definition");
+ Value *Base = StartS->getValue();
+
+ // For the moment, restrict ourselves to the case where the access size is a
+ // multiple of the requested alignment and the base is aligned.
+ // TODO: generalize if a case is found that warrants it.
+ if (EltSize.urem(Alignment.value()) != 0)
+ return false;
+ return isDereferenceableAndAlignedPointer(Base, Alignment, AccessSize, DL,
+ HeaderFirstNonPHI, &DT);
+}
+
+/// Check if executing a load of this pointer value cannot trap.
+///
+/// If DT and ScanFrom are specified this method performs context-sensitive
+/// analysis and returns true if it is safe to load immediately before ScanFrom.
+///
+/// If it is not obviously safe to load from the specified pointer, we do
+/// a quick local scan of the basic block containing \c ScanFrom, to determine
+/// if the address is already accessed.
+///
+/// This uses the pointee type to determine how many bytes need to be safe to
+/// load from the pointer.
+bool llvm::isSafeToLoadUnconditionally(Value *V, MaybeAlign MA, APInt &Size,
+ const DataLayout &DL,
+ Instruction *ScanFrom,
+ const DominatorTree *DT) {
+ // Zero alignment means that the load has the ABI alignment for the target
+ const Align Alignment =
+ DL.getValueOrABITypeAlignment(MA, V->getType()->getPointerElementType());
+
+ // If DT is not specified we can't make context-sensitive query
+ const Instruction* CtxI = DT ? ScanFrom : nullptr;
+ if (isDereferenceableAndAlignedPointer(V, Alignment, Size, DL, CtxI, DT))
+ return true;
+
+ if (!ScanFrom)
+ return false;
+
+ if (Size.getBitWidth() > 64)
+ return false;
+ const uint64_t LoadSize = Size.getZExtValue();
+
+ // Otherwise, be a little bit aggressive by scanning the local block where we
+ // want to check to see if the pointer is already being loaded or stored
+ // from/to. If so, the previous load or store would have already trapped,
+ // so there is no harm doing an extra load (also, CSE will later eliminate
+ // the load entirely).
+ BasicBlock::iterator BBI = ScanFrom->getIterator(),
+ E = ScanFrom->getParent()->begin();
+
+ // We can at least always strip pointer casts even though we can't use the
+ // base here.
+ V = V->stripPointerCasts();
+
+ while (BBI != E) {
+ --BBI;
+
+ // If we see a free or a call which may write to memory (i.e. which might do
+ // a free) the pointer could be marked invalid.
+ if (isa<CallInst>(BBI) && BBI->mayWriteToMemory() &&
+ !isa<DbgInfoIntrinsic>(BBI))
+ return false;
+
+ Value *AccessedPtr;
+ MaybeAlign MaybeAccessedAlign;
+ if (LoadInst *LI = dyn_cast<LoadInst>(BBI)) {
+ // Ignore volatile loads. The execution of a volatile load cannot
+ // be used to prove an address is backed by regular memory; it can,
+ // for example, point to an MMIO register.
+ if (LI->isVolatile())
+ continue;
+ AccessedPtr = LI->getPointerOperand();
+ MaybeAccessedAlign = MaybeAlign(LI->getAlignment());
+ } else if (StoreInst *SI = dyn_cast<StoreInst>(BBI)) {
+ // Ignore volatile stores (see comment for loads).
+ if (SI->isVolatile())
+ continue;
+ AccessedPtr = SI->getPointerOperand();
+ MaybeAccessedAlign = MaybeAlign(SI->getAlignment());
+ } else
+ continue;
+
+ Type *AccessedTy = AccessedPtr->getType()->getPointerElementType();
+
+ const Align AccessedAlign =
+ DL.getValueOrABITypeAlignment(MaybeAccessedAlign, AccessedTy);
+ if (AccessedAlign < Alignment)
+ continue;
+
+ // Handle trivial cases.
+ if (AccessedPtr == V &&
+ LoadSize <= DL.getTypeStoreSize(AccessedTy))
+ return true;
+
+ if (AreEquivalentAddressValues(AccessedPtr->stripPointerCasts(), V) &&
+ LoadSize <= DL.getTypeStoreSize(AccessedTy))
+ return true;
+ }
+ return false;
+}
+
+bool llvm::isSafeToLoadUnconditionally(Value *V, Type *Ty, MaybeAlign Alignment,
+ const DataLayout &DL,
+ Instruction *ScanFrom,
+ const DominatorTree *DT) {
+ APInt Size(DL.getIndexTypeSizeInBits(V->getType()), DL.getTypeStoreSize(Ty));
+ return isSafeToLoadUnconditionally(V, Alignment, Size, DL, ScanFrom, DT);
+}
+
+/// DefMaxInstsToScan - the default maximum number of instructions
+/// to scan in the block, used by FindAvailableLoadedValue().
+/// FindAvailableLoadedValue() was introduced in r60148, to improve jump
+/// threading in part by eliminating partially redundant loads.
+/// At that point, the value of MaxInstsToScan was already set to '6'
+/// without documented explanation.
+cl::opt<unsigned>
+llvm::DefMaxInstsToScan("available-load-scan-limit", cl::init(6), cl::Hidden,
+ cl::desc("Use this to specify the default maximum number of instructions "
+ "to scan backward from a given instruction, when searching for "
+ "available loaded value"));
+
+Value *llvm::FindAvailableLoadedValue(LoadInst *Load,
+ BasicBlock *ScanBB,
+ BasicBlock::iterator &ScanFrom,
+ unsigned MaxInstsToScan,
+ AliasAnalysis *AA, bool *IsLoad,
+ unsigned *NumScanedInst) {
+ // Don't CSE load that is volatile or anything stronger than unordered.
+ if (!Load->isUnordered())
+ return nullptr;
+
+ return FindAvailablePtrLoadStore(
+ Load->getPointerOperand(), Load->getType(), Load->isAtomic(), ScanBB,
+ ScanFrom, MaxInstsToScan, AA, IsLoad, NumScanedInst);
+}
+
+Value *llvm::FindAvailablePtrLoadStore(Value *Ptr, Type *AccessTy,
+ bool AtLeastAtomic, BasicBlock *ScanBB,
+ BasicBlock::iterator &ScanFrom,
+ unsigned MaxInstsToScan,
+ AliasAnalysis *AA, bool *IsLoadCSE,
+ unsigned *NumScanedInst) {
+ if (MaxInstsToScan == 0)
+ MaxInstsToScan = ~0U;
+
+ const DataLayout &DL = ScanBB->getModule()->getDataLayout();
+
+ // Try to get the store size for the type.
+ auto AccessSize = LocationSize::precise(DL.getTypeStoreSize(AccessTy));
+
+ Value *StrippedPtr = Ptr->stripPointerCasts();
+
+ while (ScanFrom != ScanBB->begin()) {
+ // We must ignore debug info directives when counting (otherwise they
+ // would affect codegen).
+ Instruction *Inst = &*--ScanFrom;
+ if (isa<DbgInfoIntrinsic>(Inst))
+ continue;
+
+ // Restore ScanFrom to expected value in case next test succeeds
+ ScanFrom++;
+
+ if (NumScanedInst)
+ ++(*NumScanedInst);
+
+ // Don't scan huge blocks.
+ if (MaxInstsToScan-- == 0)
+ return nullptr;
+
+ --ScanFrom;
+ // If this is a load of Ptr, the loaded value is available.
+ // (This is true even if the load is volatile or atomic, although
+ // those cases are unlikely.)
+ if (LoadInst *LI = dyn_cast<LoadInst>(Inst))
+ if (AreEquivalentAddressValues(
+ LI->getPointerOperand()->stripPointerCasts(), StrippedPtr) &&
+ CastInst::isBitOrNoopPointerCastable(LI->getType(), AccessTy, DL)) {
+
+ // We can value forward from an atomic to a non-atomic, but not the
+ // other way around.
+ if (LI->isAtomic() < AtLeastAtomic)
+ return nullptr;
+
+ if (IsLoadCSE)
+ *IsLoadCSE = true;
+ return LI;
+ }
+
+ if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
+ Value *StorePtr = SI->getPointerOperand()->stripPointerCasts();
+ // If this is a store through Ptr, the value is available!
+ // (This is true even if the store is volatile or atomic, although
+ // those cases are unlikely.)
+ if (AreEquivalentAddressValues(StorePtr, StrippedPtr) &&
+ CastInst::isBitOrNoopPointerCastable(SI->getValueOperand()->getType(),
+ AccessTy, DL)) {
+
+ // We can value forward from an atomic to a non-atomic, but not the
+ // other way around.
+ if (SI->isAtomic() < AtLeastAtomic)
+ return nullptr;
+
+ if (IsLoadCSE)
+ *IsLoadCSE = false;
+ return SI->getOperand(0);
+ }
+
+ // If both StrippedPtr and StorePtr reach all the way to an alloca or
+ // global and they are different, ignore the store. This is a trivial form
+ // of alias analysis that is important for reg2mem'd code.
+ if ((isa<AllocaInst>(StrippedPtr) || isa<GlobalVariable>(StrippedPtr)) &&
+ (isa<AllocaInst>(StorePtr) || isa<GlobalVariable>(StorePtr)) &&
+ StrippedPtr != StorePtr)
+ continue;
+
+ // If we have alias analysis and it says the store won't modify the loaded
+ // value, ignore the store.
+ if (AA && !isModSet(AA->getModRefInfo(SI, StrippedPtr, AccessSize)))
+ continue;
+
+ // Otherwise the store may or may not alias the pointer; bail out.
+ ++ScanFrom;
+ return nullptr;
+ }
+
+ // If this is some other instruction that may clobber Ptr, bail out.
+ if (Inst->mayWriteToMemory()) {
+ // If alias analysis claims that it really won't modify the load,
+ // ignore it.
+ if (AA && !isModSet(AA->getModRefInfo(Inst, StrippedPtr, AccessSize)))
+ continue;
+
+ // May modify the pointer, bail out.
+ ++ScanFrom;
+ return nullptr;
+ }
+ }
+
+ // We got to the start of the block without finding the value; we are done
+ // searching this block.
+ return nullptr;
+}
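
To show how the two exported queries above are typically combined by a client, here is a sketch (not part of the patch). The helper name and the assumption that a DominatorTree and AliasAnalysis are available from the surrounding pass are hypothetical; the function signatures match the ones defined above.

#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"

// Hypothetical client: decide whether a load can be speculated, or whether
// its value is already available earlier in the same block.
static bool canSpeculateOrForward(llvm::LoadInst *LI, llvm::DominatorTree *DT,
                                  llvm::AliasAnalysis *AA) {
  const llvm::DataLayout &DL = LI->getModule()->getDataLayout();

  // Safe to execute the load unconditionally at its current position?
  if (llvm::isSafeToLoadUnconditionally(LI->getPointerOperand(), LI->getType(),
                                        llvm::MaybeAlign(LI->getAlignment()),
                                        DL, LI, DT))
    return true;

  // Otherwise, is the loaded value already available from a prior load or
  // store in the block? ScanFrom is updated to where the scan stopped.
  llvm::BasicBlock::iterator ScanFrom = LI->getIterator();
  return llvm::FindAvailableLoadedValue(LI, LI->getParent(), ScanFrom,
                                        llvm::DefMaxInstsToScan, AA) != nullptr;
}
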
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
new file mode 100644
index 000000000000..3d8f77675f3a
--- /dev/null
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -0,0 +1,2464 @@
+//===- LoopAccessAnalysis.cpp - Loop Access Analysis Implementation --------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The implementation of the loop memory dependence analysis that was
+// originally developed for the loop vectorizer.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LoopAccessAnalysis.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AliasSetTracker.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DebugLoc.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <cstdlib>
+#include <iterator>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-accesses"
+
+static cl::opt<unsigned, true>
+VectorizationFactor("force-vector-width", cl::Hidden,
+ cl::desc("Sets the SIMD width. Zero is autoselect."),
+ cl::location(VectorizerParams::VectorizationFactor));
+unsigned VectorizerParams::VectorizationFactor;
+
+static cl::opt<unsigned, true>
+VectorizationInterleave("force-vector-interleave", cl::Hidden,
+ cl::desc("Sets the vectorization interleave count. "
+ "Zero is autoselect."),
+ cl::location(
+ VectorizerParams::VectorizationInterleave));
+unsigned VectorizerParams::VectorizationInterleave;
+
+static cl::opt<unsigned, true> RuntimeMemoryCheckThreshold(
+ "runtime-memory-check-threshold", cl::Hidden,
+ cl::desc("When performing memory disambiguation checks at runtime do not "
+ "generate more than this number of comparisons (default = 8)."),
+ cl::location(VectorizerParams::RuntimeMemoryCheckThreshold), cl::init(8));
+unsigned VectorizerParams::RuntimeMemoryCheckThreshold;
+
+/// The maximum iterations used to merge memory checks
+static cl::opt<unsigned> MemoryCheckMergeThreshold(
+ "memory-check-merge-threshold", cl::Hidden,
+ cl::desc("Maximum number of comparisons done when trying to merge "
+ "runtime memory checks. (default = 100)"),
+ cl::init(100));
+
+/// Maximum SIMD width.
+const unsigned VectorizerParams::MaxVectorWidth = 64;
+
+/// We collect dependences up to this threshold.
+static cl::opt<unsigned>
+ MaxDependences("max-dependences", cl::Hidden,
+ cl::desc("Maximum number of dependences collected by "
+ "loop-access analysis (default = 100)"),
+ cl::init(100));
+
+/// This enables versioning on the strides of symbolically striding memory
+/// accesses in code like the following.
+/// for (i = 0; i < N; ++i)
+/// A[i * Stride1] += B[i * Stride2] ...
+///
+/// Will be roughly translated to
+/// if (Stride1 == 1 && Stride2 == 1) {
+/// for (i = 0; i < N; i+=4)
+/// A[i:i+3] += ...
+/// } else
+/// ...
+static cl::opt<bool> EnableMemAccessVersioning(
+ "enable-mem-access-versioning", cl::init(true), cl::Hidden,
+ cl::desc("Enable symbolic stride memory access versioning"));
+
+/// Enable store-to-load forwarding conflict detection. This option can
+/// be disabled for correctness testing.
+static cl::opt<bool> EnableForwardingConflictDetection(
+ "store-to-load-forwarding-conflict-detection", cl::Hidden,
+ cl::desc("Enable conflict detection in loop-access analysis"),
+ cl::init(true));
+
+bool VectorizerParams::isInterleaveForced() {
+ return ::VectorizationInterleave.getNumOccurrences() > 0;
+}
+
+Value *llvm::stripIntegerCast(Value *V) {
+ if (auto *CI = dyn_cast<CastInst>(V))
+ if (CI->getOperand(0)->getType()->isIntegerTy())
+ return CI->getOperand(0);
+ return V;
+}
+
+const SCEV *llvm::replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
+ const ValueToValueMap &PtrToStride,
+ Value *Ptr, Value *OrigPtr) {
+ const SCEV *OrigSCEV = PSE.getSCEV(Ptr);
+
+ // If there is an entry in the map return the SCEV of the pointer with the
+ // symbolic stride replaced by one.
+ ValueToValueMap::const_iterator SI =
+ PtrToStride.find(OrigPtr ? OrigPtr : Ptr);
+ if (SI != PtrToStride.end()) {
+ Value *StrideVal = SI->second;
+
+ // Strip casts.
+ StrideVal = stripIntegerCast(StrideVal);
+
+ ScalarEvolution *SE = PSE.getSE();
+ const auto *U = cast<SCEVUnknown>(SE->getSCEV(StrideVal));
+ const auto *CT =
+ static_cast<const SCEVConstant *>(SE->getOne(StrideVal->getType()));
+
+ PSE.addPredicate(*SE->getEqualPredicate(U, CT));
+ auto *Expr = PSE.getSCEV(Ptr);
+
+ LLVM_DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV
+ << " by: " << *Expr << "\n");
+ return Expr;
+ }
+
+ // Otherwise, just return the SCEV of the original pointer.
+ return OrigSCEV;
+}
+
+/// Calculate Start and End points of memory access.
+/// Let's assume A is the first access and B is the memory access on the N-th
+/// loop iteration. Then B is calculated as:
+///   B = A + Step*N .
+/// The Step value may be positive or negative.
+/// N is the calculated back-edge-taken count:
+///   N = (TripCount > 0) ? RoundDown(TripCount - 1, VF) : 0
+/// Start and End points are calculated in the following way:
+///   Start = UMIN(A, B); End = UMAX(A, B) + SizeOfElt,
+/// where SizeOfElt is the size of a single memory access in bytes.
+///
+/// There is no conflict when the intervals are disjoint:
+/// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End)
+void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr,
+ unsigned DepSetId, unsigned ASId,
+ const ValueToValueMap &Strides,
+ PredicatedScalarEvolution &PSE) {
+ // Get the stride replaced scev.
+ const SCEV *Sc = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
+ ScalarEvolution *SE = PSE.getSE();
+
+ const SCEV *ScStart;
+ const SCEV *ScEnd;
+
+ if (SE->isLoopInvariant(Sc, Lp))
+ ScStart = ScEnd = Sc;
+ else {
+ const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
+ assert(AR && "Invalid addrec expression");
+ const SCEV *Ex = PSE.getBackedgeTakenCount();
+
+ ScStart = AR->getStart();
+ ScEnd = AR->evaluateAtIteration(Ex, *SE);
+ const SCEV *Step = AR->getStepRecurrence(*SE);
+
+ // For expressions with negative step, the upper bound is ScStart and the
+ // lower bound is ScEnd.
+ if (const auto *CStep = dyn_cast<SCEVConstant>(Step)) {
+ if (CStep->getValue()->isNegative())
+ std::swap(ScStart, ScEnd);
+ } else {
+ // Fallback case: the step is not constant, but we can still
+ // get the upper and lower bounds of the interval by using min/max
+ // expressions.
+ ScStart = SE->getUMinExpr(ScStart, ScEnd);
+ ScEnd = SE->getUMaxExpr(AR->getStart(), ScEnd);
+ }
+ // Add the size of the pointed element to ScEnd.
+ unsigned EltSize =
+ Ptr->getType()->getPointerElementType()->getScalarSizeInBits() / 8;
+ const SCEV *EltSizeSCEV = SE->getConstant(ScEnd->getType(), EltSize);
+ ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV);
+ }
+
+ Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc);
+}
+
+SmallVector<RuntimePointerChecking::PointerCheck, 4>
+RuntimePointerChecking::generateChecks() const {
+ SmallVector<PointerCheck, 4> Checks;
+
+ for (unsigned I = 0; I < CheckingGroups.size(); ++I) {
+ for (unsigned J = I + 1; J < CheckingGroups.size(); ++J) {
+ const RuntimePointerChecking::CheckingPtrGroup &CGI = CheckingGroups[I];
+ const RuntimePointerChecking::CheckingPtrGroup &CGJ = CheckingGroups[J];
+
+ if (needsChecking(CGI, CGJ))
+ Checks.push_back(std::make_pair(&CGI, &CGJ));
+ }
+ }
+ return Checks;
+}
+
+void RuntimePointerChecking::generateChecks(
+ MemoryDepChecker::DepCandidates &DepCands, bool UseDependencies) {
+ assert(Checks.empty() && "Checks is not empty");
+ groupChecks(DepCands, UseDependencies);
+ Checks = generateChecks();
+}
+
+bool RuntimePointerChecking::needsChecking(const CheckingPtrGroup &M,
+ const CheckingPtrGroup &N) const {
+ for (unsigned I = 0, EI = M.Members.size(); EI != I; ++I)
+ for (unsigned J = 0, EJ = N.Members.size(); EJ != J; ++J)
+ if (needsChecking(M.Members[I], N.Members[J]))
+ return true;
+ return false;
+}
+
+/// Compare \p I and \p J and return the minimum.
+/// Return nullptr in case we couldn't find an answer.
+static const SCEV *getMinFromExprs(const SCEV *I, const SCEV *J,
+ ScalarEvolution *SE) {
+ const SCEV *Diff = SE->getMinusSCEV(J, I);
+ const SCEVConstant *C = dyn_cast<const SCEVConstant>(Diff);
+
+ if (!C)
+ return nullptr;
+ if (C->getValue()->isNegative())
+ return J;
+ return I;
+}
+
+bool RuntimePointerChecking::CheckingPtrGroup::addPointer(unsigned Index) {
+ const SCEV *Start = RtCheck.Pointers[Index].Start;
+ const SCEV *End = RtCheck.Pointers[Index].End;
+
+ // Compare the starts and ends with the known minimum and maximum
+ // of this set. We need to know how we compare against the min/max
+ // of the set in order to be able to emit memchecks.
+ const SCEV *Min0 = getMinFromExprs(Start, Low, RtCheck.SE);
+ if (!Min0)
+ return false;
+
+ const SCEV *Min1 = getMinFromExprs(End, High, RtCheck.SE);
+ if (!Min1)
+ return false;
+
+ // Update the low bound expression if we've found a new min value.
+ if (Min0 == Start)
+ Low = Start;
+
+ // Update the high bound expression if we've found a new max value.
+ if (Min1 != End)
+ High = End;
+
+ Members.push_back(Index);
+ return true;
+}
+
+void RuntimePointerChecking::groupChecks(
+ MemoryDepChecker::DepCandidates &DepCands, bool UseDependencies) {
+ // We build the groups from dependency candidates equivalence classes
+ // because:
+ // - We know that pointers in the same equivalence class share
+ // the same underlying object and therefore there is a chance
+ // that we can compare pointers
+ // - We wouldn't be able to merge two pointers for which we need
+ // to emit a memcheck. The classes in DepCands are already
+ // conveniently built such that no two pointers in the same
+ // class need checking against each other.
+
+ // We use the following (greedy) algorithm to construct the groups
+ // For every pointer in the equivalence class:
+ // For each existing group:
+ // - if the difference between this pointer and the min/max bounds
+ // of the group is a constant, then make the pointer part of the
+ // group and update the min/max bounds of that group as required.
+
+ CheckingGroups.clear();
+
+ // If we need to check two pointers to the same underlying object
+ // with a non-constant difference, we shouldn't perform any pointer
+ // grouping with those pointers. This is because we can easily get
+ // into cases where the resulting check would return false, even when
+ // the accesses are safe.
+ //
+ // The following example shows this:
+ // for (i = 0; i < 1000; ++i)
+ // a[5000 + i * m] = a[i] + a[i + 9000]
+ //
+ // Here grouping gives a check of (5000, 5000 + 1000 * m) against
+ // (0, 10000) which is always false. However, if m is 1, there is no
+ // dependence. Not grouping the checks for a[i] and a[i + 9000] allows
+ // us to perform an accurate check in this case.
+ //
+ // The above case requires that we have an UnknownDependence between
+ // accesses to the same underlying object. This cannot happen unless
+ // FoundNonConstantDistanceDependence is set, and therefore UseDependencies
+ // is also false. In this case we will use the fallback path and create
+ // separate checking groups for all pointers.
+
+ // If we don't have the dependency partitions, construct a new
+ // checking pointer group for each pointer. This is also required
+ // for correctness, because in this case we can have checking between
+ // pointers to the same underlying object.
+ if (!UseDependencies) {
+ for (unsigned I = 0; I < Pointers.size(); ++I)
+ CheckingGroups.push_back(CheckingPtrGroup(I, *this));
+ return;
+ }
+
+ unsigned TotalComparisons = 0;
+
+ DenseMap<Value *, unsigned> PositionMap;
+ for (unsigned Index = 0; Index < Pointers.size(); ++Index)
+ PositionMap[Pointers[Index].PointerValue] = Index;
+
+ // We need to keep track of what pointers we've already seen so we
+ // don't process them twice.
+ SmallSet<unsigned, 2> Seen;
+
+ // Go through all equivalence classes, get the "pointer check groups"
+ // and add them to the overall solution. We use the order in which accesses
+ // appear in 'Pointers' to enforce determinism.
+ for (unsigned I = 0; I < Pointers.size(); ++I) {
+ // We've seen this pointer before, and therefore already processed
+ // its equivalence class.
+ if (Seen.count(I))
+ continue;
+
+ MemoryDepChecker::MemAccessInfo Access(Pointers[I].PointerValue,
+ Pointers[I].IsWritePtr);
+
+ SmallVector<CheckingPtrGroup, 2> Groups;
+ auto LeaderI = DepCands.findValue(DepCands.getLeaderValue(Access));
+
+ // Because DepCands is constructed by visiting accesses in the order in
+ // which they appear in alias sets (which is deterministic) and the
+ // iteration order within an equivalence class member is only dependent on
+ // the order in which unions and insertions are performed on the
+ // equivalence class, the iteration order is deterministic.
+ for (auto MI = DepCands.member_begin(LeaderI), ME = DepCands.member_end();
+ MI != ME; ++MI) {
+ unsigned Pointer = PositionMap[MI->getPointer()];
+ bool Merged = false;
+ // Mark this pointer as seen.
+ Seen.insert(Pointer);
+
+ // Go through all the existing sets and see if we can find one
+ // which can include this pointer.
+ for (CheckingPtrGroup &Group : Groups) {
+        // Don't perform more than a certain number of comparisons.
+ // This should limit the cost of grouping the pointers to something
+ // reasonable. If we do end up hitting this threshold, the algorithm
+ // will create separate groups for all remaining pointers.
+ if (TotalComparisons > MemoryCheckMergeThreshold)
+ break;
+
+ TotalComparisons++;
+
+ if (Group.addPointer(Pointer)) {
+ Merged = true;
+ break;
+ }
+ }
+
+ if (!Merged)
+ // We couldn't add this pointer to any existing set or the threshold
+ // for the number of comparisons has been reached. Create a new group
+ // to hold the current pointer.
+ Groups.push_back(CheckingPtrGroup(Pointer, *this));
+ }
+
+ // We've computed the grouped checks for this partition.
+ // Save the results and continue with the next one.
+ llvm::copy(Groups, std::back_inserter(CheckingGroups));
+ }
+}
+
+bool RuntimePointerChecking::arePointersInSamePartition(
+ const SmallVectorImpl<int> &PtrToPartition, unsigned PtrIdx1,
+ unsigned PtrIdx2) {
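+  // E.g. (illustrative): with PtrToPartition = {0, 0, -1, 1}, indices 0 and 1
+  // are in the same partition, while an entry of -1 never matches (the
+  // pointer is not attributed to a single partition).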
+ return (PtrToPartition[PtrIdx1] != -1 &&
+ PtrToPartition[PtrIdx1] == PtrToPartition[PtrIdx2]);
+}
+
+bool RuntimePointerChecking::needsChecking(unsigned I, unsigned J) const {
+ const PointerInfo &PointerI = Pointers[I];
+ const PointerInfo &PointerJ = Pointers[J];
+
+ // No need to check if two readonly pointers intersect.
+ if (!PointerI.IsWritePtr && !PointerJ.IsWritePtr)
+ return false;
+
+ // Only need to check pointers between two different dependency sets.
+ if (PointerI.DependencySetId == PointerJ.DependencySetId)
+ return false;
+
+ // Only need to check pointers in the same alias set.
+ if (PointerI.AliasSetId != PointerJ.AliasSetId)
+ return false;
+
+ return true;
+}
+
+void RuntimePointerChecking::printChecks(
+ raw_ostream &OS, const SmallVectorImpl<PointerCheck> &Checks,
+ unsigned Depth) const {
+ unsigned N = 0;
+ for (const auto &Check : Checks) {
+ const auto &First = Check.first->Members, &Second = Check.second->Members;
+
+ OS.indent(Depth) << "Check " << N++ << ":\n";
+
+ OS.indent(Depth + 2) << "Comparing group (" << Check.first << "):\n";
+ for (unsigned K = 0; K < First.size(); ++K)
+ OS.indent(Depth + 2) << *Pointers[First[K]].PointerValue << "\n";
+
+ OS.indent(Depth + 2) << "Against group (" << Check.second << "):\n";
+ for (unsigned K = 0; K < Second.size(); ++K)
+ OS.indent(Depth + 2) << *Pointers[Second[K]].PointerValue << "\n";
+ }
+}
+
+void RuntimePointerChecking::print(raw_ostream &OS, unsigned Depth) const {
+
+ OS.indent(Depth) << "Run-time memory checks:\n";
+ printChecks(OS, Checks, Depth);
+
+ OS.indent(Depth) << "Grouped accesses:\n";
+ for (unsigned I = 0; I < CheckingGroups.size(); ++I) {
+ const auto &CG = CheckingGroups[I];
+
+ OS.indent(Depth + 2) << "Group " << &CG << ":\n";
+ OS.indent(Depth + 4) << "(Low: " << *CG.Low << " High: " << *CG.High
+ << ")\n";
+ for (unsigned J = 0; J < CG.Members.size(); ++J) {
+ OS.indent(Depth + 6) << "Member: " << *Pointers[CG.Members[J]].Expr
+ << "\n";
+ }
+ }
+}
+
+namespace {
+
+/// Analyses memory accesses in a loop.
+///
+/// Checks whether run time pointer checks are needed and builds sets for data
+/// dependence checking.
+class AccessAnalysis {
+public:
+ /// Read or write access location.
+ typedef PointerIntPair<Value *, 1, bool> MemAccessInfo;
+ typedef SmallVector<MemAccessInfo, 8> MemAccessInfoList;
+
+ AccessAnalysis(const DataLayout &Dl, Loop *TheLoop, AliasAnalysis *AA,
+ LoopInfo *LI, MemoryDepChecker::DepCandidates &DA,
+ PredicatedScalarEvolution &PSE)
+ : DL(Dl), TheLoop(TheLoop), AST(*AA), LI(LI), DepCands(DA),
+ IsRTCheckAnalysisNeeded(false), PSE(PSE) {}
+
+ /// Register a load and whether it is only read from.
+ void addLoad(MemoryLocation &Loc, bool IsReadOnly) {
+ Value *Ptr = const_cast<Value*>(Loc.Ptr);
+ AST.add(Ptr, LocationSize::unknown(), Loc.AATags);
+ Accesses.insert(MemAccessInfo(Ptr, false));
+ if (IsReadOnly)
+ ReadOnlyPtr.insert(Ptr);
+ }
+
+ /// Register a store.
+ void addStore(MemoryLocation &Loc) {
+ Value *Ptr = const_cast<Value*>(Loc.Ptr);
+ AST.add(Ptr, LocationSize::unknown(), Loc.AATags);
+ Accesses.insert(MemAccessInfo(Ptr, true));
+ }
+
+ /// Check if we can emit a run-time no-alias check for \p Access.
+ ///
+ /// Returns true if we can emit a run-time no alias check for \p Access.
+ /// If we can check this access, this also adds it to a dependence set and
+  /// adds a run-time check for it to \p RtCheck. If \p Assume is true,
+ /// we will attempt to use additional run-time checks in order to get
+ /// the bounds of the pointer.
+ bool createCheckForAccess(RuntimePointerChecking &RtCheck,
+ MemAccessInfo Access,
+ const ValueToValueMap &Strides,
+ DenseMap<Value *, unsigned> &DepSetId,
+ Loop *TheLoop, unsigned &RunningDepId,
+ unsigned ASId, bool ShouldCheckStride,
+ bool Assume);
+
+ /// Check whether we can check the pointers at runtime for
+ /// non-intersection.
+ ///
+ /// Returns true if we need no check or if we do and we can generate them
+ /// (i.e. the pointers have computable bounds).
+ bool canCheckPtrAtRT(RuntimePointerChecking &RtCheck, ScalarEvolution *SE,
+ Loop *TheLoop, const ValueToValueMap &Strides,
+ bool ShouldCheckWrap = false);
+
+ /// Goes over all memory accesses, checks whether a RT check is needed
+ /// and builds sets of dependent accesses.
+ void buildDependenceSets() {
+ processMemAccesses();
+ }
+
+ /// Initial processing of memory accesses determined that we need to
+ /// perform dependency checking.
+ ///
+ /// Note that this can later be cleared if we retry memcheck analysis without
+ /// dependency checking (i.e. FoundNonConstantDistanceDependence).
+ bool isDependencyCheckNeeded() { return !CheckDeps.empty(); }
+
+ /// We decided that no dependence analysis would be used. Reset the state.
+ void resetDepChecks(MemoryDepChecker &DepChecker) {
+ CheckDeps.clear();
+ DepChecker.clearDependences();
+ }
+
+ MemAccessInfoList &getDependenciesToCheck() { return CheckDeps; }
+
+private:
+ typedef SetVector<MemAccessInfo> PtrAccessSet;
+
+ /// Go over all memory access and check whether runtime pointer checks
+ /// are needed and build sets of dependency check candidates.
+ void processMemAccesses();
+
+ /// Set of all accesses.
+ PtrAccessSet Accesses;
+
+ const DataLayout &DL;
+
+ /// The loop being checked.
+ const Loop *TheLoop;
+
+ /// List of accesses that need a further dependence check.
+ MemAccessInfoList CheckDeps;
+
+ /// Set of pointers that are read only.
+ SmallPtrSet<Value*, 16> ReadOnlyPtr;
+
+ /// An alias set tracker to partition the access set by underlying object and
+  /// intrinsic property (such as TBAA metadata).
+ AliasSetTracker AST;
+
+ LoopInfo *LI;
+
+ /// Sets of potentially dependent accesses - members of one set share an
+  /// underlying pointer. The set "CheckDeps" identifies which sets really need a
+ /// dependence check.
+ MemoryDepChecker::DepCandidates &DepCands;
+
+ /// Initial processing of memory accesses determined that we may need
+ /// to add memchecks. Perform the analysis to determine the necessary checks.
+ ///
+  /// Note that this is different from isDependencyCheckNeeded. When we retry
+ /// memcheck analysis without dependency checking
+ /// (i.e. FoundNonConstantDistanceDependence), isDependencyCheckNeeded is
+ /// cleared while this remains set if we have potentially dependent accesses.
+ bool IsRTCheckAnalysisNeeded;
+
+ /// The SCEV predicate containing all the SCEV-related assumptions.
+ PredicatedScalarEvolution &PSE;
+};
+
+} // end anonymous namespace
+
+/// Check whether a pointer can participate in a runtime bounds check.
+/// If \p Assume, try harder to prove that we can compute the bounds of \p Ptr
+/// by adding run-time checks (overflow checks) if necessary.
+static bool hasComputableBounds(PredicatedScalarEvolution &PSE,
+ const ValueToValueMap &Strides, Value *Ptr,
+ Loop *L, bool Assume) {
+ const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
+
+  // The bounds for a loop-invariant pointer are trivial.
+ if (PSE.getSE()->isLoopInvariant(PtrScev, L))
+ return true;
+
+ const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
+
+ if (!AR && Assume)
+ AR = PSE.getAsAddRec(Ptr);
+
+ if (!AR)
+ return false;
+
+ return AR->isAffine();
+}
+
+/// Check whether a pointer address cannot wrap.
+static bool isNoWrap(PredicatedScalarEvolution &PSE,
+ const ValueToValueMap &Strides, Value *Ptr, Loop *L) {
+ const SCEV *PtrScev = PSE.getSCEV(Ptr);
+ if (PSE.getSE()->isLoopInvariant(PtrScev, L))
+ return true;
+
+ int64_t Stride = getPtrStride(PSE, Ptr, L, Strides);
+ if (Stride == 1 || PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW))
+ return true;
+
+ return false;
+}
+
+bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck,
+ MemAccessInfo Access,
+ const ValueToValueMap &StridesMap,
+ DenseMap<Value *, unsigned> &DepSetId,
+ Loop *TheLoop, unsigned &RunningDepId,
+ unsigned ASId, bool ShouldCheckWrap,
+ bool Assume) {
+ Value *Ptr = Access.getPointer();
+
+ if (!hasComputableBounds(PSE, StridesMap, Ptr, TheLoop, Assume))
+ return false;
+
+ // When we run after a failing dependency check we have to make sure
+ // we don't have wrapping pointers.
+ if (ShouldCheckWrap && !isNoWrap(PSE, StridesMap, Ptr, TheLoop)) {
+ auto *Expr = PSE.getSCEV(Ptr);
+ if (!Assume || !isa<SCEVAddRecExpr>(Expr))
+ return false;
+ PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
+ }
+
+ // The id of the dependence set.
+ unsigned DepId;
+
+ if (isDependencyCheckNeeded()) {
+ Value *Leader = DepCands.getLeaderValue(Access).getPointer();
+ unsigned &LeaderId = DepSetId[Leader];
+ if (!LeaderId)
+ LeaderId = RunningDepId++;
+ DepId = LeaderId;
+ } else
+ // Each access has its own dependence set.
+ DepId = RunningDepId++;
+
+ bool IsWrite = Access.getInt();
+ RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, PSE);
+ LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
+
+ return true;
+}
+
+bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
+ ScalarEvolution *SE, Loop *TheLoop,
+ const ValueToValueMap &StridesMap,
+ bool ShouldCheckWrap) {
+ // Find pointers with computable bounds. We are going to use this information
+ // to place a runtime bound check.
+ bool CanDoRT = true;
+
+ bool NeedRTCheck = false;
+ if (!IsRTCheckAnalysisNeeded) return true;
+
+ bool IsDepCheckNeeded = isDependencyCheckNeeded();
+
+  // We assign a consecutive id to accesses from different alias sets.
+  // Accesses between different groups don't need to be checked.
+ unsigned ASId = 1;
+ for (auto &AS : AST) {
+ int NumReadPtrChecks = 0;
+ int NumWritePtrChecks = 0;
+ bool CanDoAliasSetRT = true;
+
+    // We assign consecutive ids to accesses from different dependence sets.
+ // Accesses within the same set don't need a runtime check.
+ unsigned RunningDepId = 1;
+ DenseMap<Value *, unsigned> DepSetId;
+
+ SmallVector<MemAccessInfo, 4> Retries;
+
+ for (auto A : AS) {
+ Value *Ptr = A.getValue();
+ bool IsWrite = Accesses.count(MemAccessInfo(Ptr, true));
+ MemAccessInfo Access(Ptr, IsWrite);
+
+ if (IsWrite)
+ ++NumWritePtrChecks;
+ else
+ ++NumReadPtrChecks;
+
+ if (!createCheckForAccess(RtCheck, Access, StridesMap, DepSetId, TheLoop,
+ RunningDepId, ASId, ShouldCheckWrap, false)) {
+ LLVM_DEBUG(dbgs() << "LAA: Can't find bounds for ptr:" << *Ptr << '\n');
+ Retries.push_back(Access);
+ CanDoAliasSetRT = false;
+ }
+ }
+
+ // If we have at least two writes or one write and a read then we need to
+    // check them. But there is no need for checks if there is only one
+ // dependence set for this alias set.
+ //
+ // Note that this function computes CanDoRT and NeedRTCheck independently.
+ // For example CanDoRT=false, NeedRTCheck=false means that we have a pointer
+ // for which we couldn't find the bounds but we don't actually need to emit
+ // any checks so it does not matter.
+ bool NeedsAliasSetRTCheck = false;
+ if (!(IsDepCheckNeeded && CanDoAliasSetRT && RunningDepId == 2))
+ NeedsAliasSetRTCheck = (NumWritePtrChecks >= 2 ||
+ (NumReadPtrChecks >= 1 && NumWritePtrChecks >= 1));
+
+ // We need to perform run-time alias checks, but some pointers had bounds
+ // that couldn't be checked.
+ if (NeedsAliasSetRTCheck && !CanDoAliasSetRT) {
+      // Reset the CanDoAliasSetRT flag and retry all accesses that have failed.
+ // We know that we need these checks, so we can now be more aggressive
+ // and add further checks if required (overflow checks).
+ CanDoAliasSetRT = true;
+ for (auto Access : Retries)
+ if (!createCheckForAccess(RtCheck, Access, StridesMap, DepSetId,
+ TheLoop, RunningDepId, ASId,
+ ShouldCheckWrap, /*Assume=*/true)) {
+ CanDoAliasSetRT = false;
+ break;
+ }
+ }
+
+ CanDoRT &= CanDoAliasSetRT;
+ NeedRTCheck |= NeedsAliasSetRTCheck;
+ ++ASId;
+ }
+
+ // If the pointers that we would use for the bounds comparison have different
+ // address spaces, assume the values aren't directly comparable, so we can't
+ // use them for the runtime check. We also have to assume they could
+ // overlap. In the future there should be metadata for whether address spaces
+ // are disjoint.
+ unsigned NumPointers = RtCheck.Pointers.size();
+ for (unsigned i = 0; i < NumPointers; ++i) {
+ for (unsigned j = i + 1; j < NumPointers; ++j) {
+ // Only need to check pointers between two different dependency sets.
+ if (RtCheck.Pointers[i].DependencySetId ==
+ RtCheck.Pointers[j].DependencySetId)
+ continue;
+ // Only need to check pointers in the same alias set.
+ if (RtCheck.Pointers[i].AliasSetId != RtCheck.Pointers[j].AliasSetId)
+ continue;
+
+ Value *PtrI = RtCheck.Pointers[i].PointerValue;
+ Value *PtrJ = RtCheck.Pointers[j].PointerValue;
+
+ unsigned ASi = PtrI->getType()->getPointerAddressSpace();
+ unsigned ASj = PtrJ->getType()->getPointerAddressSpace();
+ if (ASi != ASj) {
+ LLVM_DEBUG(
+ dbgs() << "LAA: Runtime check would require comparison between"
+ " different address spaces\n");
+ return false;
+ }
+ }
+ }
+
+ if (NeedRTCheck && CanDoRT)
+ RtCheck.generateChecks(DepCands, IsDepCheckNeeded);
+
+ LLVM_DEBUG(dbgs() << "LAA: We need to do " << RtCheck.getNumberOfChecks()
+ << " pointer comparisons.\n");
+
+ RtCheck.Need = NeedRTCheck;
+
+ bool CanDoRTIfNeeded = !NeedRTCheck || CanDoRT;
+ if (!CanDoRTIfNeeded)
+ RtCheck.reset();
+ return CanDoRTIfNeeded;
+}
+
+void AccessAnalysis::processMemAccesses() {
+ // We process the set twice: first we process read-write pointers, last we
+ // process read-only pointers. This allows us to skip dependence tests for
+ // read-only pointers.
+
+ LLVM_DEBUG(dbgs() << "LAA: Processing memory accesses...\n");
+ LLVM_DEBUG(dbgs() << " AST: "; AST.dump());
+ LLVM_DEBUG(dbgs() << "LAA: Accesses(" << Accesses.size() << "):\n");
+ LLVM_DEBUG({
+ for (auto A : Accesses)
+ dbgs() << "\t" << *A.getPointer() << " (" <<
+ (A.getInt() ? "write" : (ReadOnlyPtr.count(A.getPointer()) ?
+ "read-only" : "read")) << ")\n";
+ });
+
+ // The AliasSetTracker has nicely partitioned our pointers by metadata
+ // compatibility and potential for underlying-object overlap. As a result, we
+ // only need to check for potential pointer dependencies within each alias
+ // set.
+ for (auto &AS : AST) {
+    // Note that both the alias-set tracker and the alias sets themselves use
+ // linked lists internally and so the iteration order here is deterministic
+ // (matching the original instruction order within each set).
+
+ bool SetHasWrite = false;
+
+ // Map of pointers to last access encountered.
+ typedef DenseMap<const Value*, MemAccessInfo> UnderlyingObjToAccessMap;
+ UnderlyingObjToAccessMap ObjToLastAccess;
+
+    // Set of accesses to check after all writes have been processed.
+ PtrAccessSet DeferredAccesses;
+
+ // Iterate over each alias set twice, once to process read/write pointers,
+ // and then to process read-only pointers.
+ for (int SetIteration = 0; SetIteration < 2; ++SetIteration) {
+ bool UseDeferred = SetIteration > 0;
+ PtrAccessSet &S = UseDeferred ? DeferredAccesses : Accesses;
+
+ for (auto AV : AS) {
+ Value *Ptr = AV.getValue();
+
+ // For a single memory access in AliasSetTracker, Accesses may contain
+ // both read and write, and they both need to be handled for CheckDeps.
+ for (auto AC : S) {
+ if (AC.getPointer() != Ptr)
+ continue;
+
+ bool IsWrite = AC.getInt();
+
+ // If we're using the deferred access set, then it contains only
+ // reads.
+ bool IsReadOnlyPtr = ReadOnlyPtr.count(Ptr) && !IsWrite;
+ if (UseDeferred && !IsReadOnlyPtr)
+ continue;
+ // Otherwise, the pointer must be in the PtrAccessSet, either as a
+ // read or a write.
+ assert(((IsReadOnlyPtr && UseDeferred) || IsWrite ||
+ S.count(MemAccessInfo(Ptr, false))) &&
+ "Alias-set pointer not in the access set?");
+
+ MemAccessInfo Access(Ptr, IsWrite);
+ DepCands.insert(Access);
+
+ // Memorize read-only pointers for later processing and skip them in
+ // the first round (they need to be checked after we have seen all
+          // write pointers). Note: we also mark pointers that are not
+ // consecutive as "read-only" pointers (so that we check
+ // "a[b[i]] +="). Hence, we need the second check for "!IsWrite".
+ if (!UseDeferred && IsReadOnlyPtr) {
+ DeferredAccesses.insert(Access);
+ continue;
+ }
+
+          // If this is a write, check other reads and writes for conflicts. If
+          // this is a read, only check other writes for conflicts (but only if
+          // there is no other write to the ptr - this is an optimization to
+          // catch "a[i] = a[i] + ..." without having to do a dependence check).
+ if ((IsWrite || IsReadOnlyPtr) && SetHasWrite) {
+ CheckDeps.push_back(Access);
+ IsRTCheckAnalysisNeeded = true;
+ }
+
+ if (IsWrite)
+ SetHasWrite = true;
+
+ // Create sets of pointers connected by a shared alias set and
+ // underlying object.
+ typedef SmallVector<const Value *, 16> ValueVector;
+ ValueVector TempObjects;
+
+ GetUnderlyingObjects(Ptr, TempObjects, DL, LI);
+ LLVM_DEBUG(dbgs()
+ << "Underlying objects for pointer " << *Ptr << "\n");
+ for (const Value *UnderlyingObj : TempObjects) {
+            // nullptr never aliases, don't join sets for pointers that have
+            // "null" in their UnderlyingObjects list.
+ if (isa<ConstantPointerNull>(UnderlyingObj) &&
+ !NullPointerIsDefined(
+ TheLoop->getHeader()->getParent(),
+ UnderlyingObj->getType()->getPointerAddressSpace()))
+ continue;
+
+ UnderlyingObjToAccessMap::iterator Prev =
+ ObjToLastAccess.find(UnderlyingObj);
+ if (Prev != ObjToLastAccess.end())
+ DepCands.unionSets(Access, Prev->second);
+
+ ObjToLastAccess[UnderlyingObj] = Access;
+ LLVM_DEBUG(dbgs() << " " << *UnderlyingObj << "\n");
+ }
+ }
+ }
+ }
+ }
+}
+
+static bool isInBoundsGep(Value *Ptr) {
+ if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr))
+ return GEP->isInBounds();
+ return false;
+}
+
+/// Return true if an AddRec pointer \p Ptr is unsigned non-wrapping,
+/// i.e. monotonically increasing/decreasing.
+static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR,
+ PredicatedScalarEvolution &PSE, const Loop *L) {
+ // FIXME: This should probably only return true for NUW.
+ if (AR->getNoWrapFlags(SCEV::NoWrapMask))
+ return true;
+
+ // Scalar evolution does not propagate the non-wrapping flags to values that
+ // are derived from a non-wrapping induction variable because non-wrapping
+ // could be flow-sensitive.
+ //
+ // Look through the potentially overflowing instruction to try to prove
+ // non-wrapping for the *specific* value of Ptr.
+
+ // The arithmetic implied by an inbounds GEP can't overflow.
+ auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+ if (!GEP || !GEP->isInBounds())
+ return false;
+
+ // Make sure there is only one non-const index and analyze that.
+ Value *NonConstIndex = nullptr;
+ for (Value *Index : make_range(GEP->idx_begin(), GEP->idx_end()))
+ if (!isa<ConstantInt>(Index)) {
+ if (NonConstIndex)
+ return false;
+ NonConstIndex = Index;
+ }
+ if (!NonConstIndex)
+ // The recurrence is on the pointer, ignore for now.
+ return false;
+
+ // The index in GEP is signed. It is non-wrapping if it's derived from a NSW
+ // AddRec using a NSW operation.
+ if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(NonConstIndex))
+ if (OBO->hasNoSignedWrap() &&
+        // Assume a constant for the other operand so that the AddRec can be
+        // easily found.
+ isa<ConstantInt>(OBO->getOperand(1))) {
+ auto *OpScev = PSE.getSCEV(OBO->getOperand(0));
+
+ if (auto *OpAR = dyn_cast<SCEVAddRecExpr>(OpScev))
+ return OpAR->getLoop() == L && OpAR->getNoWrapFlags(SCEV::FlagNSW);
+ }
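+
+  // E.g. (illustrative): for an index of the form "add nsw i64 %iv, 3", where
+  // %iv is an NSW AddRec of this loop, the code above concludes that this
+  // particular GEP value cannot wrap, even though SCEV drops the flag on the
+  // derived expression.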
+
+ return false;
+}
+
+/// Check whether the access through \p Ptr has a constant stride.
+int64_t llvm::getPtrStride(PredicatedScalarEvolution &PSE, Value *Ptr,
+ const Loop *Lp, const ValueToValueMap &StridesMap,
+ bool Assume, bool ShouldCheckWrap) {
+ Type *Ty = Ptr->getType();
+ assert(Ty->isPointerTy() && "Unexpected non-ptr");
+
+ // Make sure that the pointer does not point to aggregate types.
+ auto *PtrTy = cast<PointerType>(Ty);
+ if (PtrTy->getElementType()->isAggregateType()) {
+ LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not a pointer to a scalar type"
+ << *Ptr << "\n");
+ return 0;
+ }
+
+ const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
+
+ const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
+ if (Assume && !AR)
+ AR = PSE.getAsAddRec(Ptr);
+
+ if (!AR) {
+ LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer " << *Ptr
+ << " SCEV: " << *PtrScev << "\n");
+ return 0;
+ }
+
+ // The access function must stride over the innermost loop.
+ if (Lp != AR->getLoop()) {
+ LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not striding over innermost loop "
+ << *Ptr << " SCEV: " << *AR << "\n");
+ return 0;
+ }
+
+ // The address calculation must not wrap. Otherwise, a dependence could be
+ // inverted.
+  // An inbounds getelementptr that is an AddRec with a unit stride
+  // cannot wrap per definition. The unit stride requirement is checked later.
+  // A getelementptr without an inbounds attribute but with unit stride would
+  // have to access the pointer value "0", which is undefined behavior in
+  // address space 0, therefore we can also vectorize this case.
+ bool IsInBoundsGEP = isInBoundsGep(Ptr);
+ bool IsNoWrapAddRec = !ShouldCheckWrap ||
+ PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW) ||
+ isNoWrapAddRec(Ptr, AR, PSE, Lp);
+ if (!IsNoWrapAddRec && !IsInBoundsGEP &&
+ NullPointerIsDefined(Lp->getHeader()->getParent(),
+ PtrTy->getAddressSpace())) {
+ if (Assume) {
+ PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
+ IsNoWrapAddRec = true;
+ LLVM_DEBUG(dbgs() << "LAA: Pointer may wrap in the address space:\n"
+ << "LAA: Pointer: " << *Ptr << "\n"
+ << "LAA: SCEV: " << *AR << "\n"
+ << "LAA: Added an overflow assumption\n");
+ } else {
+ LLVM_DEBUG(
+ dbgs() << "LAA: Bad stride - Pointer may wrap in the address space "
+ << *Ptr << " SCEV: " << *AR << "\n");
+ return 0;
+ }
+ }
+
+ // Check the step is constant.
+ const SCEV *Step = AR->getStepRecurrence(*PSE.getSE());
+
+ // Calculate the pointer stride and check if it is constant.
+ const SCEVConstant *C = dyn_cast<SCEVConstant>(Step);
+ if (!C) {
+ LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not a constant strided " << *Ptr
+ << " SCEV: " << *AR << "\n");
+ return 0;
+ }
+
+ auto &DL = Lp->getHeader()->getModule()->getDataLayout();
+ int64_t Size = DL.getTypeAllocSize(PtrTy->getElementType());
+ const APInt &APStepVal = C->getAPInt();
+
+ // Huge step value - give up.
+ if (APStepVal.getBitWidth() > 64)
+ return 0;
+
+ int64_t StepVal = APStepVal.getSExtValue();
+
+ // Strided access.
+ int64_t Stride = StepVal / Size;
+ int64_t Rem = StepVal % Size;
+ if (Rem)
+ return 0;
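+
+  // E.g. for a 4-byte element type, a step of 8 bytes per iteration gives
+  // Stride == 2 (every other element), while a step of 6 bytes leaves
+  // Rem == 2 and the access was rejected above as non-strided.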
+
+ // If the SCEV could wrap but we have an inbounds gep with a unit stride we
+ // know we can't "wrap around the address space". In case of address space
+ // zero we know that this won't happen without triggering undefined behavior.
+ if (!IsNoWrapAddRec && Stride != 1 && Stride != -1 &&
+ (IsInBoundsGEP || !NullPointerIsDefined(Lp->getHeader()->getParent(),
+ PtrTy->getAddressSpace()))) {
+ if (Assume) {
+ // We can avoid this case by adding a run-time check.
+ LLVM_DEBUG(dbgs() << "LAA: Non unit strided pointer which is not either "
+ << "inbounds or in address space 0 may wrap:\n"
+ << "LAA: Pointer: " << *Ptr << "\n"
+ << "LAA: SCEV: " << *AR << "\n"
+ << "LAA: Added an overflow assumption\n");
+ PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
+ } else
+ return 0;
+ }
+
+ return Stride;
+}
+
+bool llvm::sortPtrAccesses(ArrayRef<Value *> VL, const DataLayout &DL,
+ ScalarEvolution &SE,
+ SmallVectorImpl<unsigned> &SortedIndices) {
+ assert(llvm::all_of(
+ VL, [](const Value *V) { return V->getType()->isPointerTy(); }) &&
+ "Expected list of pointer operands.");
+ SmallVector<std::pair<int64_t, Value *>, 4> OffValPairs;
+ OffValPairs.reserve(VL.size());
+
+ // Walk over the pointers, and map each of them to an offset relative to
+  // the first pointer in the array.
+ Value *Ptr0 = VL[0];
+ const SCEV *Scev0 = SE.getSCEV(Ptr0);
+ Value *Obj0 = GetUnderlyingObject(Ptr0, DL);
+
+ llvm::SmallSet<int64_t, 4> Offsets;
+ for (auto *Ptr : VL) {
+ // TODO: Outline this code as a special, more time consuming, version of
+ // computeConstantDifference() function.
+ if (Ptr->getType()->getPointerAddressSpace() !=
+ Ptr0->getType()->getPointerAddressSpace())
+ return false;
+ // If a pointer refers to a different underlying object, bail - the
+ // pointers are by definition incomparable.
+ Value *CurrObj = GetUnderlyingObject(Ptr, DL);
+ if (CurrObj != Obj0)
+ return false;
+
+ const SCEV *Scev = SE.getSCEV(Ptr);
+ const auto *Diff = dyn_cast<SCEVConstant>(SE.getMinusSCEV(Scev, Scev0));
+ // The pointers may not have a constant offset from each other, or SCEV
+ // may just not be smart enough to figure out they do. Regardless,
+ // there's nothing we can do.
+ if (!Diff)
+ return false;
+
+    // Bail out if a pointer with the same offset has already been seen.
+ int64_t Offset = Diff->getAPInt().getSExtValue();
+ if (!Offsets.insert(Offset).second)
+ return false;
+ OffValPairs.emplace_back(Offset, Ptr);
+ }
+ SortedIndices.clear();
+ SortedIndices.resize(VL.size());
+ std::iota(SortedIndices.begin(), SortedIndices.end(), 0);
+
+  // Sort the indices by offset, recording the sorted order in SortedIndices.
+ llvm::stable_sort(SortedIndices, [&](unsigned Left, unsigned Right) {
+ return OffValPairs[Left].first < OffValPairs[Right].first;
+ });
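+
+  // E.g. (illustrative): for VL = {A+2, A, A+1} with 4-byte elements the
+  // offsets are {8, 0, 4}, so SortedIndices becomes {1, 2, 0}; if the input
+  // is already in offset order, SortedIndices is cleared below to signal that
+  // no reordering is needed.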
+
+  // Check if the order is already consecutive (i.e. no reordering is needed).
+ if (llvm::all_of(SortedIndices, [&SortedIndices](const unsigned I) {
+ return I == SortedIndices[I];
+ }))
+ SortedIndices.clear();
+
+ return true;
+}
+
+/// Take the address space operand from the Load/Store instruction.
+/// Returns -1 if this is not a valid Load/Store instruction.
+static unsigned getAddressSpaceOperand(Value *I) {
+ if (LoadInst *L = dyn_cast<LoadInst>(I))
+ return L->getPointerAddressSpace();
+ if (StoreInst *S = dyn_cast<StoreInst>(I))
+ return S->getPointerAddressSpace();
+ return -1;
+}
+
+/// Returns true if the memory operations \p A and \p B are consecutive.
+bool llvm::isConsecutiveAccess(Value *A, Value *B, const DataLayout &DL,
+ ScalarEvolution &SE, bool CheckType) {
+ Value *PtrA = getLoadStorePointerOperand(A);
+ Value *PtrB = getLoadStorePointerOperand(B);
+ unsigned ASA = getAddressSpaceOperand(A);
+ unsigned ASB = getAddressSpaceOperand(B);
+
+ // Check that the address spaces match and that the pointers are valid.
+ if (!PtrA || !PtrB || (ASA != ASB))
+ return false;
+
+ // Make sure that A and B are different pointers.
+ if (PtrA == PtrB)
+ return false;
+
+ // Make sure that A and B have the same type if required.
+ if (CheckType && PtrA->getType() != PtrB->getType())
+ return false;
+
+ unsigned IdxWidth = DL.getIndexSizeInBits(ASA);
+ Type *Ty = cast<PointerType>(PtrA->getType())->getElementType();
+
+ APInt OffsetA(IdxWidth, 0), OffsetB(IdxWidth, 0);
+ PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA);
+ PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB);
+
+ // Retrieve the address space again as pointer stripping now tracks through
+ // `addrspacecast`.
+ ASA = cast<PointerType>(PtrA->getType())->getAddressSpace();
+ ASB = cast<PointerType>(PtrB->getType())->getAddressSpace();
+ // Check that the address spaces match and that the pointers are valid.
+ if (ASA != ASB)
+ return false;
+
+ IdxWidth = DL.getIndexSizeInBits(ASA);
+ OffsetA = OffsetA.sextOrTrunc(IdxWidth);
+ OffsetB = OffsetB.sextOrTrunc(IdxWidth);
+
+ APInt Size(IdxWidth, DL.getTypeStoreSize(Ty));
+
+ // OffsetDelta = OffsetB - OffsetA;
+ const SCEV *OffsetSCEVA = SE.getConstant(OffsetA);
+ const SCEV *OffsetSCEVB = SE.getConstant(OffsetB);
+ const SCEV *OffsetDeltaSCEV = SE.getMinusSCEV(OffsetSCEVB, OffsetSCEVA);
+ const APInt &OffsetDelta = cast<SCEVConstant>(OffsetDeltaSCEV)->getAPInt();
+
+ // Check if they are based on the same pointer. That makes the offsets
+ // sufficient.
+ if (PtrA == PtrB)
+ return OffsetDelta == Size;
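+
+  // E.g. two i32 accesses through &A[0] and &A[1] strip to the same base
+  // pointer with OffsetDelta == 4 == Size, so they are reported consecutive
+  // by the early exit above.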
+
+  // Compute the base pointer delta needed to make the final delta equal to
+  // the size.
+ // BaseDelta = Size - OffsetDelta;
+ const SCEV *SizeSCEV = SE.getConstant(Size);
+ const SCEV *BaseDelta = SE.getMinusSCEV(SizeSCEV, OffsetDeltaSCEV);
+
+ // Otherwise compute the distance with SCEV between the base pointers.
+ const SCEV *PtrSCEVA = SE.getSCEV(PtrA);
+ const SCEV *PtrSCEVB = SE.getSCEV(PtrB);
+ const SCEV *X = SE.getAddExpr(PtrSCEVA, BaseDelta);
+ return X == PtrSCEVB;
+}
+
+MemoryDepChecker::VectorizationSafetyStatus
+MemoryDepChecker::Dependence::isSafeForVectorization(DepType Type) {
+ switch (Type) {
+ case NoDep:
+ case Forward:
+ case BackwardVectorizable:
+ return VectorizationSafetyStatus::Safe;
+
+ case Unknown:
+ return VectorizationSafetyStatus::PossiblySafeWithRtChecks;
+ case ForwardButPreventsForwarding:
+ case Backward:
+ case BackwardVectorizableButPreventsForwarding:
+ return VectorizationSafetyStatus::Unsafe;
+ }
+ llvm_unreachable("unexpected DepType!");
+}
+
+bool MemoryDepChecker::Dependence::isBackward() const {
+ switch (Type) {
+ case NoDep:
+ case Forward:
+ case ForwardButPreventsForwarding:
+ case Unknown:
+ return false;
+
+ case BackwardVectorizable:
+ case Backward:
+ case BackwardVectorizableButPreventsForwarding:
+ return true;
+ }
+ llvm_unreachable("unexpected DepType!");
+}
+
+bool MemoryDepChecker::Dependence::isPossiblyBackward() const {
+ return isBackward() || Type == Unknown;
+}
+
+bool MemoryDepChecker::Dependence::isForward() const {
+ switch (Type) {
+ case Forward:
+ case ForwardButPreventsForwarding:
+ return true;
+
+ case NoDep:
+ case Unknown:
+ case BackwardVectorizable:
+ case Backward:
+ case BackwardVectorizableButPreventsForwarding:
+ return false;
+ }
+ llvm_unreachable("unexpected DepType!");
+}
+
+bool MemoryDepChecker::couldPreventStoreLoadForward(uint64_t Distance,
+ uint64_t TypeByteSize) {
+  // If loads occur at a distance that is not a multiple of a feasible vector
+  // factor, store-load forwarding does not take place.
+  // Positive dependences might cause trouble because vectorizing them might
+  // prevent store-load forwarding, making vectorized code run a lot slower.
+  // E.g.
+  //   a[i] = a[i-3] ^ a[i-8];
+  // The stores to a[i:i+1] don't align with the stores to a[i-3:i-2] and
+  // hence on your typical architecture store-load forwarding does not take
+  // place. Vectorizing such cases does not make sense.
+  //
+  // Store-load forwarding distance:
+
+ // After this many iterations store-to-load forwarding conflicts should not
+ // cause any slowdowns.
+ const uint64_t NumItersForStoreLoadThroughMemory = 8 * TypeByteSize;
+ // Maximum vector factor.
+ uint64_t MaxVFWithoutSLForwardIssues = std::min(
+ VectorizerParams::MaxVectorWidth * TypeByteSize, MaxSafeDepDistBytes);
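+
+  // Illustrative numbers (not taken from any particular caller): with
+  // TypeByteSize == 4 and Distance == 12, the loop below starts at VF == 8;
+  // 12 % 8 != 0 and 12 / 8 == 1 < 32 iterations, so the maximum VF is halved
+  // to 4, which is below 2 * TypeByteSize, and a forwarding conflict is
+  // reported.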
+
+ // Compute the smallest VF at which the store and load would be misaligned.
+ for (uint64_t VF = 2 * TypeByteSize; VF <= MaxVFWithoutSLForwardIssues;
+ VF *= 2) {
+    // If the number of vector iterations between the store and the load is
+    // small, we could incur conflicts.
+ if (Distance % VF && Distance / VF < NumItersForStoreLoadThroughMemory) {
+ MaxVFWithoutSLForwardIssues = (VF >>= 1);
+ break;
+ }
+ }
+
+ if (MaxVFWithoutSLForwardIssues < 2 * TypeByteSize) {
+ LLVM_DEBUG(
+ dbgs() << "LAA: Distance " << Distance
+ << " that could cause a store-load forwarding conflict\n");
+ return true;
+ }
+
+ if (MaxVFWithoutSLForwardIssues < MaxSafeDepDistBytes &&
+ MaxVFWithoutSLForwardIssues !=
+ VectorizerParams::MaxVectorWidth * TypeByteSize)
+ MaxSafeDepDistBytes = MaxVFWithoutSLForwardIssues;
+ return false;
+}
+
+void MemoryDepChecker::mergeInStatus(VectorizationSafetyStatus S) {
+ if (Status < S)
+ Status = S;
+}
+
+/// Given a non-constant (unknown) dependence-distance \p Dist between two
+/// memory accesses that have the same stride, whose absolute value is given
+/// in \p Stride, and that have the same type size \p TypeByteSize,
+/// in a loop whose backedge-taken count is \p BackedgeTakenCount, check if it
+/// is possible to prove statically that the dependence distance is larger
+/// than the range that the accesses will travel through the execution of
+/// the loop. If so, return true; false otherwise. This is useful for
+/// example in loops such as the following (PR31098):
+/// for (i = 0; i < D; ++i) {
+/// = out[i];
+/// out[i+D] =
+/// }
+static bool isSafeDependenceDistance(const DataLayout &DL, ScalarEvolution &SE,
+ const SCEV &BackedgeTakenCount,
+ const SCEV &Dist, uint64_t Stride,
+ uint64_t TypeByteSize) {
+
+ // If we can prove that
+ // (**) |Dist| > BackedgeTakenCount * Step
+ // where Step is the absolute stride of the memory accesses in bytes,
+ // then there is no dependence.
+ //
+ // Rationale:
+ // We basically want to check if the absolute distance (|Dist/Step|)
+ // is >= the loop iteration count (or > BackedgeTakenCount).
+ // This is equivalent to the Strong SIV Test (Practical Dependence Testing,
+  // Section 4.2.1); Note that for vectorization it is sufficient to prove
+  // that the dependence distance is >= VF; this is checked elsewhere.
+ // But in some cases we can prune unknown dependence distances early, and
+ // even before selecting the VF, and without a runtime test, by comparing
+ // the distance against the loop iteration count. Since the vectorized code
+ // will be executed only if LoopCount >= VF, proving distance >= LoopCount
+ // also guarantees that distance >= VF.
+ //
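+  // Worked instance of (**) for the loop in the comment above (assuming, for
+  // illustration, an element size of TypeByteSize and D > 0): Dist is
+  // D * TypeByteSize, BackedgeTakenCount is D - 1 and Step is TypeByteSize,
+  // so Dist - BackedgeTakenCount * Step == TypeByteSize > 0, which proves
+  // (**) and hence independence without a runtime check.
+  //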
+ const uint64_t ByteStride = Stride * TypeByteSize;
+ const SCEV *Step = SE.getConstant(BackedgeTakenCount.getType(), ByteStride);
+ const SCEV *Product = SE.getMulExpr(&BackedgeTakenCount, Step);
+
+ const SCEV *CastedDist = &Dist;
+ const SCEV *CastedProduct = Product;
+ uint64_t DistTypeSize = DL.getTypeAllocSize(Dist.getType());
+ uint64_t ProductTypeSize = DL.getTypeAllocSize(Product->getType());
+
+ // The dependence distance can be positive/negative, so we sign extend Dist;
+ // The multiplication of the absolute stride in bytes and the
+ // backedgeTakenCount is non-negative, so we zero extend Product.
+ if (DistTypeSize > ProductTypeSize)
+ CastedProduct = SE.getZeroExtendExpr(Product, Dist.getType());
+ else
+ CastedDist = SE.getNoopOrSignExtend(&Dist, Product->getType());
+
+ // Is Dist - (BackedgeTakenCount * Step) > 0 ?
+ // (If so, then we have proven (**) because |Dist| >= Dist)
+ const SCEV *Minus = SE.getMinusSCEV(CastedDist, CastedProduct);
+ if (SE.isKnownPositive(Minus))
+ return true;
+
+ // Second try: Is -Dist - (BackedgeTakenCount * Step) > 0 ?
+ // (If so, then we have proven (**) because |Dist| >= -1*Dist)
+ const SCEV *NegDist = SE.getNegativeSCEV(CastedDist);
+ Minus = SE.getMinusSCEV(NegDist, CastedProduct);
+ if (SE.isKnownPositive(Minus))
+ return true;
+
+ return false;
+}
+
+/// Check the dependence for two accesses with the same stride \p Stride.
+/// \p Distance is the positive distance and \p TypeByteSize is type size in
+/// bytes.
+///
+/// \returns true if they are independent.
+static bool areStridedAccessesIndependent(uint64_t Distance, uint64_t Stride,
+ uint64_t TypeByteSize) {
+ assert(Stride > 1 && "The stride must be greater than 1");
+ assert(TypeByteSize > 0 && "The type size in byte must be non-zero");
+ assert(Distance > 0 && "The distance must be non-zero");
+
+  // Skip if the distance is not a multiple of the type byte size.
+ if (Distance % TypeByteSize)
+ return false;
+
+ uint64_t ScaledDist = Distance / TypeByteSize;
+
+  // No dependence if the scaled distance is not a multiple of the stride.
+ // E.g.
+ // for (i = 0; i < 1024 ; i += 4)
+ // A[i+2] = A[i] + 1;
+ //
+ // Two accesses in memory (scaled distance is 2, stride is 4):
+ // | A[0] | | | | A[4] | | | |
+ // | | | A[2] | | | | A[6] | |
+ //
+ // E.g.
+ // for (i = 0; i < 1024 ; i += 3)
+ // A[i+4] = A[i] + 1;
+ //
+ // Two accesses in memory (scaled distance is 4, stride is 3):
+ // | A[0] | | | A[3] | | | A[6] | | |
+ // | | | | | A[4] | | | A[7] | |
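+  //
+  // In both examples the scaled distance is not a multiple of the stride
+  // (2 % 4 == 2 and 4 % 3 == 1), so the two access streams never touch the
+  // same element and the accesses are independent.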
+ return ScaledDist % Stride;
+}
+
+MemoryDepChecker::Dependence::DepType
+MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
+ const MemAccessInfo &B, unsigned BIdx,
+ const ValueToValueMap &Strides) {
+  assert(AIdx < BIdx && "Must pass arguments in program order");
+
+ Value *APtr = A.getPointer();
+ Value *BPtr = B.getPointer();
+ bool AIsWrite = A.getInt();
+ bool BIsWrite = B.getInt();
+
+ // Two reads are independent.
+ if (!AIsWrite && !BIsWrite)
+ return Dependence::NoDep;
+
+ // We cannot check pointers in different address spaces.
+ if (APtr->getType()->getPointerAddressSpace() !=
+ BPtr->getType()->getPointerAddressSpace())
+ return Dependence::Unknown;
+
+ int64_t StrideAPtr = getPtrStride(PSE, APtr, InnermostLoop, Strides, true);
+ int64_t StrideBPtr = getPtrStride(PSE, BPtr, InnermostLoop, Strides, true);
+
+ const SCEV *Src = PSE.getSCEV(APtr);
+ const SCEV *Sink = PSE.getSCEV(BPtr);
+
+ // If the induction step is negative we have to invert source and sink of the
+ // dependence.
+ if (StrideAPtr < 0) {
+ std::swap(APtr, BPtr);
+ std::swap(Src, Sink);
+ std::swap(AIsWrite, BIsWrite);
+ std::swap(AIdx, BIdx);
+ std::swap(StrideAPtr, StrideBPtr);
+ }
+
+ const SCEV *Dist = PSE.getSE()->getMinusSCEV(Sink, Src);
+
+ LLVM_DEBUG(dbgs() << "LAA: Src Scev: " << *Src << "Sink Scev: " << *Sink
+ << "(Induction step: " << StrideAPtr << ")\n");
+ LLVM_DEBUG(dbgs() << "LAA: Distance for " << *InstMap[AIdx] << " to "
+ << *InstMap[BIdx] << ": " << *Dist << "\n");
+
+ // Need accesses with constant stride. We don't want to vectorize
+ // "A[B[i]] += ..." and similar code or pointer arithmetic that could wrap in
+ // the address space.
+  if (!StrideAPtr || !StrideBPtr || StrideAPtr != StrideBPtr) {
+ LLVM_DEBUG(dbgs() << "Pointer access with non-constant stride\n");
+ return Dependence::Unknown;
+ }
+
+ Type *ATy = APtr->getType()->getPointerElementType();
+ Type *BTy = BPtr->getType()->getPointerElementType();
+ auto &DL = InnermostLoop->getHeader()->getModule()->getDataLayout();
+ uint64_t TypeByteSize = DL.getTypeAllocSize(ATy);
+ uint64_t Stride = std::abs(StrideAPtr);
+ const SCEVConstant *C = dyn_cast<SCEVConstant>(Dist);
+ if (!C) {
+ if (TypeByteSize == DL.getTypeAllocSize(BTy) &&
+ isSafeDependenceDistance(DL, *(PSE.getSE()),
+ *(PSE.getBackedgeTakenCount()), *Dist, Stride,
+ TypeByteSize))
+ return Dependence::NoDep;
+
+ LLVM_DEBUG(dbgs() << "LAA: Dependence because of non-constant distance\n");
+ FoundNonConstantDistanceDependence = true;
+ return Dependence::Unknown;
+ }
+
+ const APInt &Val = C->getAPInt();
+ int64_t Distance = Val.getSExtValue();
+
+ // Attempt to prove strided accesses independent.
+ if (std::abs(Distance) > 0 && Stride > 1 && ATy == BTy &&
+ areStridedAccessesIndependent(std::abs(Distance), Stride, TypeByteSize)) {
+ LLVM_DEBUG(dbgs() << "LAA: Strided accesses are independent\n");
+ return Dependence::NoDep;
+ }
+
+ // Negative distances are not plausible dependencies.
+ if (Val.isNegative()) {
+ bool IsTrueDataDependence = (AIsWrite && !BIsWrite);
+ if (IsTrueDataDependence && EnableForwardingConflictDetection &&
+ (couldPreventStoreLoadForward(Val.abs().getZExtValue(), TypeByteSize) ||
+ ATy != BTy)) {
+ LLVM_DEBUG(dbgs() << "LAA: Forward but may prevent st->ld forwarding\n");
+ return Dependence::ForwardButPreventsForwarding;
+ }
+
+ LLVM_DEBUG(dbgs() << "LAA: Dependence is negative\n");
+ return Dependence::Forward;
+ }
+
+ // Write to the same location with the same size.
+ // Could be improved to assert type sizes are the same (i32 == float, etc).
+ if (Val == 0) {
+ if (ATy == BTy)
+ return Dependence::Forward;
+ LLVM_DEBUG(
+ dbgs() << "LAA: Zero dependence difference but different types\n");
+ return Dependence::Unknown;
+ }
+
+ assert(Val.isStrictlyPositive() && "Expect a positive value");
+
+ if (ATy != BTy) {
+ LLVM_DEBUG(
+ dbgs()
+ << "LAA: ReadWrite-Write positive dependency with different types\n");
+ return Dependence::Unknown;
+ }
+
+ // Bail out early if passed-in parameters make vectorization not feasible.
+ unsigned ForcedFactor = (VectorizerParams::VectorizationFactor ?
+ VectorizerParams::VectorizationFactor : 1);
+ unsigned ForcedUnroll = (VectorizerParams::VectorizationInterleave ?
+ VectorizerParams::VectorizationInterleave : 1);
+ // The minimum number of iterations for a vectorized/unrolled version.
+ unsigned MinNumIter = std::max(ForcedFactor * ForcedUnroll, 2U);
+
+ // It's not vectorizable if the distance is smaller than the minimum distance
+  // needed for a vectorized/unrolled version. Vectorizing one iteration in
+  // front needs TypeByteSize * Stride. Vectorizing the last iteration needs
+  // TypeByteSize (no need to add the last gap distance).
+ //
+ // E.g. Assume one char is 1 byte in memory and one int is 4 bytes.
+ // foo(int *A) {
+ // int *B = (int *)((char *)A + 14);
+ // for (i = 0 ; i < 1024 ; i += 2)
+ // B[i] = A[i] + 1;
+ // }
+ //
+ // Two accesses in memory (stride is 2):
+ // | A[0] | | A[2] | | A[4] | | A[6] | |
+ // | B[0] | | B[2] | | B[4] |
+ //
+  // Distance needed for vectorizing iterations except the last iteration:
+  // 4 * 2 * (MinNumIter - 1). Distance needed for the last iteration: 4.
+ // So the minimum distance needed is: 4 * 2 * (MinNumIter - 1) + 4.
+ //
+  // If MinNumIter is 2, it is vectorizable as the minimum distance needed is
+  // 12, which is less than the distance of 14.
+  //
+  // If MinNumIter is 4 (say, if a user forces the vectorization factor to be
+  // 4), the minimum distance needed is 28, which is greater than the distance
+  // of 14. It is not safe to do vectorization.
+ uint64_t MinDistanceNeeded =
+ TypeByteSize * Stride * (MinNumIter - 1) + TypeByteSize;
+ if (MinDistanceNeeded > static_cast<uint64_t>(Distance)) {
+ LLVM_DEBUG(dbgs() << "LAA: Failure because of positive distance "
+ << Distance << '\n');
+ return Dependence::Backward;
+ }
+
+ // Unsafe if the minimum distance needed is greater than max safe distance.
+ if (MinDistanceNeeded > MaxSafeDepDistBytes) {
+ LLVM_DEBUG(dbgs() << "LAA: Failure because it needs at least "
+ << MinDistanceNeeded << " size in bytes");
+ return Dependence::Backward;
+ }
+
+ // Positive distance bigger than max vectorization factor.
+ // FIXME: Should use max factor instead of max distance in bytes, which could
+ // not handle different types.
+ // E.g. Assume one char is 1 byte in memory and one int is 4 bytes.
+ // void foo (int *A, char *B) {
+ // for (unsigned i = 0; i < 1024; i++) {
+ // A[i+2] = A[i] + 1;
+ // B[i+2] = B[i] + 1;
+ // }
+ // }
+ //
+  // This case is currently unsafe according to the max safe distance. If we
+  // analyze the two accesses on array B, the max safe dependence distance
+  // is 2. Then we analyze the accesses on array A, where the minimum distance
+  // needed is 8, which is greater than 2, so vectorization is forbidden. But
+  // actually both A and B could be vectorized by 2 iterations.
+ MaxSafeDepDistBytes =
+ std::min(static_cast<uint64_t>(Distance), MaxSafeDepDistBytes);
+
+ bool IsTrueDataDependence = (!AIsWrite && BIsWrite);
+ if (IsTrueDataDependence && EnableForwardingConflictDetection &&
+ couldPreventStoreLoadForward(Distance, TypeByteSize))
+ return Dependence::BackwardVectorizableButPreventsForwarding;
+
+ uint64_t MaxVF = MaxSafeDepDistBytes / (TypeByteSize * Stride);
+ LLVM_DEBUG(dbgs() << "LAA: Positive distance " << Val.getSExtValue()
+ << " with max VF = " << MaxVF << '\n');
+ uint64_t MaxVFInBits = MaxVF * TypeByteSize * 8;
+ MaxSafeRegisterWidth = std::min(MaxSafeRegisterWidth, MaxVFInBits);
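+  // E.g. with MaxSafeDepDistBytes == 32, TypeByteSize == 4 and Stride == 2,
+  // MaxVF is 4 and the maximum safe register width is capped at 128 bits.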
+ return Dependence::BackwardVectorizable;
+}
+
+bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets,
+ MemAccessInfoList &CheckDeps,
+ const ValueToValueMap &Strides) {
+
+ MaxSafeDepDistBytes = -1;
+ SmallPtrSet<MemAccessInfo, 8> Visited;
+ for (MemAccessInfo CurAccess : CheckDeps) {
+ if (Visited.count(CurAccess))
+ continue;
+
+ // Get the relevant memory access set.
+ EquivalenceClasses<MemAccessInfo>::iterator I =
+ AccessSets.findValue(AccessSets.getLeaderValue(CurAccess));
+
+ // Check accesses within this set.
+ EquivalenceClasses<MemAccessInfo>::member_iterator AI =
+ AccessSets.member_begin(I);
+ EquivalenceClasses<MemAccessInfo>::member_iterator AE =
+ AccessSets.member_end();
+
+ // Check every access pair.
+ while (AI != AE) {
+ Visited.insert(*AI);
+ bool AIIsWrite = AI->getInt();
+      // Check loads only against subsequent members of the equivalence class,
+      // but check stores also against other stores to the same address within
+      // the same equivalence class.
+ EquivalenceClasses<MemAccessInfo>::member_iterator OI =
+ (AIIsWrite ? AI : std::next(AI));
+ while (OI != AE) {
+ // Check every accessing instruction pair in program order.
+ for (std::vector<unsigned>::iterator I1 = Accesses[*AI].begin(),
+ I1E = Accesses[*AI].end(); I1 != I1E; ++I1)
+          // Scan all accesses of another equivalence class, but only the
+          // subsequent accesses of the same equivalence class.
+ for (std::vector<unsigned>::iterator
+ I2 = (OI == AI ? std::next(I1) : Accesses[*OI].begin()),
+ I2E = (OI == AI ? I1E : Accesses[*OI].end());
+ I2 != I2E; ++I2) {
+ auto A = std::make_pair(&*AI, *I1);
+ auto B = std::make_pair(&*OI, *I2);
+
+ assert(*I1 != *I2);
+ if (*I1 > *I2)
+ std::swap(A, B);
+
+ Dependence::DepType Type =
+ isDependent(*A.first, A.second, *B.first, B.second, Strides);
+ mergeInStatus(Dependence::isSafeForVectorization(Type));
+
+ // Gather dependences unless we accumulated MaxDependences
+ // dependences. In that case return as soon as we find the first
+ // unsafe dependence. This puts a limit on this quadratic
+ // algorithm.
+ if (RecordDependences) {
+ if (Type != Dependence::NoDep)
+ Dependences.push_back(Dependence(A.second, B.second, Type));
+
+ if (Dependences.size() >= MaxDependences) {
+ RecordDependences = false;
+ Dependences.clear();
+ LLVM_DEBUG(dbgs()
+ << "Too many dependences, stopped recording\n");
+ }
+ }
+ if (!RecordDependences && !isSafeForVectorization())
+ return false;
+ }
+ ++OI;
+ }
+ AI++;
+ }
+ }
+
+ LLVM_DEBUG(dbgs() << "Total Dependences: " << Dependences.size() << "\n");
+ return isSafeForVectorization();
+}
+
+SmallVector<Instruction *, 4>
+MemoryDepChecker::getInstructionsForAccess(Value *Ptr, bool isWrite) const {
+ MemAccessInfo Access(Ptr, isWrite);
+ auto &IndexVector = Accesses.find(Access)->second;
+
+ SmallVector<Instruction *, 4> Insts;
+ transform(IndexVector,
+ std::back_inserter(Insts),
+ [&](unsigned Idx) { return this->InstMap[Idx]; });
+ return Insts;
+}
+
+const char *MemoryDepChecker::Dependence::DepName[] = {
+ "NoDep", "Unknown", "Forward", "ForwardButPreventsForwarding", "Backward",
+ "BackwardVectorizable", "BackwardVectorizableButPreventsForwarding"};
+
+void MemoryDepChecker::Dependence::print(
+ raw_ostream &OS, unsigned Depth,
+ const SmallVectorImpl<Instruction *> &Instrs) const {
+ OS.indent(Depth) << DepName[Type] << ":\n";
+ OS.indent(Depth + 2) << *Instrs[Source] << " -> \n";
+ OS.indent(Depth + 2) << *Instrs[Destination] << "\n";
+}
+
+bool LoopAccessInfo::canAnalyzeLoop() {
+ // We need to have a loop header.
+ LLVM_DEBUG(dbgs() << "LAA: Found a loop in "
+ << TheLoop->getHeader()->getParent()->getName() << ": "
+ << TheLoop->getHeader()->getName() << '\n');
+
+ // We can only analyze innermost loops.
+ if (!TheLoop->empty()) {
+ LLVM_DEBUG(dbgs() << "LAA: loop is not the innermost loop\n");
+ recordAnalysis("NotInnerMostLoop") << "loop is not the innermost loop";
+ return false;
+ }
+
+ // We must have a single backedge.
+ if (TheLoop->getNumBackEdges() != 1) {
+ LLVM_DEBUG(
+ dbgs() << "LAA: loop control flow is not understood by analyzer\n");
+ recordAnalysis("CFGNotUnderstood")
+ << "loop control flow is not understood by analyzer";
+ return false;
+ }
+
+ // We must have a single exiting block.
+ if (!TheLoop->getExitingBlock()) {
+ LLVM_DEBUG(
+ dbgs() << "LAA: loop control flow is not understood by analyzer\n");
+ recordAnalysis("CFGNotUnderstood")
+ << "loop control flow is not understood by analyzer";
+ return false;
+ }
+
+  // We only handle bottom-tested loops, i.e. loops in which the condition is
+ // checked at the end of each iteration. With that we can assume that all
+ // instructions in the loop are executed the same number of times.
+ if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) {
+ LLVM_DEBUG(
+ dbgs() << "LAA: loop control flow is not understood by analyzer\n");
+ recordAnalysis("CFGNotUnderstood")
+ << "loop control flow is not understood by analyzer";
+ return false;
+ }
+
+ // ScalarEvolution needs to be able to find the exit count.
+ const SCEV *ExitCount = PSE->getBackedgeTakenCount();
+ if (ExitCount == PSE->getSE()->getCouldNotCompute()) {
+ recordAnalysis("CantComputeNumberOfIterations")
+ << "could not determine number of loop iterations";
+ LLVM_DEBUG(dbgs() << "LAA: SCEV could not compute the loop exit count.\n");
+ return false;
+ }
+
+ return true;
+}
+
+void LoopAccessInfo::analyzeLoop(AliasAnalysis *AA, LoopInfo *LI,
+ const TargetLibraryInfo *TLI,
+ DominatorTree *DT) {
+ typedef SmallPtrSet<Value*, 16> ValueSet;
+
+ // Holds the Load and Store instructions.
+ SmallVector<LoadInst *, 16> Loads;
+ SmallVector<StoreInst *, 16> Stores;
+
+ // Holds all the different accesses in the loop.
+ unsigned NumReads = 0;
+ unsigned NumReadWrites = 0;
+
+ bool HasComplexMemInst = false;
+
+ // A runtime check is only legal to insert if there are no convergent calls.
+ HasConvergentOp = false;
+
+ PtrRtChecking->Pointers.clear();
+ PtrRtChecking->Need = false;
+
+ const bool IsAnnotatedParallel = TheLoop->isAnnotatedParallel();
+
+ // For each block.
+ for (BasicBlock *BB : TheLoop->blocks()) {
+ // Scan the BB and collect legal loads and stores. Also detect any
+ // convergent instructions.
+ for (Instruction &I : *BB) {
+ if (auto *Call = dyn_cast<CallBase>(&I)) {
+ if (Call->isConvergent())
+ HasConvergentOp = true;
+ }
+
+      // If both a non-vectorizable memory instruction and a convergent
+      // operation have been found in this loop, there is no reason to
+      // continue the search.
+ if (HasComplexMemInst && HasConvergentOp) {
+ CanVecMem = false;
+ return;
+ }
+
+ // Avoid hitting recordAnalysis multiple times.
+ if (HasComplexMemInst)
+ continue;
+
+ // If this is a load, save it. If this instruction can read from memory
+ // but is not a load, then we quit. Notice that we don't handle function
+ // calls that read or write.
+ if (I.mayReadFromMemory()) {
+ // Many math library functions read the rounding mode. We will only
+ // vectorize a loop if it contains known function calls that don't set
+ // the flag. Therefore, it is safe to ignore this read from memory.
+ auto *Call = dyn_cast<CallInst>(&I);
+ if (Call && getVectorIntrinsicIDForCall(Call, TLI))
+ continue;
+
+ // If the function has an explicit vectorized counterpart, we can safely
+ // assume that it can be vectorized.
+ if (Call && !Call->isNoBuiltin() && Call->getCalledFunction() &&
+ TLI->isFunctionVectorizable(Call->getCalledFunction()->getName()))
+ continue;
+
+ auto *Ld = dyn_cast<LoadInst>(&I);
+ if (!Ld) {
+ recordAnalysis("CantVectorizeInstruction", Ld)
+ << "instruction cannot be vectorized";
+ HasComplexMemInst = true;
+ continue;
+ }
+ if (!Ld->isSimple() && !IsAnnotatedParallel) {
+ recordAnalysis("NonSimpleLoad", Ld)
+ << "read with atomic ordering or volatile read";
+ LLVM_DEBUG(dbgs() << "LAA: Found a non-simple load.\n");
+ HasComplexMemInst = true;
+ continue;
+ }
+ NumLoads++;
+ Loads.push_back(Ld);
+ DepChecker->addAccess(Ld);
+ if (EnableMemAccessVersioning)
+ collectStridedAccess(Ld);
+ continue;
+ }
+
+ // Save 'store' instructions. Abort if other instructions write to memory.
+ if (I.mayWriteToMemory()) {
+ auto *St = dyn_cast<StoreInst>(&I);
+ if (!St) {
+ recordAnalysis("CantVectorizeInstruction", St)
+ << "instruction cannot be vectorized";
+ HasComplexMemInst = true;
+ continue;
+ }
+ if (!St->isSimple() && !IsAnnotatedParallel) {
+ recordAnalysis("NonSimpleStore", St)
+ << "write with atomic ordering or volatile write";
+ LLVM_DEBUG(dbgs() << "LAA: Found a non-simple store.\n");
+ HasComplexMemInst = true;
+ continue;
+ }
+ NumStores++;
+ Stores.push_back(St);
+ DepChecker->addAccess(St);
+ if (EnableMemAccessVersioning)
+ collectStridedAccess(St);
+ }
+ } // Next instr.
+ } // Next block.
+
+ if (HasComplexMemInst) {
+ CanVecMem = false;
+ return;
+ }
+
+ // Now we have two lists that hold the loads and the stores.
+ // Next, we find the pointers that they use.
+
+ // Check if we see any stores. If there are no stores, then we don't
+ // care if the pointers are *restrict*.
+ if (!Stores.size()) {
+ LLVM_DEBUG(dbgs() << "LAA: Found a read-only loop!\n");
+ CanVecMem = true;
+ return;
+ }
+
+ MemoryDepChecker::DepCandidates DependentAccesses;
+ AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(),
+ TheLoop, AA, LI, DependentAccesses, *PSE);
+
+ // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects
+ // multiple times on the same object. If the ptr is accessed twice, once
+ // for read and once for write, it will only appear once (on the write
+ // list). This is okay, since we are going to check for conflicts between
+ // writes and between reads and writes, but not between reads and reads.
+ ValueSet Seen;
+
+ // Record uniform store addresses to identify if we have multiple stores
+ // to the same address.
+ ValueSet UniformStores;
+
+ for (StoreInst *ST : Stores) {
+ Value *Ptr = ST->getPointerOperand();
+
+ if (isUniform(Ptr))
+ HasDependenceInvolvingLoopInvariantAddress |=
+ !UniformStores.insert(Ptr).second;
+
+ // If we did *not* see this pointer before, insert it to the read-write
+ // list. At this phase it is only a 'write' list.
+ if (Seen.insert(Ptr).second) {
+ ++NumReadWrites;
+
+ MemoryLocation Loc = MemoryLocation::get(ST);
+ // The TBAA metadata could have a control dependency on the predication
+ // condition, so we cannot rely on it when determining whether or not we
+ // need runtime pointer checks.
+ if (blockNeedsPredication(ST->getParent(), TheLoop, DT))
+ Loc.AATags.TBAA = nullptr;
+
+ Accesses.addStore(Loc);
+ }
+ }
+
+ if (IsAnnotatedParallel) {
+ LLVM_DEBUG(
+ dbgs() << "LAA: A loop annotated parallel, ignore memory dependency "
+ << "checks.\n");
+ CanVecMem = true;
+ return;
+ }
+
+ for (LoadInst *LD : Loads) {
+ Value *Ptr = LD->getPointerOperand();
+ // If we did *not* see this pointer before, insert it to the
+ // read list. If we *did* see it before, then it is already in
+ // the read-write list. This allows us to vectorize expressions
+    // such as A[i] += x, because the address of A[i] is a read-write
+ // pointer. This only works if the index of A[i] is consecutive.
+ // If the address of i is unknown (for example A[B[i]]) then we may
+ // read a few words, modify, and write a few words, and some of the
+ // words may be written to the same address.
+ bool IsReadOnlyPtr = false;
+ if (Seen.insert(Ptr).second ||
+ !getPtrStride(*PSE, Ptr, TheLoop, SymbolicStrides)) {
+ ++NumReads;
+ IsReadOnlyPtr = true;
+ }
+
+    // See if there is an unsafe dependency between a load from a uniform
+    // address and a store to the same uniform address.
+ if (UniformStores.count(Ptr)) {
+ LLVM_DEBUG(dbgs() << "LAA: Found an unsafe dependency between a uniform "
+ "load and uniform store to the same address!\n");
+ HasDependenceInvolvingLoopInvariantAddress = true;
+ }
+
+ MemoryLocation Loc = MemoryLocation::get(LD);
+ // The TBAA metadata could have a control dependency on the predication
+ // condition, so we cannot rely on it when determining whether or not we
+ // need runtime pointer checks.
+ if (blockNeedsPredication(LD->getParent(), TheLoop, DT))
+ Loc.AATags.TBAA = nullptr;
+
+ Accesses.addLoad(Loc, IsReadOnlyPtr);
+ }
+
+ // If we write (or read-write) to a single destination and there are no
+ // other reads in this loop then it is safe to vectorize.
+ if (NumReadWrites == 1 && NumReads == 0) {
+ LLVM_DEBUG(dbgs() << "LAA: Found a write-only loop!\n");
+ CanVecMem = true;
+ return;
+ }
+
+ // Build dependence sets and check whether we need a runtime pointer bounds
+ // check.
+ Accesses.buildDependenceSets();
+
+ // Find pointers with computable bounds. We are going to use this information
+ // to place a runtime bound check.
+ bool CanDoRTIfNeeded = Accesses.canCheckPtrAtRT(*PtrRtChecking, PSE->getSE(),
+ TheLoop, SymbolicStrides);
+ if (!CanDoRTIfNeeded) {
+ recordAnalysis("CantIdentifyArrayBounds") << "cannot identify array bounds";
+ LLVM_DEBUG(dbgs() << "LAA: We can't vectorize because we can't find "
+ << "the array bounds.\n");
+ CanVecMem = false;
+ return;
+ }
+
+ LLVM_DEBUG(
+ dbgs() << "LAA: May be able to perform a memory runtime check if needed.\n");
+
+ CanVecMem = true;
+ if (Accesses.isDependencyCheckNeeded()) {
+ LLVM_DEBUG(dbgs() << "LAA: Checking memory dependencies\n");
+ CanVecMem = DepChecker->areDepsSafe(
+ DependentAccesses, Accesses.getDependenciesToCheck(), SymbolicStrides);
+ MaxSafeDepDistBytes = DepChecker->getMaxSafeDepDistBytes();
+
+ if (!CanVecMem && DepChecker->shouldRetryWithRuntimeCheck()) {
+ LLVM_DEBUG(dbgs() << "LAA: Retrying with memory checks\n");
+
+ // Clear the dependency checks. We assume they are not needed.
+ Accesses.resetDepChecks(*DepChecker);
+
+ PtrRtChecking->reset();
+ PtrRtChecking->Need = true;
+
+ auto *SE = PSE->getSE();
+ CanDoRTIfNeeded = Accesses.canCheckPtrAtRT(*PtrRtChecking, SE, TheLoop,
+ SymbolicStrides, true);
+
+ // Check that we found the bounds for the pointer.
+ if (!CanDoRTIfNeeded) {
+ recordAnalysis("CantCheckMemDepsAtRunTime")
+ << "cannot check memory dependencies at runtime";
+ LLVM_DEBUG(dbgs() << "LAA: Can't vectorize with memory checks\n");
+ CanVecMem = false;
+ return;
+ }
+
+ CanVecMem = true;
+ }
+ }
+
+ if (HasConvergentOp) {
+ recordAnalysis("CantInsertRuntimeCheckWithConvergent")
+ << "cannot add control dependency to convergent operation";
+ LLVM_DEBUG(dbgs() << "LAA: We can't vectorize because a runtime check "
+ "would be needed with a convergent operation\n");
+ CanVecMem = false;
+ return;
+ }
+
+ if (CanVecMem)
+ LLVM_DEBUG(
+ dbgs() << "LAA: No unsafe dependent memory operations in loop. We"
+ << (PtrRtChecking->Need ? "" : " don't")
+ << " need runtime memory checks.\n");
+ else {
+ recordAnalysis("UnsafeMemDep")
+ << "unsafe dependent memory operations in loop. Use "
+ "#pragma loop distribute(enable) to allow loop distribution "
+ "to attempt to isolate the offending operations into a separate "
+ "loop";
+ LLVM_DEBUG(dbgs() << "LAA: unsafe dependent memory operations in loop\n");
+ }
+}
+
+bool LoopAccessInfo::blockNeedsPredication(BasicBlock *BB, Loop *TheLoop,
+ DominatorTree *DT) {
+ assert(TheLoop->contains(BB) && "Unknown block used");
+
+ // Blocks that do not dominate the latch need predication.
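+ // For illustration: the 'then' block of an if-statement inside the loop body
+ // does not dominate the latch, so its memory accesses only execute on some
+ // iterations and must be predicated when the loop is vectorized.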
+ BasicBlock* Latch = TheLoop->getLoopLatch();
+ return !DT->dominates(BB, Latch);
+}
+
+OptimizationRemarkAnalysis &LoopAccessInfo::recordAnalysis(StringRef RemarkName,
+ Instruction *I) {
+ assert(!Report && "Multiple reports generated");
+
+ Value *CodeRegion = TheLoop->getHeader();
+ DebugLoc DL = TheLoop->getStartLoc();
+
+ if (I) {
+ CodeRegion = I->getParent();
+ // If there is no debug location attached to the instruction, fall back to
+ // using the loop's.
+ if (I->getDebugLoc())
+ DL = I->getDebugLoc();
+ }
+
+ Report = std::make_unique<OptimizationRemarkAnalysis>(DEBUG_TYPE, RemarkName, DL,
+ CodeRegion);
+ return *Report;
+}
+
+bool LoopAccessInfo::isUniform(Value *V) const {
+ auto *SE = PSE->getSE();
+ // Since we rely on SCEV for uniformity, if the type is not SCEVable, it is
+ // never considered uniform.
+ // TODO: Is this really what we want? Even without FP SCEV, we may want some
+ // trivially loop-invariant FP values to be considered uniform.
+ if (!SE->isSCEVable(V->getType()))
+ return false;
+ return (SE->isLoopInvariant(SE->getSCEV(V), TheLoop));
+}
+
+// FIXME: this function is currently a duplicate of the one in
+// LoopVectorize.cpp.
+static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
+ Instruction *Loc) {
+ if (FirstInst)
+ return FirstInst;
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ return I->getParent() == Loc->getParent() ? I : nullptr;
+ return nullptr;
+}
+
+namespace {
+
+/// IR Values for the lower and upper bounds of a pointer evolution. We
+/// need to use value-handles because SCEV expansion can invalidate previously
+/// expanded values. Thus expansion of a pointer can invalidate the bounds for
+/// a previous one.
+struct PointerBounds {
+ TrackingVH<Value> Start;
+ TrackingVH<Value> End;
+};
+
+} // end anonymous namespace
+
+/// Expand code for the lower and upper bound of the pointer group \p CG
+/// in \p TheLoop. \return the values for the bounds.
+static PointerBounds
+expandBounds(const RuntimePointerChecking::CheckingPtrGroup *CG, Loop *TheLoop,
+ Instruction *Loc, SCEVExpander &Exp, ScalarEvolution *SE,
+ const RuntimePointerChecking &PtrRtChecking) {
+ Value *Ptr = PtrRtChecking.Pointers[CG->Members[0]].PointerValue;
+ const SCEV *Sc = SE->getSCEV(Ptr);
+
+ unsigned AS = Ptr->getType()->getPointerAddressSpace();
+ LLVMContext &Ctx = Loc->getContext();
+
+ // Use this type for pointer arithmetic.
+ Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS);
+
+ if (SE->isLoopInvariant(Sc, TheLoop)) {
+ LLVM_DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:"
+ << *Ptr << "\n");
+ // Ptr could be in the loop body. If so, expand a new one at the correct
+ // location.
+ Instruction *Inst = dyn_cast<Instruction>(Ptr);
+ Value *NewPtr = (Inst && TheLoop->contains(Inst))
+ ? Exp.expandCodeFor(Sc, PtrArithTy, Loc)
+ : Ptr;
+ // We must return a half-open range, which means incrementing Sc.
+ const SCEV *ScPlusOne = SE->getAddExpr(Sc, SE->getOne(PtrArithTy));
+ Value *NewPtrPlusOne = Exp.expandCodeFor(ScPlusOne, PtrArithTy, Loc);
+ return {NewPtr, NewPtrPlusOne};
+ } else {
+ Value *Start = nullptr, *End = nullptr;
+ LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
+ Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc);
+ End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc);
+ LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High
+ << "\n");
+ return {Start, End};
+ }
+}
+
+/// Turns a collection of checks into a collection of expanded upper and
+/// lower bounds for both pointers in the check.
+static SmallVector<std::pair<PointerBounds, PointerBounds>, 4> expandBounds(
+ const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks,
+ Loop *L, Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp,
+ const RuntimePointerChecking &PtrRtChecking) {
+ SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
+
+ // Here we're relying on the SCEV Expander's cache to only emit code for the
+ // same bounds once.
+ transform(
+ PointerChecks, std::back_inserter(ChecksWithBounds),
+ [&](const RuntimePointerChecking::PointerCheck &Check) {
+ PointerBounds
+ First = expandBounds(Check.first, L, Loc, Exp, SE, PtrRtChecking),
+ Second = expandBounds(Check.second, L, Loc, Exp, SE, PtrRtChecking);
+ return std::make_pair(First, Second);
+ });
+
+ return ChecksWithBounds;
+}
+
+std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeChecks(
+ Instruction *Loc,
+ const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks)
+ const {
+ const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout();
+ auto *SE = PSE->getSE();
+ SCEVExpander Exp(*SE, DL, "induction");
+ auto ExpandedChecks =
+ expandBounds(PointerChecks, TheLoop, Loc, SE, Exp, *PtrRtChecking);
+
+ LLVMContext &Ctx = Loc->getContext();
+ Instruction *FirstInst = nullptr;
+ IRBuilder<> ChkBuilder(Loc);
+ // Our instructions might fold to a constant.
+ Value *MemoryRuntimeCheck = nullptr;
+
+ for (const auto &Check : ExpandedChecks) {
+ const PointerBounds &A = Check.first, &B = Check.second;
+ // Check if two pointers (A and B) conflict where conflict is computed as:
+ // start(A) <= end(B) && start(B) <= end(A)
+ unsigned AS0 = A.Start->getType()->getPointerAddressSpace();
+ unsigned AS1 = B.Start->getType()->getPointerAddressSpace();
+
+ assert((AS0 == B.End->getType()->getPointerAddressSpace()) &&
+ (AS1 == A.End->getType()->getPointerAddressSpace()) &&
+ "Trying to bounds check pointers with different address spaces");
+
+ Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0);
+ Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1);
+
+ Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc");
+ Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc");
+ Value *End0 = ChkBuilder.CreateBitCast(A.End, PtrArithTy1, "bc");
+ Value *End1 = ChkBuilder.CreateBitCast(B.End, PtrArithTy0, "bc");
+
+ // [A|B].Start points to the first accessed byte under base [A|B].
+ // [A|B].End points to the last accessed byte, plus one.
+ // There is no conflict when the intervals are disjoint:
+ // NoConflict = (B.Start >= A.End) || (A.Start >= B.End)
+ //
+ // bound0 = (B.Start < A.End)
+ // bound1 = (A.Start < B.End)
+ // IsConflict = bound0 & bound1
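+ // For illustration (hypothetical addresses): if A covers [0x100, 0x140) and
+ // B covers [0x130, 0x170), then bound0 = (0x130 < 0x140) and
+ // bound1 = (0x100 < 0x170) are both true, so IsConflict is true; if B instead
+ // started at 0x140, bound0 would be false and the intervals would be disjoint.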
+ Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0");
+ FirstInst = getFirstInst(FirstInst, Cmp0, Loc);
+ Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1");
+ FirstInst = getFirstInst(FirstInst, Cmp1, Loc);
+ Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+ FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
+ if (MemoryRuntimeCheck) {
+ IsConflict =
+ ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
+ FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
+ }
+ MemoryRuntimeCheck = IsConflict;
+ }
+
+ if (!MemoryRuntimeCheck)
+ return std::make_pair(nullptr, nullptr);
+
+ // We have to do this trickery because the IRBuilder might fold the check to a
+ // constant expression, in which case there is no Instruction anchored in
+ // the block.
+ Instruction *Check = BinaryOperator::CreateAnd(MemoryRuntimeCheck,
+ ConstantInt::getTrue(Ctx));
+ ChkBuilder.Insert(Check, "memcheck.conflict");
+ FirstInst = getFirstInst(FirstInst, Check, Loc);
+ return std::make_pair(FirstInst, Check);
+}
+
+std::pair<Instruction *, Instruction *>
+LoopAccessInfo::addRuntimeChecks(Instruction *Loc) const {
+ if (!PtrRtChecking->Need)
+ return std::make_pair(nullptr, nullptr);
+
+ return addRuntimeChecks(Loc, PtrRtChecking->getChecks());
+}
+
+void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
+ Value *Ptr = nullptr;
+ if (LoadInst *LI = dyn_cast<LoadInst>(MemAccess))
+ Ptr = LI->getPointerOperand();
+ else if (StoreInst *SI = dyn_cast<StoreInst>(MemAccess))
+ Ptr = SI->getPointerOperand();
+ else
+ return;
+
+ Value *Stride = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
+ if (!Stride)
+ return;
+
+ LLVM_DEBUG(dbgs() << "LAA: Found a strided access that is a candidate for "
+ "versioning:");
+ LLVM_DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *Stride << "\n");
+
+ // Avoid adding the "Stride == 1" predicate when we know that
+ // Stride >= Trip-Count. Such a predicate will effectively optimize a single
+ // or zero iteration loop, as Trip-Count <= Stride == 1.
+ //
+ // TODO: We are currently not making a very informed decision on when it is
+ // beneficial to apply stride versioning. It might make more sense that the
+ // users of this analysis (such as the vectorizer) will trigger it, based on
+ // their specific cost considerations; For example, in cases where stride
+ // versioning does not help resolving memory accesses/dependences, the
+ // vectorizer should evaluate the cost of the runtime test, and the benefit
+ // of various possible stride specializations, considering the alternatives
+ // of using gather/scatters (if available).
+
+ const SCEV *StrideExpr = PSE->getSCEV(Stride);
+ const SCEV *BETakenCount = PSE->getBackedgeTakenCount();
+
+ // Match the types so we can compare the stride and the BETakenCount.
+ // The Stride can be positive/negative, so we sign extend Stride;
+ // The backedgeTakenCount is non-negative, so we zero extend BETakenCount.
+ const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout();
+ uint64_t StrideTypeSize = DL.getTypeAllocSize(StrideExpr->getType());
+ uint64_t BETypeSize = DL.getTypeAllocSize(BETakenCount->getType());
+ const SCEV *CastedStride = StrideExpr;
+ const SCEV *CastedBECount = BETakenCount;
+ ScalarEvolution *SE = PSE->getSE();
+ if (BETypeSize >= StrideTypeSize)
+ CastedStride = SE->getNoopOrSignExtend(StrideExpr, BETakenCount->getType());
+ else
+ CastedBECount = SE->getZeroExtendExpr(BETakenCount, StrideExpr->getType());
+ const SCEV *StrideMinusBETaken = SE->getMinusSCEV(CastedStride, CastedBECount);
+ // Since TripCount == BackEdgeTakenCount + 1, checking:
+ // "Stride >= TripCount" is equivalent to checking:
+ // Stride - BETakenCount > 0
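+ // (Over the integers: Stride >= BETakenCount + 1 <=> Stride - BETakenCount >= 1
+ // <=> Stride - BETakenCount > 0, which is what isKnownPositive tests below.)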
+ if (SE->isKnownPositive(StrideMinusBETaken)) {
+ LLVM_DEBUG(
+ dbgs() << "LAA: Stride>=TripCount; No point in versioning as the "
+ "Stride==1 predicate will imply that the loop executes "
+ "at most once.\n");
+ return;
+ }
+ LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.");
+
+ SymbolicStrides[Ptr] = Stride;
+ StrideSet.insert(Stride);
+}
+
+LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
+ const TargetLibraryInfo *TLI, AliasAnalysis *AA,
+ DominatorTree *DT, LoopInfo *LI)
+ : PSE(std::make_unique<PredicatedScalarEvolution>(*SE, *L)),
+ PtrRtChecking(std::make_unique<RuntimePointerChecking>(SE)),
+ DepChecker(std::make_unique<MemoryDepChecker>(*PSE, L)), TheLoop(L),
+ NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1), CanVecMem(false),
+ HasConvergentOp(false),
+ HasDependenceInvolvingLoopInvariantAddress(false) {
+ if (canAnalyzeLoop())
+ analyzeLoop(AA, LI, TLI, DT);
+}
+
+void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const {
+ if (CanVecMem) {
+ OS.indent(Depth) << "Memory dependences are safe";
+ if (MaxSafeDepDistBytes != -1ULL)
+ OS << " with a maximum dependence distance of " << MaxSafeDepDistBytes
+ << " bytes";
+ if (PtrRtChecking->Need)
+ OS << " with run-time checks";
+ OS << "\n";
+ }
+
+ if (HasConvergentOp)
+ OS.indent(Depth) << "Has convergent operation in loop\n";
+
+ if (Report)
+ OS.indent(Depth) << "Report: " << Report->getMsg() << "\n";
+
+ if (auto *Dependences = DepChecker->getDependences()) {
+ OS.indent(Depth) << "Dependences:\n";
+ for (auto &Dep : *Dependences) {
+ Dep.print(OS, Depth + 2, DepChecker->getMemoryInstructions());
+ OS << "\n";
+ }
+ } else
+ OS.indent(Depth) << "Too many dependences, not recorded\n";
+
+ // List the pairs of accesses that need run-time checks to prove independence.
+ PtrRtChecking->print(OS, Depth);
+ OS << "\n";
+
+ OS.indent(Depth) << "Non vectorizable stores to invariant address were "
+ << (HasDependenceInvolvingLoopInvariantAddress ? "" : "not ")
+ << "found in loop.\n";
+
+ OS.indent(Depth) << "SCEV assumptions:\n";
+ PSE->getUnionPredicate().print(OS, Depth);
+
+ OS << "\n";
+
+ OS.indent(Depth) << "Expressions re-written:\n";
+ PSE->print(OS, Depth);
+}
+
+const LoopAccessInfo &LoopAccessLegacyAnalysis::getInfo(Loop *L) {
+ auto &LAI = LoopAccessInfoMap[L];
+
+ if (!LAI)
+ LAI = std::make_unique<LoopAccessInfo>(L, SE, TLI, AA, DT, LI);
+
+ return *LAI.get();
+}
+
+void LoopAccessLegacyAnalysis::print(raw_ostream &OS, const Module *M) const {
+ LoopAccessLegacyAnalysis &LAA = *const_cast<LoopAccessLegacyAnalysis *>(this);
+
+ for (Loop *TopLevelLoop : *LI)
+ for (Loop *L : depth_first(TopLevelLoop)) {
+ OS.indent(2) << L->getHeader()->getName() << ":\n";
+ auto &LAI = LAA.getInfo(L);
+ LAI.print(OS, 4);
+ }
+}
+
+bool LoopAccessLegacyAnalysis::runOnFunction(Function &F) {
+ SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+ auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
+ TLI = TLIP ? &TLIP->getTLI(F) : nullptr;
+ AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
+ DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+
+ return false;
+}
+
+void LoopAccessLegacyAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addRequired<ScalarEvolutionWrapperPass>();
+ AU.addRequired<AAResultsWrapperPass>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+
+ AU.setPreservesAll();
+}
+
+char LoopAccessLegacyAnalysis::ID = 0;
+static const char laa_name[] = "Loop Access Analysis";
+#define LAA_NAME "loop-accesses"
+
+INITIALIZE_PASS_BEGIN(LoopAccessLegacyAnalysis, LAA_NAME, laa_name, false, true)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(LoopAccessLegacyAnalysis, LAA_NAME, laa_name, false, true)
+
+AnalysisKey LoopAccessAnalysis::Key;
+
+LoopAccessInfo LoopAccessAnalysis::run(Loop &L, LoopAnalysisManager &AM,
+ LoopStandardAnalysisResults &AR) {
+ return LoopAccessInfo(&L, &AR.SE, &AR.TLI, &AR.AA, &AR.DT, &AR.LI);
+}
+
+namespace llvm {
+
+ Pass *createLAAPass() {
+ return new LoopAccessLegacyAnalysis();
+ }
+
+} // end namespace llvm
diff --git a/llvm/lib/Analysis/LoopAnalysisManager.cpp b/llvm/lib/Analysis/LoopAnalysisManager.cpp
new file mode 100644
index 000000000000..02d40fb8d72a
--- /dev/null
+++ b/llvm/lib/Analysis/LoopAnalysisManager.cpp
@@ -0,0 +1,151 @@
+//===- LoopAnalysisManager.cpp - Loop analysis management -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
+#include "llvm/IR/Dominators.h"
+
+using namespace llvm;
+
+namespace llvm {
+// Explicit template instantiations and specialization definitions for core
+// template typedefs.
+template class AllAnalysesOn<Loop>;
+template class AnalysisManager<Loop, LoopStandardAnalysisResults &>;
+template class InnerAnalysisManagerProxy<LoopAnalysisManager, Function>;
+template class OuterAnalysisManagerProxy<FunctionAnalysisManager, Loop,
+ LoopStandardAnalysisResults &>;
+
+bool LoopAnalysisManagerFunctionProxy::Result::invalidate(
+ Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &Inv) {
+ // First compute the sequence of IR units covered by this proxy. We will want
+ // to visit this in postorder, but because this is a tree structure we can do
+ // this by building a preorder sequence and walking it backwards. We also
+ // want siblings in forward program order to match the LoopPassManager so we
+ // get the preorder with siblings reversed.
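+ // For illustration: for a nest { A { A1, A2 }, B } the reverse-sibling
+ // preorder is [B, A, A2, A1]; walking it backwards yields [A1, A2, A, B],
+ // a postorder with siblings in forward program order.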
+ SmallVector<Loop *, 4> PreOrderLoops = LI->getLoopsInReverseSiblingPreorder();
+
+ // If this proxy or the loop info is going to be invalidated, we also need
+ // to clear all the keys coming from that analysis. We also completely blow
+ // away the loop analyses if any of the standard analyses provided by the
+ // loop pass manager go away so that loop analyses can freely use these
+ // without worrying about declaring dependencies on them etc.
+ // FIXME: It isn't clear if this is the right tradeoff. We could instead make
+ // loop analyses declare any dependencies on these and use the more general
+ // invalidation logic below to act on that.
+ auto PAC = PA.getChecker<LoopAnalysisManagerFunctionProxy>();
+ bool invalidateMemorySSAAnalysis = false;
+ if (MSSAUsed)
+ invalidateMemorySSAAnalysis = Inv.invalidate<MemorySSAAnalysis>(F, PA);
+ if (!(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
+ Inv.invalidate<AAManager>(F, PA) ||
+ Inv.invalidate<AssumptionAnalysis>(F, PA) ||
+ Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
+ Inv.invalidate<LoopAnalysis>(F, PA) ||
+ Inv.invalidate<ScalarEvolutionAnalysis>(F, PA) ||
+ invalidateMemorySSAAnalysis) {
+ // Note that the LoopInfo may be stale at this point, however the loop
+ // objects themselves remain the only viable keys that could be in the
+ // analysis manager's cache. So we just walk the keys and forcibly clear
+ // those results. Note that the order doesn't matter here as this will just
+ // directly destroy the results without calling methods on them.
+ for (Loop *L : PreOrderLoops) {
+ // NB! `L` may not be in a good enough state to run Loop::getName.
+ InnerAM->clear(*L, "<possibly invalidated loop>");
+ }
+
+ // We also need to null out the inner AM so that when the object gets
+ // destroyed as invalid we don't try to clear the inner AM again. At that
+ // point we won't be able to reliably walk the loops for this function and
+ // only clear results associated with those loops the way we do here.
+ // FIXME: Making InnerAM null at this point isn't very nice. Most analyses
+ // try to remain valid during invalidation. Maybe we should add an
+ // `IsClean` flag?
+ InnerAM = nullptr;
+
+ // Now return true to indicate this *is* invalid and a fresh proxy result
+ // needs to be built. This is especially important given the null InnerAM.
+ return true;
+ }
+
+ // Directly check if the relevant set is preserved so we can short circuit
+ // invalidating loops.
+ bool AreLoopAnalysesPreserved =
+ PA.allAnalysesInSetPreserved<AllAnalysesOn<Loop>>();
+
+ // Since we have a valid LoopInfo we can actually leave the cached results in
+ // the analysis manager associated with the Loop keys, but we need to
+ // propagate any necessary invalidation logic into them. We'd like to
+ // invalidate things in roughly the same order as they were put into the
+ // cache and so we walk the preorder list in reverse to form a valid
+ // postorder.
+ for (Loop *L : reverse(PreOrderLoops)) {
+ Optional<PreservedAnalyses> InnerPA;
+
+ // Check to see whether the preserved set needs to be adjusted based on
+ // function-level analysis invalidation triggering deferred invalidation
+ // for this loop.
+ if (auto *OuterProxy =
+ InnerAM->getCachedResult<FunctionAnalysisManagerLoopProxy>(*L))
+ for (const auto &OuterInvalidationPair :
+ OuterProxy->getOuterInvalidations()) {
+ AnalysisKey *OuterAnalysisID = OuterInvalidationPair.first;
+ const auto &InnerAnalysisIDs = OuterInvalidationPair.second;
+ if (Inv.invalidate(OuterAnalysisID, F, PA)) {
+ if (!InnerPA)
+ InnerPA = PA;
+ for (AnalysisKey *InnerAnalysisID : InnerAnalysisIDs)
+ InnerPA->abandon(InnerAnalysisID);
+ }
+ }
+
+ // Check if we needed a custom PA set. If so we'll need to run the inner
+ // invalidation.
+ if (InnerPA) {
+ InnerAM->invalidate(*L, *InnerPA);
+ continue;
+ }
+
+ // Otherwise we only need to do invalidation if the original PA set didn't
+ // preserve all Loop analyses.
+ if (!AreLoopAnalysesPreserved)
+ InnerAM->invalidate(*L, PA);
+ }
+
+ // Return false to indicate that this result is still a valid proxy.
+ return false;
+}
+
+template <>
+LoopAnalysisManagerFunctionProxy::Result
+LoopAnalysisManagerFunctionProxy::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ return Result(*InnerAM, AM.getResult<LoopAnalysis>(F));
+}
+}
+
+PreservedAnalyses llvm::getLoopPassPreservedAnalyses() {
+ PreservedAnalyses PA;
+ PA.preserve<DominatorTreeAnalysis>();
+ PA.preserve<LoopAnalysis>();
+ PA.preserve<LoopAnalysisManagerFunctionProxy>();
+ PA.preserve<ScalarEvolutionAnalysis>();
+ // FIXME: What we really want to do here is preserve an AA category, but that
+ // concept doesn't exist yet.
+ PA.preserve<AAManager>();
+ PA.preserve<BasicAA>();
+ PA.preserve<GlobalsAA>();
+ PA.preserve<SCEVAA>();
+ return PA;
+}
diff --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp
new file mode 100644
index 000000000000..10d2fe07884a
--- /dev/null
+++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp
@@ -0,0 +1,625 @@
+//===- LoopCacheAnalysis.cpp - Loop Cache Analysis -------------------------==//
+//
+// The LLVM Compiler Infrastructure
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines the implementation for the loop cache analysis.
+/// The implementation is largely based on the following paper:
+///
+/// Compiler Optimizations for Improving Data Locality
+/// By: Steve Carr, Katherine S. McKinley, Chau-Wen Tseng
+/// http://www.cs.utexas.edu/users/mckinley/papers/asplos-1994.pdf
+///
+/// The general approach taken to estimate the number of cache lines used by the
+/// memory references in an inner loop is:
+/// 1. Partition memory references that exhibit temporal or spacial reuse
+/// into reference groups.
+/// 2. For each loop L in a loop nest LN:
+/// a. Compute the cost of the reference group
+/// b. Compute the loop cost by summing up the reference group costs
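+///
+/// For illustration (hypothetical element size, cache line size and trip
+/// counts): for a reference A[i][j] inside a j loop with trip count 100, an
+/// 8-byte element and a 64-byte cache line, the access is consecutive in j and
+/// costs roughly 100*8/64 = 12 cache lines per traversal, whereas treating the
+/// i loop as innermost would touch a new cache line on every iteration (cost
+/// 100). Summing such reference-group costs gives the per-loop cache cost.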
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LoopCacheAnalysis.h"
+#include "llvm/ADT/BreadthFirstIterator.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-cache-cost"
+
+static cl::opt<unsigned> DefaultTripCount(
+ "default-trip-count", cl::init(100), cl::Hidden,
+ cl::desc("Use this to specify the default trip count of a loop"));
+
+// In this analysis two array references are considered to exhibit temporal
+// reuse if they access either the same memory location, or a memory location
+// with distance smaller than a configurable threshold.
+static cl::opt<unsigned> TemporalReuseThreshold(
+ "temporal-reuse-threshold", cl::init(2), cl::Hidden,
+ cl::desc("Use this to specify the max. distance between array elements "
+ "accessed in a loop so that the elements are classified to have "
+ "temporal reuse"));
+
+/// Retrieve the innermost loop in the given loop nest \p Loops. It returns
+/// nullptr if any loop in the supplied loop vector has more than one sibling.
+/// The loop vector is expected to contain loops collected in breadth-first
+/// order.
+static Loop *getInnerMostLoop(const LoopVectorTy &Loops) {
+ assert(!Loops.empty() && "Expecting a non-empy loop vector");
+
+ Loop *LastLoop = Loops.back();
+ Loop *ParentLoop = LastLoop->getParentLoop();
+
+ if (ParentLoop == nullptr) {
+ assert(Loops.size() == 1 && "Expecting a single loop");
+ return LastLoop;
+ }
+
+ return (std::is_sorted(Loops.begin(), Loops.end(),
+ [](const Loop *L1, const Loop *L2) {
+ return L1->getLoopDepth() < L2->getLoopDepth();
+ }))
+ ? LastLoop
+ : nullptr;
+}
+
+static bool isOneDimensionalArray(const SCEV &AccessFn, const SCEV &ElemSize,
+ const Loop &L, ScalarEvolution &SE) {
+ const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(&AccessFn);
+ if (!AR || !AR->isAffine())
+ return false;
+
+ assert(AR->getLoop() && "AR should have a loop");
+
+ // Check that start and increment are not add recurrences.
+ const SCEV *Start = AR->getStart();
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ if (isa<SCEVAddRecExpr>(Start) || isa<SCEVAddRecExpr>(Step))
+ return false;
+
+ // Check that start and increment are both invariant in the loop.
+ if (!SE.isLoopInvariant(Start, &L) || !SE.isLoopInvariant(Step, &L))
+ return false;
+
+ return AR->getStepRecurrence(SE) == &ElemSize;
+}
+
+/// Compute the trip count for the given loop \p L. Return the SCEV expression
+/// for the trip count or nullptr if it cannot be computed.
+static const SCEV *computeTripCount(const Loop &L, ScalarEvolution &SE) {
+ const SCEV *BackedgeTakenCount = SE.getBackedgeTakenCount(&L);
+ if (isa<SCEVCouldNotCompute>(BackedgeTakenCount) ||
+ !isa<SCEVConstant>(BackedgeTakenCount))
+ return nullptr;
+
+ return SE.getAddExpr(BackedgeTakenCount,
+ SE.getOne(BackedgeTakenCount->getType()));
+}
+
+//===----------------------------------------------------------------------===//
+// IndexedReference implementation
+//
+raw_ostream &llvm::operator<<(raw_ostream &OS, const IndexedReference &R) {
+ if (!R.IsValid) {
+ OS << R.StoreOrLoadInst;
+ OS << ", IsValid=false.";
+ return OS;
+ }
+
+ OS << *R.BasePointer;
+ for (const SCEV *Subscript : R.Subscripts)
+ OS << "[" << *Subscript << "]";
+
+ OS << ", Sizes: ";
+ for (const SCEV *Size : R.Sizes)
+ OS << "[" << *Size << "]";
+
+ return OS;
+}
+
+IndexedReference::IndexedReference(Instruction &StoreOrLoadInst,
+ const LoopInfo &LI, ScalarEvolution &SE)
+ : StoreOrLoadInst(StoreOrLoadInst), SE(SE) {
+ assert((isa<StoreInst>(StoreOrLoadInst) || isa<LoadInst>(StoreOrLoadInst)) &&
+ "Expecting a load or store instruction");
+
+ IsValid = delinearize(LI);
+ if (IsValid)
+ LLVM_DEBUG(dbgs().indent(2) << "Succesfully delinearized: " << *this
+ << "\n");
+}
+
+Optional<bool> IndexedReference::hasSpacialReuse(const IndexedReference &Other,
+ unsigned CLS,
+ AliasAnalysis &AA) const {
+ assert(IsValid && "Expecting a valid reference");
+
+ if (BasePointer != Other.getBasePointer() && !isAliased(Other, AA)) {
+ LLVM_DEBUG(dbgs().indent(2)
+ << "No spacial reuse: different base pointers\n");
+ return false;
+ }
+
+ unsigned NumSubscripts = getNumSubscripts();
+ if (NumSubscripts != Other.getNumSubscripts()) {
+ LLVM_DEBUG(dbgs().indent(2)
+ << "No spacial reuse: different number of subscripts\n");
+ return false;
+ }
+
+ // All subscripts must be equal, except the last one (the innermost subscript).
+ for (auto SubNum : seq<unsigned>(0, NumSubscripts - 1)) {
+ if (getSubscript(SubNum) != Other.getSubscript(SubNum)) {
+ LLVM_DEBUG(dbgs().indent(2) << "No spacial reuse, different subscripts: "
+ << "\n\t" << *getSubscript(SubNum) << "\n\t"
+ << *Other.getSubscript(SubNum) << "\n");
+ return false;
+ }
+ }
+
+ // the difference between the last subscripts must be less than the cache line
+ // size.
+ const SCEV *LastSubscript = getLastSubscript();
+ const SCEV *OtherLastSubscript = Other.getLastSubscript();
+ const SCEVConstant *Diff = dyn_cast<SCEVConstant>(
+ SE.getMinusSCEV(LastSubscript, OtherLastSubscript));
+
+ if (Diff == nullptr) {
+ LLVM_DEBUG(dbgs().indent(2)
+ << "No spacial reuse, difference between subscript:\n\t"
+ << *LastSubscript << "\n\t" << OtherLastSubscript
+ << "\nis not constant.\n");
+ return None;
+ }
+
+ bool InSameCacheLine = (Diff->getValue()->getSExtValue() < CLS);
+
+ LLVM_DEBUG({
+ if (InSameCacheLine)
+ dbgs().indent(2) << "Found spacial reuse.\n";
+ else
+ dbgs().indent(2) << "No spacial reuse.\n";
+ });
+
+ return InSameCacheLine;
+}
+
+Optional<bool> IndexedReference::hasTemporalReuse(const IndexedReference &Other,
+ unsigned MaxDistance,
+ const Loop &L,
+ DependenceInfo &DI,
+ AliasAnalysis &AA) const {
+ assert(IsValid && "Expecting a valid reference");
+
+ if (BasePointer != Other.getBasePointer() && !isAliased(Other, AA)) {
+ LLVM_DEBUG(dbgs().indent(2)
+ << "No temporal reuse: different base pointer\n");
+ return false;
+ }
+
+ std::unique_ptr<Dependence> D =
+ DI.depends(&StoreOrLoadInst, &Other.StoreOrLoadInst, true);
+
+ if (D == nullptr) {
+ LLVM_DEBUG(dbgs().indent(2) << "No temporal reuse: no dependence\n");
+ return false;
+ }
+
+ if (D->isLoopIndependent()) {
+ LLVM_DEBUG(dbgs().indent(2) << "Found temporal reuse\n");
+ return true;
+ }
+
+ // Check the dependence distance at every loop level. There is temporal reuse
+ // if the distance at the given loop's depth is small (|d| <= MaxDistance) and
+ // it is zero at every other loop level.
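+ // For illustration: for references A[i][j] and A[i][j-2] inside the j loop
+ // with MaxDistance == 2, the dependence distance vector is (0, 2); it is zero
+ // at the i level and within MaxDistance at the j level, so the pair exhibits
+ // temporal reuse.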
+ int LoopDepth = L.getLoopDepth();
+ int Levels = D->getLevels();
+ for (int Level = 1; Level <= Levels; ++Level) {
+ const SCEV *Distance = D->getDistance(Level);
+ const SCEVConstant *SCEVConst = dyn_cast_or_null<SCEVConstant>(Distance);
+
+ if (SCEVConst == nullptr) {
+ LLVM_DEBUG(dbgs().indent(2) << "No temporal reuse: distance unknown\n");
+ return None;
+ }
+
+ const ConstantInt &CI = *SCEVConst->getValue();
+ if (Level != LoopDepth && !CI.isZero()) {
+ LLVM_DEBUG(dbgs().indent(2)
+ << "No temporal reuse: distance is not zero at depth=" << Level
+ << "\n");
+ return false;
+ } else if (Level == LoopDepth && CI.getSExtValue() > MaxDistance) {
+ LLVM_DEBUG(
+ dbgs().indent(2)
+ << "No temporal reuse: distance is greater than MaxDistance at depth="
+ << Level << "\n");
+ return false;
+ }
+ }
+
+ LLVM_DEBUG(dbgs().indent(2) << "Found temporal reuse\n");
+ return true;
+}
+
+CacheCostTy IndexedReference::computeRefCost(const Loop &L,
+ unsigned CLS) const {
+ assert(IsValid && "Expecting a valid reference");
+ LLVM_DEBUG({
+ dbgs().indent(2) << "Computing cache cost for:\n";
+ dbgs().indent(4) << *this << "\n";
+ });
+
+ // If the indexed reference is loop invariant the cost is one.
+ if (isLoopInvariant(L)) {
+ LLVM_DEBUG(dbgs().indent(4) << "Reference is loop invariant: RefCost=1\n");
+ return 1;
+ }
+
+ const SCEV *TripCount = computeTripCount(L, SE);
+ if (!TripCount) {
+ LLVM_DEBUG(dbgs() << "Trip count of loop " << L.getName()
+ << " could not be computed, using DefaultTripCount\n");
+ const SCEV *ElemSize = Sizes.back();
+ TripCount = SE.getConstant(ElemSize->getType(), DefaultTripCount);
+ }
+ LLVM_DEBUG(dbgs() << "TripCount=" << *TripCount << "\n");
+
+ // If the indexed reference is 'consecutive' the cost is
+ // (TripCount*Stride)/CLS, otherwise the cost is TripCount.
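+ // For illustration (hypothetical values): with TripCount = 100, a 4-byte
+ // element accessed with unit stride (Stride = 4) and CLS = 64, a consecutive
+ // reference costs 100*4/64 = 6 cache lines, whereas a non-consecutive
+ // reference costs TripCount = 100.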
+ const SCEV *RefCost = TripCount;
+
+ if (isConsecutive(L, CLS)) {
+ const SCEV *Coeff = getLastCoefficient();
+ const SCEV *ElemSize = Sizes.back();
+ const SCEV *Stride = SE.getMulExpr(Coeff, ElemSize);
+ const SCEV *CacheLineSize = SE.getConstant(Stride->getType(), CLS);
+ const SCEV *Numerator = SE.getMulExpr(Stride, TripCount);
+ RefCost = SE.getUDivExpr(Numerator, CacheLineSize);
+ LLVM_DEBUG(dbgs().indent(4)
+ << "Access is consecutive: RefCost=(TripCount*Stride)/CLS="
+ << *RefCost << "\n");
+ } else
+ LLVM_DEBUG(dbgs().indent(4)
+ << "Access is not consecutive: RefCost=TripCount=" << *RefCost
+ << "\n");
+
+ // Attempt to fold RefCost into a constant.
+ if (auto ConstantCost = dyn_cast<SCEVConstant>(RefCost))
+ return ConstantCost->getValue()->getSExtValue();
+
+ LLVM_DEBUG(dbgs().indent(4)
+ << "RefCost is not a constant! Setting to RefCost=InvalidCost "
+ "(invalid value).\n");
+
+ return CacheCost::InvalidCost;
+}
+
+bool IndexedReference::delinearize(const LoopInfo &LI) {
+ assert(Subscripts.empty() && "Subscripts should be empty");
+ assert(Sizes.empty() && "Sizes should be empty");
+ assert(!IsValid && "Should be called once from the constructor");
+ LLVM_DEBUG(dbgs() << "Delinearizing: " << StoreOrLoadInst << "\n");
+
+ const SCEV *ElemSize = SE.getElementSize(&StoreOrLoadInst);
+ const BasicBlock *BB = StoreOrLoadInst.getParent();
+
+ for (Loop *L = LI.getLoopFor(BB); L != nullptr; L = L->getParentLoop()) {
+ const SCEV *AccessFn =
+ SE.getSCEVAtScope(getPointerOperand(&StoreOrLoadInst), L);
+
+ BasePointer = dyn_cast<SCEVUnknown>(SE.getPointerBase(AccessFn));
+ if (BasePointer == nullptr) {
+ LLVM_DEBUG(
+ dbgs().indent(2)
+ << "ERROR: failed to delinearize, can't identify base pointer\n");
+ return false;
+ }
+
+ AccessFn = SE.getMinusSCEV(AccessFn, BasePointer);
+
+ LLVM_DEBUG(dbgs().indent(2) << "In Loop '" << L->getName()
+ << "', AccessFn: " << *AccessFn << "\n");
+
+ SE.delinearize(AccessFn, Subscripts, Sizes,
+ SE.getElementSize(&StoreOrLoadInst));
+
+ if (Subscripts.empty() || Sizes.empty() ||
+ Subscripts.size() != Sizes.size()) {
+ // Attempt to determine whether we have a single-dimensional array access
+ // before giving up.
+ if (!isOneDimensionalArray(*AccessFn, *ElemSize, *L, SE)) {
+ LLVM_DEBUG(dbgs().indent(2)
+ << "ERROR: failed to delinearize reference\n");
+ Subscripts.clear();
+ Sizes.clear();
+ break;
+ }
+
+ const SCEV *Div = SE.getUDivExactExpr(AccessFn, ElemSize);
+ Subscripts.push_back(Div);
+ Sizes.push_back(ElemSize);
+ }
+
+ return all_of(Subscripts, [&](const SCEV *Subscript) {
+ return isSimpleAddRecurrence(*Subscript, *L);
+ });
+ }
+
+ return false;
+}
+
+bool IndexedReference::isLoopInvariant(const Loop &L) const {
+ Value *Addr = getPointerOperand(&StoreOrLoadInst);
+ assert(Addr != nullptr && "Expecting either a load or a store instruction");
+ assert(SE.isSCEVable(Addr->getType()) && "Addr should be SCEVable");
+
+ if (SE.isLoopInvariant(SE.getSCEV(Addr), &L))
+ return true;
+
+ // The indexed reference is loop invariant if none of the coefficients use
+ // the loop induction variable.
+ bool allCoeffForLoopAreZero = all_of(Subscripts, [&](const SCEV *Subscript) {
+ return isCoeffForLoopZeroOrInvariant(*Subscript, L);
+ });
+
+ return allCoeffForLoopAreZero;
+}
+
+bool IndexedReference::isConsecutive(const Loop &L, unsigned CLS) const {
+ // The indexed reference is 'consecutive' if the only coefficient that uses
+ // the loop induction variable is the last one...
+ const SCEV *LastSubscript = Subscripts.back();
+ for (const SCEV *Subscript : Subscripts) {
+ if (Subscript == LastSubscript)
+ continue;
+ if (!isCoeffForLoopZeroOrInvariant(*Subscript, L))
+ return false;
+ }
+
+ // ...and the access stride is less than the cache line size.
+ const SCEV *Coeff = getLastCoefficient();
+ const SCEV *ElemSize = Sizes.back();
+ const SCEV *Stride = SE.getMulExpr(Coeff, ElemSize);
+ const SCEV *CacheLineSize = SE.getConstant(Stride->getType(), CLS);
+
+ return SE.isKnownPredicate(ICmpInst::ICMP_ULT, Stride, CacheLineSize);
+}
+
+const SCEV *IndexedReference::getLastCoefficient() const {
+ const SCEV *LastSubscript = getLastSubscript();
+ assert(isa<SCEVAddRecExpr>(LastSubscript) &&
+ "Expecting a SCEV add recurrence expression");
+ const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LastSubscript);
+ return AR->getStepRecurrence(SE);
+}
+
+bool IndexedReference::isCoeffForLoopZeroOrInvariant(const SCEV &Subscript,
+ const Loop &L) const {
+ const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(&Subscript);
+ return (AR != nullptr) ? AR->getLoop() != &L
+ : SE.isLoopInvariant(&Subscript, &L);
+}
+
+bool IndexedReference::isSimpleAddRecurrence(const SCEV &Subscript,
+ const Loop &L) const {
+ if (!isa<SCEVAddRecExpr>(Subscript))
+ return false;
+
+ const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(&Subscript);
+ assert(AR->getLoop() && "AR should have a loop");
+
+ if (!AR->isAffine())
+ return false;
+
+ const SCEV *Start = AR->getStart();
+ const SCEV *Step = AR->getStepRecurrence(SE);
+
+ if (!SE.isLoopInvariant(Start, &L) || !SE.isLoopInvariant(Step, &L))
+ return false;
+
+ return true;
+}
+
+bool IndexedReference::isAliased(const IndexedReference &Other,
+ AliasAnalysis &AA) const {
+ const auto &Loc1 = MemoryLocation::get(&StoreOrLoadInst);
+ const auto &Loc2 = MemoryLocation::get(&Other.StoreOrLoadInst);
+ return AA.isMustAlias(Loc1, Loc2);
+}
+
+//===----------------------------------------------------------------------===//
+// CacheCost implementation
+//
+raw_ostream &llvm::operator<<(raw_ostream &OS, const CacheCost &CC) {
+ for (const auto &LC : CC.LoopCosts) {
+ const Loop *L = LC.first;
+ OS << "Loop '" << L->getName() << "' has cost = " << LC.second << "\n";
+ }
+ return OS;
+}
+
+CacheCost::CacheCost(const LoopVectorTy &Loops, const LoopInfo &LI,
+ ScalarEvolution &SE, TargetTransformInfo &TTI,
+ AliasAnalysis &AA, DependenceInfo &DI,
+ Optional<unsigned> TRT)
+ : Loops(Loops), TripCounts(), LoopCosts(),
+ TRT(TRT == None ? Optional<unsigned>(TemporalReuseThreshold) : TRT),
+ LI(LI), SE(SE), TTI(TTI), AA(AA), DI(DI) {
+ assert(!Loops.empty() && "Expecting a non-empty loop vector.");
+
+ for (const Loop *L : Loops) {
+ unsigned TripCount = SE.getSmallConstantTripCount(L);
+ TripCount = (TripCount == 0) ? DefaultTripCount : TripCount;
+ TripCounts.push_back({L, TripCount});
+ }
+
+ calculateCacheFootprint();
+}
+
+std::unique_ptr<CacheCost>
+CacheCost::getCacheCost(Loop &Root, LoopStandardAnalysisResults &AR,
+ DependenceInfo &DI, Optional<unsigned> TRT) {
+ if (Root.getParentLoop()) {
+ LLVM_DEBUG(dbgs() << "Expecting the outermost loop in a loop nest\n");
+ return nullptr;
+ }
+
+ LoopVectorTy Loops;
+ for (Loop *L : breadth_first(&Root))
+ Loops.push_back(L);
+
+ if (!getInnerMostLoop(Loops)) {
+ LLVM_DEBUG(dbgs() << "Cannot compute cache cost of loop nest with more "
+ "than one innermost loop\n");
+ return nullptr;
+ }
+
+ return std::make_unique<CacheCost>(Loops, AR.LI, AR.SE, AR.TTI, AR.AA, DI, TRT);
+}
+
+void CacheCost::calculateCacheFootprint() {
+ LLVM_DEBUG(dbgs() << "POPULATING REFERENCE GROUPS\n");
+ ReferenceGroupsTy RefGroups;
+ if (!populateReferenceGroups(RefGroups))
+ return;
+
+ LLVM_DEBUG(dbgs() << "COMPUTING LOOP CACHE COSTS\n");
+ for (const Loop *L : Loops) {
+ assert((std::find_if(LoopCosts.begin(), LoopCosts.end(),
+ [L](const LoopCacheCostTy &LCC) {
+ return LCC.first == L;
+ }) == LoopCosts.end()) &&
+ "Should not add duplicate element");
+ CacheCostTy LoopCost = computeLoopCacheCost(*L, RefGroups);
+ LoopCosts.push_back(std::make_pair(L, LoopCost));
+ }
+
+ sortLoopCosts();
+ RefGroups.clear();
+}
+
+bool CacheCost::populateReferenceGroups(ReferenceGroupsTy &RefGroups) const {
+ assert(RefGroups.empty() && "Reference groups should be empty");
+
+ unsigned CLS = TTI.getCacheLineSize();
+ Loop *InnerMostLoop = getInnerMostLoop(Loops);
+ assert(InnerMostLoop != nullptr && "Expecting a valid innermost loop");
+
+ for (BasicBlock *BB : InnerMostLoop->getBlocks()) {
+ for (Instruction &I : *BB) {
+ if (!isa<StoreInst>(I) && !isa<LoadInst>(I))
+ continue;
+
+ std::unique_ptr<IndexedReference> R(new IndexedReference(I, LI, SE));
+ if (!R->isValid())
+ continue;
+
+ bool Added = false;
+ for (ReferenceGroupTy &RefGroup : RefGroups) {
+ const IndexedReference &Representative = *RefGroup.front().get();
+ LLVM_DEBUG({
+ dbgs() << "References:\n";
+ dbgs().indent(2) << *R << "\n";
+ dbgs().indent(2) << Representative << "\n";
+ });
+
+ Optional<bool> HasTemporalReuse =
+ R->hasTemporalReuse(Representative, *TRT, *InnerMostLoop, DI, AA);
+ Optional<bool> HasSpacialReuse =
+ R->hasSpacialReuse(Representative, CLS, AA);
+
+ if ((HasTemporalReuse.hasValue() && *HasTemporalReuse) ||
+ (HasSpacialReuse.hasValue() && *HasSpacialReuse)) {
+ RefGroup.push_back(std::move(R));
+ Added = true;
+ break;
+ }
+ }
+
+ if (!Added) {
+ ReferenceGroupTy RG;
+ RG.push_back(std::move(R));
+ RefGroups.push_back(std::move(RG));
+ }
+ }
+ }
+
+ if (RefGroups.empty())
+ return false;
+
+ LLVM_DEBUG({
+ dbgs() << "\nIDENTIFIED REFERENCE GROUPS:\n";
+ int n = 1;
+ for (const ReferenceGroupTy &RG : RefGroups) {
+ dbgs().indent(2) << "RefGroup " << n << ":\n";
+ for (const auto &IR : RG)
+ dbgs().indent(4) << *IR << "\n";
+ n++;
+ }
+ dbgs() << "\n";
+ });
+
+ return true;
+}
+
+CacheCostTy
+CacheCost::computeLoopCacheCost(const Loop &L,
+ const ReferenceGroupsTy &RefGroups) const {
+ if (!L.isLoopSimplifyForm())
+ return InvalidCost;
+
+ LLVM_DEBUG(dbgs() << "Considering loop '" << L.getName()
+ << "' as innermost loop.\n");
+
+ // Compute the product of the trip counts of each other loop in the nest.
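+ // For illustration: in a nest with trip counts {i: 100, j: 50, k: 20} and
+ // L being the j loop, the product is 100 * 20 = 2000, i.e. the number of
+ // times the j loop would be entered if it were the innermost loop.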
+ CacheCostTy TripCountsProduct = 1;
+ for (const auto &TC : TripCounts) {
+ if (TC.first == &L)
+ continue;
+ TripCountsProduct *= TC.second;
+ }
+
+ CacheCostTy LoopCost = 0;
+ for (const ReferenceGroupTy &RG : RefGroups) {
+ CacheCostTy RefGroupCost = computeRefGroupCacheCost(RG, L);
+ LoopCost += RefGroupCost * TripCountsProduct;
+ }
+
+ LLVM_DEBUG(dbgs().indent(2) << "Loop '" << L.getName()
+ << "' has cost=" << LoopCost << "\n");
+
+ return LoopCost;
+}
+
+CacheCostTy CacheCost::computeRefGroupCacheCost(const ReferenceGroupTy &RG,
+ const Loop &L) const {
+ assert(!RG.empty() && "Reference group should have at least one member.");
+
+ const IndexedReference *Representative = RG.front().get();
+ return Representative->computeRefCost(L, TTI.getCacheLineSize());
+}
+
+//===----------------------------------------------------------------------===//
+// LoopCachePrinterPass implementation
+//
+PreservedAnalyses LoopCachePrinterPass::run(Loop &L, LoopAnalysisManager &AM,
+ LoopStandardAnalysisResults &AR,
+ LPMUpdater &U) {
+ Function *F = L.getHeader()->getParent();
+ DependenceInfo DI(F, &AR.AA, &AR.SE, &AR.LI);
+
+ if (auto CC = CacheCost::getCacheCost(L, AR, DI))
+ OS << *CC;
+
+ return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp
new file mode 100644
index 000000000000..dbab5db7dbc2
--- /dev/null
+++ b/llvm/lib/Analysis/LoopInfo.cpp
@@ -0,0 +1,1109 @@
+//===- LoopInfo.cpp - Natural Loop Calculator -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the LoopInfo class that is used to identify natural loops
+// and determine the loop depth of various nodes of the CFG. Note that the
+// loops identified may actually be several natural loops that share the same
+// header node... not just a single natural loop.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/IVDescriptors.h"
+#include "llvm/Analysis/LoopInfoImpl.h"
+#include "llvm/Analysis/LoopIterator.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DebugLoc.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRPrintingPasses.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+using namespace llvm;
+
+// Explicitly instantiate methods in LoopInfoImpl.h for IR-level Loops.
+template class llvm::LoopBase<BasicBlock, Loop>;
+template class llvm::LoopInfoBase<BasicBlock, Loop>;
+
+// Always verify loopinfo if expensive checking is enabled.
+#ifdef EXPENSIVE_CHECKS
+bool llvm::VerifyLoopInfo = true;
+#else
+bool llvm::VerifyLoopInfo = false;
+#endif
+static cl::opt<bool, true>
+ VerifyLoopInfoX("verify-loop-info", cl::location(VerifyLoopInfo),
+ cl::Hidden, cl::desc("Verify loop info (time consuming)"));
+
+//===----------------------------------------------------------------------===//
+// Loop implementation
+//
+
+bool Loop::isLoopInvariant(const Value *V) const {
+ if (const Instruction *I = dyn_cast<Instruction>(V))
+ return !contains(I);
+ return true; // All non-instructions are loop invariant
+}
+
+bool Loop::hasLoopInvariantOperands(const Instruction *I) const {
+ return all_of(I->operands(), [this](Value *V) { return isLoopInvariant(V); });
+}
+
+bool Loop::makeLoopInvariant(Value *V, bool &Changed, Instruction *InsertPt,
+ MemorySSAUpdater *MSSAU) const {
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ return makeLoopInvariant(I, Changed, InsertPt, MSSAU);
+ return true; // All non-instructions are loop-invariant.
+}
+
+bool Loop::makeLoopInvariant(Instruction *I, bool &Changed,
+ Instruction *InsertPt,
+ MemorySSAUpdater *MSSAU) const {
+ // Test if the value is already loop-invariant.
+ if (isLoopInvariant(I))
+ return true;
+ if (!isSafeToSpeculativelyExecute(I))
+ return false;
+ if (I->mayReadFromMemory())
+ return false;
+ // EH block instructions are immobile.
+ if (I->isEHPad())
+ return false;
+ // Determine the insertion point, unless one was given.
+ if (!InsertPt) {
+ BasicBlock *Preheader = getLoopPreheader();
+ // Without a preheader, hoisting is not feasible.
+ if (!Preheader)
+ return false;
+ InsertPt = Preheader->getTerminator();
+ }
+ // Don't hoist instructions with loop-variant operands.
+ for (Value *Operand : I->operands())
+ if (!makeLoopInvariant(Operand, Changed, InsertPt, MSSAU))
+ return false;
+
+ // Hoist.
+ I->moveBefore(InsertPt);
+ if (MSSAU)
+ if (auto *MUD = MSSAU->getMemorySSA()->getMemoryAccess(I))
+ MSSAU->moveToPlace(MUD, InsertPt->getParent(), MemorySSA::End);
+
+ // There is a possibility of hoisting this instruction above some arbitrary
+ // condition. Any metadata defined on it can be control dependent on this
+ // condition. Conservatively strip it here so that we don't give any wrong
+ // information to the optimizer.
+ I->dropUnknownNonDebugMetadata();
+
+ Changed = true;
+ return true;
+}
+
+bool Loop::getIncomingAndBackEdge(BasicBlock *&Incoming,
+ BasicBlock *&Backedge) const {
+ BasicBlock *H = getHeader();
+
+ Incoming = nullptr;
+ Backedge = nullptr;
+ pred_iterator PI = pred_begin(H);
+ assert(PI != pred_end(H) && "Loop must have at least one backedge!");
+ Backedge = *PI++;
+ if (PI == pred_end(H))
+ return false; // dead loop
+ Incoming = *PI++;
+ if (PI != pred_end(H))
+ return false; // multiple backedges?
+
+ if (contains(Incoming)) {
+ if (contains(Backedge))
+ return false;
+ std::swap(Incoming, Backedge);
+ } else if (!contains(Backedge))
+ return false;
+
+ assert(Incoming && Backedge && "expected non-null incoming and backedges");
+ return true;
+}
+
+PHINode *Loop::getCanonicalInductionVariable() const {
+ BasicBlock *H = getHeader();
+
+ BasicBlock *Incoming = nullptr, *Backedge = nullptr;
+ if (!getIncomingAndBackEdge(Incoming, Backedge))
+ return nullptr;
+
+ // Loop over all of the PHI nodes, looking for a canonical indvar.
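+ // A canonical induction variable starts at zero and steps by one, e.g. (in
+ // textual IR) %iv = phi i32 [ 0, %incoming ], [ %iv.next, %backedge ] with
+ // %iv.next = add i32 %iv, 1.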
+ for (BasicBlock::iterator I = H->begin(); isa<PHINode>(I); ++I) {
+ PHINode *PN = cast<PHINode>(I);
+ if (ConstantInt *CI =
+ dyn_cast<ConstantInt>(PN->getIncomingValueForBlock(Incoming)))
+ if (CI->isZero())
+ if (Instruction *Inc =
+ dyn_cast<Instruction>(PN->getIncomingValueForBlock(Backedge)))
+ if (Inc->getOpcode() == Instruction::Add && Inc->getOperand(0) == PN)
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(Inc->getOperand(1)))
+ if (CI->isOne())
+ return PN;
+ }
+ return nullptr;
+}
+
+/// Get the latch condition instruction.
+static ICmpInst *getLatchCmpInst(const Loop &L) {
+ if (BasicBlock *Latch = L.getLoopLatch())
+ if (BranchInst *BI = dyn_cast_or_null<BranchInst>(Latch->getTerminator()))
+ if (BI->isConditional())
+ return dyn_cast<ICmpInst>(BI->getCondition());
+
+ return nullptr;
+}
+
+/// Return the final value of the loop induction variable if found.
+static Value *findFinalIVValue(const Loop &L, const PHINode &IndVar,
+ const Instruction &StepInst) {
+ ICmpInst *LatchCmpInst = getLatchCmpInst(L);
+ if (!LatchCmpInst)
+ return nullptr;
+
+ Value *Op0 = LatchCmpInst->getOperand(0);
+ Value *Op1 = LatchCmpInst->getOperand(1);
+ if (Op0 == &IndVar || Op0 == &StepInst)
+ return Op1;
+
+ if (Op1 == &IndVar || Op1 == &StepInst)
+ return Op0;
+
+ return nullptr;
+}
+
+Optional<Loop::LoopBounds> Loop::LoopBounds::getBounds(const Loop &L,
+ PHINode &IndVar,
+ ScalarEvolution &SE) {
+ InductionDescriptor IndDesc;
+ if (!InductionDescriptor::isInductionPHI(&IndVar, &L, &SE, IndDesc))
+ return None;
+
+ Value *InitialIVValue = IndDesc.getStartValue();
+ Instruction *StepInst = IndDesc.getInductionBinOp();
+ if (!InitialIVValue || !StepInst)
+ return None;
+
+ const SCEV *Step = IndDesc.getStep();
+ Value *StepInstOp1 = StepInst->getOperand(1);
+ Value *StepInstOp0 = StepInst->getOperand(0);
+ Value *StepValue = nullptr;
+ if (SE.getSCEV(StepInstOp1) == Step)
+ StepValue = StepInstOp1;
+ else if (SE.getSCEV(StepInstOp0) == Step)
+ StepValue = StepInstOp0;
+
+ Value *FinalIVValue = findFinalIVValue(L, IndVar, *StepInst);
+ if (!FinalIVValue)
+ return None;
+
+ return LoopBounds(L, *InitialIVValue, *StepInst, StepValue, *FinalIVValue,
+ SE);
+}
+
+using Direction = Loop::LoopBounds::Direction;
+
+ICmpInst::Predicate Loop::LoopBounds::getCanonicalPredicate() const {
+ BasicBlock *Latch = L.getLoopLatch();
+ assert(Latch && "Expecting valid latch");
+
+ BranchInst *BI = dyn_cast_or_null<BranchInst>(Latch->getTerminator());
+ assert(BI && BI->isConditional() && "Expecting conditional latch branch");
+
+ ICmpInst *LatchCmpInst = dyn_cast<ICmpInst>(BI->getCondition());
+ assert(LatchCmpInst &&
+ "Expecting the latch compare instruction to be a CmpInst");
+
+ // Need to invert the predicate when the first successor is not the loop
+ // header.
+ ICmpInst::Predicate Pred = (BI->getSuccessor(0) == L.getHeader())
+ ? LatchCmpInst->getPredicate()
+ : LatchCmpInst->getInversePredicate();
+
+ if (LatchCmpInst->getOperand(0) == &getFinalIVValue())
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+
+ // Need to flip strictness of the predicate when the latch compare instruction
+ // is not using StepInst
+ if (LatchCmpInst->getOperand(0) == &getStepInst() ||
+ LatchCmpInst->getOperand(1) == &getStepInst())
+ return Pred;
+
+ // Cannot flip strictness of NE and EQ
+ if (Pred != ICmpInst::ICMP_NE && Pred != ICmpInst::ICMP_EQ)
+ return ICmpInst::getFlippedStrictnessPredicate(Pred);
+
+ Direction D = getDirection();
+ if (D == Direction::Increasing)
+ return ICmpInst::ICMP_SLT;
+
+ if (D == Direction::Decreasing)
+ return ICmpInst::ICMP_SGT;
+
+ // If the direction cannot be determined, then we are unable to find the
+ // canonical predicate.
+ return ICmpInst::BAD_ICMP_PREDICATE;
+}
+
+Direction Loop::LoopBounds::getDirection() const {
+ if (const SCEVAddRecExpr *StepAddRecExpr =
+ dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&getStepInst())))
+ if (const SCEV *StepRecur = StepAddRecExpr->getStepRecurrence(SE)) {
+ if (SE.isKnownPositive(StepRecur))
+ return Direction::Increasing;
+ if (SE.isKnownNegative(StepRecur))
+ return Direction::Decreasing;
+ }
+
+ return Direction::Unknown;
+}
+
+Optional<Loop::LoopBounds> Loop::getBounds(ScalarEvolution &SE) const {
+ if (PHINode *IndVar = getInductionVariable(SE))
+ return LoopBounds::getBounds(*this, *IndVar, SE);
+
+ return None;
+}
+
+PHINode *Loop::getInductionVariable(ScalarEvolution &SE) const {
+ if (!isLoopSimplifyForm())
+ return nullptr;
+
+ BasicBlock *Header = getHeader();
+ assert(Header && "Expected a valid loop header");
+ ICmpInst *CmpInst = getLatchCmpInst(*this);
+ if (!CmpInst)
+ return nullptr;
+
+ Instruction *LatchCmpOp0 = dyn_cast<Instruction>(CmpInst->getOperand(0));
+ Instruction *LatchCmpOp1 = dyn_cast<Instruction>(CmpInst->getOperand(1));
+
+ for (PHINode &IndVar : Header->phis()) {
+ InductionDescriptor IndDesc;
+ if (!InductionDescriptor::isInductionPHI(&IndVar, this, &SE, IndDesc))
+ continue;
+
+ Instruction *StepInst = IndDesc.getInductionBinOp();
+
+ // case 1:
+ // IndVar = phi[{InitialValue, preheader}, {StepInst, latch}]
+ // StepInst = IndVar + step
+ // cmp = StepInst < FinalValue
+ if (StepInst == LatchCmpOp0 || StepInst == LatchCmpOp1)
+ return &IndVar;
+
+ // case 2:
+ // IndVar = phi[{InitialValue, preheader}, {StepInst, latch}]
+ // StepInst = IndVar + step
+ // cmp = IndVar < FinalValue
+ if (&IndVar == LatchCmpOp0 || &IndVar == LatchCmpOp1)
+ return &IndVar;
+ }
+
+ return nullptr;
+}
+
+bool Loop::getInductionDescriptor(ScalarEvolution &SE,
+ InductionDescriptor &IndDesc) const {
+ if (PHINode *IndVar = getInductionVariable(SE))
+ return InductionDescriptor::isInductionPHI(IndVar, this, &SE, IndDesc);
+
+ return false;
+}
+
+bool Loop::isAuxiliaryInductionVariable(PHINode &AuxIndVar,
+ ScalarEvolution &SE) const {
+ // Located in the loop header
+ BasicBlock *Header = getHeader();
+ if (AuxIndVar.getParent() != Header)
+ return false;
+
+ // No uses outside of the loop
+ for (User *U : AuxIndVar.users())
+ if (const Instruction *I = dyn_cast<Instruction>(U))
+ if (!contains(I))
+ return false;
+
+ InductionDescriptor IndDesc;
+ if (!InductionDescriptor::isInductionPHI(&AuxIndVar, this, &SE, IndDesc))
+ return false;
+
+ // The step instruction opcode should be add or sub.
+ if (IndDesc.getInductionOpcode() != Instruction::Add &&
+ IndDesc.getInductionOpcode() != Instruction::Sub)
+ return false;
+
+ // Incremented by a loop invariant step for each loop iteration
+ return SE.isLoopInvariant(IndDesc.getStep(), this);
+}
+
+BranchInst *Loop::getLoopGuardBranch() const {
+ if (!isLoopSimplifyForm())
+ return nullptr;
+
+ BasicBlock *Preheader = getLoopPreheader();
+ BasicBlock *Latch = getLoopLatch();
+ assert(Preheader && Latch &&
+ "Expecting a loop with valid preheader and latch");
+
+ // Loop should be in rotate form.
+ if (!isLoopExiting(Latch))
+ return nullptr;
+
+ // Disallow loops with more than one unique exit block, as we do not verify
+ // that GuardOtherSucc post dominates all exit blocks.
+ BasicBlock *ExitFromLatch = getUniqueExitBlock();
+ if (!ExitFromLatch)
+ return nullptr;
+
+ BasicBlock *ExitFromLatchSucc = ExitFromLatch->getUniqueSuccessor();
+ if (!ExitFromLatchSucc)
+ return nullptr;
+
+ BasicBlock *GuardBB = Preheader->getUniquePredecessor();
+ if (!GuardBB)
+ return nullptr;
+
+ assert(GuardBB->getTerminator() && "Expecting valid guard terminator");
+
+ BranchInst *GuardBI = dyn_cast<BranchInst>(GuardBB->getTerminator());
+ if (!GuardBI || GuardBI->isUnconditional())
+ return nullptr;
+
+ BasicBlock *GuardOtherSucc = (GuardBI->getSuccessor(0) == Preheader)
+ ? GuardBI->getSuccessor(1)
+ : GuardBI->getSuccessor(0);
+ return (GuardOtherSucc == ExitFromLatchSucc) ? GuardBI : nullptr;
+}
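+
+// For illustration, the guarded, rotated shape recognized above looks roughly
+// like this (block names are hypothetical):
+//
+//   GuardBB:       br i1 %guard.cond, label %Preheader, label %GuardOtherSucc
+//   Preheader:     br label %Header
+//   ...
+//   Latch:         br i1 %latch.cond, label %Header, label %ExitFromLatch
+//   ExitFromLatch: br label %GuardOtherSucc
+//
+// The guard branch is returned only when its non-preheader successor is also
+// the unique successor of the loop's unique exit block.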
+
+bool Loop::isCanonical(ScalarEvolution &SE) const {
+ InductionDescriptor IndDesc;
+ if (!getInductionDescriptor(SE, IndDesc))
+ return false;
+
+ ConstantInt *Init = dyn_cast_or_null<ConstantInt>(IndDesc.getStartValue());
+ if (!Init || !Init->isZero())
+ return false;
+
+ if (IndDesc.getInductionOpcode() != Instruction::Add)
+ return false;
+
+ ConstantInt *Step = IndDesc.getConstIntStepValue();
+ if (!Step || !Step->isOne())
+ return false;
+
+ return true;
+}
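+
+// For example, a loop equivalent to `for (int i = 0; i < n; ++i)` is canonical
+// (start 0, opcode add, step 1), whereas `for (int i = n; i > 0; --i)` is not,
+// since its step is -1.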
+
+// Check that 'BB' doesn't have any uses outside of loop 'L'.
+static bool isBlockInLCSSAForm(const Loop &L, const BasicBlock &BB,
+ DominatorTree &DT) {
+ for (const Instruction &I : BB) {
+    // Tokens can't be used in PHI nodes and live-out tokens prevent loop
+    // optimizations, so for the purposes of checking LCSSA form, we can
+    // ignore them.
+ if (I.getType()->isTokenTy())
+ continue;
+
+ for (const Use &U : I.uses()) {
+ const Instruction *UI = cast<Instruction>(U.getUser());
+ const BasicBlock *UserBB = UI->getParent();
+ if (const PHINode *P = dyn_cast<PHINode>(UI))
+ UserBB = P->getIncomingBlock(U);
+
+ // Check the current block, as a fast-path, before checking whether
+ // the use is anywhere in the loop. Most values are used in the same
+ // block they are defined in. Also, blocks not reachable from the
+ // entry are special; uses in them don't need to go through PHIs.
+ if (UserBB != &BB && !L.contains(UserBB) &&
+ DT.isReachableFromEntry(UserBB))
+ return false;
+ }
+ }
+ return true;
+}
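+
+// A small IR sketch of what LCSSA form requires (names are illustrative): a
+// value defined inside the loop that is live outside must reach its users
+// through a PHI node in the exit block.
+//
+//   loop:
+//     %v = add i32 %iv, 1
+//     br i1 %cond, label %loop, label %exit
+//   exit:
+//     %v.lcssa = phi i32 [ %v, %loop ]  ; all uses outside the loop go here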
+
+bool Loop::isLCSSAForm(DominatorTree &DT) const {
+ // For each block we check that it doesn't have any uses outside of this loop.
+ return all_of(this->blocks(), [&](const BasicBlock *BB) {
+ return isBlockInLCSSAForm(*this, *BB, DT);
+ });
+}
+
+bool Loop::isRecursivelyLCSSAForm(DominatorTree &DT, const LoopInfo &LI) const {
+ // For each block we check that it doesn't have any uses outside of its
+ // innermost loop. This process will transitively guarantee that the current
+ // loop and all of the nested loops are in LCSSA form.
+ return all_of(this->blocks(), [&](const BasicBlock *BB) {
+ return isBlockInLCSSAForm(*LI.getLoopFor(BB), *BB, DT);
+ });
+}
+
+bool Loop::isLoopSimplifyForm() const {
+ // Normal-form loops have a preheader, a single backedge, and all of their
+ // exits have all their predecessors inside the loop.
+ return getLoopPreheader() && getLoopLatch() && hasDedicatedExits();
+}
+
+// Routines that reform the loop CFG and split edges often fail on indirectbr.
+bool Loop::isSafeToClone() const {
+ // Return false if any loop blocks contain indirectbrs, or there are any calls
+ // to noduplicate functions.
+ // FIXME: it should be ok to clone CallBrInst's if we correctly update the
+ // operand list to reflect the newly cloned labels.
+ for (BasicBlock *BB : this->blocks()) {
+ if (isa<IndirectBrInst>(BB->getTerminator()) ||
+ isa<CallBrInst>(BB->getTerminator()))
+ return false;
+
+ for (Instruction &I : *BB)
+ if (auto CS = CallSite(&I))
+ if (CS.cannotDuplicate())
+ return false;
+ }
+ return true;
+}
+
+MDNode *Loop::getLoopID() const {
+ MDNode *LoopID = nullptr;
+
+ // Go through the latch blocks and check the terminator for the metadata.
+ SmallVector<BasicBlock *, 4> LatchesBlocks;
+ getLoopLatches(LatchesBlocks);
+ for (BasicBlock *BB : LatchesBlocks) {
+ Instruction *TI = BB->getTerminator();
+ MDNode *MD = TI->getMetadata(LLVMContext::MD_loop);
+
+ if (!MD)
+ return nullptr;
+
+ if (!LoopID)
+ LoopID = MD;
+ else if (MD != LoopID)
+ return nullptr;
+ }
+ if (!LoopID || LoopID->getNumOperands() == 0 ||
+ LoopID->getOperand(0) != LoopID)
+ return nullptr;
+ return LoopID;
+}
+
+void Loop::setLoopID(MDNode *LoopID) const {
+ assert((!LoopID || LoopID->getNumOperands() > 0) &&
+ "Loop ID needs at least one operand");
+ assert((!LoopID || LoopID->getOperand(0) == LoopID) &&
+ "Loop ID should refer to itself");
+
+ SmallVector<BasicBlock *, 4> LoopLatches;
+ getLoopLatches(LoopLatches);
+ for (BasicBlock *BB : LoopLatches)
+ BB->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID);
+}
+
+void Loop::setLoopAlreadyUnrolled() {
+ LLVMContext &Context = getHeader()->getContext();
+
+ MDNode *DisableUnrollMD =
+ MDNode::get(Context, MDString::get(Context, "llvm.loop.unroll.disable"));
+ MDNode *LoopID = getLoopID();
+ MDNode *NewLoopID = makePostTransformationMetadata(
+ Context, LoopID, {"llvm.loop.unroll."}, {DisableUnrollMD});
+ setLoopID(NewLoopID);
+}
+
+bool Loop::isAnnotatedParallel() const {
+ MDNode *DesiredLoopIdMetadata = getLoopID();
+
+ if (!DesiredLoopIdMetadata)
+ return false;
+
+ MDNode *ParallelAccesses =
+ findOptionMDForLoop(this, "llvm.loop.parallel_accesses");
+ SmallPtrSet<MDNode *, 4>
+ ParallelAccessGroups; // For scalable 'contains' check.
+ if (ParallelAccesses) {
+ for (const MDOperand &MD : drop_begin(ParallelAccesses->operands(), 1)) {
+ MDNode *AccGroup = cast<MDNode>(MD.get());
+ assert(isValidAsAccessGroup(AccGroup) &&
+ "List item must be an access group");
+ ParallelAccessGroups.insert(AccGroup);
+ }
+ }
+
+ // The loop branch contains the parallel loop metadata. In order to ensure
+ // that any parallel-loop-unaware optimization pass hasn't added loop-carried
+ // dependencies (thus converted the loop back to a sequential loop), check
+ // that all the memory instructions in the loop belong to an access group that
+ // is parallel to this loop.
+ for (BasicBlock *BB : this->blocks()) {
+ for (Instruction &I : *BB) {
+ if (!I.mayReadOrWriteMemory())
+ continue;
+
+ if (MDNode *AccessGroup = I.getMetadata(LLVMContext::MD_access_group)) {
+ auto ContainsAccessGroup = [&ParallelAccessGroups](MDNode *AG) -> bool {
+ if (AG->getNumOperands() == 0) {
+ assert(isValidAsAccessGroup(AG) && "Item must be an access group");
+ return ParallelAccessGroups.count(AG);
+ }
+
+ for (const MDOperand &AccessListItem : AG->operands()) {
+ MDNode *AccGroup = cast<MDNode>(AccessListItem.get());
+ assert(isValidAsAccessGroup(AccGroup) &&
+ "List item must be an access group");
+ if (ParallelAccessGroups.count(AccGroup))
+ return true;
+ }
+ return false;
+ };
+
+ if (ContainsAccessGroup(AccessGroup))
+ continue;
+ }
+
+ // The memory instruction can refer to the loop identifier metadata
+ // directly or indirectly through another list metadata (in case of
+ // nested parallel loops). The loop identifier metadata refers to
+ // itself so we can check both cases with the same routine.
+ MDNode *LoopIdMD =
+ I.getMetadata(LLVMContext::MD_mem_parallel_loop_access);
+
+ if (!LoopIdMD)
+ return false;
+
+ bool LoopIdMDFound = false;
+ for (const MDOperand &MDOp : LoopIdMD->operands()) {
+ if (MDOp == DesiredLoopIdMetadata) {
+ LoopIdMDFound = true;
+ break;
+ }
+ }
+
+ if (!LoopIdMDFound)
+ return false;
+ }
+ }
+ return true;
+}
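+
+// A sketch of the metadata shape checked above (node numbers are
+// illustrative): every memory access in the loop carries an access group that
+// is listed as parallel for this loop.
+//
+//   %x = load i32, i32* %p, !llvm.access.group !2
+//   ...
+//   br i1 %c, label %header, label %exit, !llvm.loop !0
+//
+//   !0 = distinct !{!0, !1}
+//   !1 = !{!"llvm.loop.parallel_accesses", !2}
+//   !2 = distinct !{}
+//
+// Older IR may instead attach !llvm.mem.parallel_loop_access metadata that
+// refers to the loop ID !0, which the fallback path above handles.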
+
+DebugLoc Loop::getStartLoc() const { return getLocRange().getStart(); }
+
+Loop::LocRange Loop::getLocRange() const {
+ // If we have a debug location in the loop ID, then use it.
+ if (MDNode *LoopID = getLoopID()) {
+ DebugLoc Start;
+    // We use the first DILocation in the loop metadata as the start location
+    // of the loop and, if there is a second one, we use it as the end
+    // location.
+ for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) {
+ if (DILocation *L = dyn_cast<DILocation>(LoopID->getOperand(i))) {
+ if (!Start)
+ Start = DebugLoc(L);
+ else
+ return LocRange(Start, DebugLoc(L));
+ }
+ }
+
+ if (Start)
+ return LocRange(Start);
+ }
+
+ // Try the pre-header first.
+ if (BasicBlock *PHeadBB = getLoopPreheader())
+ if (DebugLoc DL = PHeadBB->getTerminator()->getDebugLoc())
+ return LocRange(DL);
+
+ // If we have no pre-header or there are no instructions with debug
+ // info in it, try the header.
+ if (BasicBlock *HeadBB = getHeader())
+ return LocRange(HeadBB->getTerminator()->getDebugLoc());
+
+ return LocRange();
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void Loop::dump() const { print(dbgs()); }
+
+LLVM_DUMP_METHOD void Loop::dumpVerbose() const {
+ print(dbgs(), /*Depth=*/0, /*Verbose=*/true);
+}
+#endif
+
+//===----------------------------------------------------------------------===//
+// UnloopUpdater implementation
+//
+
+namespace {
+/// Find the new parent loop for all blocks within the "unloop" whose last
+/// backedge has just been removed.
+class UnloopUpdater {
+ Loop &Unloop;
+ LoopInfo *LI;
+
+ LoopBlocksDFS DFS;
+
+ // Map unloop's immediate subloops to their nearest reachable parents. Nested
+ // loops within these subloops will not change parents. However, an immediate
+ // subloop's new parent will be the nearest loop reachable from either its own
+ // exits *or* any of its nested loop's exits.
+ DenseMap<Loop *, Loop *> SubloopParents;
+
+ // Flag the presence of an irreducible backedge whose destination is a block
+ // directly contained by the original unloop.
+ bool FoundIB;
+
+public:
+ UnloopUpdater(Loop *UL, LoopInfo *LInfo)
+ : Unloop(*UL), LI(LInfo), DFS(UL), FoundIB(false) {}
+
+ void updateBlockParents();
+
+ void removeBlocksFromAncestors();
+
+ void updateSubloopParents();
+
+protected:
+ Loop *getNearestLoop(BasicBlock *BB, Loop *BBLoop);
+};
+} // end anonymous namespace
+
+/// Update the parent loop for all blocks that are directly contained within the
+/// original "unloop".
+void UnloopUpdater::updateBlockParents() {
+ if (Unloop.getNumBlocks()) {
+ // Perform a post order CFG traversal of all blocks within this loop,
+ // propagating the nearest loop from successors to predecessors.
+ LoopBlocksTraversal Traversal(DFS, LI);
+ for (BasicBlock *POI : Traversal) {
+
+ Loop *L = LI->getLoopFor(POI);
+ Loop *NL = getNearestLoop(POI, L);
+
+ if (NL != L) {
+ // For reducible loops, NL is now an ancestor of Unloop.
+ assert((NL != &Unloop && (!NL || NL->contains(&Unloop))) &&
+ "uninitialized successor");
+ LI->changeLoopFor(POI, NL);
+ } else {
+ // Or the current block is part of a subloop, in which case its parent
+ // is unchanged.
+ assert((FoundIB || Unloop.contains(L)) && "uninitialized successor");
+ }
+ }
+ }
+ // Each irreducible loop within the unloop induces a round of iteration using
+ // the DFS result cached by Traversal.
+ bool Changed = FoundIB;
+ for (unsigned NIters = 0; Changed; ++NIters) {
+ assert(NIters < Unloop.getNumBlocks() && "runaway iterative algorithm");
+
+ // Iterate over the postorder list of blocks, propagating the nearest loop
+ // from successors to predecessors as before.
+ Changed = false;
+ for (LoopBlocksDFS::POIterator POI = DFS.beginPostorder(),
+ POE = DFS.endPostorder();
+ POI != POE; ++POI) {
+
+ Loop *L = LI->getLoopFor(*POI);
+ Loop *NL = getNearestLoop(*POI, L);
+ if (NL != L) {
+ assert(NL != &Unloop && (!NL || NL->contains(&Unloop)) &&
+ "uninitialized successor");
+ LI->changeLoopFor(*POI, NL);
+ Changed = true;
+ }
+ }
+ }
+}
+
+/// Remove unloop's blocks from all ancestors below their new parents.
+void UnloopUpdater::removeBlocksFromAncestors() {
+ // Remove all unloop's blocks (including those in nested subloops) from
+ // ancestors below the new parent loop.
+ for (Loop::block_iterator BI = Unloop.block_begin(), BE = Unloop.block_end();
+ BI != BE; ++BI) {
+ Loop *OuterParent = LI->getLoopFor(*BI);
+ if (Unloop.contains(OuterParent)) {
+ while (OuterParent->getParentLoop() != &Unloop)
+ OuterParent = OuterParent->getParentLoop();
+ OuterParent = SubloopParents[OuterParent];
+ }
+    // Remove blocks from former ancestors, except from Unloop itself, which
+    // will be deleted.
+ for (Loop *OldParent = Unloop.getParentLoop(); OldParent != OuterParent;
+ OldParent = OldParent->getParentLoop()) {
+ assert(OldParent && "new loop is not an ancestor of the original");
+ OldParent->removeBlockFromLoop(*BI);
+ }
+ }
+}
+
+/// Update the parent loop for all subloops directly nested within unloop.
+void UnloopUpdater::updateSubloopParents() {
+ while (!Unloop.empty()) {
+ Loop *Subloop = *std::prev(Unloop.end());
+ Unloop.removeChildLoop(std::prev(Unloop.end()));
+
+ assert(SubloopParents.count(Subloop) && "DFS failed to visit subloop");
+ if (Loop *Parent = SubloopParents[Subloop])
+ Parent->addChildLoop(Subloop);
+ else
+ LI->addTopLevelLoop(Subloop);
+ }
+}
+
+/// Return the nearest parent loop among this block's successors. If a successor
+/// is a subloop header, consider its parent to be the nearest parent of the
+/// subloop's exits.
+///
+/// For subloop blocks, simply update SubloopParents and return NULL.
+Loop *UnloopUpdater::getNearestLoop(BasicBlock *BB, Loop *BBLoop) {
+
+ // Initially for blocks directly contained by Unloop, NearLoop == Unloop and
+ // is considered uninitialized.
+ Loop *NearLoop = BBLoop;
+
+ Loop *Subloop = nullptr;
+ if (NearLoop != &Unloop && Unloop.contains(NearLoop)) {
+ Subloop = NearLoop;
+ // Find the subloop ancestor that is directly contained within Unloop.
+ while (Subloop->getParentLoop() != &Unloop) {
+ Subloop = Subloop->getParentLoop();
+ assert(Subloop && "subloop is not an ancestor of the original loop");
+ }
+ // Get the current nearest parent of the Subloop exits, initially Unloop.
+ NearLoop = SubloopParents.insert({Subloop, &Unloop}).first->second;
+ }
+
+ succ_iterator I = succ_begin(BB), E = succ_end(BB);
+ if (I == E) {
+ assert(!Subloop && "subloop blocks must have a successor");
+ NearLoop = nullptr; // unloop blocks may now exit the function.
+ }
+ for (; I != E; ++I) {
+ if (*I == BB)
+ continue; // self loops are uninteresting
+
+ Loop *L = LI->getLoopFor(*I);
+ if (L == &Unloop) {
+ // This successor has not been processed. This path must lead to an
+ // irreducible backedge.
+ assert((FoundIB || !DFS.hasPostorder(*I)) && "should have seen IB");
+ FoundIB = true;
+ }
+ if (L != &Unloop && Unloop.contains(L)) {
+ // Successor is in a subloop.
+ if (Subloop)
+ continue; // Branching within subloops. Ignore it.
+
+ // BB branches from the original into a subloop header.
+ assert(L->getParentLoop() == &Unloop && "cannot skip into nested loops");
+
+ // Get the current nearest parent of the Subloop's exits.
+ L = SubloopParents[L];
+ // L could be Unloop if the only exit was an irreducible backedge.
+ }
+ if (L == &Unloop) {
+ continue;
+ }
+ // Handle critical edges from Unloop into a sibling loop.
+ if (L && !L->contains(&Unloop)) {
+ L = L->getParentLoop();
+ }
+ // Remember the nearest parent loop among successors or subloop exits.
+ if (NearLoop == &Unloop || !NearLoop || NearLoop->contains(L))
+ NearLoop = L;
+ }
+ if (Subloop) {
+ SubloopParents[Subloop] = NearLoop;
+ return BBLoop;
+ }
+ return NearLoop;
+}
+
+LoopInfo::LoopInfo(const DomTreeBase<BasicBlock> &DomTree) { analyze(DomTree); }
+
+bool LoopInfo::invalidate(Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &) {
+ // Check whether the analysis, all analyses on functions, or the function's
+ // CFG have been preserved.
+ auto PAC = PA.getChecker<LoopAnalysis>();
+ return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() ||
+ PAC.preservedSet<CFGAnalyses>());
+}
+
+void LoopInfo::erase(Loop *Unloop) {
+ assert(!Unloop->isInvalid() && "Loop has already been erased!");
+
+ auto InvalidateOnExit = make_scope_exit([&]() { destroy(Unloop); });
+
+ // First handle the special case of no parent loop to simplify the algorithm.
+ if (!Unloop->getParentLoop()) {
+    // Since Unloop had no parent, its blocks are no longer inside any loop.
+ for (Loop::block_iterator I = Unloop->block_begin(),
+ E = Unloop->block_end();
+ I != E; ++I) {
+
+ // Don't reparent blocks in subloops.
+ if (getLoopFor(*I) != Unloop)
+ continue;
+
+ // Blocks no longer have a parent but are still referenced by Unloop until
+ // the Unloop object is deleted.
+ changeLoopFor(*I, nullptr);
+ }
+
+ // Remove the loop from the top-level LoopInfo object.
+ for (iterator I = begin();; ++I) {
+ assert(I != end() && "Couldn't find loop");
+ if (*I == Unloop) {
+ removeLoop(I);
+ break;
+ }
+ }
+
+ // Move all of the subloops to the top-level.
+ while (!Unloop->empty())
+ addTopLevelLoop(Unloop->removeChildLoop(std::prev(Unloop->end())));
+
+ return;
+ }
+
+ // Update the parent loop for all blocks within the loop. Blocks within
+ // subloops will not change parents.
+ UnloopUpdater Updater(Unloop, this);
+ Updater.updateBlockParents();
+
+ // Remove blocks from former ancestor loops.
+ Updater.removeBlocksFromAncestors();
+
+ // Add direct subloops as children in their new parent loop.
+ Updater.updateSubloopParents();
+
+ // Remove unloop from its parent loop.
+ Loop *ParentLoop = Unloop->getParentLoop();
+ for (Loop::iterator I = ParentLoop->begin();; ++I) {
+ assert(I != ParentLoop->end() && "Couldn't find loop");
+ if (*I == Unloop) {
+ ParentLoop->removeChildLoop(I);
+ break;
+ }
+ }
+}
+
+AnalysisKey LoopAnalysis::Key;
+
+LoopInfo LoopAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
+ // FIXME: Currently we create a LoopInfo from scratch for every function.
+ // This may prove to be too wasteful due to deallocating and re-allocating
+ // memory each time for the underlying map and vector datastructures. At some
+ // point it may prove worthwhile to use a freelist and recycle LoopInfo
+ // objects. I don't want to add that kind of complexity until the scope of
+ // the problem is better understood.
+ LoopInfo LI;
+ LI.analyze(AM.getResult<DominatorTreeAnalysis>(F));
+ return LI;
+}
+
+PreservedAnalyses LoopPrinterPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ AM.getResult<LoopAnalysis>(F).print(OS);
+ return PreservedAnalyses::all();
+}
+
+void llvm::printLoop(Loop &L, raw_ostream &OS, const std::string &Banner) {
+
+ if (forcePrintModuleIR()) {
+ // handling -print-module-scope
+ OS << Banner << " (loop: ";
+ L.getHeader()->printAsOperand(OS, false);
+ OS << ")\n";
+
+ // printing whole module
+ OS << *L.getHeader()->getModule();
+ return;
+ }
+
+ OS << Banner;
+
+ auto *PreHeader = L.getLoopPreheader();
+ if (PreHeader) {
+ OS << "\n; Preheader:";
+ PreHeader->print(OS);
+ OS << "\n; Loop:";
+ }
+
+ for (auto *Block : L.blocks())
+ if (Block)
+ Block->print(OS);
+ else
+ OS << "Printing <null> block";
+
+ SmallVector<BasicBlock *, 8> ExitBlocks;
+ L.getExitBlocks(ExitBlocks);
+ if (!ExitBlocks.empty()) {
+ OS << "\n; Exit blocks";
+ for (auto *Block : ExitBlocks)
+ if (Block)
+ Block->print(OS);
+ else
+ OS << "Printing <null> block";
+ }
+}
+
+MDNode *llvm::findOptionMDForLoopID(MDNode *LoopID, StringRef Name) {
+ // No loop metadata node, no loop properties.
+ if (!LoopID)
+ return nullptr;
+
+ // First operand should refer to the metadata node itself, for legacy reasons.
+ assert(LoopID->getNumOperands() > 0 && "requires at least one operand");
+ assert(LoopID->getOperand(0) == LoopID && "invalid loop id");
+
+  // Iterate over the metadata node operands and look for MDString metadata.
+ for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) {
+ MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i));
+ if (!MD || MD->getNumOperands() < 1)
+ continue;
+ MDString *S = dyn_cast<MDString>(MD->getOperand(0));
+ if (!S)
+ continue;
+ // Return the operand node if MDString holds expected metadata.
+ if (Name.equals(S->getString()))
+ return MD;
+ }
+
+ // Loop property not found.
+ return nullptr;
+}
+
+MDNode *llvm::findOptionMDForLoop(const Loop *TheLoop, StringRef Name) {
+ return findOptionMDForLoopID(TheLoop->getLoopID(), Name);
+}
+
+bool llvm::isValidAsAccessGroup(MDNode *Node) {
+ return Node->getNumOperands() == 0 && Node->isDistinct();
+}
+
+MDNode *llvm::makePostTransformationMetadata(LLVMContext &Context,
+ MDNode *OrigLoopID,
+ ArrayRef<StringRef> RemovePrefixes,
+ ArrayRef<MDNode *> AddAttrs) {
+ // First remove any existing loop metadata related to this transformation.
+ SmallVector<Metadata *, 4> MDs;
+
+ // Reserve first location for self reference to the LoopID metadata node.
+ TempMDTuple TempNode = MDNode::getTemporary(Context, None);
+ MDs.push_back(TempNode.get());
+
+ // Remove metadata for the transformation that has been applied or that became
+ // outdated.
+ if (OrigLoopID) {
+ for (unsigned i = 1, ie = OrigLoopID->getNumOperands(); i < ie; ++i) {
+ bool IsVectorMetadata = false;
+ Metadata *Op = OrigLoopID->getOperand(i);
+ if (MDNode *MD = dyn_cast<MDNode>(Op)) {
+ const MDString *S = dyn_cast<MDString>(MD->getOperand(0));
+ if (S)
+ IsVectorMetadata =
+ llvm::any_of(RemovePrefixes, [S](StringRef Prefix) -> bool {
+ return S->getString().startswith(Prefix);
+ });
+ }
+ if (!IsVectorMetadata)
+ MDs.push_back(Op);
+ }
+ }
+
+ // Add metadata to avoid reapplying a transformation, such as
+ // llvm.loop.unroll.disable and llvm.loop.isvectorized.
+ MDs.append(AddAttrs.begin(), AddAttrs.end());
+
+ MDNode *NewLoopID = MDNode::getDistinct(Context, MDs);
+ // Replace the temporary node with a self-reference.
+ NewLoopID->replaceOperandWith(0, NewLoopID);
+ return NewLoopID;
+}
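+
+// Worked example with illustrative node numbers: given an original loop ID
+//
+//   !0 = distinct !{!0, !1, !2}
+//   !1 = !{!"llvm.loop.unroll.count", i32 4}
+//   !2 = !{!"llvm.loop.vectorize.enable", i1 true}
+//
+// calling this with RemovePrefixes = {"llvm.loop.unroll."} and AddAttrs =
+// {!3}, where !3 = !{!"llvm.loop.unroll.disable"}, drops !1, keeps !2, appends
+// !3, and returns a fresh self-referential node:
+//
+//   !4 = distinct !{!4, !2, !3}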
+
+//===----------------------------------------------------------------------===//
+// LoopInfo implementation
+//
+
+char LoopInfoWrapperPass::ID = 0;
+INITIALIZE_PASS_BEGIN(LoopInfoWrapperPass, "loops", "Natural Loop Information",
+ true, true)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_END(LoopInfoWrapperPass, "loops", "Natural Loop Information",
+ true, true)
+
+bool LoopInfoWrapperPass::runOnFunction(Function &) {
+ releaseMemory();
+ LI.analyze(getAnalysis<DominatorTreeWrapperPass>().getDomTree());
+ return false;
+}
+
+void LoopInfoWrapperPass::verifyAnalysis() const {
+ // LoopInfoWrapperPass is a FunctionPass, but verifying every loop in the
+ // function each time verifyAnalysis is called is very expensive. The
+ // -verify-loop-info option can enable this. In order to perform some
+ // checking by default, LoopPass has been taught to call verifyLoop manually
+ // during loop pass sequences.
+ if (VerifyLoopInfo) {
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ LI.verify(DT);
+ }
+}
+
+void LoopInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequiredTransitive<DominatorTreeWrapperPass>();
+}
+
+void LoopInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
+ LI.print(OS);
+}
+
+PreservedAnalyses LoopVerifierPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
+ auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ LI.verify(DT);
+ return PreservedAnalyses::all();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopBlocksDFS implementation
+//
+
+/// Traverse the loop blocks and store the DFS result.
+/// Useful for clients that just want the final DFS result and don't need to
+/// visit blocks during the initial traversal.
+void LoopBlocksDFS::perform(LoopInfo *LI) {
+ LoopBlocksTraversal Traversal(*this, LI);
+ for (LoopBlocksTraversal::POTIterator POI = Traversal.begin(),
+ POE = Traversal.end();
+ POI != POE; ++POI)
+ ;
+}
diff --git a/llvm/lib/Analysis/LoopPass.cpp b/llvm/lib/Analysis/LoopPass.cpp
new file mode 100644
index 000000000000..4ab3798039d8
--- /dev/null
+++ b/llvm/lib/Analysis/LoopPass.cpp
@@ -0,0 +1,414 @@
+//===- LoopPass.cpp - Loop Pass and Loop Pass Manager ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements LoopPass and LPPassManager. All loop optimization
+// and transformation passes are derived from LoopPass. LPPassManager is
+// responsible for managing LoopPasses.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRPrintingPasses.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/OptBisect.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/PassTimingInfo.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Timer.h"
+#include "llvm/Support/TimeProfiler.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-pass-manager"
+
+namespace {
+
+/// PrintLoopPass - Print a Function corresponding to a Loop.
+///
+class PrintLoopPassWrapper : public LoopPass {
+ raw_ostream &OS;
+ std::string Banner;
+
+public:
+ static char ID;
+ PrintLoopPassWrapper() : LoopPass(ID), OS(dbgs()) {}
+ PrintLoopPassWrapper(raw_ostream &OS, const std::string &Banner)
+ : LoopPass(ID), OS(OS), Banner(Banner) {}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ }
+
+ bool runOnLoop(Loop *L, LPPassManager &) override {
+ auto BBI = llvm::find_if(L->blocks(), [](BasicBlock *BB) { return BB; });
+ if (BBI != L->blocks().end() &&
+ isFunctionInPrintList((*BBI)->getParent()->getName())) {
+ printLoop(*L, OS, Banner);
+ }
+ return false;
+ }
+
+ StringRef getPassName() const override { return "Print Loop IR"; }
+};
+
+char PrintLoopPassWrapper::ID = 0;
+}
+
+//===----------------------------------------------------------------------===//
+// LPPassManager
+//
+
+char LPPassManager::ID = 0;
+
+LPPassManager::LPPassManager()
+ : FunctionPass(ID), PMDataManager() {
+ LI = nullptr;
+ CurrentLoop = nullptr;
+}
+
+// Insert the loop into the loop queue (LQ): top-level loops go to the front,
+// all other loops are inserted right after their parent loop.
+void LPPassManager::addLoop(Loop &L) {
+ if (!L.getParentLoop()) {
+ // This is the top level loop.
+ LQ.push_front(&L);
+ return;
+ }
+
+ // Insert L into the loop queue after the parent loop.
+ for (auto I = LQ.begin(), E = LQ.end(); I != E; ++I) {
+ if (*I == L.getParentLoop()) {
+ // deque does not support insert after.
+ ++I;
+ LQ.insert(I, 1, &L);
+ return;
+ }
+ }
+}
+
+/// cloneBasicBlockSimpleAnalysis - Invoke cloneBasicBlockAnalysis hook for
+/// all loop passes.
+void LPPassManager::cloneBasicBlockSimpleAnalysis(BasicBlock *From,
+ BasicBlock *To, Loop *L) {
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ LoopPass *LP = getContainedPass(Index);
+ LP->cloneBasicBlockAnalysis(From, To, L);
+ }
+}
+
+/// deleteSimpleAnalysisValue - Invoke deleteAnalysisValue hook for all passes.
+void LPPassManager::deleteSimpleAnalysisValue(Value *V, Loop *L) {
+ if (BasicBlock *BB = dyn_cast<BasicBlock>(V)) {
+ for (Instruction &I : *BB) {
+ deleteSimpleAnalysisValue(&I, L);
+ }
+ }
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ LoopPass *LP = getContainedPass(Index);
+ LP->deleteAnalysisValue(V, L);
+ }
+}
+
+/// Invoke deleteAnalysisLoop hook for all passes.
+void LPPassManager::deleteSimpleAnalysisLoop(Loop *L) {
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ LoopPass *LP = getContainedPass(Index);
+ LP->deleteAnalysisLoop(L);
+ }
+}
+
+
+// Recursively insert the loop L and all of its subloops into the loop queue
+// LQ.
+static void addLoopIntoQueue(Loop *L, std::deque<Loop *> &LQ) {
+ LQ.push_back(L);
+ for (Loop *I : reverse(*L))
+ addLoopIntoQueue(I, LQ);
+}
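+
+// For illustration: for a nest where Outer contains Inner, the queue becomes
+// [Outer, Inner] (each loop followed by its subloops). runOnFunction below
+// pops from the back, so Inner is processed before Outer.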
+
+/// Pass Manager itself does not invalidate any analysis info.
+void LPPassManager::getAnalysisUsage(AnalysisUsage &Info) const {
+ // LPPassManager needs LoopInfo. In the long term LoopInfo class will
+ // become part of LPPassManager.
+ Info.addRequired<LoopInfoWrapperPass>();
+ Info.addRequired<DominatorTreeWrapperPass>();
+ Info.setPreservesAll();
+}
+
+void LPPassManager::markLoopAsDeleted(Loop &L) {
+ assert((&L == CurrentLoop || CurrentLoop->contains(&L)) &&
+ "Must not delete loop outside the current loop tree!");
+ // If this loop appears elsewhere within the queue, we also need to remove it
+ // there. However, we have to be careful to not remove the back of the queue
+ // as that is assumed to match the current loop.
+ assert(LQ.back() == CurrentLoop && "Loop queue back isn't the current loop!");
+ LQ.erase(std::remove(LQ.begin(), LQ.end(), &L), LQ.end());
+
+ if (&L == CurrentLoop) {
+ CurrentLoopDeleted = true;
+ // Add this loop back onto the back of the queue to preserve our invariants.
+ LQ.push_back(&L);
+ }
+}
+
+/// run - Execute all of the passes scheduled for execution. Keep track of
+/// whether any of the passes modifies the function, and if so, return true.
+bool LPPassManager::runOnFunction(Function &F) {
+ auto &LIWP = getAnalysis<LoopInfoWrapperPass>();
+ LI = &LIWP.getLoopInfo();
+ Module &M = *F.getParent();
+#if 0
+ DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+#endif
+ bool Changed = false;
+
+ // Collect inherited analysis from Module level pass manager.
+ populateInheritedAnalysis(TPM->activeStack);
+
+ // Populate the loop queue in reverse program order. There is no clear need to
+ // process sibling loops in either forward or reverse order. There may be some
+ // advantage in deleting uses in a later loop before optimizing the
+ // definitions in an earlier loop. If we find a clear reason to process in
+ // forward order, then a forward variant of LoopPassManager should be created.
+ //
+ // Note that LoopInfo::iterator visits loops in reverse program
+ // order. Here, reverse_iterator gives us a forward order, and the LoopQueue
+ // reverses the order a third time by popping from the back.
+ for (Loop *L : reverse(*LI))
+ addLoopIntoQueue(L, LQ);
+
+ if (LQ.empty()) // No loops, skip calling finalizers
+ return false;
+
+ // Initialization
+ for (Loop *L : LQ) {
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ LoopPass *P = getContainedPass(Index);
+ Changed |= P->doInitialization(L, *this);
+ }
+ }
+
+ // Walk Loops
+  unsigned InstrCount = 0, FunctionSize = 0;
+ StringMap<std::pair<unsigned, unsigned>> FunctionToInstrCount;
+ bool EmitICRemark = M.shouldEmitInstrCountChangedRemark();
+ // Collect the initial size of the module and the function we're looking at.
+ if (EmitICRemark) {
+ InstrCount = initSizeRemarkInfo(M, FunctionToInstrCount);
+ FunctionSize = F.getInstructionCount();
+ }
+ while (!LQ.empty()) {
+ CurrentLoopDeleted = false;
+ CurrentLoop = LQ.back();
+
+ // Run all passes on the current Loop.
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ LoopPass *P = getContainedPass(Index);
+
+ llvm::TimeTraceScope LoopPassScope("RunLoopPass", P->getPassName());
+
+ dumpPassInfo(P, EXECUTION_MSG, ON_LOOP_MSG,
+ CurrentLoop->getHeader()->getName());
+ dumpRequiredSet(P);
+
+ initializeAnalysisImpl(P);
+
+ bool LocalChanged = false;
+ {
+ PassManagerPrettyStackEntry X(P, *CurrentLoop->getHeader());
+ TimeRegion PassTimer(getPassTimer(P));
+ LocalChanged = P->runOnLoop(CurrentLoop, *this);
+ Changed |= LocalChanged;
+ if (EmitICRemark) {
+ unsigned NewSize = F.getInstructionCount();
+ // Update the size of the function, emit a remark, and update the
+ // size of the module.
+ if (NewSize != FunctionSize) {
+ int64_t Delta = static_cast<int64_t>(NewSize) -
+ static_cast<int64_t>(FunctionSize);
+ emitInstrCountChangedRemark(P, M, Delta, InstrCount,
+ FunctionToInstrCount, &F);
+ InstrCount = static_cast<int64_t>(InstrCount) + Delta;
+ FunctionSize = NewSize;
+ }
+ }
+ }
+
+ if (LocalChanged)
+ dumpPassInfo(P, MODIFICATION_MSG, ON_LOOP_MSG,
+ CurrentLoopDeleted ? "<deleted loop>"
+ : CurrentLoop->getName());
+ dumpPreservedSet(P);
+
+ if (CurrentLoopDeleted) {
+ // Notify passes that the loop is being deleted.
+ deleteSimpleAnalysisLoop(CurrentLoop);
+ } else {
+ // Manually check that this loop is still healthy. This is done
+ // instead of relying on LoopInfo::verifyLoop since LoopInfo
+ // is a function pass and it's really expensive to verify every
+ // loop in the function every time. That level of checking can be
+ // enabled with the -verify-loop-info option.
+ {
+ TimeRegion PassTimer(getPassTimer(&LIWP));
+ CurrentLoop->verifyLoop();
+ }
+      // Here we apply the same reasoning as in the above case. The only
+      // difference is that LPPassManager might run passes which do not
+      // require LCSSA form (LoopPassPrinter for example). We should skip
+      // verification for such passes.
+      // FIXME: Loop-sink currently breaks LCSSA. Fix it and reenable the
+ // verification!
+#if 0
+ if (mustPreserveAnalysisID(LCSSAVerificationPass::ID))
+ assert(CurrentLoop->isRecursivelyLCSSAForm(*DT, *LI));
+#endif
+
+ // Then call the regular verifyAnalysis functions.
+ verifyPreservedAnalysis(P);
+
+ F.getContext().yield();
+ }
+
+ removeNotPreservedAnalysis(P);
+ recordAvailableAnalysis(P);
+ removeDeadPasses(P,
+ CurrentLoopDeleted ? "<deleted>"
+ : CurrentLoop->getHeader()->getName(),
+ ON_LOOP_MSG);
+
+ if (CurrentLoopDeleted)
+ // Do not run other passes on this loop.
+ break;
+ }
+
+ // If the loop was deleted, release all the loop passes. This frees up
+ // some memory, and avoids trouble with the pass manager trying to call
+ // verifyAnalysis on them.
+ if (CurrentLoopDeleted) {
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ Pass *P = getContainedPass(Index);
+ freePass(P, "<deleted>", ON_LOOP_MSG);
+ }
+ }
+
+ // Pop the loop from queue after running all passes.
+ LQ.pop_back();
+ }
+
+ // Finalization
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ LoopPass *P = getContainedPass(Index);
+ Changed |= P->doFinalization();
+ }
+
+ return Changed;
+}
+
+/// Print passes managed by this manager
+void LPPassManager::dumpPassStructure(unsigned Offset) {
+ errs().indent(Offset*2) << "Loop Pass Manager\n";
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ Pass *P = getContainedPass(Index);
+ P->dumpPassStructure(Offset + 1);
+ dumpLastUses(P, Offset+1);
+ }
+}
+
+
+//===----------------------------------------------------------------------===//
+// LoopPass
+
+Pass *LoopPass::createPrinterPass(raw_ostream &O,
+ const std::string &Banner) const {
+ return new PrintLoopPassWrapper(O, Banner);
+}
+
+// Check if this pass is suitable for the current LPPassManager, if
+// available. This pass P is not suitable for an LPPassManager if P
+// does not preserve higher-level analysis info used by other
+// LPPassManager passes. In such a case, pop the LPPassManager from the
+// stack. This will force assignPassManager() to create a new
+// LPPassManager as expected.
+void LoopPass::preparePassManager(PMStack &PMS) {
+
+ // Find LPPassManager
+ while (!PMS.empty() &&
+ PMS.top()->getPassManagerType() > PMT_LoopPassManager)
+ PMS.pop();
+
+  // If this pass is destroying high-level information that is used
+  // by other passes managed by the LPM, then do not insert this pass
+  // into the current LPM. Use a new LPPassManager instead.
+ if (PMS.top()->getPassManagerType() == PMT_LoopPassManager &&
+ !PMS.top()->preserveHigherLevelAnalysis(this))
+ PMS.pop();
+}
+
+/// Assign pass manager to manage this pass.
+void LoopPass::assignPassManager(PMStack &PMS,
+ PassManagerType PreferredType) {
+ // Find LPPassManager
+ while (!PMS.empty() &&
+ PMS.top()->getPassManagerType() > PMT_LoopPassManager)
+ PMS.pop();
+
+ LPPassManager *LPPM;
+ if (PMS.top()->getPassManagerType() == PMT_LoopPassManager)
+ LPPM = (LPPassManager*)PMS.top();
+ else {
+ // Create new Loop Pass Manager if it does not exist.
+ assert (!PMS.empty() && "Unable to create Loop Pass Manager");
+ PMDataManager *PMD = PMS.top();
+
+ // [1] Create new Loop Pass Manager
+ LPPM = new LPPassManager();
+ LPPM->populateInheritedAnalysis(PMS);
+
+ // [2] Set up new manager's top level manager
+ PMTopLevelManager *TPM = PMD->getTopLevelManager();
+ TPM->addIndirectPassManager(LPPM);
+
+ // [3] Assign manager to manage this new manager. This may create
+ // and push new managers into PMS
+ Pass *P = LPPM->getAsPass();
+ TPM->schedulePass(P);
+
+ // [4] Push new manager into PMS
+ PMS.push(LPPM);
+ }
+
+ LPPM->add(this);
+}
+
+static std::string getDescription(const Loop &L) {
+ return "loop";
+}
+
+bool LoopPass::skipLoop(const Loop *L) const {
+ const Function *F = L->getHeader()->getParent();
+ if (!F)
+ return false;
+ // Check the opt bisect limit.
+ OptPassGate &Gate = F->getContext().getOptPassGate();
+ if (Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(*L)))
+ return true;
+ // Check for the OptimizeNone attribute.
+ if (F->hasOptNone()) {
+ // FIXME: Report this to dbgs() only once per function.
+ LLVM_DEBUG(dbgs() << "Skipping pass '" << getPassName() << "' in function "
+ << F->getName() << "\n");
+ // FIXME: Delete loop from pass manager's queue?
+ return true;
+ }
+ return false;
+}
+
+char LCSSAVerificationPass::ID = 0;
+INITIALIZE_PASS(LCSSAVerificationPass, "lcssa-verification", "LCSSA Verifier",
+ false, false)
diff --git a/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp b/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp
new file mode 100644
index 000000000000..762623de41e9
--- /dev/null
+++ b/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp
@@ -0,0 +1,214 @@
+//===- LoopUnrollAnalyzer.cpp - Unrolling Effect Estimation -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements UnrolledInstAnalyzer class. It's used for predicting
+// potential effects that loop unrolling might have, such as enabling constant
+// propagation and other optimizations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LoopUnrollAnalyzer.h"
+
+using namespace llvm;
+
+/// Try to simplify instruction \p I using its SCEV expression.
+///
+/// The idea is that some AddRec expressions become constants, which then
+/// could trigger folding of other instructions. However, that only happens
+/// for expressions whose start value is also constant, which isn't always the
+/// case. In another common and important case the start value is just some
+/// address (i.e. SCEVUnknown) - in this case we compute the offset and save
+/// it along with the base address instead.
+bool UnrolledInstAnalyzer::simplifyInstWithSCEV(Instruction *I) {
+ if (!SE.isSCEVable(I->getType()))
+ return false;
+
+ const SCEV *S = SE.getSCEV(I);
+ if (auto *SC = dyn_cast<SCEVConstant>(S)) {
+ SimplifiedValues[I] = SC->getValue();
+ return true;
+ }
+
+ auto *AR = dyn_cast<SCEVAddRecExpr>(S);
+ if (!AR || AR->getLoop() != L)
+ return false;
+
+ const SCEV *ValueAtIteration = AR->evaluateAtIteration(IterationNumber, SE);
+ // Check if the AddRec expression becomes a constant.
+ if (auto *SC = dyn_cast<SCEVConstant>(ValueAtIteration)) {
+ SimplifiedValues[I] = SC->getValue();
+ return true;
+ }
+
+ // Check if the offset from the base address becomes a constant.
+ auto *Base = dyn_cast<SCEVUnknown>(SE.getPointerBase(S));
+ if (!Base)
+ return false;
+ auto *Offset =
+ dyn_cast<SCEVConstant>(SE.getMinusSCEV(ValueAtIteration, Base));
+ if (!Offset)
+ return false;
+ SimplifiedAddress Address;
+ Address.Base = Base->getValue();
+ Address.Offset = Offset->getValue();
+ SimplifiedAddresses[I] = Address;
+ return false;
+}
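+
+// Worked example (names and constants are illustrative): for a pointer whose
+// SCEV is {%base,+,4}<%L>, evaluating at IterationNumber == 3 yields
+// (%base + 12). The offset 12 is constant, so the instruction is recorded in
+// SimplifiedAddresses as {Base = %base, Offset = 12}; visitLoad below can then
+// try to fold loads through such an address.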
+
+/// Try to simplify binary operator I.
+///
+/// TODO: It's probably worth hoisting the code for estimating the
+/// simplification effects into a separate class, since we have very similar
+/// code in InlineCost already.
+bool UnrolledInstAnalyzer::visitBinaryOperator(BinaryOperator &I) {
+ Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+ if (!isa<Constant>(LHS))
+ if (Constant *SimpleLHS = SimplifiedValues.lookup(LHS))
+ LHS = SimpleLHS;
+ if (!isa<Constant>(RHS))
+ if (Constant *SimpleRHS = SimplifiedValues.lookup(RHS))
+ RHS = SimpleRHS;
+
+ Value *SimpleV = nullptr;
+ const DataLayout &DL = I.getModule()->getDataLayout();
+ if (auto FI = dyn_cast<FPMathOperator>(&I))
+ SimpleV =
+ SimplifyBinOp(I.getOpcode(), LHS, RHS, FI->getFastMathFlags(), DL);
+ else
+ SimpleV = SimplifyBinOp(I.getOpcode(), LHS, RHS, DL);
+
+ if (Constant *C = dyn_cast_or_null<Constant>(SimpleV))
+ SimplifiedValues[&I] = C;
+
+ if (SimpleV)
+ return true;
+ return Base::visitBinaryOperator(I);
+}
+
+/// Try to fold load I.
+bool UnrolledInstAnalyzer::visitLoad(LoadInst &I) {
+ Value *AddrOp = I.getPointerOperand();
+
+ auto AddressIt = SimplifiedAddresses.find(AddrOp);
+ if (AddressIt == SimplifiedAddresses.end())
+ return false;
+ ConstantInt *SimplifiedAddrOp = AddressIt->second.Offset;
+
+ auto *GV = dyn_cast<GlobalVariable>(AddressIt->second.Base);
+ // We're only interested in loads that can be completely folded to a
+ // constant.
+ if (!GV || !GV->hasDefinitiveInitializer() || !GV->isConstant())
+ return false;
+
+ ConstantDataSequential *CDS =
+ dyn_cast<ConstantDataSequential>(GV->getInitializer());
+ if (!CDS)
+ return false;
+
+ // We might have a vector load from an array. FIXME: for now we just bail
+ // out in this case, but we should be able to resolve and simplify such
+ // loads.
+ if (CDS->getElementType() != I.getType())
+ return false;
+
+ unsigned ElemSize = CDS->getElementType()->getPrimitiveSizeInBits() / 8U;
+ if (SimplifiedAddrOp->getValue().getActiveBits() > 64)
+ return false;
+ int64_t SimplifiedAddrOpV = SimplifiedAddrOp->getSExtValue();
+ if (SimplifiedAddrOpV < 0) {
+ // FIXME: For now we conservatively ignore out of bound accesses, but
+ // we're allowed to perform the optimization in this case.
+ return false;
+ }
+ uint64_t Index = static_cast<uint64_t>(SimplifiedAddrOpV) / ElemSize;
+ if (Index >= CDS->getNumElements()) {
+ // FIXME: For now we conservatively ignore out of bound accesses, but
+ // we're allowed to perform the optimization in this case.
+ return false;
+ }
+
+ Constant *CV = CDS->getElementAsConstant(Index);
+ assert(CV && "Constant expected.");
+ SimplifiedValues[&I] = CV;
+
+ return true;
+}
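+
+// Worked example (illustrative IR): given
+//
+//   @tbl = internal constant [4 x i32] [i32 10, i32 20, i32 30, i32 40]
+//
+// an i32 load whose simplified address is {Base = @tbl, Offset = 8} indexes
+// element 8 / 4 == 2 and folds to the constant i32 30.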
+
+/// Try to simplify cast instruction.
+bool UnrolledInstAnalyzer::visitCastInst(CastInst &I) {
+ // Propagate constants through casts.
+ Constant *COp = dyn_cast<Constant>(I.getOperand(0));
+ if (!COp)
+ COp = SimplifiedValues.lookup(I.getOperand(0));
+
+ // If we know a simplified value for this operand and cast is valid, save the
+ // result to SimplifiedValues.
+ // The cast can be invalid, because SimplifiedValues contains results of SCEV
+ // analysis, which operates on integers (and, e.g., might convert i8* null to
+ // i32 0).
+ if (COp && CastInst::castIsValid(I.getOpcode(), COp, I.getType())) {
+ if (Constant *C =
+ ConstantExpr::getCast(I.getOpcode(), COp, I.getType())) {
+ SimplifiedValues[&I] = C;
+ return true;
+ }
+ }
+
+ return Base::visitCastInst(I);
+}
+
+/// Try to simplify cmp instruction.
+bool UnrolledInstAnalyzer::visitCmpInst(CmpInst &I) {
+ Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+
+ // First try to handle simplified comparisons.
+ if (!isa<Constant>(LHS))
+ if (Constant *SimpleLHS = SimplifiedValues.lookup(LHS))
+ LHS = SimpleLHS;
+ if (!isa<Constant>(RHS))
+ if (Constant *SimpleRHS = SimplifiedValues.lookup(RHS))
+ RHS = SimpleRHS;
+
+ if (!isa<Constant>(LHS) && !isa<Constant>(RHS)) {
+ auto SimplifiedLHS = SimplifiedAddresses.find(LHS);
+ if (SimplifiedLHS != SimplifiedAddresses.end()) {
+ auto SimplifiedRHS = SimplifiedAddresses.find(RHS);
+ if (SimplifiedRHS != SimplifiedAddresses.end()) {
+ SimplifiedAddress &LHSAddr = SimplifiedLHS->second;
+ SimplifiedAddress &RHSAddr = SimplifiedRHS->second;
+ if (LHSAddr.Base == RHSAddr.Base) {
+ LHS = LHSAddr.Offset;
+ RHS = RHSAddr.Offset;
+ }
+ }
+ }
+ }
+
+ if (Constant *CLHS = dyn_cast<Constant>(LHS)) {
+ if (Constant *CRHS = dyn_cast<Constant>(RHS)) {
+ if (CLHS->getType() == CRHS->getType()) {
+ if (Constant *C = ConstantExpr::getCompare(I.getPredicate(), CLHS, CRHS)) {
+ SimplifiedValues[&I] = C;
+ return true;
+ }
+ }
+ }
+ }
+
+ return Base::visitCmpInst(I);
+}
+
+bool UnrolledInstAnalyzer::visitPHINode(PHINode &PN) {
+  // Run the base visitor first. This way we can gather some information that
+  // is useful for later analysis.
+ if (Base::visitPHINode(PN))
+ return true;
+
+ // The loop induction PHI nodes are definitionally free.
+ return PN.getParent() == L->getHeader();
+}
diff --git a/llvm/lib/Analysis/MemDepPrinter.cpp b/llvm/lib/Analysis/MemDepPrinter.cpp
new file mode 100644
index 000000000000..6e1bb50e8893
--- /dev/null
+++ b/llvm/lib/Analysis/MemDepPrinter.cpp
@@ -0,0 +1,164 @@
+//===- MemDepPrinter.cpp - Printer for MemoryDependenceAnalysis -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Analysis/MemoryDependenceAnalysis.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace llvm;
+
+namespace {
+ struct MemDepPrinter : public FunctionPass {
+ const Function *F;
+
+ enum DepType {
+ Clobber = 0,
+ Def,
+ NonFuncLocal,
+ Unknown
+ };
+
+ static const char *const DepTypeStr[];
+
+ typedef PointerIntPair<const Instruction *, 2, DepType> InstTypePair;
+ typedef std::pair<InstTypePair, const BasicBlock *> Dep;
+ typedef SmallSetVector<Dep, 4> DepSet;
+ typedef DenseMap<const Instruction *, DepSet> DepSetMap;
+ DepSetMap Deps;
+
+    static char ID; // Pass identification, replacement for typeid
+ MemDepPrinter() : FunctionPass(ID) {
+ initializeMemDepPrinterPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override;
+
+ void print(raw_ostream &OS, const Module * = nullptr) const override;
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequiredTransitive<AAResultsWrapperPass>();
+ AU.addRequiredTransitive<MemoryDependenceWrapperPass>();
+ AU.setPreservesAll();
+ }
+
+ void releaseMemory() override {
+ Deps.clear();
+ F = nullptr;
+ }
+
+ private:
+ static InstTypePair getInstTypePair(MemDepResult dep) {
+ if (dep.isClobber())
+ return InstTypePair(dep.getInst(), Clobber);
+ if (dep.isDef())
+ return InstTypePair(dep.getInst(), Def);
+ if (dep.isNonFuncLocal())
+ return InstTypePair(dep.getInst(), NonFuncLocal);
+ assert(dep.isUnknown() && "unexpected dependence type");
+ return InstTypePair(dep.getInst(), Unknown);
+ }
+ static InstTypePair getInstTypePair(const Instruction* inst, DepType type) {
+ return InstTypePair(inst, type);
+ }
+ };
+}
+
+char MemDepPrinter::ID = 0;
+INITIALIZE_PASS_BEGIN(MemDepPrinter, "print-memdeps",
+ "Print MemDeps of function", false, true)
+INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
+INITIALIZE_PASS_END(MemDepPrinter, "print-memdeps",
+ "Print MemDeps of function", false, true)
+
+FunctionPass *llvm::createMemDepPrinter() {
+ return new MemDepPrinter();
+}
+
+const char *const MemDepPrinter::DepTypeStr[]
+ = {"Clobber", "Def", "NonFuncLocal", "Unknown"};
+
+bool MemDepPrinter::runOnFunction(Function &F) {
+ this->F = &F;
+ MemoryDependenceResults &MDA = getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
+
+ // All this code uses non-const interfaces because MemDep is not
+ // const-friendly, though nothing is actually modified.
+ for (auto &I : instructions(F)) {
+ Instruction *Inst = &I;
+
+ if (!Inst->mayReadFromMemory() && !Inst->mayWriteToMemory())
+ continue;
+
+ MemDepResult Res = MDA.getDependency(Inst);
+ if (!Res.isNonLocal()) {
+ Deps[Inst].insert(std::make_pair(getInstTypePair(Res),
+ static_cast<BasicBlock *>(nullptr)));
+ } else if (auto *Call = dyn_cast<CallBase>(Inst)) {
+ const MemoryDependenceResults::NonLocalDepInfo &NLDI =
+ MDA.getNonLocalCallDependency(Call);
+
+ DepSet &InstDeps = Deps[Inst];
+ for (const NonLocalDepEntry &I : NLDI) {
+ const MemDepResult &Res = I.getResult();
+ InstDeps.insert(std::make_pair(getInstTypePair(Res), I.getBB()));
+ }
+ } else {
+ SmallVector<NonLocalDepResult, 4> NLDI;
+ assert( (isa<LoadInst>(Inst) || isa<StoreInst>(Inst) ||
+ isa<VAArgInst>(Inst)) && "Unknown memory instruction!");
+ MDA.getNonLocalPointerDependency(Inst, NLDI);
+
+ DepSet &InstDeps = Deps[Inst];
+ for (const NonLocalDepResult &I : NLDI) {
+ const MemDepResult &Res = I.getResult();
+ InstDeps.insert(std::make_pair(getInstTypePair(Res), I.getBB()));
+ }
+ }
+ }
+
+ return false;
+}
+
+void MemDepPrinter::print(raw_ostream &OS, const Module *M) const {
+ for (const auto &I : instructions(*F)) {
+ const Instruction *Inst = &I;
+
+ DepSetMap::const_iterator DI = Deps.find(Inst);
+ if (DI == Deps.end())
+ continue;
+
+ const DepSet &InstDeps = DI->second;
+
+ for (const auto &I : InstDeps) {
+ const Instruction *DepInst = I.first.getPointer();
+ DepType type = I.first.getInt();
+ const BasicBlock *DepBB = I.second;
+
+ OS << " ";
+ OS << DepTypeStr[type];
+ if (DepBB) {
+ OS << " in block ";
+ DepBB->printAsOperand(OS, /*PrintType=*/false, M);
+ }
+ if (DepInst) {
+ OS << " from: ";
+ DepInst->print(OS);
+ }
+ OS << "\n";
+ }
+
+ Inst->print(OS);
+ OS << "\n\n";
+ }
+}
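+
+// The printed form, for a hypothetical function, looks roughly like
+//
+//     Clobber in block %for.body from:   store i32 %v, i32* %p
+//   %x = load i32, i32* %p
+//
+// i.e. each memory instruction is preceded by one line per recorded
+// dependency, naming its kind, the originating block (for non-local results),
+// and the depended-upon instruction when it is known.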
diff --git a/llvm/lib/Analysis/MemDerefPrinter.cpp b/llvm/lib/Analysis/MemDerefPrinter.cpp
new file mode 100644
index 000000000000..5cf516a538b5
--- /dev/null
+++ b/llvm/lib/Analysis/MemDerefPrinter.cpp
@@ -0,0 +1,76 @@
+//===- MemDerefPrinter.cpp - Printer for isDereferenceablePointer ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace llvm;
+
+namespace {
+ struct MemDerefPrinter : public FunctionPass {
+ SmallVector<Value *, 4> Deref;
+ SmallPtrSet<Value *, 4> DerefAndAligned;
+
+ static char ID; // Pass identification, replacement for typeid
+ MemDerefPrinter() : FunctionPass(ID) {
+ initializeMemDerefPrinterPass(*PassRegistry::getPassRegistry());
+ }
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ }
+ bool runOnFunction(Function &F) override;
+ void print(raw_ostream &OS, const Module * = nullptr) const override;
+ void releaseMemory() override {
+ Deref.clear();
+ DerefAndAligned.clear();
+ }
+ };
+}
+
+char MemDerefPrinter::ID = 0;
+INITIALIZE_PASS_BEGIN(MemDerefPrinter, "print-memderefs",
+                  "Memory Dereferenceability of pointers in function", false, true)
+INITIALIZE_PASS_END(MemDerefPrinter, "print-memderefs",
+                  "Memory Dereferenceability of pointers in function", false, true)
+
+FunctionPass *llvm::createMemDerefPrinter() {
+ return new MemDerefPrinter();
+}
+
+bool MemDerefPrinter::runOnFunction(Function &F) {
+ const DataLayout &DL = F.getParent()->getDataLayout();
+ for (auto &I: instructions(F)) {
+ if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
+ Value *PO = LI->getPointerOperand();
+ if (isDereferenceablePointer(PO, LI->getType(), DL))
+ Deref.push_back(PO);
+ if (isDereferenceableAndAlignedPointer(
+ PO, LI->getType(), MaybeAlign(LI->getAlignment()), DL))
+ DerefAndAligned.insert(PO);
+ }
+ }
+ return false;
+}
+
+void MemDerefPrinter::print(raw_ostream &OS, const Module *M) const {
+ OS << "The following are dereferenceable:\n";
+ for (Value *V: Deref) {
+ V->print(OS);
+ if (DerefAndAligned.count(V))
+ OS << "\t(aligned)";
+ else
+ OS << "\t(unaligned)";
+ OS << "\n\n";
+ }
+}
diff --git a/llvm/lib/Analysis/MemoryBuiltins.cpp b/llvm/lib/Analysis/MemoryBuiltins.cpp
new file mode 100644
index 000000000000..172c86eb4646
--- /dev/null
+++ b/llvm/lib/Analysis/MemoryBuiltins.cpp
@@ -0,0 +1,1051 @@
+//===- MemoryBuiltins.cpp - Identify calls to memory builtins -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This family of functions identifies calls to builtin functions that allocate
+// or free memory.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/TargetFolder.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/Utils/Local.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalAlias.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+#include <cstdint>
+#include <iterator>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "memory-builtins"
+
+enum AllocType : uint8_t {
+ OpNewLike = 1<<0, // allocates; never returns null
+ MallocLike = 1<<1 | OpNewLike, // allocates; may return null
+ CallocLike = 1<<2, // allocates + bzero
+ ReallocLike = 1<<3, // reallocates
+ StrDupLike = 1<<4,
+ MallocOrCallocLike = MallocLike | CallocLike,
+ AllocLike = MallocLike | CallocLike | StrDupLike,
+ AnyAlloc = AllocLike | ReallocLike
+};
+
+struct AllocFnsTy {
+ AllocType AllocTy;
+ unsigned NumParams;
+ // First and Second size parameters (or -1 if unused)
+ int FstParam, SndParam;
+};
+
+// FIXME: certain users need more information. E.g., SimplifyLibCalls needs to
+// know which functions are nounwind or noalias, which parameters are
+// nocapture, and so on.
+static const std::pair<LibFunc, AllocFnsTy> AllocationFnData[] = {
+ {LibFunc_malloc, {MallocLike, 1, 0, -1}},
+ {LibFunc_valloc, {MallocLike, 1, 0, -1}},
+ {LibFunc_Znwj, {OpNewLike, 1, 0, -1}}, // new(unsigned int)
+ {LibFunc_ZnwjRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new(unsigned int, nothrow)
+ {LibFunc_ZnwjSt11align_val_t, {OpNewLike, 2, 0, -1}}, // new(unsigned int, align_val_t)
+ {LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t, // new(unsigned int, align_val_t, nothrow)
+ {MallocLike, 3, 0, -1}},
+ {LibFunc_Znwm, {OpNewLike, 1, 0, -1}}, // new(unsigned long)
+ {LibFunc_ZnwmRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new(unsigned long, nothrow)
+ {LibFunc_ZnwmSt11align_val_t, {OpNewLike, 2, 0, -1}}, // new(unsigned long, align_val_t)
+ {LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t, // new(unsigned long, align_val_t, nothrow)
+ {MallocLike, 3, 0, -1}},
+ {LibFunc_Znaj, {OpNewLike, 1, 0, -1}}, // new[](unsigned int)
+ {LibFunc_ZnajRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new[](unsigned int, nothrow)
+ {LibFunc_ZnajSt11align_val_t, {OpNewLike, 2, 0, -1}}, // new[](unsigned int, align_val_t)
+ {LibFunc_ZnajSt11align_val_tRKSt9nothrow_t, // new[](unsigned int, align_val_t, nothrow)
+ {MallocLike, 3, 0, -1}},
+ {LibFunc_Znam, {OpNewLike, 1, 0, -1}}, // new[](unsigned long)
+ {LibFunc_ZnamRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new[](unsigned long, nothrow)
+ {LibFunc_ZnamSt11align_val_t, {OpNewLike, 2, 0, -1}}, // new[](unsigned long, align_val_t)
+ {LibFunc_ZnamSt11align_val_tRKSt9nothrow_t, // new[](unsigned long, align_val_t, nothrow)
+ {MallocLike, 3, 0, -1}},
+ {LibFunc_msvc_new_int, {OpNewLike, 1, 0, -1}}, // new(unsigned int)
+ {LibFunc_msvc_new_int_nothrow, {MallocLike, 2, 0, -1}}, // new(unsigned int, nothrow)
+ {LibFunc_msvc_new_longlong, {OpNewLike, 1, 0, -1}}, // new(unsigned long long)
+ {LibFunc_msvc_new_longlong_nothrow, {MallocLike, 2, 0, -1}}, // new(unsigned long long, nothrow)
+ {LibFunc_msvc_new_array_int, {OpNewLike, 1, 0, -1}}, // new[](unsigned int)
+ {LibFunc_msvc_new_array_int_nothrow, {MallocLike, 2, 0, -1}}, // new[](unsigned int, nothrow)
+ {LibFunc_msvc_new_array_longlong, {OpNewLike, 1, 0, -1}}, // new[](unsigned long long)
+ {LibFunc_msvc_new_array_longlong_nothrow, {MallocLike, 2, 0, -1}}, // new[](unsigned long long, nothrow)
+ {LibFunc_calloc, {CallocLike, 2, 0, 1}},
+ {LibFunc_realloc, {ReallocLike, 2, 1, -1}},
+ {LibFunc_reallocf, {ReallocLike, 2, 1, -1}},
+ {LibFunc_strdup, {StrDupLike, 1, -1, -1}},
+ {LibFunc_strndup, {StrDupLike, 2, 1, -1}}
+ // TODO: Handle "int posix_memalign(void **, size_t, size_t)"
+};
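+// Example reading of an entry (illustrative): {LibFunc_calloc, {CallocLike, 2, 0, 1}}
+// says calloc is calloc-like, takes 2 parameters, and its allocation size is
+// the product of parameters 0 and 1 (i.e. calloc(nmemb, size)).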
+
+static const Function *getCalledFunction(const Value *V, bool LookThroughBitCast,
+ bool &IsNoBuiltin) {
+ // Don't care about intrinsics in this case.
+ if (isa<IntrinsicInst>(V))
+ return nullptr;
+
+ if (LookThroughBitCast)
+ V = V->stripPointerCasts();
+
+ ImmutableCallSite CS(V);
+ if (!CS.getInstruction())
+ return nullptr;
+
+ IsNoBuiltin = CS.isNoBuiltin();
+
+ if (const Function *Callee = CS.getCalledFunction())
+ return Callee;
+ return nullptr;
+}
+
+/// Returns the allocation data for the given value if it's either a call to a
+/// known allocation function, or a call to a function with the allocsize
+/// attribute.
+static Optional<AllocFnsTy>
+getAllocationDataForFunction(const Function *Callee, AllocType AllocTy,
+ const TargetLibraryInfo *TLI) {
+ // Make sure that the function is available.
+ StringRef FnName = Callee->getName();
+ LibFunc TLIFn;
+ if (!TLI || !TLI->getLibFunc(FnName, TLIFn) || !TLI->has(TLIFn))
+ return None;
+
+ const auto *Iter = find_if(
+ AllocationFnData, [TLIFn](const std::pair<LibFunc, AllocFnsTy> &P) {
+ return P.first == TLIFn;
+ });
+
+ if (Iter == std::end(AllocationFnData))
+ return None;
+
+ const AllocFnsTy *FnData = &Iter->second;
+ if ((FnData->AllocTy & AllocTy) != FnData->AllocTy)
+ return None;
+
+ // Check function prototype.
+ int FstParam = FnData->FstParam;
+ int SndParam = FnData->SndParam;
+ FunctionType *FTy = Callee->getFunctionType();
+
+ if (FTy->getReturnType() == Type::getInt8PtrTy(FTy->getContext()) &&
+ FTy->getNumParams() == FnData->NumParams &&
+ (FstParam < 0 ||
+ (FTy->getParamType(FstParam)->isIntegerTy(32) ||
+ FTy->getParamType(FstParam)->isIntegerTy(64))) &&
+ (SndParam < 0 ||
+ FTy->getParamType(SndParam)->isIntegerTy(32) ||
+ FTy->getParamType(SndParam)->isIntegerTy(64)))
+ return *FnData;
+ return None;
+}
+
+static Optional<AllocFnsTy> getAllocationData(const Value *V, AllocType AllocTy,
+ const TargetLibraryInfo *TLI,
+ bool LookThroughBitCast = false) {
+ bool IsNoBuiltinCall;
+ if (const Function *Callee =
+ getCalledFunction(V, LookThroughBitCast, IsNoBuiltinCall))
+ if (!IsNoBuiltinCall)
+ return getAllocationDataForFunction(Callee, AllocTy, TLI);
+ return None;
+}
+
+static Optional<AllocFnsTy>
+getAllocationData(const Value *V, AllocType AllocTy,
+ function_ref<const TargetLibraryInfo &(Function &)> GetTLI,
+ bool LookThroughBitCast = false) {
+ bool IsNoBuiltinCall;
+ if (const Function *Callee =
+ getCalledFunction(V, LookThroughBitCast, IsNoBuiltinCall))
+ if (!IsNoBuiltinCall)
+ return getAllocationDataForFunction(
+ Callee, AllocTy, &GetTLI(const_cast<Function &>(*Callee)));
+ return None;
+}
+
+static Optional<AllocFnsTy> getAllocationSize(const Value *V,
+ const TargetLibraryInfo *TLI) {
+ bool IsNoBuiltinCall;
+ const Function *Callee =
+ getCalledFunction(V, /*LookThroughBitCast=*/false, IsNoBuiltinCall);
+ if (!Callee)
+ return None;
+
+ // Prefer to use existing information over allocsize. This will give us an
+ // accurate AllocTy.
+ if (!IsNoBuiltinCall)
+ if (Optional<AllocFnsTy> Data =
+ getAllocationDataForFunction(Callee, AnyAlloc, TLI))
+ return Data;
+
+ Attribute Attr = Callee->getFnAttribute(Attribute::AllocSize);
+ if (Attr == Attribute())
+ return None;
+
+ std::pair<unsigned, Optional<unsigned>> Args = Attr.getAllocSizeArgs();
+
+ AllocFnsTy Result;
+ // Because allocsize only tells us how many bytes are allocated, we're not
+ // really allowed to assume anything, so we use MallocLike.
+ Result.AllocTy = MallocLike;
+ Result.NumParams = Callee->getNumOperands();
+ Result.FstParam = Args.first;
+ Result.SndParam = Args.second.getValueOr(-1);
+ return Result;
+}
+
+static bool hasNoAliasAttr(const Value *V, bool LookThroughBitCast) {
+ ImmutableCallSite CS(LookThroughBitCast ? V->stripPointerCasts() : V);
+ return CS && CS.hasRetAttr(Attribute::NoAlias);
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// allocates or reallocates memory (either malloc, calloc, realloc, or strdup
+/// like).
+bool llvm::isAllocationFn(const Value *V, const TargetLibraryInfo *TLI,
+ bool LookThroughBitCast) {
+ return getAllocationData(V, AnyAlloc, TLI, LookThroughBitCast).hasValue();
+}
+bool llvm::isAllocationFn(
+ const Value *V, function_ref<const TargetLibraryInfo &(Function &)> GetTLI,
+ bool LookThroughBitCast) {
+ return getAllocationData(V, AnyAlloc, GetTLI, LookThroughBitCast).hasValue();
+}
+
+/// Tests if a value is a call or invoke to a function that returns a
+/// NoAlias pointer (including malloc/calloc/realloc/strdup-like functions).
+bool llvm::isNoAliasFn(const Value *V, const TargetLibraryInfo *TLI,
+ bool LookThroughBitCast) {
+ // It's safe to consider realloc as noalias since accessing the original
+ // pointer afterwards is undefined behavior.
+ return isAllocationFn(V, TLI, LookThroughBitCast) ||
+ hasNoAliasAttr(V, LookThroughBitCast);
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// allocates uninitialized memory (such as malloc).
+bool llvm::isMallocLikeFn(const Value *V, const TargetLibraryInfo *TLI,
+ bool LookThroughBitCast) {
+ return getAllocationData(V, MallocLike, TLI, LookThroughBitCast).hasValue();
+}
+bool llvm::isMallocLikeFn(
+ const Value *V, function_ref<const TargetLibraryInfo &(Function &)> GetTLI,
+ bool LookThroughBitCast) {
+ return getAllocationData(V, MallocLike, GetTLI, LookThroughBitCast)
+ .hasValue();
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// allocates zero-filled memory (such as calloc).
+bool llvm::isCallocLikeFn(const Value *V, const TargetLibraryInfo *TLI,
+ bool LookThroughBitCast) {
+ return getAllocationData(V, CallocLike, TLI, LookThroughBitCast).hasValue();
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// allocates memory similar to malloc or calloc.
+bool llvm::isMallocOrCallocLikeFn(const Value *V, const TargetLibraryInfo *TLI,
+ bool LookThroughBitCast) {
+ return getAllocationData(V, MallocOrCallocLike, TLI,
+ LookThroughBitCast).hasValue();
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// allocates memory (either malloc, calloc, or strdup like).
+bool llvm::isAllocLikeFn(const Value *V, const TargetLibraryInfo *TLI,
+ bool LookThroughBitCast) {
+ return getAllocationData(V, AllocLike, TLI, LookThroughBitCast).hasValue();
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// reallocates memory (e.g., realloc).
+bool llvm::isReallocLikeFn(const Value *V, const TargetLibraryInfo *TLI,
+ bool LookThroughBitCast) {
+ return getAllocationData(V, ReallocLike, TLI, LookThroughBitCast).hasValue();
+}
+
+/// Tests if a functions is a call or invoke to a library function that
+/// reallocates memory (e.g., realloc).
+bool llvm::isReallocLikeFn(const Function *F, const TargetLibraryInfo *TLI) {
+ return getAllocationDataForFunction(F, ReallocLike, TLI).hasValue();
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// allocates memory and throws if an allocation failed (e.g., new).
+bool llvm::isOpNewLikeFn(const Value *V, const TargetLibraryInfo *TLI,
+ bool LookThroughBitCast) {
+ return getAllocationData(V, OpNewLike, TLI, LookThroughBitCast).hasValue();
+}
+
+/// Tests if a value is a call or invoke to a library function that
+/// allocates memory (strdup, strndup).
+bool llvm::isStrdupLikeFn(const Value *V, const TargetLibraryInfo *TLI,
+ bool LookThroughBitCast) {
+ return getAllocationData(V, StrDupLike, TLI, LookThroughBitCast).hasValue();
+}
+
+/// extractMallocCall - Returns the corresponding CallInst if the instruction
+/// is a malloc call. Since CallInst::CreateMalloc() only creates calls, we
+/// ignore InvokeInst here.
+const CallInst *llvm::extractMallocCall(
+ const Value *I,
+ function_ref<const TargetLibraryInfo &(Function &)> GetTLI) {
+ return isMallocLikeFn(I, GetTLI) ? dyn_cast<CallInst>(I) : nullptr;
+}
+
+static Value *computeArraySize(const CallInst *CI, const DataLayout &DL,
+ const TargetLibraryInfo *TLI,
+ bool LookThroughSExt = false) {
+ if (!CI)
+ return nullptr;
+
+ // The size of the malloc's result type must be known to determine array size.
+ Type *T = getMallocAllocatedType(CI, TLI);
+ if (!T || !T->isSized())
+ return nullptr;
+
+ unsigned ElementSize = DL.getTypeAllocSize(T);
+ if (StructType *ST = dyn_cast<StructType>(T))
+ ElementSize = DL.getStructLayout(ST)->getSizeInBytes();
+
+ // If malloc call's arg can be determined to be a multiple of ElementSize,
+ // return the multiple. Otherwise, return NULL.
+ Value *MallocArg = CI->getArgOperand(0);
+ Value *Multiple = nullptr;
+ if (ComputeMultiple(MallocArg, ElementSize, Multiple, LookThroughSExt))
+ return Multiple;
+
+ return nullptr;
+}
+
+/// getMallocType - Returns the PointerType resulting from the malloc call.
+/// The PointerType depends on the number of bitcast uses of the malloc call:
+/// 0: PointerType is the call's return type.
+/// 1: PointerType is the bitcast's result type.
+/// >1: Unique PointerType cannot be determined, return NULL.
+PointerType *llvm::getMallocType(const CallInst *CI,
+ const TargetLibraryInfo *TLI) {
+ assert(isMallocLikeFn(CI, TLI) && "getMallocType and not malloc call");
+
+ PointerType *MallocType = nullptr;
+ unsigned NumOfBitCastUses = 0;
+
+ // Determine if CallInst has a bitcast use.
+ for (Value::const_user_iterator UI = CI->user_begin(), E = CI->user_end();
+ UI != E;)
+ if (const BitCastInst *BCI = dyn_cast<BitCastInst>(*UI++)) {
+ MallocType = cast<PointerType>(BCI->getDestTy());
+ NumOfBitCastUses++;
+ }
+
+ // Malloc call has 1 bitcast use, so type is the bitcast's destination type.
+ if (NumOfBitCastUses == 1)
+ return MallocType;
+
+ // Malloc call was not bitcast, so type is the malloc function's return type.
+ if (NumOfBitCastUses == 0)
+ return cast<PointerType>(CI->getType());
+
+ // Type could not be determined.
+ return nullptr;
+}
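+// Illustrative IR (hypothetical) for getMallocType:
+//   %p = call i8* @malloc(i64 400)
+// With no bitcast user, the type is i8* (the call's return type); with a
+// single user such as
+//   %q = bitcast i8* %p to i32*
+// the type is i32*; with more than one bitcast user the result is null.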
+
+/// getMallocAllocatedType - Returns the Type allocated by malloc call.
+/// The Type depends on the number of bitcast uses of the malloc call:
+/// 0: PointerType is the malloc call's return type.
+/// 1: PointerType is the bitcast's result type.
+/// >1: Unique PointerType cannot be determined, return NULL.
+Type *llvm::getMallocAllocatedType(const CallInst *CI,
+ const TargetLibraryInfo *TLI) {
+ PointerType *PT = getMallocType(CI, TLI);
+ return PT ? PT->getElementType() : nullptr;
+}
+
+/// getMallocArraySize - Returns the array size of a malloc call. If the
+/// argument passed to malloc is a multiple of the size of the malloced type,
+/// then return that multiple. For non-array mallocs, the multiple is
+/// constant 1. Otherwise, return NULL for mallocs whose array size cannot be
+/// determined.
+Value *llvm::getMallocArraySize(CallInst *CI, const DataLayout &DL,
+ const TargetLibraryInfo *TLI,
+ bool LookThroughSExt) {
+ assert(isMallocLikeFn(CI, TLI) && "getMallocArraySize and not malloc call");
+ return computeArraySize(CI, DL, TLI, LookThroughSExt);
+}
+
+/// extractCallocCall - Returns the corresponding CallInst if the instruction
+/// is a calloc call.
+const CallInst *llvm::extractCallocCall(const Value *I,
+ const TargetLibraryInfo *TLI) {
+ return isCallocLikeFn(I, TLI) ? cast<CallInst>(I) : nullptr;
+}
+
+/// isLibFreeFunction - Returns true if the function is a builtin free()
+bool llvm::isLibFreeFunction(const Function *F, const LibFunc TLIFn) {
+ unsigned ExpectedNumParams;
+ if (TLIFn == LibFunc_free ||
+ TLIFn == LibFunc_ZdlPv || // operator delete(void*)
+ TLIFn == LibFunc_ZdaPv || // operator delete[](void*)
+ TLIFn == LibFunc_msvc_delete_ptr32 || // operator delete(void*)
+ TLIFn == LibFunc_msvc_delete_ptr64 || // operator delete(void*)
+ TLIFn == LibFunc_msvc_delete_array_ptr32 || // operator delete[](void*)
+ TLIFn == LibFunc_msvc_delete_array_ptr64) // operator delete[](void*)
+ ExpectedNumParams = 1;
+ else if (TLIFn == LibFunc_ZdlPvj || // delete(void*, uint)
+ TLIFn == LibFunc_ZdlPvm || // delete(void*, ulong)
+ TLIFn == LibFunc_ZdlPvRKSt9nothrow_t || // delete(void*, nothrow)
+ TLIFn == LibFunc_ZdlPvSt11align_val_t || // delete(void*, align_val_t)
+ TLIFn == LibFunc_ZdaPvj || // delete[](void*, uint)
+ TLIFn == LibFunc_ZdaPvm || // delete[](void*, ulong)
+ TLIFn == LibFunc_ZdaPvRKSt9nothrow_t || // delete[](void*, nothrow)
+ TLIFn == LibFunc_ZdaPvSt11align_val_t || // delete[](void*, align_val_t)
+ TLIFn == LibFunc_msvc_delete_ptr32_int || // delete(void*, uint)
+ TLIFn == LibFunc_msvc_delete_ptr64_longlong || // delete(void*, ulonglong)
+ TLIFn == LibFunc_msvc_delete_ptr32_nothrow || // delete(void*, nothrow)
+ TLIFn == LibFunc_msvc_delete_ptr64_nothrow || // delete(void*, nothrow)
+ TLIFn == LibFunc_msvc_delete_array_ptr32_int || // delete[](void*, uint)
+ TLIFn == LibFunc_msvc_delete_array_ptr64_longlong || // delete[](void*, ulonglong)
+ TLIFn == LibFunc_msvc_delete_array_ptr32_nothrow || // delete[](void*, nothrow)
+ TLIFn == LibFunc_msvc_delete_array_ptr64_nothrow) // delete[](void*, nothrow)
+ ExpectedNumParams = 2;
+ else if (TLIFn == LibFunc_ZdaPvSt11align_val_tRKSt9nothrow_t || // delete(void*, align_val_t, nothrow)
+ TLIFn == LibFunc_ZdlPvSt11align_val_tRKSt9nothrow_t) // delete[](void*, align_val_t, nothrow)
+ ExpectedNumParams = 3;
+ else
+ return false;
+
+ // Check free prototype.
+ // FIXME: Workaround for PR5130; this will be obsolete once a nobuiltin
+ // attribute exists.
+ FunctionType *FTy = F->getFunctionType();
+ if (!FTy->getReturnType()->isVoidTy())
+ return false;
+ if (FTy->getNumParams() != ExpectedNumParams)
+ return false;
+ if (FTy->getParamType(0) != Type::getInt8PtrTy(F->getContext()))
+ return false;
+
+ return true;
+}
+
+/// isFreeCall - Returns non-null if the value is a call to the builtin free()
+const CallInst *llvm::isFreeCall(const Value *I, const TargetLibraryInfo *TLI) {
+ bool IsNoBuiltinCall;
+ const Function *Callee =
+ getCalledFunction(I, /*LookThroughBitCast=*/false, IsNoBuiltinCall);
+ if (Callee == nullptr || IsNoBuiltinCall)
+ return nullptr;
+
+ StringRef FnName = Callee->getName();
+ LibFunc TLIFn;
+ if (!TLI || !TLI->getLibFunc(FnName, TLIFn) || !TLI->has(TLIFn))
+ return nullptr;
+
+ return isLibFreeFunction(Callee, TLIFn) ? dyn_cast<CallInst>(I) : nullptr;
+}
+
+
+//===----------------------------------------------------------------------===//
+// Utility functions to compute size of objects.
+//
+static APInt getSizeWithOverflow(const SizeOffsetType &Data) {
+ if (Data.second.isNegative() || Data.first.ult(Data.second))
+ return APInt(Data.first.getBitWidth(), 0);
+ return Data.first - Data.second;
+}
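+// E.g. for Data = (size 16, offset 4) this yields 12 accessible bytes, while
+// an offset that is negative or larger than the size yields 0.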
+
+/// Compute the size of the object pointed by Ptr. Returns true and the
+/// object size in Size if successful, and false otherwise.
+/// If RoundToAlign is true, then Size is rounded up to the alignment of
+/// allocas, byval arguments, and global variables.
+bool llvm::getObjectSize(const Value *Ptr, uint64_t &Size, const DataLayout &DL,
+ const TargetLibraryInfo *TLI, ObjectSizeOpts Opts) {
+ ObjectSizeOffsetVisitor Visitor(DL, TLI, Ptr->getContext(), Opts);
+ SizeOffsetType Data = Visitor.compute(const_cast<Value*>(Ptr));
+ if (!Visitor.bothKnown(Data))
+ return false;
+
+ Size = getSizeWithOverflow(Data).getZExtValue();
+ return true;
+}
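+// A minimal usage sketch (hypothetical caller):
+//   uint64_t Size;
+//   if (getObjectSize(Ptr, Size, DL, &TLI, ObjectSizeOpts()))
+//     use(Size); // Size holds the accessible byte count for Ptr.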
+
+Value *llvm::lowerObjectSizeCall(IntrinsicInst *ObjectSize,
+ const DataLayout &DL,
+ const TargetLibraryInfo *TLI,
+ bool MustSucceed) {
+ assert(ObjectSize->getIntrinsicID() == Intrinsic::objectsize &&
+ "ObjectSize must be a call to llvm.objectsize!");
+
+ bool MaxVal = cast<ConstantInt>(ObjectSize->getArgOperand(1))->isZero();
+ ObjectSizeOpts EvalOptions;
+ // Unless we have to fold this to something, try to be as accurate as
+ // possible.
+ if (MustSucceed)
+ EvalOptions.EvalMode =
+ MaxVal ? ObjectSizeOpts::Mode::Max : ObjectSizeOpts::Mode::Min;
+ else
+ EvalOptions.EvalMode = ObjectSizeOpts::Mode::Exact;
+
+ EvalOptions.NullIsUnknownSize =
+ cast<ConstantInt>(ObjectSize->getArgOperand(2))->isOne();
+
+ auto *ResultType = cast<IntegerType>(ObjectSize->getType());
+ bool StaticOnly = cast<ConstantInt>(ObjectSize->getArgOperand(3))->isZero();
+ if (StaticOnly) {
+ // FIXME: Does it make sense to just return a failure value if the size won't
+ // fit in the output and `!MustSucceed`?
+ uint64_t Size;
+ if (getObjectSize(ObjectSize->getArgOperand(0), Size, DL, TLI, EvalOptions) &&
+ isUIntN(ResultType->getBitWidth(), Size))
+ return ConstantInt::get(ResultType, Size);
+ } else {
+ LLVMContext &Ctx = ObjectSize->getFunction()->getContext();
+ ObjectSizeOffsetEvaluator Eval(DL, TLI, Ctx, EvalOptions);
+ SizeOffsetEvalType SizeOffsetPair =
+ Eval.compute(ObjectSize->getArgOperand(0));
+
+ if (SizeOffsetPair != ObjectSizeOffsetEvaluator::unknown()) {
+ IRBuilder<TargetFolder> Builder(Ctx, TargetFolder(DL));
+ Builder.SetInsertPoint(ObjectSize);
+
+ // If we're past the end of the object, then we can always access
+ // exactly 0 bytes.
+ Value *ResultSize =
+ Builder.CreateSub(SizeOffsetPair.first, SizeOffsetPair.second);
+ Value *UseZero =
+ Builder.CreateICmpULT(SizeOffsetPair.first, SizeOffsetPair.second);
+ return Builder.CreateSelect(UseZero, ConstantInt::get(ResultType, 0),
+ ResultSize);
+ }
+ }
+
+ if (!MustSucceed)
+ return nullptr;
+
+ return ConstantInt::get(ResultType, MaxVal ? -1ULL : 0);
+}
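+// E.g. (hypothetical call) 'call i64 @llvm.objectsize.i64.p0i8(i8* %p,
+// i1 false, i1 true, i1 false)': arg 1 == false selects the max/-1 fallback,
+// arg 2 == true treats null as an unknown size, and arg 3 == false requests a
+// purely static (constant) result.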
+
+STATISTIC(ObjectVisitorArgument,
+ "Number of arguments with unsolved size and offset");
+STATISTIC(ObjectVisitorLoad,
+ "Number of load instructions with unsolved size and offset");
+
+APInt ObjectSizeOffsetVisitor::align(APInt Size, uint64_t Alignment) {
+ if (Options.RoundToAlign && Alignment)
+ return APInt(IntTyBits, alignTo(Size.getZExtValue(), Align(Alignment)));
+ return Size;
+}
+
+ObjectSizeOffsetVisitor::ObjectSizeOffsetVisitor(const DataLayout &DL,
+ const TargetLibraryInfo *TLI,
+ LLVMContext &Context,
+ ObjectSizeOpts Options)
+ : DL(DL), TLI(TLI), Options(Options) {
+ // Pointer size must be rechecked for each object visited since it could have
+ // a different address space.
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::compute(Value *V) {
+ IntTyBits = DL.getPointerTypeSizeInBits(V->getType());
+ Zero = APInt::getNullValue(IntTyBits);
+
+ V = V->stripPointerCasts();
+ if (Instruction *I = dyn_cast<Instruction>(V)) {
+ // If we have already seen this instruction, bail out. Cycles can happen in
+ // unreachable code after constant propagation.
+ if (!SeenInsts.insert(I).second)
+ return unknown();
+
+ if (GEPOperator *GEP = dyn_cast<GEPOperator>(V))
+ return visitGEPOperator(*GEP);
+ return visit(*I);
+ }
+ if (Argument *A = dyn_cast<Argument>(V))
+ return visitArgument(*A);
+ if (ConstantPointerNull *P = dyn_cast<ConstantPointerNull>(V))
+ return visitConstantPointerNull(*P);
+ if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
+ return visitGlobalAlias(*GA);
+ if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
+ return visitGlobalVariable(*GV);
+ if (UndefValue *UV = dyn_cast<UndefValue>(V))
+ return visitUndefValue(*UV);
+ if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
+ if (CE->getOpcode() == Instruction::IntToPtr)
+ return unknown(); // clueless
+ if (CE->getOpcode() == Instruction::GetElementPtr)
+ return visitGEPOperator(cast<GEPOperator>(*CE));
+ }
+
+ LLVM_DEBUG(dbgs() << "ObjectSizeOffsetVisitor::compute() unhandled value: "
+ << *V << '\n');
+ return unknown();
+}
+
+/// When we're compiling N-bit code, and the user uses parameters that are
+/// greater than N bits (e.g. uint64_t on a 32-bit build), we can run into
+/// trouble with APInt size issues. This function handles resizing + overflow
+/// checks for us. Check and zext or trunc \p I depending on IntTyBits and
+/// I's value.
+bool ObjectSizeOffsetVisitor::CheckedZextOrTrunc(APInt &I) {
+ // More bits than we can handle. Checking the bit width isn't necessary, but
+ // it's faster than checking active bits, and should give `false` in the
+ // vast majority of cases.
+ if (I.getBitWidth() > IntTyBits && I.getActiveBits() > IntTyBits)
+ return false;
+ if (I.getBitWidth() != IntTyBits)
+ I = I.zextOrTrunc(IntTyBits);
+ return true;
+}
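+// E.g. with IntTyBits == 32, a 64-bit APInt holding 2^40 has more than 32
+// active bits and is rejected, while a 64-bit APInt holding 100 is simply
+// truncated to 32 bits.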
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitAllocaInst(AllocaInst &I) {
+ if (!I.getAllocatedType()->isSized())
+ return unknown();
+
+ APInt Size(IntTyBits, DL.getTypeAllocSize(I.getAllocatedType()));
+ if (!I.isArrayAllocation())
+ return std::make_pair(align(Size, I.getAlignment()), Zero);
+
+ Value *ArraySize = I.getArraySize();
+ if (const ConstantInt *C = dyn_cast<ConstantInt>(ArraySize)) {
+ APInt NumElems = C->getValue();
+ if (!CheckedZextOrTrunc(NumElems))
+ return unknown();
+
+ bool Overflow;
+ Size = Size.umul_ov(NumElems, Overflow);
+ return Overflow ? unknown() : std::make_pair(align(Size, I.getAlignment()),
+ Zero);
+ }
+ return unknown();
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitArgument(Argument &A) {
+ // No interprocedural analysis is done at the moment.
+ if (!A.hasByValOrInAllocaAttr()) {
+ ++ObjectVisitorArgument;
+ return unknown();
+ }
+ PointerType *PT = cast<PointerType>(A.getType());
+ APInt Size(IntTyBits, DL.getTypeAllocSize(PT->getElementType()));
+ return std::make_pair(align(Size, A.getParamAlignment()), Zero);
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitCallSite(CallSite CS) {
+ Optional<AllocFnsTy> FnData = getAllocationSize(CS.getInstruction(), TLI);
+ if (!FnData)
+ return unknown();
+
+ // Handle strdup-like functions separately.
+ if (FnData->AllocTy == StrDupLike) {
+ APInt Size(IntTyBits, GetStringLength(CS.getArgument(0)));
+ if (!Size)
+ return unknown();
+
+ // Strndup limits strlen.
+ if (FnData->FstParam > 0) {
+ ConstantInt *Arg =
+ dyn_cast<ConstantInt>(CS.getArgument(FnData->FstParam));
+ if (!Arg)
+ return unknown();
+
+ APInt MaxSize = Arg->getValue().zextOrSelf(IntTyBits);
+ if (Size.ugt(MaxSize))
+ Size = MaxSize + 1;
+ }
+ return std::make_pair(Size, Zero);
+ }
+
+ ConstantInt *Arg = dyn_cast<ConstantInt>(CS.getArgument(FnData->FstParam));
+ if (!Arg)
+ return unknown();
+
+ APInt Size = Arg->getValue();
+ if (!CheckedZextOrTrunc(Size))
+ return unknown();
+
+ // Size is determined by just 1 parameter.
+ if (FnData->SndParam < 0)
+ return std::make_pair(Size, Zero);
+
+ Arg = dyn_cast<ConstantInt>(CS.getArgument(FnData->SndParam));
+ if (!Arg)
+ return unknown();
+
+ APInt NumElems = Arg->getValue();
+ if (!CheckedZextOrTrunc(NumElems))
+ return unknown();
+
+ bool Overflow;
+ Size = Size.umul_ov(NumElems, Overflow);
+ return Overflow ? unknown() : std::make_pair(Size, Zero);
+
+ // TODO: handle more standard functions (+ wchar cousins):
+ // - strdup / strndup
+ // - strcpy / strncpy
+ // - strcat / strncat
+ // - memcpy / memmove
+ // - strcat / strncat
+ // - memset
+}
+
+SizeOffsetType
+ObjectSizeOffsetVisitor::visitConstantPointerNull(ConstantPointerNull& CPN) {
+ // If null is unknown, there's nothing we can do. Additionally, non-zero
+ // address spaces can make use of null, so we don't presume to know anything
+ // about that.
+ //
+ // TODO: How should this work with address space casts? We currently just drop
+ // them on the floor, but it's unclear what we should do when a NULL from
+ // addrspace(1) gets cast to addrspace(0) (or vice versa).
+ if (Options.NullIsUnknownSize || CPN.getType()->getAddressSpace())
+ return unknown();
+ return std::make_pair(Zero, Zero);
+}
+
+SizeOffsetType
+ObjectSizeOffsetVisitor::visitExtractElementInst(ExtractElementInst&) {
+ return unknown();
+}
+
+SizeOffsetType
+ObjectSizeOffsetVisitor::visitExtractValueInst(ExtractValueInst&) {
+ // Easy cases were already folded by previous passes.
+ return unknown();
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitGEPOperator(GEPOperator &GEP) {
+ SizeOffsetType PtrData = compute(GEP.getPointerOperand());
+ APInt Offset(IntTyBits, 0);
+ if (!bothKnown(PtrData) || !GEP.accumulateConstantOffset(DL, Offset))
+ return unknown();
+
+ return std::make_pair(PtrData.first, PtrData.second + Offset);
+}
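+// E.g. (hypothetical IR) for a 16-byte object %p at offset 0,
+//   %q = getelementptr i8, i8* %p, i64 8
+// produces (size 16, offset 8): the size tracks the underlying object and
+// only the offset grows by the GEP's constant offset.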
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitGlobalAlias(GlobalAlias &GA) {
+ if (GA.isInterposable())
+ return unknown();
+ return compute(GA.getAliasee());
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitGlobalVariable(GlobalVariable &GV){
+ if (!GV.hasDefinitiveInitializer())
+ return unknown();
+
+ APInt Size(IntTyBits, DL.getTypeAllocSize(GV.getValueType()));
+ return std::make_pair(align(Size, GV.getAlignment()), Zero);
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitIntToPtrInst(IntToPtrInst&) {
+ // clueless
+ return unknown();
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitLoadInst(LoadInst&) {
+ ++ObjectVisitorLoad;
+ return unknown();
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitPHINode(PHINode&) {
+ // too complex to analyze statically.
+ return unknown();
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitSelectInst(SelectInst &I) {
+ SizeOffsetType TrueSide = compute(I.getTrueValue());
+ SizeOffsetType FalseSide = compute(I.getFalseValue());
+ if (bothKnown(TrueSide) && bothKnown(FalseSide)) {
+ if (TrueSide == FalseSide) {
+ return TrueSide;
+ }
+
+ APInt TrueResult = getSizeWithOverflow(TrueSide);
+ APInt FalseResult = getSizeWithOverflow(FalseSide);
+
+ if (TrueResult == FalseResult) {
+ return TrueSide;
+ }
+ if (Options.EvalMode == ObjectSizeOpts::Mode::Min) {
+ if (TrueResult.slt(FalseResult))
+ return TrueSide;
+ return FalseSide;
+ }
+ if (Options.EvalMode == ObjectSizeOpts::Mode::Max) {
+ if (TrueResult.sgt(FalseResult))
+ return TrueSide;
+ return FalseSide;
+ }
+ }
+ return unknown();
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitUndefValue(UndefValue&) {
+ return std::make_pair(Zero, Zero);
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitInstruction(Instruction &I) {
+ LLVM_DEBUG(dbgs() << "ObjectSizeOffsetVisitor unknown instruction:" << I
+ << '\n');
+ return unknown();
+}
+
+ObjectSizeOffsetEvaluator::ObjectSizeOffsetEvaluator(
+ const DataLayout &DL, const TargetLibraryInfo *TLI, LLVMContext &Context,
+ ObjectSizeOpts EvalOpts)
+ : DL(DL), TLI(TLI), Context(Context),
+ Builder(Context, TargetFolder(DL),
+ IRBuilderCallbackInserter(
+ [&](Instruction *I) { InsertedInstructions.insert(I); })),
+ EvalOpts(EvalOpts) {
+ // IntTy and Zero must be set for each compute() since the address space may
+ // be different for later objects.
+}
+
+SizeOffsetEvalType ObjectSizeOffsetEvaluator::compute(Value *V) {
+ // XXX - Are vectors of pointers possible here?
+ IntTy = cast<IntegerType>(DL.getIntPtrType(V->getType()));
+ Zero = ConstantInt::get(IntTy, 0);
+
+ SizeOffsetEvalType Result = compute_(V);
+
+ if (!bothKnown(Result)) {
+ // Erase everything that was computed in this iteration from the cache, so
+ // that no dangling references are left behind. We could be a bit smarter if
+ // we kept a dependency graph. It's probably not worth the complexity.
+ for (const Value *SeenVal : SeenVals) {
+ CacheMapTy::iterator CacheIt = CacheMap.find(SeenVal);
+ // non-computable results can be safely cached
+ if (CacheIt != CacheMap.end() && anyKnown(CacheIt->second))
+ CacheMap.erase(CacheIt);
+ }
+
+ // Erase any instructions we inserted as part of the traversal.
+ for (Instruction *I : InsertedInstructions) {
+ I->replaceAllUsesWith(UndefValue::get(I->getType()));
+ I->eraseFromParent();
+ }
+ }
+
+ SeenVals.clear();
+ InsertedInstructions.clear();
+ return Result;
+}
+
+SizeOffsetEvalType ObjectSizeOffsetEvaluator::compute_(Value *V) {
+ ObjectSizeOffsetVisitor Visitor(DL, TLI, Context, EvalOpts);
+ SizeOffsetType Const = Visitor.compute(V);
+ if (Visitor.bothKnown(Const))
+ return std::make_pair(ConstantInt::get(Context, Const.first),
+ ConstantInt::get(Context, Const.second));
+
+ V = V->stripPointerCasts();
+
+ // Check cache.
+ CacheMapTy::iterator CacheIt = CacheMap.find(V);
+ if (CacheIt != CacheMap.end())
+ return CacheIt->second;
+
+ // Always generate code immediately before the instruction being
+ // processed, so that the generated code dominates the same BBs.
+ BuilderTy::InsertPointGuard Guard(Builder);
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ Builder.SetInsertPoint(I);
+
+ // Now compute the size and offset.
+ SizeOffsetEvalType Result;
+
+ // Record the pointers that were handled in this run, so that they can be
+ // cleaned later if something fails. We also use this set to break cycles that
+ // can occur in dead code.
+ if (!SeenVals.insert(V).second) {
+ Result = unknown();
+ } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
+ Result = visitGEPOperator(*GEP);
+ } else if (Instruction *I = dyn_cast<Instruction>(V)) {
+ Result = visit(*I);
+ } else if (isa<Argument>(V) ||
+ (isa<ConstantExpr>(V) &&
+ cast<ConstantExpr>(V)->getOpcode() == Instruction::IntToPtr) ||
+ isa<GlobalAlias>(V) ||
+ isa<GlobalVariable>(V)) {
+ // Ignore values where we cannot do more than ObjectSizeVisitor.
+ Result = unknown();
+ } else {
+ LLVM_DEBUG(
+ dbgs() << "ObjectSizeOffsetEvaluator::compute() unhandled value: " << *V
+ << '\n');
+ Result = unknown();
+ }
+
+ // Don't reuse CacheIt since it may be invalid at this point.
+ CacheMap[V] = Result;
+ return Result;
+}
+
+SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitAllocaInst(AllocaInst &I) {
+ if (!I.getAllocatedType()->isSized())
+ return unknown();
+
+ // must be a VLA
+ assert(I.isArrayAllocation());
+ Value *ArraySize = I.getArraySize();
+ Value *Size = ConstantInt::get(ArraySize->getType(),
+ DL.getTypeAllocSize(I.getAllocatedType()));
+ Size = Builder.CreateMul(Size, ArraySize);
+ return std::make_pair(Size, Zero);
+}
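+// E.g. (hypothetical IR) for '%a = alloca i32, i32 %n' this emits
+// 'mul i32 4, %n' (assuming a 4-byte i32) as the object size, with offset 0.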
+
+SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitCallSite(CallSite CS) {
+ Optional<AllocFnsTy> FnData = getAllocationSize(CS.getInstruction(), TLI);
+ if (!FnData)
+ return unknown();
+
+ // Handle strdup-like functions separately.
+ if (FnData->AllocTy == StrDupLike) {
+ // TODO
+ return unknown();
+ }
+
+ Value *FirstArg = CS.getArgument(FnData->FstParam);
+ FirstArg = Builder.CreateZExt(FirstArg, IntTy);
+ if (FnData->SndParam < 0)
+ return std::make_pair(FirstArg, Zero);
+
+ Value *SecondArg = CS.getArgument(FnData->SndParam);
+ SecondArg = Builder.CreateZExt(SecondArg, IntTy);
+ Value *Size = Builder.CreateMul(FirstArg, SecondArg);
+ return std::make_pair(Size, Zero);
+
+ // TODO: handle more standard functions (+ wchar cousins):
+ // - strdup / strndup
+ // - strcpy / strncpy
+ // - strcat / strncat
+ // - memcpy / memmove
+ // - strcat / strncat
+ // - memset
+}
+
+SizeOffsetEvalType
+ObjectSizeOffsetEvaluator::visitExtractElementInst(ExtractElementInst&) {
+ return unknown();
+}
+
+SizeOffsetEvalType
+ObjectSizeOffsetEvaluator::visitExtractValueInst(ExtractValueInst&) {
+ return unknown();
+}
+
+SizeOffsetEvalType
+ObjectSizeOffsetEvaluator::visitGEPOperator(GEPOperator &GEP) {
+ SizeOffsetEvalType PtrData = compute_(GEP.getPointerOperand());
+ if (!bothKnown(PtrData))
+ return unknown();
+
+ Value *Offset = EmitGEPOffset(&Builder, DL, &GEP, /*NoAssumptions=*/true);
+ Offset = Builder.CreateAdd(PtrData.second, Offset);
+ return std::make_pair(PtrData.first, Offset);
+}
+
+SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitIntToPtrInst(IntToPtrInst&) {
+ // clueless
+ return unknown();
+}
+
+SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitLoadInst(LoadInst&) {
+ return unknown();
+}
+
+SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitPHINode(PHINode &PHI) {
+ // Create 2 PHIs: one for size and another for offset.
+ PHINode *SizePHI = Builder.CreatePHI(IntTy, PHI.getNumIncomingValues());
+ PHINode *OffsetPHI = Builder.CreatePHI(IntTy, PHI.getNumIncomingValues());
+
+ // Insert right away in the cache to handle recursive PHIs.
+ CacheMap[&PHI] = std::make_pair(SizePHI, OffsetPHI);
+
+ // Compute offset/size for each PHI incoming pointer.
+ for (unsigned i = 0, e = PHI.getNumIncomingValues(); i != e; ++i) {
+ Builder.SetInsertPoint(&*PHI.getIncomingBlock(i)->getFirstInsertionPt());
+ SizeOffsetEvalType EdgeData = compute_(PHI.getIncomingValue(i));
+
+ if (!bothKnown(EdgeData)) {
+ OffsetPHI->replaceAllUsesWith(UndefValue::get(IntTy));
+ OffsetPHI->eraseFromParent();
+ InsertedInstructions.erase(OffsetPHI);
+ SizePHI->replaceAllUsesWith(UndefValue::get(IntTy));
+ SizePHI->eraseFromParent();
+ InsertedInstructions.erase(SizePHI);
+ return unknown();
+ }
+ SizePHI->addIncoming(EdgeData.first, PHI.getIncomingBlock(i));
+ OffsetPHI->addIncoming(EdgeData.second, PHI.getIncomingBlock(i));
+ }
+
+ Value *Size = SizePHI, *Offset = OffsetPHI;
+ if (Value *Tmp = SizePHI->hasConstantValue()) {
+ Size = Tmp;
+ SizePHI->replaceAllUsesWith(Size);
+ SizePHI->eraseFromParent();
+ InsertedInstructions.erase(SizePHI);
+ }
+ if (Value *Tmp = OffsetPHI->hasConstantValue()) {
+ Offset = Tmp;
+ OffsetPHI->replaceAllUsesWith(Offset);
+ OffsetPHI->eraseFromParent();
+ InsertedInstructions.erase(OffsetPHI);
+ }
+ return std::make_pair(Size, Offset);
+}
+
+SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitSelectInst(SelectInst &I) {
+ SizeOffsetEvalType TrueSide = compute_(I.getTrueValue());
+ SizeOffsetEvalType FalseSide = compute_(I.getFalseValue());
+
+ if (!bothKnown(TrueSide) || !bothKnown(FalseSide))
+ return unknown();
+ if (TrueSide == FalseSide)
+ return TrueSide;
+
+ Value *Size = Builder.CreateSelect(I.getCondition(), TrueSide.first,
+ FalseSide.first);
+ Value *Offset = Builder.CreateSelect(I.getCondition(), TrueSide.second,
+ FalseSide.second);
+ return std::make_pair(Size, Offset);
+}
+
+SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitInstruction(Instruction &I) {
+ LLVM_DEBUG(dbgs() << "ObjectSizeOffsetEvaluator unknown instruction:" << I
+ << '\n');
+ return unknown();
+}
diff --git a/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp b/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp
new file mode 100644
index 000000000000..884587e020bb
--- /dev/null
+++ b/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp
@@ -0,0 +1,1824 @@
+//===- MemoryDependenceAnalysis.cpp - Mem Deps Implementation -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an analysis that determines, for a given memory
+// operation, what preceding memory operations it depends on. It builds on
+// alias analysis information, and tries to provide a lazy, caching interface to
+// a common kind of alias information query.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/MemoryDependenceAnalysis.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/OrderedBasicBlock.h"
+#include "llvm/Analysis/PHITransAddr.h"
+#include "llvm/Analysis/PhiValues.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PredIteratorCache.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/AtomicOrdering.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <iterator>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "memdep"
+
+STATISTIC(NumCacheNonLocal, "Number of fully cached non-local responses");
+STATISTIC(NumCacheDirtyNonLocal, "Number of dirty cached non-local responses");
+STATISTIC(NumUncacheNonLocal, "Number of uncached non-local responses");
+
+STATISTIC(NumCacheNonLocalPtr,
+ "Number of fully cached non-local ptr responses");
+STATISTIC(NumCacheDirtyNonLocalPtr,
+ "Number of cached, but dirty, non-local ptr responses");
+STATISTIC(NumUncacheNonLocalPtr, "Number of uncached non-local ptr responses");
+STATISTIC(NumCacheCompleteNonLocalPtr,
+ "Number of block queries that were completely cached");
+
+// Limit for the number of instructions to scan in a block.
+
+static cl::opt<unsigned> BlockScanLimit(
+ "memdep-block-scan-limit", cl::Hidden, cl::init(100),
+ cl::desc("The number of instructions to scan in a block in memory "
+ "dependency analysis (default = 100)"));
+
+static cl::opt<unsigned>
+ BlockNumberLimit("memdep-block-number-limit", cl::Hidden, cl::init(1000),
+ cl::desc("The number of blocks to scan during memory "
+ "dependency analysis (default = 1000)"));
+
+// Limit on the number of memdep results to process.
+static const unsigned int NumResultsLimit = 100;
+
+/// This is a helper function that removes Val from 'Inst's set in ReverseMap.
+///
+/// If the set becomes empty, remove Inst's entry.
+template <typename KeyTy>
+static void
+RemoveFromReverseMap(DenseMap<Instruction *, SmallPtrSet<KeyTy, 4>> &ReverseMap,
+ Instruction *Inst, KeyTy Val) {
+ typename DenseMap<Instruction *, SmallPtrSet<KeyTy, 4>>::iterator InstIt =
+ ReverseMap.find(Inst);
+ assert(InstIt != ReverseMap.end() && "Reverse map out of sync?");
+ bool Found = InstIt->second.erase(Val);
+ assert(Found && "Invalid reverse map!");
+ (void)Found;
+ if (InstIt->second.empty())
+ ReverseMap.erase(InstIt);
+}
+
+/// If the given instruction references a specific memory location, fill in Loc
+/// with the details, otherwise set Loc.Ptr to null.
+///
+/// Returns a ModRefInfo value describing the general behavior of the
+/// instruction.
+static ModRefInfo GetLocation(const Instruction *Inst, MemoryLocation &Loc,
+ const TargetLibraryInfo &TLI) {
+ if (const LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
+ if (LI->isUnordered()) {
+ Loc = MemoryLocation::get(LI);
+ return ModRefInfo::Ref;
+ }
+ if (LI->getOrdering() == AtomicOrdering::Monotonic) {
+ Loc = MemoryLocation::get(LI);
+ return ModRefInfo::ModRef;
+ }
+ Loc = MemoryLocation();
+ return ModRefInfo::ModRef;
+ }
+
+ if (const StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
+ if (SI->isUnordered()) {
+ Loc = MemoryLocation::get(SI);
+ return ModRefInfo::Mod;
+ }
+ if (SI->getOrdering() == AtomicOrdering::Monotonic) {
+ Loc = MemoryLocation::get(SI);
+ return ModRefInfo::ModRef;
+ }
+ Loc = MemoryLocation();
+ return ModRefInfo::ModRef;
+ }
+
+ if (const VAArgInst *V = dyn_cast<VAArgInst>(Inst)) {
+ Loc = MemoryLocation::get(V);
+ return ModRefInfo::ModRef;
+ }
+
+ if (const CallInst *CI = isFreeCall(Inst, &TLI)) {
+ // calls to free() deallocate the entire structure
+ Loc = MemoryLocation(CI->getArgOperand(0));
+ return ModRefInfo::Mod;
+ }
+
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::lifetime_start:
+ case Intrinsic::lifetime_end:
+ case Intrinsic::invariant_start:
+ Loc = MemoryLocation::getForArgument(II, 1, TLI);
+ // These intrinsics don't really modify the memory, but returning Mod
+ // will allow them to be handled conservatively.
+ return ModRefInfo::Mod;
+ case Intrinsic::invariant_end:
+ Loc = MemoryLocation::getForArgument(II, 2, TLI);
+ // These intrinsics don't really modify the memory, but returning Mod
+ // will allow them to be handled conservatively.
+ return ModRefInfo::Mod;
+ default:
+ break;
+ }
+ }
+
+ // Otherwise, just do the coarse-grained thing that always works.
+ if (Inst->mayWriteToMemory())
+ return ModRefInfo::ModRef;
+ if (Inst->mayReadFromMemory())
+ return ModRefInfo::Ref;
+ return ModRefInfo::NoModRef;
+}
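+// E.g. an unordered 'load i32, i32* %p' fills Loc with %p's location and
+// returns ModRefInfo::Ref, while a monotonic atomic load still fills Loc but
+// is reported as ModRef so that clients treat it conservatively.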
+
+/// Private helper for finding the local dependencies of a call site.
+MemDepResult MemoryDependenceResults::getCallDependencyFrom(
+ CallBase *Call, bool isReadOnlyCall, BasicBlock::iterator ScanIt,
+ BasicBlock *BB) {
+ unsigned Limit = getDefaultBlockScanLimit();
+
+ // Walk backwards through the block, looking for dependencies.
+ while (ScanIt != BB->begin()) {
+ Instruction *Inst = &*--ScanIt;
+ // Debug intrinsics don't cause dependences and should not affect Limit
+ if (isa<DbgInfoIntrinsic>(Inst))
+ continue;
+
+ // Limit the amount of scanning we do so we don't end up with quadratic
+ // running time on extreme testcases.
+ --Limit;
+ if (!Limit)
+ return MemDepResult::getUnknown();
+
+ // If this inst is a memory op, get the pointer it accessed
+ MemoryLocation Loc;
+ ModRefInfo MR = GetLocation(Inst, Loc, TLI);
+ if (Loc.Ptr) {
+ // A simple instruction.
+ if (isModOrRefSet(AA.getModRefInfo(Call, Loc)))
+ return MemDepResult::getClobber(Inst);
+ continue;
+ }
+
+ if (auto *CallB = dyn_cast<CallBase>(Inst)) {
+ // If these two calls do not interfere, look past it.
+ if (isNoModRef(AA.getModRefInfo(Call, CallB))) {
+ // If the two calls are the same, return Inst as a Def, so that
+ // Call can be found redundant and eliminated.
+ if (isReadOnlyCall && !isModSet(MR) &&
+ Call->isIdenticalToWhenDefined(CallB))
+ return MemDepResult::getDef(Inst);
+
+ // Otherwise if the two calls don't interact (e.g. CallB is readnone)
+ // keep scanning.
+ continue;
+ } else
+ return MemDepResult::getClobber(Inst);
+ }
+
+ // If we could not obtain a pointer for the instruction and the instruction
+ // touches memory then assume that this is a dependency.
+ if (isModOrRefSet(MR))
+ return MemDepResult::getClobber(Inst);
+ }
+
+ // No dependence found. If this is the entry block of the function, it is
+ // unknown, otherwise it is non-local.
+ if (BB != &BB->getParent()->getEntryBlock())
+ return MemDepResult::getNonLocal();
+ return MemDepResult::getNonFuncLocal();
+}
+
+unsigned MemoryDependenceResults::getLoadLoadClobberFullWidthSize(
+ const Value *MemLocBase, int64_t MemLocOffs, unsigned MemLocSize,
+ const LoadInst *LI) {
+ // We can only extend simple integer loads.
+ if (!isa<IntegerType>(LI->getType()) || !LI->isSimple())
+ return 0;
+
+ // Load widening is hostile to ThreadSanitizer: it may cause false positives
+ // or make the reports more cryptic (access sizes are wrong).
+ if (LI->getParent()->getParent()->hasFnAttribute(Attribute::SanitizeThread))
+ return 0;
+
+ const DataLayout &DL = LI->getModule()->getDataLayout();
+
+ // Get the base of this load.
+ int64_t LIOffs = 0;
+ const Value *LIBase =
+ GetPointerBaseWithConstantOffset(LI->getPointerOperand(), LIOffs, DL);
+
+ // If the two pointers are not based on the same pointer, we can't tell that
+ // they are related.
+ if (LIBase != MemLocBase)
+ return 0;
+
+ // Okay, the two values are based on the same pointer, but returned as
+ // no-alias. This happens when we have things like two byte loads at "P+1"
+ // and "P+3". Check to see if increasing the size of the "LI" load up to its
+ // alignment (or the largest native integer type) will allow us to load all
+ // the bits required by MemLoc.
+
+ // If MemLoc is before LI, then no widening of LI will help us out.
+ if (MemLocOffs < LIOffs)
+ return 0;
+
+ // Get the alignment of the load in bytes. We assume that it is safe to load
+ // any legal integer up to this size without a problem. For example, if we're
+ // looking at an i8 load on x86-32 that is known 1024 byte aligned, we can
+ // widen it up to an i32 load. If it is known 2-byte aligned, we can widen it
+ // to i16.
+ unsigned LoadAlign = LI->getAlignment();
+
+ int64_t MemLocEnd = MemLocOffs + MemLocSize;
+
+ // If no amount of rounding up will let MemLoc fit into LI, then bail out.
+ if (LIOffs + LoadAlign < MemLocEnd)
+ return 0;
+
+ // This is the size of the load to try. Start with the next larger power of
+ // two.
+ unsigned NewLoadByteSize = LI->getType()->getPrimitiveSizeInBits() / 8U;
+ NewLoadByteSize = NextPowerOf2(NewLoadByteSize);
+
+ while (true) {
+ // If this load size is bigger than our known alignment or would not fit
+ // into a native integer register, then we fail.
+ if (NewLoadByteSize > LoadAlign ||
+ !DL.fitsInLegalInteger(NewLoadByteSize * 8))
+ return 0;
+
+ if (LIOffs + NewLoadByteSize > MemLocEnd &&
+ (LI->getParent()->getParent()->hasFnAttribute(
+ Attribute::SanitizeAddress) ||
+ LI->getParent()->getParent()->hasFnAttribute(
+ Attribute::SanitizeHWAddress)))
+ // We will be reading past the location accessed by the original program.
+ // While this is safe in a regular build, Address Safety analysis tools
+ // may start reporting false warnings. So, don't do widening.
+ return 0;
+
+ // If a load of this width would include all of MemLoc, then we succeed.
+ if (LIOffs + NewLoadByteSize >= MemLocEnd)
+ return NewLoadByteSize;
+
+ NewLoadByteSize <<= 1;
+ }
+}
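+// Worked example (hypothetical): LI is an i8 load of P (LIOffs == 0) known to
+// be 4-byte aligned, and MemLoc is the single byte at P+3. Widening is tried
+// at 2 and then 4 bytes; a 4-byte load fits the alignment, is a legal integer
+// on typical targets, and covers bytes 0..3, so the function returns 4.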
+
+static bool isVolatile(Instruction *Inst) {
+ if (auto *LI = dyn_cast<LoadInst>(Inst))
+ return LI->isVolatile();
+ if (auto *SI = dyn_cast<StoreInst>(Inst))
+ return SI->isVolatile();
+ if (auto *AI = dyn_cast<AtomicCmpXchgInst>(Inst))
+ return AI->isVolatile();
+ return false;
+}
+
+MemDepResult MemoryDependenceResults::getPointerDependencyFrom(
+ const MemoryLocation &MemLoc, bool isLoad, BasicBlock::iterator ScanIt,
+ BasicBlock *BB, Instruction *QueryInst, unsigned *Limit,
+ OrderedBasicBlock *OBB) {
+ MemDepResult InvariantGroupDependency = MemDepResult::getUnknown();
+ if (QueryInst != nullptr) {
+ if (auto *LI = dyn_cast<LoadInst>(QueryInst)) {
+ InvariantGroupDependency = getInvariantGroupPointerDependency(LI, BB);
+
+ if (InvariantGroupDependency.isDef())
+ return InvariantGroupDependency;
+ }
+ }
+ MemDepResult SimpleDep = getSimplePointerDependencyFrom(
+ MemLoc, isLoad, ScanIt, BB, QueryInst, Limit, OBB);
+ if (SimpleDep.isDef())
+ return SimpleDep;
+ // A non-local invariant group dependency indicates there is a non-local Def
+ // (it only returns NonLocal if it finds a non-local Def), which is better
+ // than a local clobber and everything else.
+ if (InvariantGroupDependency.isNonLocal())
+ return InvariantGroupDependency;
+
+ assert(InvariantGroupDependency.isUnknown() &&
+ "InvariantGroupDependency should be only unknown at this point");
+ return SimpleDep;
+}
+
+MemDepResult
+MemoryDependenceResults::getInvariantGroupPointerDependency(LoadInst *LI,
+ BasicBlock *BB) {
+
+ if (!LI->hasMetadata(LLVMContext::MD_invariant_group))
+ return MemDepResult::getUnknown();
+
+ // Take the pointer operand after stripping all casts and zero-index GEPs.
+ // This way we only need to search down the cast graph.
+ Value *LoadOperand = LI->getPointerOperand()->stripPointerCasts();
+
+ // It is not safe to walk the use list of a global value, because function
+ // passes aren't allowed to look outside their functions.
+ // FIXME: this could be fixed by filtering instructions from outside
+ // of current function.
+ if (isa<GlobalValue>(LoadOperand))
+ return MemDepResult::getUnknown();
+
+ // Queue to process all pointers that are equivalent to load operand.
+ SmallVector<const Value *, 8> LoadOperandsQueue;
+ LoadOperandsQueue.push_back(LoadOperand);
+
+ Instruction *ClosestDependency = nullptr;
+ // The order of instructions in the use list is unpredictable. In order to
+ // always get the same result, we will look for the closest dominance.
+ auto GetClosestDependency = [this](Instruction *Best, Instruction *Other) {
+ assert(Other && "Must call it with not null instruction");
+ if (Best == nullptr || DT.dominates(Best, Other))
+ return Other;
+ return Best;
+ };
+
+ // FIXME: This loop is O(N^2) because dominates can be O(n) and in worst case
+ // we will see all the instructions. This should be fixed in MSSA.
+ while (!LoadOperandsQueue.empty()) {
+ const Value *Ptr = LoadOperandsQueue.pop_back_val();
+ assert(Ptr && !isa<GlobalValue>(Ptr) &&
+ "Null or GlobalValue should not be inserted");
+
+ for (const Use &Us : Ptr->uses()) {
+ auto *U = dyn_cast<Instruction>(Us.getUser());
+ if (!U || U == LI || !DT.dominates(U, LI))
+ continue;
+
+ // A bitcast or gep with zero indices uses Ptr. Add it to the queue to check
+ // its users. U = bitcast Ptr
+ if (isa<BitCastInst>(U)) {
+ LoadOperandsQueue.push_back(U);
+ continue;
+ }
+ // Gep with zeros is equivalent to bitcast.
+ // FIXME: we are not sure whether a bitcast should be canonicalized to a
+ // gep 0 or a gep 0 to a bitcast because of SROA, so there are 2 forms. Once
+ // typeless pointers are ready, both cases will be gone
+ // (and this BFS also won't be needed).
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(U))
+ if (GEP->hasAllZeroIndices()) {
+ LoadOperandsQueue.push_back(U);
+ continue;
+ }
+
+ // If we hit load/store with the same invariant.group metadata (and the
+ // same pointer operand) we can assume that value pointed by pointer
+ // operand didn't change.
+ if ((isa<LoadInst>(U) || isa<StoreInst>(U)) &&
+ U->hasMetadata(LLVMContext::MD_invariant_group))
+ ClosestDependency = GetClosestDependency(ClosestDependency, U);
+ }
+ }
+
+ if (!ClosestDependency)
+ return MemDepResult::getUnknown();
+ if (ClosestDependency->getParent() == BB)
+ return MemDepResult::getDef(ClosestDependency);
+ // Def(U) can't be returned here because it is non-local. If no local
+ // dependency is found, return NonLocal, counting on the user to call
+ // getNonLocalPointerDependency, which will return the cached result.
+ NonLocalDefsCache.try_emplace(
+ LI, NonLocalDepResult(ClosestDependency->getParent(),
+ MemDepResult::getDef(ClosestDependency), nullptr));
+ ReverseNonLocalDefsCache[ClosestDependency].insert(LI);
+ return MemDepResult::getNonLocal();
+}
+
+MemDepResult MemoryDependenceResults::getSimplePointerDependencyFrom(
+ const MemoryLocation &MemLoc, bool isLoad, BasicBlock::iterator ScanIt,
+ BasicBlock *BB, Instruction *QueryInst, unsigned *Limit,
+ OrderedBasicBlock *OBB) {
+ bool isInvariantLoad = false;
+
+ unsigned DefaultLimit = getDefaultBlockScanLimit();
+ if (!Limit)
+ Limit = &DefaultLimit;
+
+ // We must be careful with atomic accesses, as they may allow another thread
+ // to touch this location, clobbering it. We are conservative: if the
+ // QueryInst is not a simple (non-atomic) memory access, we automatically
+ // return getClobber.
+ // If it is simple, we know based on the results of
+ // "Compiler testing via a theory of sound optimisations in the C11/C++11
+ // memory model" in PLDI 2013, that a non-atomic location can only be
+ // clobbered between a pair of a release and an acquire action, with no
+ // access to the location in between.
+ // Here is an example to give the general intuition behind this rule.
+ // In the following code:
+ // store x 0;
+ // release action; [1]
+ // acquire action; [4]
+ // %val = load x;
+ // It is unsafe to replace %val by 0 because another thread may be running:
+ // acquire action; [2]
+ // store x 42;
+ // release action; [3]
+ // with synchronization from 1 to 2 and from 3 to 4, resulting in %val
+ // being 42. A key property of this program however is that if either
+ // 1 or 4 were missing, there would be a race between the store of 42 and
+ // either the store of 0 or the load (making the whole program racy).
+ // The paper mentioned above shows that the same property is respected
+ // by every program that can detect any optimization of that kind: either
+ // it is racy (undefined) or there is a release followed by an acquire
+ // between the pair of accesses under consideration.
+
+ // If the load is invariant, we "know" that it doesn't alias *any* write. We
+ // do want to respect mustalias results since defs are useful for value
+ // forwarding, but any mayalias write can be assumed to be noalias.
+ // Arguably, this logic should be pushed inside AliasAnalysis itself.
+ if (isLoad && QueryInst) {
+ LoadInst *LI = dyn_cast<LoadInst>(QueryInst);
+ if (LI && LI->hasMetadata(LLVMContext::MD_invariant_load))
+ isInvariantLoad = true;
+ }
+
+ const DataLayout &DL = BB->getModule()->getDataLayout();
+
+ // If the caller did not provide an ordered basic block,
+ // create one to lazily compute and cache instruction
+ // positions inside a BB. This is used to provide fast queries for relative
+ // position between two instructions in a BB and can be used by
+ // AliasAnalysis::callCapturesBefore.
+ OrderedBasicBlock OBBTmp(BB);
+ if (!OBB)
+ OBB = &OBBTmp;
+
+ // Return "true" if and only if the instruction I is either a non-simple
+ // load or a non-simple store.
+ auto isNonSimpleLoadOrStore = [](Instruction *I) -> bool {
+ if (auto *LI = dyn_cast<LoadInst>(I))
+ return !LI->isSimple();
+ if (auto *SI = dyn_cast<StoreInst>(I))
+ return !SI->isSimple();
+ return false;
+ };
+
+ // Return "true" if I is not a load and not a store, but it does access
+ // memory.
+ auto isOtherMemAccess = [](Instruction *I) -> bool {
+ return !isa<LoadInst>(I) && !isa<StoreInst>(I) && I->mayReadOrWriteMemory();
+ };
+
+ // Walk backwards through the basic block, looking for dependencies.
+ while (ScanIt != BB->begin()) {
+ Instruction *Inst = &*--ScanIt;
+
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst))
+ // Debug intrinsics don't (and can't) cause dependencies.
+ if (isa<DbgInfoIntrinsic>(II))
+ continue;
+
+ // Limit the amount of scanning we do so we don't end up with quadratic
+ // running time on extreme testcases.
+ --*Limit;
+ if (!*Limit)
+ return MemDepResult::getUnknown();
+
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
+ // If we reach a lifetime begin or end marker, then the query ends here
+ // because the value is undefined.
+ if (II->getIntrinsicID() == Intrinsic::lifetime_start) {
+ // FIXME: This only considers queries directly on the invariant-tagged
+ // pointer, not on query pointers that are indexed off of them. It'd
+ // be nice to handle that at some point (the right approach is to use
+ // GetPointerBaseWithConstantOffset).
+ if (AA.isMustAlias(MemoryLocation(II->getArgOperand(1)), MemLoc))
+ return MemDepResult::getDef(II);
+ continue;
+ }
+ }
+
+ // Values depend on loads if the pointers are must aliased. This means
+ // that a load depends on another must aliased load from the same value.
+ // One exception is atomic loads: a value can depend on an atomic load that
+ // it does not alias with when this atomic load indicates that another
+ // thread may be accessing the location.
+ if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
+ // While volatile accesses cannot be eliminated, they do not have to clobber
+ // non-aliasing locations, as normal accesses, for example, can be safely
+ // reordered with volatile accesses.
+ if (LI->isVolatile()) {
+ if (!QueryInst)
+ // Original QueryInst *may* be volatile
+ return MemDepResult::getClobber(LI);
+ if (isVolatile(QueryInst))
+ // Ordering required if QueryInst is itself volatile
+ return MemDepResult::getClobber(LI);
+ // Otherwise, volatile doesn't imply any special ordering
+ }
+
+      // Atomic loads are more complicated.  A Monotonic (or stronger) load is
+      // OK if the query instruction is itself not atomic.
+      // FIXME: This is overly conservative.
+ if (LI->isAtomic() && isStrongerThanUnordered(LI->getOrdering())) {
+ if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) ||
+ isOtherMemAccess(QueryInst))
+ return MemDepResult::getClobber(LI);
+ if (LI->getOrdering() != AtomicOrdering::Monotonic)
+ return MemDepResult::getClobber(LI);
+ }
+
+ MemoryLocation LoadLoc = MemoryLocation::get(LI);
+
+ // If we found a pointer, check if it could be the same as our pointer.
+ AliasResult R = AA.alias(LoadLoc, MemLoc);
+
+ if (isLoad) {
+ if (R == NoAlias)
+ continue;
+
+ // Must aliased loads are defs of each other.
+ if (R == MustAlias)
+ return MemDepResult::getDef(Inst);
+
+#if 0 // FIXME: Temporarily disabled. GVN is cleverly rewriting loads
+ // in terms of clobbering loads, but since it does this by looking
+ // at the clobbering load directly, it doesn't know about any
+ // phi translation that may have happened along the way.
+
+ // If we have a partial alias, then return this as a clobber for the
+ // client to handle.
+ if (R == PartialAlias)
+ return MemDepResult::getClobber(Inst);
+#endif
+
+        // Loads that merely may-alias do not clobber one another, so there is
+        // no dependence here; keep scanning.
+ continue;
+ }
+
+ // Stores don't depend on other no-aliased accesses.
+ if (R == NoAlias)
+ continue;
+
+ // Stores don't alias loads from read-only memory.
+ if (AA.pointsToConstantMemory(LoadLoc))
+ continue;
+
+ // Stores depend on may/must aliased loads.
+ return MemDepResult::getDef(Inst);
+ }
+
+ if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
+      // Atomic stores are more complicated.  A Monotonic store is OK if the
+      // query instruction is itself not atomic.
+      // FIXME: This is overly conservative.
+ if (!SI->isUnordered() && SI->isAtomic()) {
+ if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) ||
+ isOtherMemAccess(QueryInst))
+ return MemDepResult::getClobber(SI);
+ if (SI->getOrdering() != AtomicOrdering::Monotonic)
+ return MemDepResult::getClobber(SI);
+ }
+
+      // FIXME: This is overly conservative.
+      // While volatile accesses cannot be eliminated, they do not have to
+      // clobber non-aliasing locations; normal accesses, for example, can be
+      // reordered with volatile accesses.
+ if (SI->isVolatile())
+ if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) ||
+ isOtherMemAccess(QueryInst))
+ return MemDepResult::getClobber(SI);
+
+ // If alias analysis can tell that this store is guaranteed to not modify
+ // the query pointer, ignore it. Use getModRefInfo to handle cases where
+ // the query pointer points to constant memory etc.
+ if (!isModOrRefSet(AA.getModRefInfo(SI, MemLoc)))
+ continue;
+
+ // Ok, this store might clobber the query pointer. Check to see if it is
+ // a must alias: in this case, we want to return this as a def.
+ // FIXME: Use ModRefInfo::Must bit from getModRefInfo call above.
+ MemoryLocation StoreLoc = MemoryLocation::get(SI);
+
+ // If we found a pointer, check if it could be the same as our pointer.
+ AliasResult R = AA.alias(StoreLoc, MemLoc);
+
+ if (R == NoAlias)
+ continue;
+ if (R == MustAlias)
+ return MemDepResult::getDef(Inst);
+ if (isInvariantLoad)
+ continue;
+ return MemDepResult::getClobber(Inst);
+ }
+
+ // If this is an allocation, and if we know that the accessed pointer is to
+ // the allocation, return Def. This means that there is no dependence and
+ // the access can be optimized based on that. For example, a load could
+ // turn into undef. Note that we can bypass the allocation itself when
+ // looking for a clobber in many cases; that's an alias property and is
+ // handled by BasicAA.
+ if (isa<AllocaInst>(Inst) || isNoAliasFn(Inst, &TLI)) {
+ const Value *AccessPtr = GetUnderlyingObject(MemLoc.Ptr, DL);
+ if (AccessPtr == Inst || AA.isMustAlias(Inst, AccessPtr))
+ return MemDepResult::getDef(Inst);
+ }
+
+ if (isInvariantLoad)
+ continue;
+
+ // A release fence requires that all stores complete before it, but does
+ // not prevent the reordering of following loads or stores 'before' the
+ // fence. As a result, we look past it when finding a dependency for
+ // loads. DSE uses this to find preceding stores to delete and thus we
+ // can't bypass the fence if the query instruction is a store.
+ if (FenceInst *FI = dyn_cast<FenceInst>(Inst))
+ if (isLoad && FI->getOrdering() == AtomicOrdering::Release)
+ continue;
+
+ // See if this instruction (e.g. a call or vaarg) mod/ref's the pointer.
+ ModRefInfo MR = AA.getModRefInfo(Inst, MemLoc);
+ // If necessary, perform additional analysis.
+ if (isModAndRefSet(MR))
+ MR = AA.callCapturesBefore(Inst, MemLoc, &DT, OBB);
+ switch (clearMust(MR)) {
+ case ModRefInfo::NoModRef:
+ // If the call has no effect on the queried pointer, just ignore it.
+ continue;
+ case ModRefInfo::Mod:
+ return MemDepResult::getClobber(Inst);
+ case ModRefInfo::Ref:
+ // If the call is known to never store to the pointer, and if this is a
+ // load query, we can safely ignore it (scan past it).
+ if (isLoad)
+ continue;
+ LLVM_FALLTHROUGH;
+ default:
+ // Otherwise, there is a potential dependence. Return a clobber.
+ return MemDepResult::getClobber(Inst);
+ }
+ }
+
+  // No dependence found.  If this is the entry block of the function, it is
+  // non-function-local, otherwise it is non-local.
+ if (BB != &BB->getParent()->getEntryBlock())
+ return MemDepResult::getNonLocal();
+ return MemDepResult::getNonFuncLocal();
+}
+
+MemDepResult MemoryDependenceResults::getDependency(Instruction *QueryInst,
+ OrderedBasicBlock *OBB) {
+ Instruction *ScanPos = QueryInst;
+
+ // Check for a cached result
+ MemDepResult &LocalCache = LocalDeps[QueryInst];
+
+ // If the cached entry is non-dirty, just return it. Note that this depends
+ // on MemDepResult's default constructing to 'dirty'.
+ if (!LocalCache.isDirty())
+ return LocalCache;
+
+ // Otherwise, if we have a dirty entry, we know we can start the scan at that
+ // instruction, which may save us some work.
+ if (Instruction *Inst = LocalCache.getInst()) {
+ ScanPos = Inst;
+
+ RemoveFromReverseMap(ReverseLocalDeps, Inst, QueryInst);
+ }
+
+ BasicBlock *QueryParent = QueryInst->getParent();
+
+ // Do the scan.
+ if (BasicBlock::iterator(QueryInst) == QueryParent->begin()) {
+    // No dependence found.  If this is the entry block of the function, it is
+    // non-function-local, otherwise it is non-local.
+ if (QueryParent != &QueryParent->getParent()->getEntryBlock())
+ LocalCache = MemDepResult::getNonLocal();
+ else
+ LocalCache = MemDepResult::getNonFuncLocal();
+ } else {
+ MemoryLocation MemLoc;
+ ModRefInfo MR = GetLocation(QueryInst, MemLoc, TLI);
+ if (MemLoc.Ptr) {
+ // If we can do a pointer scan, make it happen.
+ bool isLoad = !isModSet(MR);
+ if (auto *II = dyn_cast<IntrinsicInst>(QueryInst))
+ isLoad |= II->getIntrinsicID() == Intrinsic::lifetime_start;
+
+ LocalCache =
+ getPointerDependencyFrom(MemLoc, isLoad, ScanPos->getIterator(),
+ QueryParent, QueryInst, nullptr, OBB);
+ } else if (auto *QueryCall = dyn_cast<CallBase>(QueryInst)) {
+ bool isReadOnly = AA.onlyReadsMemory(QueryCall);
+ LocalCache = getCallDependencyFrom(QueryCall, isReadOnly,
+ ScanPos->getIterator(), QueryParent);
+ } else
+ // Non-memory instruction.
+ LocalCache = MemDepResult::getUnknown();
+ }
+
+ // Remember the result!
+ if (Instruction *I = LocalCache.getInst())
+ ReverseLocalDeps[I].insert(QueryInst);
+
+ return LocalCache;
+}
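+
+// A minimal usage sketch, assuming a client that holds a
+// MemoryDependenceResults &MD, a load or store Inst, and a
+// SmallVector<NonLocalDepResult, 4> NonLocalResults (hypothetical
+// client-side names):
+//
+//   MemDepResult Dep = MD.getDependency(Inst);
+//   if (Dep.isDef() || Dep.isClobber())
+//     handleLocalDep(Dep.getInst());     // dependency within the same block
+//   else if (Dep.isNonLocal() && isa<LoadInst>(Inst))
+//     MD.getNonLocalPointerDependency(Inst, NonLocalResults);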
+
+#ifndef NDEBUG
+/// This method is used when -debug is specified to verify that cache arrays
+/// are properly kept sorted.
+static void AssertSorted(MemoryDependenceResults::NonLocalDepInfo &Cache,
+ int Count = -1) {
+ if (Count == -1)
+ Count = Cache.size();
+ assert(std::is_sorted(Cache.begin(), Cache.begin() + Count) &&
+ "Cache isn't sorted!");
+}
+#endif
+
+const MemoryDependenceResults::NonLocalDepInfo &
+MemoryDependenceResults::getNonLocalCallDependency(CallBase *QueryCall) {
+ assert(getDependency(QueryCall).isNonLocal() &&
+ "getNonLocalCallDependency should only be used on calls with "
+ "non-local deps!");
+ PerInstNLInfo &CacheP = NonLocalDeps[QueryCall];
+ NonLocalDepInfo &Cache = CacheP.first;
+
+ // This is the set of blocks that need to be recomputed. In the cached case,
+ // this can happen due to instructions being deleted etc. In the uncached
+ // case, this starts out as the set of predecessors we care about.
+ SmallVector<BasicBlock *, 32> DirtyBlocks;
+
+ if (!Cache.empty()) {
+ // Okay, we have a cache entry. If we know it is not dirty, just return it
+ // with no computation.
+ if (!CacheP.second) {
+ ++NumCacheNonLocal;
+ return Cache;
+ }
+
+ // If we already have a partially computed set of results, scan them to
+ // determine what is dirty, seeding our initial DirtyBlocks worklist.
+ for (auto &Entry : Cache)
+ if (Entry.getResult().isDirty())
+ DirtyBlocks.push_back(Entry.getBB());
+
+ // Sort the cache so that we can do fast binary search lookups below.
+ llvm::sort(Cache);
+
+ ++NumCacheDirtyNonLocal;
+ // cerr << "CACHED CASE: " << DirtyBlocks.size() << " dirty: "
+ // << Cache.size() << " cached: " << *QueryInst;
+ } else {
+ // Seed DirtyBlocks with each of the preds of QueryInst's block.
+ BasicBlock *QueryBB = QueryCall->getParent();
+ for (BasicBlock *Pred : PredCache.get(QueryBB))
+ DirtyBlocks.push_back(Pred);
+ ++NumUncacheNonLocal;
+ }
+
+ // isReadonlyCall - If this is a read-only call, we can be more aggressive.
+ bool isReadonlyCall = AA.onlyReadsMemory(QueryCall);
+
+ SmallPtrSet<BasicBlock *, 32> Visited;
+
+ unsigned NumSortedEntries = Cache.size();
+ LLVM_DEBUG(AssertSorted(Cache));
+
+ // Iterate while we still have blocks to update.
+ while (!DirtyBlocks.empty()) {
+ BasicBlock *DirtyBB = DirtyBlocks.back();
+ DirtyBlocks.pop_back();
+
+ // Already processed this block?
+ if (!Visited.insert(DirtyBB).second)
+ continue;
+
+ // Do a binary search to see if we already have an entry for this block in
+ // the cache set. If so, find it.
+ LLVM_DEBUG(AssertSorted(Cache, NumSortedEntries));
+ NonLocalDepInfo::iterator Entry =
+ std::upper_bound(Cache.begin(), Cache.begin() + NumSortedEntries,
+ NonLocalDepEntry(DirtyBB));
+ if (Entry != Cache.begin() && std::prev(Entry)->getBB() == DirtyBB)
+ --Entry;
+
+ NonLocalDepEntry *ExistingResult = nullptr;
+ if (Entry != Cache.begin() + NumSortedEntries &&
+ Entry->getBB() == DirtyBB) {
+ // If we already have an entry, and if it isn't already dirty, the block
+ // is done.
+ if (!Entry->getResult().isDirty())
+ continue;
+
+ // Otherwise, remember this slot so we can update the value.
+ ExistingResult = &*Entry;
+ }
+
+ // If the dirty entry has a pointer, start scanning from it so we don't have
+ // to rescan the entire block.
+ BasicBlock::iterator ScanPos = DirtyBB->end();
+ if (ExistingResult) {
+ if (Instruction *Inst = ExistingResult->getResult().getInst()) {
+ ScanPos = Inst->getIterator();
+ // We're removing QueryInst's use of Inst.
+ RemoveFromReverseMap<Instruction *>(ReverseNonLocalDeps, Inst,
+ QueryCall);
+ }
+ }
+
+ // Find out if this block has a local dependency for QueryInst.
+ MemDepResult Dep;
+
+ if (ScanPos != DirtyBB->begin()) {
+ Dep = getCallDependencyFrom(QueryCall, isReadonlyCall, ScanPos, DirtyBB);
+ } else if (DirtyBB != &DirtyBB->getParent()->getEntryBlock()) {
+      // No dependence found.  If this is the entry block of the function, it
+      // is non-function-local, otherwise it is non-local.
+ Dep = MemDepResult::getNonLocal();
+ } else {
+ Dep = MemDepResult::getNonFuncLocal();
+ }
+
+ // If we had a dirty entry for the block, update it. Otherwise, just add
+ // a new entry.
+ if (ExistingResult)
+ ExistingResult->setResult(Dep);
+ else
+ Cache.push_back(NonLocalDepEntry(DirtyBB, Dep));
+
+ // If the block has a dependency (i.e. it isn't completely transparent to
+ // the value), remember the association!
+ if (!Dep.isNonLocal()) {
+ // Keep the ReverseNonLocalDeps map up to date so we can efficiently
+ // update this when we remove instructions.
+ if (Instruction *Inst = Dep.getInst())
+ ReverseNonLocalDeps[Inst].insert(QueryCall);
+ } else {
+
+      // If the block *is* completely transparent to the call, we need to
+      // check the predecessors of this block.  Add them to our worklist.
+ for (BasicBlock *Pred : PredCache.get(DirtyBB))
+ DirtyBlocks.push_back(Pred);
+ }
+ }
+
+ return Cache;
+}
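+
+// A sketch of the intended call-side usage (the client loop and
+// visitBlockDep are hypothetical): once getDependency(Call) reports a
+// non-local result, the per-block results computed above can be walked as
+//
+//   for (const NonLocalDepEntry &E : MD.getNonLocalCallDependency(Call))
+//     if (E.getResult().isDef() || E.getResult().isClobber())
+//       visitBlockDep(E.getBB(), E.getResult().getInst());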
+
+void MemoryDependenceResults::getNonLocalPointerDependency(
+ Instruction *QueryInst, SmallVectorImpl<NonLocalDepResult> &Result) {
+ const MemoryLocation Loc = MemoryLocation::get(QueryInst);
+ bool isLoad = isa<LoadInst>(QueryInst);
+ BasicBlock *FromBB = QueryInst->getParent();
+ assert(FromBB);
+
+ assert(Loc.Ptr->getType()->isPointerTy() &&
+ "Can't get pointer deps of a non-pointer!");
+ Result.clear();
+ {
+ // Check if there is cached Def with invariant.group.
+ auto NonLocalDefIt = NonLocalDefsCache.find(QueryInst);
+ if (NonLocalDefIt != NonLocalDefsCache.end()) {
+ Result.push_back(NonLocalDefIt->second);
+ ReverseNonLocalDefsCache[NonLocalDefIt->second.getResult().getInst()]
+ .erase(QueryInst);
+ NonLocalDefsCache.erase(NonLocalDefIt);
+ return;
+ }
+ }
+  // This routine does not expect to deal with volatile instructions.
+  // Doing so would require piping the QueryInst all the way through.
+ // TODO: volatiles can't be elided, but they can be reordered with other
+ // non-volatile accesses.
+
+ // We currently give up on any instruction which is ordered, but we do handle
+ // atomic instructions which are unordered.
+ // TODO: Handle ordered instructions
+ auto isOrdered = [](Instruction *Inst) {
+ if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
+ return !LI->isUnordered();
+ } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
+ return !SI->isUnordered();
+ }
+ return false;
+ };
+ if (isVolatile(QueryInst) || isOrdered(QueryInst)) {
+ Result.push_back(NonLocalDepResult(FromBB, MemDepResult::getUnknown(),
+ const_cast<Value *>(Loc.Ptr)));
+ return;
+ }
+ const DataLayout &DL = FromBB->getModule()->getDataLayout();
+ PHITransAddr Address(const_cast<Value *>(Loc.Ptr), DL, &AC);
+
+ // This is the set of blocks we've inspected, and the pointer we consider in
+ // each block. Because of critical edges, we currently bail out if querying
+ // a block with multiple different pointers. This can happen during PHI
+ // translation.
+ DenseMap<BasicBlock *, Value *> Visited;
+ if (getNonLocalPointerDepFromBB(QueryInst, Address, Loc, isLoad, FromBB,
+ Result, Visited, true))
+ return;
+ Result.clear();
+ Result.push_back(NonLocalDepResult(FromBB, MemDepResult::getUnknown(),
+ const_cast<Value *>(Loc.Ptr)));
+}
+
+/// Compute the memdep value for BB with Pointer/PointeeSize using either
+/// cached information in Cache or by doing a lookup (which may use dirty cache
+/// info if available).
+///
+/// If we do a lookup, add the result to the cache.
+MemDepResult MemoryDependenceResults::GetNonLocalInfoForBlock(
+ Instruction *QueryInst, const MemoryLocation &Loc, bool isLoad,
+ BasicBlock *BB, NonLocalDepInfo *Cache, unsigned NumSortedEntries) {
+
+ // Do a binary search to see if we already have an entry for this block in
+ // the cache set. If so, find it.
+ NonLocalDepInfo::iterator Entry = std::upper_bound(
+ Cache->begin(), Cache->begin() + NumSortedEntries, NonLocalDepEntry(BB));
+ if (Entry != Cache->begin() && (Entry - 1)->getBB() == BB)
+ --Entry;
+
+ NonLocalDepEntry *ExistingResult = nullptr;
+ if (Entry != Cache->begin() + NumSortedEntries && Entry->getBB() == BB)
+ ExistingResult = &*Entry;
+
+ // If we have a cached entry, and it is non-dirty, use it as the value for
+ // this dependency.
+ if (ExistingResult && !ExistingResult->getResult().isDirty()) {
+ ++NumCacheNonLocalPtr;
+ return ExistingResult->getResult();
+ }
+
+ // Otherwise, we have to scan for the value. If we have a dirty cache
+ // entry, start scanning from its position, otherwise we scan from the end
+ // of the block.
+ BasicBlock::iterator ScanPos = BB->end();
+ if (ExistingResult && ExistingResult->getResult().getInst()) {
+ assert(ExistingResult->getResult().getInst()->getParent() == BB &&
+ "Instruction invalidated?");
+ ++NumCacheDirtyNonLocalPtr;
+ ScanPos = ExistingResult->getResult().getInst()->getIterator();
+
+ // Eliminating the dirty entry from 'Cache', so update the reverse info.
+ ValueIsLoadPair CacheKey(Loc.Ptr, isLoad);
+ RemoveFromReverseMap(ReverseNonLocalPtrDeps, &*ScanPos, CacheKey);
+ } else {
+ ++NumUncacheNonLocalPtr;
+ }
+
+ // Scan the block for the dependency.
+ MemDepResult Dep =
+ getPointerDependencyFrom(Loc, isLoad, ScanPos, BB, QueryInst);
+
+ // If we had a dirty entry for the block, update it. Otherwise, just add
+ // a new entry.
+ if (ExistingResult)
+ ExistingResult->setResult(Dep);
+ else
+ Cache->push_back(NonLocalDepEntry(BB, Dep));
+
+ // If the block has a dependency (i.e. it isn't completely transparent to
+ // the value), remember the reverse association because we just added it
+ // to Cache!
+ if (!Dep.isDef() && !Dep.isClobber())
+ return Dep;
+
+ // Keep the ReverseNonLocalPtrDeps map up to date so we can efficiently
+ // update MemDep when we remove instructions.
+ Instruction *Inst = Dep.getInst();
+ assert(Inst && "Didn't depend on anything?");
+ ValueIsLoadPair CacheKey(Loc.Ptr, isLoad);
+ ReverseNonLocalPtrDeps[Inst].insert(CacheKey);
+ return Dep;
+}
+
+/// Sort the NonLocalDepInfo cache, given a certain number of elements in the
+/// array that are already properly ordered.
+///
+/// This is optimized for the case when only a few entries are added.
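+///
+/// For example (hypothetical contents), with Cache == {A, B, D, E, C} and
+/// NumSortedEntries == 3, the two-entry case below first re-inserts C among
+/// the sorted prefix {A, B, D}, then falls through to the one-entry case to
+/// place E, avoiding a full sort.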
+static void
+SortNonLocalDepInfoCache(MemoryDependenceResults::NonLocalDepInfo &Cache,
+ unsigned NumSortedEntries) {
+ switch (Cache.size() - NumSortedEntries) {
+ case 0:
+ // done, no new entries.
+ break;
+ case 2: {
+ // Two new entries, insert the last one into place.
+ NonLocalDepEntry Val = Cache.back();
+ Cache.pop_back();
+ MemoryDependenceResults::NonLocalDepInfo::iterator Entry =
+ std::upper_bound(Cache.begin(), Cache.end() - 1, Val);
+ Cache.insert(Entry, Val);
+ LLVM_FALLTHROUGH;
+ }
+ case 1:
+    // One new entry, just insert the new value at the appropriate position.
+ if (Cache.size() != 1) {
+ NonLocalDepEntry Val = Cache.back();
+ Cache.pop_back();
+ MemoryDependenceResults::NonLocalDepInfo::iterator Entry =
+ std::upper_bound(Cache.begin(), Cache.end(), Val);
+ Cache.insert(Entry, Val);
+ }
+ break;
+ default:
+ // Added many values, do a full scale sort.
+ llvm::sort(Cache);
+ break;
+ }
+}
+
+/// Perform a dependency query based on pointer/pointeesize starting at the end
+/// of StartBB.
+///
+/// Add any clobber/def results to the results vector and keep track of which
+/// blocks are visited in 'Visited'.
+///
+/// This has special behavior for the first block queries (when SkipFirstBlock
+/// is true). In this special case, it ignores the contents of the specified
+/// block and starts returning dependence info for its predecessors.
+///
+/// This function returns true on success, or false to indicate that it could
+/// not compute dependence information for some reason. This should be treated
+/// as a clobber dependence on the first instruction in the predecessor block.
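+///
+/// As a small illustration (hypothetical IR), given
+///   pred1:  %p1 = getelementptr i32, i32* %base, i64 1
+///   pred2:  %p2 = getelementptr i32, i32* %base, i64 2
+///   merge:  %p  = phi i32* [ %p1, %pred1 ], [ %p2, %pred2 ]
+///           %v  = load i32, i32* %p
+/// a query on %p starting at 'merge' PHI-translates the pointer to %p1 in
+/// pred1 and to %p2 in pred2, and each predecessor is then scanned with its
+/// translated pointer.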
+bool MemoryDependenceResults::getNonLocalPointerDepFromBB(
+ Instruction *QueryInst, const PHITransAddr &Pointer,
+ const MemoryLocation &Loc, bool isLoad, BasicBlock *StartBB,
+ SmallVectorImpl<NonLocalDepResult> &Result,
+ DenseMap<BasicBlock *, Value *> &Visited, bool SkipFirstBlock) {
+ // Look up the cached info for Pointer.
+ ValueIsLoadPair CacheKey(Pointer.getAddr(), isLoad);
+
+ // Set up a temporary NLPI value. If the map doesn't yet have an entry for
+ // CacheKey, this value will be inserted as the associated value. Otherwise,
+ // it'll be ignored, and we'll have to check to see if the cached size and
+ // aa tags are consistent with the current query.
+ NonLocalPointerInfo InitialNLPI;
+ InitialNLPI.Size = Loc.Size;
+ InitialNLPI.AATags = Loc.AATags;
+
+ // Get the NLPI for CacheKey, inserting one into the map if it doesn't
+ // already have one.
+ std::pair<CachedNonLocalPointerInfo::iterator, bool> Pair =
+ NonLocalPointerDeps.insert(std::make_pair(CacheKey, InitialNLPI));
+ NonLocalPointerInfo *CacheInfo = &Pair.first->second;
+
+ // If we already have a cache entry for this CacheKey, we may need to do some
+ // work to reconcile the cache entry and the current query.
+ if (!Pair.second) {
+ if (CacheInfo->Size != Loc.Size) {
+ bool ThrowOutEverything;
+ if (CacheInfo->Size.hasValue() && Loc.Size.hasValue()) {
+ // FIXME: We may be able to do better in the face of results with mixed
+ // precision. We don't appear to get them in practice, though, so just
+ // be conservative.
+ ThrowOutEverything =
+ CacheInfo->Size.isPrecise() != Loc.Size.isPrecise() ||
+ CacheInfo->Size.getValue() < Loc.Size.getValue();
+ } else {
+ // For our purposes, unknown size > all others.
+ ThrowOutEverything = !Loc.Size.hasValue();
+ }
+
+ if (ThrowOutEverything) {
+ // The query's Size is greater than the cached one. Throw out the
+ // cached data and proceed with the query at the greater size.
+ CacheInfo->Pair = BBSkipFirstBlockPair();
+ CacheInfo->Size = Loc.Size;
+ for (auto &Entry : CacheInfo->NonLocalDeps)
+ if (Instruction *Inst = Entry.getResult().getInst())
+ RemoveFromReverseMap(ReverseNonLocalPtrDeps, Inst, CacheKey);
+ CacheInfo->NonLocalDeps.clear();
+ } else {
+ // This query's Size is less than the cached one. Conservatively restart
+ // the query using the greater size.
+ return getNonLocalPointerDepFromBB(
+ QueryInst, Pointer, Loc.getWithNewSize(CacheInfo->Size), isLoad,
+ StartBB, Result, Visited, SkipFirstBlock);
+ }
+ }
+
+ // If the query's AATags are inconsistent with the cached one,
+ // conservatively throw out the cached data and restart the query with
+ // no tag if needed.
+ if (CacheInfo->AATags != Loc.AATags) {
+ if (CacheInfo->AATags) {
+ CacheInfo->Pair = BBSkipFirstBlockPair();
+ CacheInfo->AATags = AAMDNodes();
+ for (auto &Entry : CacheInfo->NonLocalDeps)
+ if (Instruction *Inst = Entry.getResult().getInst())
+ RemoveFromReverseMap(ReverseNonLocalPtrDeps, Inst, CacheKey);
+ CacheInfo->NonLocalDeps.clear();
+ }
+ if (Loc.AATags)
+ return getNonLocalPointerDepFromBB(
+ QueryInst, Pointer, Loc.getWithoutAATags(), isLoad, StartBB, Result,
+ Visited, SkipFirstBlock);
+ }
+ }
+
+ NonLocalDepInfo *Cache = &CacheInfo->NonLocalDeps;
+
+ // If we have valid cached information for exactly the block we are
+ // investigating, just return it with no recomputation.
+ if (CacheInfo->Pair == BBSkipFirstBlockPair(StartBB, SkipFirstBlock)) {
+    // We have a fully cached result for this query, so we can just return the
+ // cached results and populate the visited set. However, we have to verify
+ // that we don't already have conflicting results for these blocks. Check
+ // to ensure that if a block in the results set is in the visited set that
+ // it was for the same pointer query.
+ if (!Visited.empty()) {
+ for (auto &Entry : *Cache) {
+ DenseMap<BasicBlock *, Value *>::iterator VI =
+ Visited.find(Entry.getBB());
+ if (VI == Visited.end() || VI->second == Pointer.getAddr())
+ continue;
+
+ // We have a pointer mismatch in a block. Just return false, saying
+ // that something was clobbered in this result. We could also do a
+ // non-fully cached query, but there is little point in doing this.
+ return false;
+ }
+ }
+
+ Value *Addr = Pointer.getAddr();
+ for (auto &Entry : *Cache) {
+ Visited.insert(std::make_pair(Entry.getBB(), Addr));
+ if (Entry.getResult().isNonLocal()) {
+ continue;
+ }
+
+ if (DT.isReachableFromEntry(Entry.getBB())) {
+ Result.push_back(
+ NonLocalDepResult(Entry.getBB(), Entry.getResult(), Addr));
+ }
+ }
+ ++NumCacheCompleteNonLocalPtr;
+ return true;
+ }
+
+  // Otherwise, this is either a new block, a block with an invalid cache
+  // pointer, or one whose cache we are about to invalidate by putting more
+  // info into it than it currently holds.  If the cache is empty, the result
+  // will be valid cache info; otherwise it will not.
+ if (Cache->empty())
+ CacheInfo->Pair = BBSkipFirstBlockPair(StartBB, SkipFirstBlock);
+ else
+ CacheInfo->Pair = BBSkipFirstBlockPair();
+
+ SmallVector<BasicBlock *, 32> Worklist;
+ Worklist.push_back(StartBB);
+
+ // PredList used inside loop.
+ SmallVector<std::pair<BasicBlock *, PHITransAddr>, 16> PredList;
+
+ // Keep track of the entries that we know are sorted. Previously cached
+ // entries will all be sorted. The entries we add we only sort on demand (we
+ // don't insert every element into its sorted position). We know that we
+ // won't get any reuse from currently inserted values, because we don't
+ // revisit blocks after we insert info for them.
+ unsigned NumSortedEntries = Cache->size();
+ unsigned WorklistEntries = BlockNumberLimit;
+ bool GotWorklistLimit = false;
+ LLVM_DEBUG(AssertSorted(*Cache));
+
+ while (!Worklist.empty()) {
+ BasicBlock *BB = Worklist.pop_back_val();
+
+    // If we do process a large number of blocks, the query becomes very
+    // expensive and is likely not worth the effort; give up.
+ if (Result.size() > NumResultsLimit) {
+ Worklist.clear();
+ // Sort it now (if needed) so that recursive invocations of
+ // getNonLocalPointerDepFromBB and other routines that could reuse the
+ // cache value will only see properly sorted cache arrays.
+ if (Cache && NumSortedEntries != Cache->size()) {
+ SortNonLocalDepInfoCache(*Cache, NumSortedEntries);
+ }
+ // Since we bail out, the "Cache" set won't contain all of the
+ // results for the query. This is ok (we can still use it to accelerate
+ // specific block queries) but we can't do the fastpath "return all
+ // results from the set". Clear out the indicator for this.
+ CacheInfo->Pair = BBSkipFirstBlockPair();
+ return false;
+ }
+
+ // Skip the first block if we have it.
+ if (!SkipFirstBlock) {
+ // Analyze the dependency of *Pointer in FromBB. See if we already have
+ // been here.
+ assert(Visited.count(BB) && "Should check 'visited' before adding to WL");
+
+ // Get the dependency info for Pointer in BB. If we have cached
+ // information, we will use it, otherwise we compute it.
+ LLVM_DEBUG(AssertSorted(*Cache, NumSortedEntries));
+ MemDepResult Dep = GetNonLocalInfoForBlock(QueryInst, Loc, isLoad, BB,
+ Cache, NumSortedEntries);
+
+ // If we got a Def or Clobber, add this to the list of results.
+ if (!Dep.isNonLocal()) {
+ if (DT.isReachableFromEntry(BB)) {
+ Result.push_back(NonLocalDepResult(BB, Dep, Pointer.getAddr()));
+ continue;
+ }
+ }
+ }
+
+ // If 'Pointer' is an instruction defined in this block, then we need to do
+ // phi translation to change it into a value live in the predecessor block.
+ // If not, we just add the predecessors to the worklist and scan them with
+ // the same Pointer.
+ if (!Pointer.NeedsPHITranslationFromBlock(BB)) {
+ SkipFirstBlock = false;
+ SmallVector<BasicBlock *, 16> NewBlocks;
+ for (BasicBlock *Pred : PredCache.get(BB)) {
+ // Verify that we haven't looked at this block yet.
+ std::pair<DenseMap<BasicBlock *, Value *>::iterator, bool> InsertRes =
+ Visited.insert(std::make_pair(Pred, Pointer.getAddr()));
+ if (InsertRes.second) {
+ // First time we've looked at *PI.
+ NewBlocks.push_back(Pred);
+ continue;
+ }
+
+ // If we have seen this block before, but it was with a different
+ // pointer then we have a phi translation failure and we have to treat
+ // this as a clobber.
+ if (InsertRes.first->second != Pointer.getAddr()) {
+ // Make sure to clean up the Visited map before continuing on to
+ // PredTranslationFailure.
+ for (unsigned i = 0; i < NewBlocks.size(); i++)
+ Visited.erase(NewBlocks[i]);
+ goto PredTranslationFailure;
+ }
+ }
+ if (NewBlocks.size() > WorklistEntries) {
+ // Make sure to clean up the Visited map before continuing on to
+ // PredTranslationFailure.
+ for (unsigned i = 0; i < NewBlocks.size(); i++)
+ Visited.erase(NewBlocks[i]);
+ GotWorklistLimit = true;
+ goto PredTranslationFailure;
+ }
+ WorklistEntries -= NewBlocks.size();
+ Worklist.append(NewBlocks.begin(), NewBlocks.end());
+ continue;
+ }
+
+    // We do need to do phi translation; if we know ahead of time that we
+    // can't phi translate this value, don't even try.
+ if (!Pointer.IsPotentiallyPHITranslatable())
+ goto PredTranslationFailure;
+
+ // We may have added values to the cache list before this PHI translation.
+ // If so, we haven't done anything to ensure that the cache remains sorted.
+ // Sort it now (if needed) so that recursive invocations of
+ // getNonLocalPointerDepFromBB and other routines that could reuse the cache
+ // value will only see properly sorted cache arrays.
+ if (Cache && NumSortedEntries != Cache->size()) {
+ SortNonLocalDepInfoCache(*Cache, NumSortedEntries);
+ NumSortedEntries = Cache->size();
+ }
+ Cache = nullptr;
+
+ PredList.clear();
+ for (BasicBlock *Pred : PredCache.get(BB)) {
+ PredList.push_back(std::make_pair(Pred, Pointer));
+
+ // Get the PHI translated pointer in this predecessor. This can fail if
+ // not translatable, in which case the getAddr() returns null.
+ PHITransAddr &PredPointer = PredList.back().second;
+ PredPointer.PHITranslateValue(BB, Pred, &DT, /*MustDominate=*/false);
+ Value *PredPtrVal = PredPointer.getAddr();
+
+ // Check to see if we have already visited this pred block with another
+ // pointer. If so, we can't do this lookup. This failure can occur
+ // with PHI translation when a critical edge exists and the PHI node in
+ // the successor translates to a pointer value different than the
+ // pointer the block was first analyzed with.
+ std::pair<DenseMap<BasicBlock *, Value *>::iterator, bool> InsertRes =
+ Visited.insert(std::make_pair(Pred, PredPtrVal));
+
+ if (!InsertRes.second) {
+ // We found the pred; take it off the list of preds to visit.
+ PredList.pop_back();
+
+ // If the predecessor was visited with PredPtr, then we already did
+ // the analysis and can ignore it.
+ if (InsertRes.first->second == PredPtrVal)
+ continue;
+
+ // Otherwise, the block was previously analyzed with a different
+ // pointer. We can't represent the result of this case, so we just
+ // treat this as a phi translation failure.
+
+ // Make sure to clean up the Visited map before continuing on to
+ // PredTranslationFailure.
+ for (unsigned i = 0, n = PredList.size(); i < n; ++i)
+ Visited.erase(PredList[i].first);
+
+ goto PredTranslationFailure;
+ }
+ }
+
+    // Actually process results here; this needs to be a separate loop to avoid
+ // calling getNonLocalPointerDepFromBB for blocks we don't want to return
+ // any results for. (getNonLocalPointerDepFromBB will modify our
+ // datastructures in ways the code after the PredTranslationFailure label
+ // doesn't expect.)
+ for (unsigned i = 0, n = PredList.size(); i < n; ++i) {
+ BasicBlock *Pred = PredList[i].first;
+ PHITransAddr &PredPointer = PredList[i].second;
+ Value *PredPtrVal = PredPointer.getAddr();
+
+ bool CanTranslate = true;
+ // If PHI translation was unable to find an available pointer in this
+ // predecessor, then we have to assume that the pointer is clobbered in
+ // that predecessor. We can still do PRE of the load, which would insert
+ // a computation of the pointer in this predecessor.
+ if (!PredPtrVal)
+ CanTranslate = false;
+
+ // FIXME: it is entirely possible that PHI translating will end up with
+ // the same value. Consider PHI translating something like:
+ // X = phi [x, bb1], [y, bb2]. PHI translating for bb1 doesn't *need*
+ // to recurse here, pedantically speaking.
+
+ // If getNonLocalPointerDepFromBB fails here, that means the cached
+ // result conflicted with the Visited list; we have to conservatively
+ // assume it is unknown, but this also does not block PRE of the load.
+ if (!CanTranslate ||
+ !getNonLocalPointerDepFromBB(QueryInst, PredPointer,
+ Loc.getWithNewPtr(PredPtrVal), isLoad,
+ Pred, Result, Visited)) {
+ // Add the entry to the Result list.
+ NonLocalDepResult Entry(Pred, MemDepResult::getUnknown(), PredPtrVal);
+ Result.push_back(Entry);
+
+ // Since we had a phi translation failure, the cache for CacheKey won't
+ // include all of the entries that we need to immediately satisfy future
+ // queries. Mark this in NonLocalPointerDeps by setting the
+        // BBSkipFirstBlockPair pointer to null.  This means later reuse of the
+        // cached value will do more work, but it will not miss the phi
+        // translation failure.
+ NonLocalPointerInfo &NLPI = NonLocalPointerDeps[CacheKey];
+ NLPI.Pair = BBSkipFirstBlockPair();
+ continue;
+ }
+ }
+
+ // Refresh the CacheInfo/Cache pointer so that it isn't invalidated.
+ CacheInfo = &NonLocalPointerDeps[CacheKey];
+ Cache = &CacheInfo->NonLocalDeps;
+ NumSortedEntries = Cache->size();
+
+ // Since we did phi translation, the "Cache" set won't contain all of the
+ // results for the query. This is ok (we can still use it to accelerate
+ // specific block queries) but we can't do the fastpath "return all
+ // results from the set" Clear out the indicator for this.
+ CacheInfo->Pair = BBSkipFirstBlockPair();
+ SkipFirstBlock = false;
+ continue;
+
+ PredTranslationFailure:
+ // The following code is "failure"; we can't produce a sane translation
+ // for the given block. It assumes that we haven't modified any of
+ // our datastructures while processing the current block.
+
+ if (!Cache) {
+ // Refresh the CacheInfo/Cache pointer if it got invalidated.
+ CacheInfo = &NonLocalPointerDeps[CacheKey];
+ Cache = &CacheInfo->NonLocalDeps;
+ NumSortedEntries = Cache->size();
+ }
+
+ // Since we failed phi translation, the "Cache" set won't contain all of the
+ // results for the query. This is ok (we can still use it to accelerate
+ // specific block queries) but we can't do the fastpath "return all
+ // results from the set". Clear out the indicator for this.
+ CacheInfo->Pair = BBSkipFirstBlockPair();
+
+ // If *nothing* works, mark the pointer as unknown.
+ //
+ // If this is the magic first block, return this as a clobber of the whole
+ // incoming value. Since we can't phi translate to one of the predecessors,
+ // we have to bail out.
+ if (SkipFirstBlock)
+ return false;
+
+ bool foundBlock = false;
+ for (NonLocalDepEntry &I : llvm::reverse(*Cache)) {
+ if (I.getBB() != BB)
+ continue;
+
+ assert((GotWorklistLimit || I.getResult().isNonLocal() ||
+ !DT.isReachableFromEntry(BB)) &&
+ "Should only be here with transparent block");
+ foundBlock = true;
+ I.setResult(MemDepResult::getUnknown());
+ Result.push_back(
+ NonLocalDepResult(I.getBB(), I.getResult(), Pointer.getAddr()));
+ break;
+ }
+ (void)foundBlock; (void)GotWorklistLimit;
+ assert((foundBlock || GotWorklistLimit) && "Current block not in cache?");
+ }
+
+ // Okay, we're done now. If we added new values to the cache, re-sort it.
+ SortNonLocalDepInfoCache(*Cache, NumSortedEntries);
+ LLVM_DEBUG(AssertSorted(*Cache));
+ return true;
+}
+
+/// If P exists in CachedNonLocalPointerInfo or NonLocalDefsCache, remove it.
+void MemoryDependenceResults::RemoveCachedNonLocalPointerDependencies(
+ ValueIsLoadPair P) {
+
+ // Most of the time this cache is empty.
+ if (!NonLocalDefsCache.empty()) {
+ auto it = NonLocalDefsCache.find(P.getPointer());
+ if (it != NonLocalDefsCache.end()) {
+ RemoveFromReverseMap(ReverseNonLocalDefsCache,
+ it->second.getResult().getInst(), P.getPointer());
+ NonLocalDefsCache.erase(it);
+ }
+
+ if (auto *I = dyn_cast<Instruction>(P.getPointer())) {
+ auto toRemoveIt = ReverseNonLocalDefsCache.find(I);
+ if (toRemoveIt != ReverseNonLocalDefsCache.end()) {
+ for (const auto &entry : toRemoveIt->second)
+ NonLocalDefsCache.erase(entry);
+ ReverseNonLocalDefsCache.erase(toRemoveIt);
+ }
+ }
+ }
+
+ CachedNonLocalPointerInfo::iterator It = NonLocalPointerDeps.find(P);
+ if (It == NonLocalPointerDeps.end())
+ return;
+
+ // Remove all of the entries in the BB->val map. This involves removing
+ // instructions from the reverse map.
+ NonLocalDepInfo &PInfo = It->second.NonLocalDeps;
+
+ for (unsigned i = 0, e = PInfo.size(); i != e; ++i) {
+ Instruction *Target = PInfo[i].getResult().getInst();
+ if (!Target)
+ continue; // Ignore non-local dep results.
+ assert(Target->getParent() == PInfo[i].getBB());
+
+ // Eliminating the dirty entry from 'Cache', so update the reverse info.
+ RemoveFromReverseMap(ReverseNonLocalPtrDeps, Target, P);
+ }
+
+ // Remove P from NonLocalPointerDeps (which deletes NonLocalDepInfo).
+ NonLocalPointerDeps.erase(It);
+}
+
+void MemoryDependenceResults::invalidateCachedPointerInfo(Value *Ptr) {
+ // If Ptr isn't really a pointer, just ignore it.
+ if (!Ptr->getType()->isPointerTy())
+ return;
+ // Flush store info for the pointer.
+ RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(Ptr, false));
+ // Flush load info for the pointer.
+ RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(Ptr, true));
+ // Invalidate phis that use the pointer.
+ PV.invalidateValue(Ptr);
+}
+
+void MemoryDependenceResults::invalidateCachedPredecessors() {
+ PredCache.clear();
+}
+
+void MemoryDependenceResults::removeInstruction(Instruction *RemInst) {
+ // Walk through the Non-local dependencies, removing this one as the value
+ // for any cached queries.
+ NonLocalDepMapType::iterator NLDI = NonLocalDeps.find(RemInst);
+ if (NLDI != NonLocalDeps.end()) {
+ NonLocalDepInfo &BlockMap = NLDI->second.first;
+ for (auto &Entry : BlockMap)
+ if (Instruction *Inst = Entry.getResult().getInst())
+ RemoveFromReverseMap(ReverseNonLocalDeps, Inst, RemInst);
+ NonLocalDeps.erase(NLDI);
+ }
+
+ // If we have a cached local dependence query for this instruction, remove it.
+ LocalDepMapType::iterator LocalDepEntry = LocalDeps.find(RemInst);
+ if (LocalDepEntry != LocalDeps.end()) {
+ // Remove us from DepInst's reverse set now that the local dep info is gone.
+ if (Instruction *Inst = LocalDepEntry->second.getInst())
+ RemoveFromReverseMap(ReverseLocalDeps, Inst, RemInst);
+
+ // Remove this local dependency info.
+ LocalDeps.erase(LocalDepEntry);
+ }
+
+ // If we have any cached pointer dependencies on this instruction, remove
+ // them. If the instruction has non-pointer type, then it can't be a pointer
+ // base.
+
+ // Remove it from both the load info and the store info. The instruction
+ // can't be in either of these maps if it is non-pointer.
+ if (RemInst->getType()->isPointerTy()) {
+ RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(RemInst, false));
+ RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(RemInst, true));
+ }
+
+ // Loop over all of the things that depend on the instruction we're removing.
+ SmallVector<std::pair<Instruction *, Instruction *>, 8> ReverseDepsToAdd;
+
+ // If we find RemInst as a clobber or Def in any of the maps for other values,
+ // we need to replace its entry with a dirty version of the instruction after
+ // it. If RemInst is a terminator, we use a null dirty value.
+ //
+ // Using a dirty version of the instruction after RemInst saves having to scan
+ // the entire block to get to this point.
+ MemDepResult NewDirtyVal;
+ if (!RemInst->isTerminator())
+ NewDirtyVal = MemDepResult::getDirty(&*++RemInst->getIterator());
+
+ ReverseDepMapType::iterator ReverseDepIt = ReverseLocalDeps.find(RemInst);
+ if (ReverseDepIt != ReverseLocalDeps.end()) {
+ // RemInst can't be the terminator if it has local stuff depending on it.
+ assert(!ReverseDepIt->second.empty() && !RemInst->isTerminator() &&
+ "Nothing can locally depend on a terminator");
+
+ for (Instruction *InstDependingOnRemInst : ReverseDepIt->second) {
+ assert(InstDependingOnRemInst != RemInst &&
+ "Already removed our local dep info");
+
+ LocalDeps[InstDependingOnRemInst] = NewDirtyVal;
+
+      // Make sure to remember that new things depend on the instruction in
+      // NewDirtyVal.
+ assert(NewDirtyVal.getInst() &&
+ "There is no way something else can have "
+ "a local dep on this if it is a terminator!");
+ ReverseDepsToAdd.push_back(
+ std::make_pair(NewDirtyVal.getInst(), InstDependingOnRemInst));
+ }
+
+ ReverseLocalDeps.erase(ReverseDepIt);
+
+ // Add new reverse deps after scanning the set, to avoid invalidating the
+ // 'ReverseDeps' reference.
+ while (!ReverseDepsToAdd.empty()) {
+ ReverseLocalDeps[ReverseDepsToAdd.back().first].insert(
+ ReverseDepsToAdd.back().second);
+ ReverseDepsToAdd.pop_back();
+ }
+ }
+
+ ReverseDepIt = ReverseNonLocalDeps.find(RemInst);
+ if (ReverseDepIt != ReverseNonLocalDeps.end()) {
+ for (Instruction *I : ReverseDepIt->second) {
+ assert(I != RemInst && "Already removed NonLocalDep info for RemInst");
+
+ PerInstNLInfo &INLD = NonLocalDeps[I];
+ // The information is now dirty!
+ INLD.second = true;
+
+ for (auto &Entry : INLD.first) {
+ if (Entry.getResult().getInst() != RemInst)
+ continue;
+
+ // Convert to a dirty entry for the subsequent instruction.
+ Entry.setResult(NewDirtyVal);
+
+ if (Instruction *NextI = NewDirtyVal.getInst())
+ ReverseDepsToAdd.push_back(std::make_pair(NextI, I));
+ }
+ }
+
+ ReverseNonLocalDeps.erase(ReverseDepIt);
+
+    // Add new reverse deps after scanning the set, to avoid invalidating it.
+ while (!ReverseDepsToAdd.empty()) {
+ ReverseNonLocalDeps[ReverseDepsToAdd.back().first].insert(
+ ReverseDepsToAdd.back().second);
+ ReverseDepsToAdd.pop_back();
+ }
+ }
+
+ // If the instruction is in ReverseNonLocalPtrDeps then it appears as a
+ // value in the NonLocalPointerDeps info.
+ ReverseNonLocalPtrDepTy::iterator ReversePtrDepIt =
+ ReverseNonLocalPtrDeps.find(RemInst);
+ if (ReversePtrDepIt != ReverseNonLocalPtrDeps.end()) {
+ SmallVector<std::pair<Instruction *, ValueIsLoadPair>, 8>
+ ReversePtrDepsToAdd;
+
+ for (ValueIsLoadPair P : ReversePtrDepIt->second) {
+ assert(P.getPointer() != RemInst &&
+ "Already removed NonLocalPointerDeps info for RemInst");
+
+ NonLocalDepInfo &NLPDI = NonLocalPointerDeps[P].NonLocalDeps;
+
+ // The cache is not valid for any specific block anymore.
+ NonLocalPointerDeps[P].Pair = BBSkipFirstBlockPair();
+
+ // Update any entries for RemInst to use the instruction after it.
+ for (auto &Entry : NLPDI) {
+ if (Entry.getResult().getInst() != RemInst)
+ continue;
+
+ // Convert to a dirty entry for the subsequent instruction.
+ Entry.setResult(NewDirtyVal);
+
+ if (Instruction *NewDirtyInst = NewDirtyVal.getInst())
+ ReversePtrDepsToAdd.push_back(std::make_pair(NewDirtyInst, P));
+ }
+
+ // Re-sort the NonLocalDepInfo. Changing the dirty entry to its
+ // subsequent value may invalidate the sortedness.
+ llvm::sort(NLPDI);
+ }
+
+ ReverseNonLocalPtrDeps.erase(ReversePtrDepIt);
+
+ while (!ReversePtrDepsToAdd.empty()) {
+ ReverseNonLocalPtrDeps[ReversePtrDepsToAdd.back().first].insert(
+ ReversePtrDepsToAdd.back().second);
+ ReversePtrDepsToAdd.pop_back();
+ }
+ }
+
+ // Invalidate phis that use the removed instruction.
+ PV.invalidateValue(RemInst);
+
+ assert(!NonLocalDeps.count(RemInst) && "RemInst got reinserted?");
+ LLVM_DEBUG(verifyRemoved(RemInst));
+}
+
+/// Verify that the specified instruction does not occur in our internal data
+/// structures.
+///
+/// This function verifies by asserting in debug builds.
+void MemoryDependenceResults::verifyRemoved(Instruction *D) const {
+#ifndef NDEBUG
+ for (const auto &DepKV : LocalDeps) {
+ assert(DepKV.first != D && "Inst occurs in data structures");
+ assert(DepKV.second.getInst() != D && "Inst occurs in data structures");
+ }
+
+ for (const auto &DepKV : NonLocalPointerDeps) {
+ assert(DepKV.first.getPointer() != D && "Inst occurs in NLPD map key");
+ for (const auto &Entry : DepKV.second.NonLocalDeps)
+ assert(Entry.getResult().getInst() != D && "Inst occurs as NLPD value");
+ }
+
+ for (const auto &DepKV : NonLocalDeps) {
+ assert(DepKV.first != D && "Inst occurs in data structures");
+ const PerInstNLInfo &INLD = DepKV.second;
+ for (const auto &Entry : INLD.first)
+ assert(Entry.getResult().getInst() != D &&
+ "Inst occurs in data structures");
+ }
+
+ for (const auto &DepKV : ReverseLocalDeps) {
+ assert(DepKV.first != D && "Inst occurs in data structures");
+ for (Instruction *Inst : DepKV.second)
+ assert(Inst != D && "Inst occurs in data structures");
+ }
+
+ for (const auto &DepKV : ReverseNonLocalDeps) {
+ assert(DepKV.first != D && "Inst occurs in data structures");
+ for (Instruction *Inst : DepKV.second)
+ assert(Inst != D && "Inst occurs in data structures");
+ }
+
+ for (const auto &DepKV : ReverseNonLocalPtrDeps) {
+ assert(DepKV.first != D && "Inst occurs in rev NLPD map");
+
+ for (ValueIsLoadPair P : DepKV.second)
+ assert(P != ValueIsLoadPair(D, false) && P != ValueIsLoadPair(D, true) &&
+ "Inst occurs in ReverseNonLocalPtrDeps map");
+ }
+#endif
+}
+
+AnalysisKey MemoryDependenceAnalysis::Key;
+
+MemoryDependenceAnalysis::MemoryDependenceAnalysis()
+ : DefaultBlockScanLimit(BlockScanLimit) {}
+
+MemoryDependenceResults
+MemoryDependenceAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
+ auto &AA = AM.getResult<AAManager>(F);
+ auto &AC = AM.getResult<AssumptionAnalysis>(F);
+ auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+ auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ auto &PV = AM.getResult<PhiValuesAnalysis>(F);
+ return MemoryDependenceResults(AA, AC, TLI, DT, PV, DefaultBlockScanLimit);
+}
+
+char MemoryDependenceWrapperPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(MemoryDependenceWrapperPass, "memdep",
+ "Memory Dependence Analysis", false, true)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(PhiValuesWrapperPass)
+INITIALIZE_PASS_END(MemoryDependenceWrapperPass, "memdep",
+ "Memory Dependence Analysis", false, true)
+
+MemoryDependenceWrapperPass::MemoryDependenceWrapperPass() : FunctionPass(ID) {
+ initializeMemoryDependenceWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+MemoryDependenceWrapperPass::~MemoryDependenceWrapperPass() = default;
+
+void MemoryDependenceWrapperPass::releaseMemory() {
+ MemDep.reset();
+}
+
+void MemoryDependenceWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequired<AssumptionCacheTracker>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<PhiValuesWrapperPass>();
+ AU.addRequiredTransitive<AAResultsWrapperPass>();
+ AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
+}
+
+bool MemoryDependenceResults::invalidate(Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &Inv) {
+ // Check whether our analysis is preserved.
+ auto PAC = PA.getChecker<MemoryDependenceAnalysis>();
+ if (!PAC.preserved() && !PAC.preservedSet<AllAnalysesOn<Function>>())
+ // If not, give up now.
+ return true;
+
+ // Check whether the analyses we depend on became invalid for any reason.
+ if (Inv.invalidate<AAManager>(F, PA) ||
+ Inv.invalidate<AssumptionAnalysis>(F, PA) ||
+ Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
+ Inv.invalidate<PhiValuesAnalysis>(F, PA))
+ return true;
+
+ // Otherwise this analysis result remains valid.
+ return false;
+}
+
+unsigned MemoryDependenceResults::getDefaultBlockScanLimit() const {
+ return DefaultBlockScanLimit;
+}
+
+bool MemoryDependenceWrapperPass::runOnFunction(Function &F) {
+ auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
+ auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+ auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto &PV = getAnalysis<PhiValuesWrapperPass>().getResult();
+ MemDep.emplace(AA, AC, TLI, DT, PV, BlockScanLimit);
+ return false;
+}
diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp
new file mode 100644
index 000000000000..163830eee797
--- /dev/null
+++ b/llvm/lib/Analysis/MemoryLocation.cpp
@@ -0,0 +1,211 @@
+//===- MemoryLocation.cpp - Memory location descriptions -------------------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+using namespace llvm;
+
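+// Prints a human-readable form such as "LocationSize::precise(4)" or
+// "LocationSize::unknown"; mainly useful for debug output and debuggers.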
+void LocationSize::print(raw_ostream &OS) const {
+ OS << "LocationSize::";
+ if (*this == unknown())
+ OS << "unknown";
+ else if (*this == mapEmpty())
+ OS << "mapEmpty";
+ else if (*this == mapTombstone())
+ OS << "mapTombstone";
+ else if (isPrecise())
+ OS << "precise(" << getValue() << ')';
+ else
+ OS << "upperBound(" << getValue() << ')';
+}
+
+MemoryLocation MemoryLocation::get(const LoadInst *LI) {
+ AAMDNodes AATags;
+ LI->getAAMetadata(AATags);
+ const auto &DL = LI->getModule()->getDataLayout();
+
+ return MemoryLocation(
+ LI->getPointerOperand(),
+ LocationSize::precise(DL.getTypeStoreSize(LI->getType())), AATags);
+}
+
+MemoryLocation MemoryLocation::get(const StoreInst *SI) {
+ AAMDNodes AATags;
+ SI->getAAMetadata(AATags);
+ const auto &DL = SI->getModule()->getDataLayout();
+
+ return MemoryLocation(SI->getPointerOperand(),
+ LocationSize::precise(DL.getTypeStoreSize(
+ SI->getValueOperand()->getType())),
+ AATags);
+}
+
+MemoryLocation MemoryLocation::get(const VAArgInst *VI) {
+ AAMDNodes AATags;
+ VI->getAAMetadata(AATags);
+
+ return MemoryLocation(VI->getPointerOperand(), LocationSize::unknown(),
+ AATags);
+}
+
+MemoryLocation MemoryLocation::get(const AtomicCmpXchgInst *CXI) {
+ AAMDNodes AATags;
+ CXI->getAAMetadata(AATags);
+ const auto &DL = CXI->getModule()->getDataLayout();
+
+ return MemoryLocation(CXI->getPointerOperand(),
+ LocationSize::precise(DL.getTypeStoreSize(
+ CXI->getCompareOperand()->getType())),
+ AATags);
+}
+
+MemoryLocation MemoryLocation::get(const AtomicRMWInst *RMWI) {
+ AAMDNodes AATags;
+ RMWI->getAAMetadata(AATags);
+ const auto &DL = RMWI->getModule()->getDataLayout();
+
+ return MemoryLocation(RMWI->getPointerOperand(),
+ LocationSize::precise(DL.getTypeStoreSize(
+ RMWI->getValOperand()->getType())),
+ AATags);
+}
+
+MemoryLocation MemoryLocation::getForSource(const MemTransferInst *MTI) {
+ return getForSource(cast<AnyMemTransferInst>(MTI));
+}
+
+MemoryLocation MemoryLocation::getForSource(const AtomicMemTransferInst *MTI) {
+ return getForSource(cast<AnyMemTransferInst>(MTI));
+}
+
+MemoryLocation MemoryLocation::getForSource(const AnyMemTransferInst *MTI) {
+ auto Size = LocationSize::unknown();
+ if (ConstantInt *C = dyn_cast<ConstantInt>(MTI->getLength()))
+ Size = LocationSize::precise(C->getValue().getZExtValue());
+
+ // memcpy/memmove can have AA tags. For memcpy, they apply
+ // to both the source and the destination.
+ AAMDNodes AATags;
+ MTI->getAAMetadata(AATags);
+
+ return MemoryLocation(MTI->getRawSource(), Size, AATags);
+}
+
+MemoryLocation MemoryLocation::getForDest(const MemIntrinsic *MI) {
+ return getForDest(cast<AnyMemIntrinsic>(MI));
+}
+
+MemoryLocation MemoryLocation::getForDest(const AtomicMemIntrinsic *MI) {
+ return getForDest(cast<AnyMemIntrinsic>(MI));
+}
+
+MemoryLocation MemoryLocation::getForDest(const AnyMemIntrinsic *MI) {
+ auto Size = LocationSize::unknown();
+ if (ConstantInt *C = dyn_cast<ConstantInt>(MI->getLength()))
+ Size = LocationSize::precise(C->getValue().getZExtValue());
+
+ // memcpy/memmove can have AA tags. For memcpy, they apply
+ // to both the source and the destination.
+ AAMDNodes AATags;
+ MI->getAAMetadata(AATags);
+
+ return MemoryLocation(MI->getRawDest(), Size, AATags);
+}
+
+MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
+ unsigned ArgIdx,
+ const TargetLibraryInfo *TLI) {
+ AAMDNodes AATags;
+ Call->getAAMetadata(AATags);
+ const Value *Arg = Call->getArgOperand(ArgIdx);
+
+ // We may be able to produce an exact size for known intrinsics.
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Call)) {
+ const DataLayout &DL = II->getModule()->getDataLayout();
+
+ switch (II->getIntrinsicID()) {
+ default:
+ break;
+ case Intrinsic::memset:
+ case Intrinsic::memcpy:
+ case Intrinsic::memmove:
+ assert((ArgIdx == 0 || ArgIdx == 1) &&
+ "Invalid argument index for memory intrinsic");
+ if (ConstantInt *LenCI = dyn_cast<ConstantInt>(II->getArgOperand(2)))
+ return MemoryLocation(Arg, LocationSize::precise(LenCI->getZExtValue()),
+ AATags);
+ break;
+
+ case Intrinsic::lifetime_start:
+ case Intrinsic::lifetime_end:
+ case Intrinsic::invariant_start:
+ assert(ArgIdx == 1 && "Invalid argument index");
+ return MemoryLocation(
+ Arg,
+ LocationSize::precise(
+ cast<ConstantInt>(II->getArgOperand(0))->getZExtValue()),
+ AATags);
+
+ case Intrinsic::invariant_end:
+ // The first argument to an invariant.end is a "descriptor" type (e.g. a
+      // pointer to an empty struct) which is never actually dereferenced.
+ if (ArgIdx == 0)
+ return MemoryLocation(Arg, LocationSize::precise(0), AATags);
+ assert(ArgIdx == 2 && "Invalid argument index");
+ return MemoryLocation(
+ Arg,
+ LocationSize::precise(
+ cast<ConstantInt>(II->getArgOperand(1))->getZExtValue()),
+ AATags);
+
+ case Intrinsic::arm_neon_vld1:
+ assert(ArgIdx == 0 && "Invalid argument index");
+ // LLVM's vld1 and vst1 intrinsics currently only support a single
+ // vector register.
+ return MemoryLocation(
+ Arg, LocationSize::precise(DL.getTypeStoreSize(II->getType())),
+ AATags);
+
+ case Intrinsic::arm_neon_vst1:
+ assert(ArgIdx == 0 && "Invalid argument index");
+ return MemoryLocation(Arg,
+ LocationSize::precise(DL.getTypeStoreSize(
+ II->getArgOperand(1)->getType())),
+ AATags);
+ }
+ }
+
+ // We can bound the aliasing properties of memset_pattern16 just as we can
+ // for memcpy/memset. This is particularly important because the
+ // LoopIdiomRecognizer likes to turn loops into calls to memset_pattern16
+ // whenever possible.
+ LibFunc F;
+ if (TLI && Call->getCalledFunction() &&
+ TLI->getLibFunc(*Call->getCalledFunction(), F) &&
+ F == LibFunc_memset_pattern16 && TLI->has(F)) {
+ assert((ArgIdx == 0 || ArgIdx == 1) &&
+ "Invalid argument index for memset_pattern16");
+ if (ArgIdx == 1)
+ return MemoryLocation(Arg, LocationSize::precise(16), AATags);
+ if (const ConstantInt *LenCI =
+ dyn_cast<ConstantInt>(Call->getArgOperand(2)))
+ return MemoryLocation(Arg, LocationSize::precise(LenCI->getZExtValue()),
+ AATags);
+ }
+ // FIXME: Handle memset_pattern4 and memset_pattern8 also.
+
+ return MemoryLocation(Call->getArgOperand(ArgIdx), LocationSize::unknown(),
+ AATags);
+}
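+
+// An illustrative call site (hypothetical variables Call and TLI): for a
+// memcpy call, the source operand's location can be obtained with
+//   MemoryLocation SrcLoc = MemoryLocation::getForArgument(Call, 1, &TLI);
+// which yields a precise size whenever the length operand is a ConstantInt.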
diff --git a/llvm/lib/Analysis/MemorySSA.cpp b/llvm/lib/Analysis/MemorySSA.cpp
new file mode 100644
index 000000000000..cfb8b7e7dcb5
--- /dev/null
+++ b/llvm/lib/Analysis/MemorySSA.cpp
@@ -0,0 +1,2475 @@
+//===- MemorySSA.cpp - Memory SSA Builder ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the MemorySSA class.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/IteratedDominanceFrontier.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/AssemblyAnnotationWriter.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Use.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/AtomicOrdering.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/FormattedStream.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <cstdlib>
+#include <iterator>
+#include <memory>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "memoryssa"
+
+INITIALIZE_PASS_BEGIN(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false,
+ true)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_END(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false,
+ true)
+
+INITIALIZE_PASS_BEGIN(MemorySSAPrinterLegacyPass, "print-memoryssa",
+ "Memory SSA Printer", false, false)
+INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
+INITIALIZE_PASS_END(MemorySSAPrinterLegacyPass, "print-memoryssa",
+ "Memory SSA Printer", false, false)
+
+static cl::opt<unsigned> MaxCheckLimit(
+ "memssa-check-limit", cl::Hidden, cl::init(100),
+    cl::desc("The maximum number of stores/phis MemorySSA "
+             "will consider trying to walk past (default = 100)"));
+
+// Always verify MemorySSA if expensive checking is enabled.
+#ifdef EXPENSIVE_CHECKS
+bool llvm::VerifyMemorySSA = true;
+#else
+bool llvm::VerifyMemorySSA = false;
+#endif
+/// Enables memory ssa as a dependency for loop passes in legacy pass manager.
+cl::opt<bool> llvm::EnableMSSALoopDependency(
+ "enable-mssa-loop-dependency", cl::Hidden, cl::init(true),
+ cl::desc("Enable MemorySSA dependency for loop pass manager"));
+
+static cl::opt<bool, true>
+ VerifyMemorySSAX("verify-memoryssa", cl::location(VerifyMemorySSA),
+ cl::Hidden, cl::desc("Enable verification of MemorySSA."));
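+// For example, the printer and verifier can be exercised from the command
+// line with something like:
+//   opt -print-memoryssa -verify-memoryssa -disable-output foo.ll
+// (flag spelling shown here is approximate).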
+
+namespace llvm {
+
+/// An assembly annotator class to print Memory SSA information in
+/// comments.
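+/// For example, a store is typically annotated as
+///   ; 1 = MemoryDef(liveOnEntry)
+/// and a later load of the same memory as
+///   ; MemoryUse(1)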
+class MemorySSAAnnotatedWriter : public AssemblyAnnotationWriter {
+ friend class MemorySSA;
+
+ const MemorySSA *MSSA;
+
+public:
+ MemorySSAAnnotatedWriter(const MemorySSA *M) : MSSA(M) {}
+
+ void emitBasicBlockStartAnnot(const BasicBlock *BB,
+ formatted_raw_ostream &OS) override {
+ if (MemoryAccess *MA = MSSA->getMemoryAccess(BB))
+ OS << "; " << *MA << "\n";
+ }
+
+ void emitInstructionAnnot(const Instruction *I,
+ formatted_raw_ostream &OS) override {
+ if (MemoryAccess *MA = MSSA->getMemoryAccess(I))
+ OS << "; " << *MA << "\n";
+ }
+};
+
+} // end namespace llvm
+
+namespace {
+
+/// Our current alias analysis API differentiates heavily between calls and
+/// non-calls, and functions called on one usually assert on the other.
+/// This class encapsulates the distinction to simplify other code that wants
+/// "Memory affecting instructions and related data" to use as a key.
+/// For example, this class is used as a densemap key in the use optimizer.
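+/// Two non-call keys compare equal when their MemoryLocations match; two call
+/// keys compare equal when they have the same callee and argument operands.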
+class MemoryLocOrCall {
+public:
+ bool IsCall = false;
+
+ MemoryLocOrCall(MemoryUseOrDef *MUD)
+ : MemoryLocOrCall(MUD->getMemoryInst()) {}
+ MemoryLocOrCall(const MemoryUseOrDef *MUD)
+ : MemoryLocOrCall(MUD->getMemoryInst()) {}
+
+ MemoryLocOrCall(Instruction *Inst) {
+ if (auto *C = dyn_cast<CallBase>(Inst)) {
+ IsCall = true;
+ Call = C;
+ } else {
+ IsCall = false;
+      // There is no such thing as a MemoryLocation for a fence instruction,
+      // and it is unique in that regard.
+ if (!isa<FenceInst>(Inst))
+ Loc = MemoryLocation::get(Inst);
+ }
+ }
+
+ explicit MemoryLocOrCall(const MemoryLocation &Loc) : Loc(Loc) {}
+
+ const CallBase *getCall() const {
+ assert(IsCall);
+ return Call;
+ }
+
+ MemoryLocation getLoc() const {
+ assert(!IsCall);
+ return Loc;
+ }
+
+ bool operator==(const MemoryLocOrCall &Other) const {
+ if (IsCall != Other.IsCall)
+ return false;
+
+ if (!IsCall)
+ return Loc == Other.Loc;
+
+ if (Call->getCalledValue() != Other.Call->getCalledValue())
+ return false;
+
+ return Call->arg_size() == Other.Call->arg_size() &&
+ std::equal(Call->arg_begin(), Call->arg_end(),
+ Other.Call->arg_begin());
+ }
+
+private:
+ union {
+ const CallBase *Call;
+ MemoryLocation Loc;
+ };
+};
+
+} // end anonymous namespace
+
+namespace llvm {
+
+template <> struct DenseMapInfo<MemoryLocOrCall> {
+ static inline MemoryLocOrCall getEmptyKey() {
+ return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getEmptyKey());
+ }
+
+ static inline MemoryLocOrCall getTombstoneKey() {
+ return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getTombstoneKey());
+ }
+
+ static unsigned getHashValue(const MemoryLocOrCall &MLOC) {
+ if (!MLOC.IsCall)
+ return hash_combine(
+ MLOC.IsCall,
+ DenseMapInfo<MemoryLocation>::getHashValue(MLOC.getLoc()));
+
+ hash_code hash =
+ hash_combine(MLOC.IsCall, DenseMapInfo<const Value *>::getHashValue(
+ MLOC.getCall()->getCalledValue()));
+
+ for (const Value *Arg : MLOC.getCall()->args())
+ hash = hash_combine(hash, DenseMapInfo<const Value *>::getHashValue(Arg));
+ return hash;
+ }
+
+ static bool isEqual(const MemoryLocOrCall &LHS, const MemoryLocOrCall &RHS) {
+ return LHS == RHS;
+ }
+};
+
+} // end namespace llvm
+
+/// This does one-way checks to see if Use could theoretically be hoisted above
+/// MayClobber. This will not check the other way around.
+///
+/// This assumes that, for the purposes of MemorySSA, Use comes directly after
+/// MayClobber, with no potentially clobbering operations in between them.
+/// (Where potentially clobbering ops are memory barriers, aliased stores, etc.)
+static bool areLoadsReorderable(const LoadInst *Use,
+ const LoadInst *MayClobber) {
+ bool VolatileUse = Use->isVolatile();
+ bool VolatileClobber = MayClobber->isVolatile();
+ // Volatile operations may never be reordered with other volatile operations.
+ if (VolatileUse && VolatileClobber)
+ return false;
+ // Otherwise, volatile doesn't matter here. From the language reference:
+ // 'optimizers may change the order of volatile operations relative to
+  // non-volatile operations.'
+
+ // If a load is seq_cst, it cannot be moved above other loads. If its ordering
+ // is weaker, it can be moved above other loads. We just need to be sure that
+ // MayClobber isn't an acquire load, because loads can't be moved above
+ // acquire loads.
+ //
+ // Note that this explicitly *does* allow the free reordering of monotonic (or
+ // weaker) loads of the same address.
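+  // For example, two monotonic loads of the same address may be freely
+  // reordered; a seq_cst Use may never be hoisted above MayClobber, and no
+  // Use may be hoisted above an acquire (or stronger) MayClobber.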
+ bool SeqCstUse = Use->getOrdering() == AtomicOrdering::SequentiallyConsistent;
+ bool MayClobberIsAcquire = isAtLeastOrStrongerThan(MayClobber->getOrdering(),
+ AtomicOrdering::Acquire);
+ return !(SeqCstUse || MayClobberIsAcquire);
+}
+
+namespace {
+
+struct ClobberAlias {
+ bool IsClobber;
+ Optional<AliasResult> AR;
+};
+
+} // end anonymous namespace
+
+// Return a pair of {IsClobber (bool), AR (AliasResult)}. It relies on AR being
+// ignored if IsClobber = false.
+template <typename AliasAnalysisType>
+static ClobberAlias
+instructionClobbersQuery(const MemoryDef *MD, const MemoryLocation &UseLoc,
+ const Instruction *UseInst, AliasAnalysisType &AA) {
+ Instruction *DefInst = MD->getMemoryInst();
+ assert(DefInst && "Defining instruction not actually an instruction");
+ const auto *UseCall = dyn_cast<CallBase>(UseInst);
+ Optional<AliasResult> AR;
+
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(DefInst)) {
+ // These intrinsics will show up as affecting memory, but they are just
+ // markers, mostly.
+ //
+ // FIXME: We probably don't actually want MemorySSA to model these at all
+ // (including creating MemoryAccesses for them): we just end up inventing
+ // clobbers where they don't really exist at all. Please see D43269 for
+ // context.
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::lifetime_start:
+ if (UseCall)
+ return {false, NoAlias};
+ AR = AA.alias(MemoryLocation(II->getArgOperand(1)), UseLoc);
+ return {AR != NoAlias, AR};
+ case Intrinsic::lifetime_end:
+ case Intrinsic::invariant_start:
+ case Intrinsic::invariant_end:
+ case Intrinsic::assume:
+ return {false, NoAlias};
+ case Intrinsic::dbg_addr:
+ case Intrinsic::dbg_declare:
+ case Intrinsic::dbg_label:
+ case Intrinsic::dbg_value:
+ llvm_unreachable("debuginfo shouldn't have associated defs!");
+ default:
+ break;
+ }
+ }
+
+ if (UseCall) {
+ ModRefInfo I = AA.getModRefInfo(DefInst, UseCall);
+ AR = isMustSet(I) ? MustAlias : MayAlias;
+ return {isModOrRefSet(I), AR};
+ }
+
+ if (auto *DefLoad = dyn_cast<LoadInst>(DefInst))
+ if (auto *UseLoad = dyn_cast<LoadInst>(UseInst))
+ return {!areLoadsReorderable(UseLoad, DefLoad), MayAlias};
+
+ ModRefInfo I = AA.getModRefInfo(DefInst, UseLoc);
+ AR = isMustSet(I) ? MustAlias : MayAlias;
+ return {isModSet(I), AR};
+}
+
+template <typename AliasAnalysisType>
+static ClobberAlias instructionClobbersQuery(MemoryDef *MD,
+ const MemoryUseOrDef *MU,
+ const MemoryLocOrCall &UseMLOC,
+ AliasAnalysisType &AA) {
+ // FIXME: This is a temporary hack to allow a single instructionClobbersQuery
+ // to exist while MemoryLocOrCall is pushed through places.
+ if (UseMLOC.IsCall)
+ return instructionClobbersQuery(MD, MemoryLocation(), MU->getMemoryInst(),
+ AA);
+ return instructionClobbersQuery(MD, UseMLOC.getLoc(), MU->getMemoryInst(),
+ AA);
+}
+
+// Return true when MD may alias MU, return false otherwise.
+bool MemorySSAUtil::defClobbersUseOrDef(MemoryDef *MD, const MemoryUseOrDef *MU,
+ AliasAnalysis &AA) {
+ return instructionClobbersQuery(MD, MU, MemoryLocOrCall(MU), AA).IsClobber;
+}
+
+namespace {
+
+struct UpwardsMemoryQuery {
+ // True if our original query started off as a call
+ bool IsCall = false;
+ // The pointer location we started the query with. This will be empty if
+ // IsCall is true.
+ MemoryLocation StartingLoc;
+ // This is the instruction we were querying about.
+ const Instruction *Inst = nullptr;
+ // The MemoryAccess we actually got called with, used to test local domination
+ const MemoryAccess *OriginalAccess = nullptr;
+ Optional<AliasResult> AR = MayAlias;
+ bool SkipSelfAccess = false;
+
+ UpwardsMemoryQuery() = default;
+
+ UpwardsMemoryQuery(const Instruction *Inst, const MemoryAccess *Access)
+ : IsCall(isa<CallBase>(Inst)), Inst(Inst), OriginalAccess(Access) {
+ if (!IsCall)
+ StartingLoc = MemoryLocation::get(Inst);
+ }
+};
+
+} // end anonymous namespace
+
+static bool lifetimeEndsAt(MemoryDef *MD, const MemoryLocation &Loc,
+ BatchAAResults &AA) {
+ Instruction *Inst = MD->getMemoryInst();
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::lifetime_end:
+ return AA.alias(MemoryLocation(II->getArgOperand(1)), Loc) == MustAlias;
+ default:
+ return false;
+ }
+ }
+ return false;
+}
+
+template <typename AliasAnalysisType>
+static bool isUseTriviallyOptimizableToLiveOnEntry(AliasAnalysisType &AA,
+ const Instruction *I) {
+ // If the memory can't be changed, then loads of the memory can't be
+ // clobbered.
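+  // This covers loads tagged !invariant.load and loads whose pointer AA can
+  // prove refers to constant memory (e.g. a constant global).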
+ return isa<LoadInst>(I) && (I->hasMetadata(LLVMContext::MD_invariant_load) ||
+ AA.pointsToConstantMemory(MemoryLocation(
+ cast<LoadInst>(I)->getPointerOperand())));
+}
+
+/// Verifies that `Start` is clobbered by `ClobberAt`, and that nothing
+/// in between `Start` and `ClobberAt` can clobber `Start`.
+///
+/// This is meant to be as simple and self-contained as possible. Because it
+/// uses no cache, etc., it can be relatively expensive.
+///
+/// \param Start The MemoryAccess that we want to walk from.
+/// \param ClobberAt A clobber for Start.
+/// \param StartLoc The MemoryLocation for Start.
+/// \param MSSA The MemorySSA instance that Start and ClobberAt belong to.
+/// \param Query The UpwardsMemoryQuery we used for our search.
+/// \param AA The AliasAnalysis we used for our search.
+/// \param AllowImpreciseClobber Always false, unless we do relaxed verify.
+
+template <typename AliasAnalysisType>
+LLVM_ATTRIBUTE_UNUSED static void
+checkClobberSanity(const MemoryAccess *Start, MemoryAccess *ClobberAt,
+ const MemoryLocation &StartLoc, const MemorySSA &MSSA,
+ const UpwardsMemoryQuery &Query, AliasAnalysisType &AA,
+ bool AllowImpreciseClobber = false) {
+ assert(MSSA.dominates(ClobberAt, Start) && "Clobber doesn't dominate start?");
+
+ if (MSSA.isLiveOnEntryDef(Start)) {
+ assert(MSSA.isLiveOnEntryDef(ClobberAt) &&
+ "liveOnEntry must clobber itself");
+ return;
+ }
+
+ bool FoundClobber = false;
+ DenseSet<ConstMemoryAccessPair> VisitedPhis;
+ SmallVector<ConstMemoryAccessPair, 8> Worklist;
+ Worklist.emplace_back(Start, StartLoc);
+ // Walk all paths from Start to ClobberAt, while looking for clobbers. If one
+ // is found, complain.
+ while (!Worklist.empty()) {
+ auto MAP = Worklist.pop_back_val();
+ // All we care about is that nothing from Start to ClobberAt clobbers Start.
+ // We learn nothing from revisiting nodes.
+ if (!VisitedPhis.insert(MAP).second)
+ continue;
+
+ for (const auto *MA : def_chain(MAP.first)) {
+ if (MA == ClobberAt) {
+ if (const auto *MD = dyn_cast<MemoryDef>(MA)) {
+ // instructionClobbersQuery isn't essentially free, so don't use `|=`,
+ // since it won't let us short-circuit.
+ //
+ // Also, note that this can't be hoisted out of the `Worklist` loop,
+ // since MD may only act as a clobber for 1 of N MemoryLocations.
+ FoundClobber = FoundClobber || MSSA.isLiveOnEntryDef(MD);
+ if (!FoundClobber) {
+ ClobberAlias CA =
+ instructionClobbersQuery(MD, MAP.second, Query.Inst, AA);
+ if (CA.IsClobber) {
+ FoundClobber = true;
+ // Not used: CA.AR;
+ }
+ }
+ }
+ break;
+ }
+
+ // We should never hit liveOnEntry, unless it's the clobber.
+ assert(!MSSA.isLiveOnEntryDef(MA) && "Hit liveOnEntry before clobber?");
+
+ if (const auto *MD = dyn_cast<MemoryDef>(MA)) {
+ // If Start is a Def, skip self.
+ if (MD == Start)
+ continue;
+
+ assert(!instructionClobbersQuery(MD, MAP.second, Query.Inst, AA)
+ .IsClobber &&
+ "Found clobber before reaching ClobberAt!");
+ continue;
+ }
+
+ if (const auto *MU = dyn_cast<MemoryUse>(MA)) {
+ (void)MU;
+      assert(MU == Start &&
+ "Can only find use in def chain if Start is a use");
+ continue;
+ }
+
+ assert(isa<MemoryPhi>(MA));
+ Worklist.append(
+ upward_defs_begin({const_cast<MemoryAccess *>(MA), MAP.second}),
+ upward_defs_end());
+ }
+ }
+
+ // If the verify is done following an optimization, it's possible that
+ // ClobberAt was a conservative clobbering, that we can now infer is not a
+ // true clobbering access. Don't fail the verify if that's the case.
+ // We do have accesses that claim they're optimized, but could be optimized
+ // further. Updating all these can be expensive, so allow it for now (FIXME).
+ if (AllowImpreciseClobber)
+ return;
+
+ // If ClobberAt is a MemoryPhi, we can assume something above it acted as a
+ // clobber. Otherwise, `ClobberAt` should've acted as a clobber at some point.
+ assert((isa<MemoryPhi>(ClobberAt) || FoundClobber) &&
+ "ClobberAt never acted as a clobber");
+}
+
+namespace {
+
+/// Our algorithm for walking (and trying to optimize) clobbers, all wrapped up
+/// in one class.
+template <class AliasAnalysisType> class ClobberWalker {
+ /// Save a few bytes by using unsigned instead of size_t.
+ using ListIndex = unsigned;
+
+ /// Represents a span of contiguous MemoryDefs, potentially ending in a
+ /// MemoryPhi.
+ struct DefPath {
+ MemoryLocation Loc;
+ // Note that, because we always walk in reverse, Last will always dominate
+ // First. Also note that First and Last are inclusive.
+ MemoryAccess *First;
+ MemoryAccess *Last;
+ Optional<ListIndex> Previous;
+
+ DefPath(const MemoryLocation &Loc, MemoryAccess *First, MemoryAccess *Last,
+ Optional<ListIndex> Previous)
+ : Loc(Loc), First(First), Last(Last), Previous(Previous) {}
+
+ DefPath(const MemoryLocation &Loc, MemoryAccess *Init,
+ Optional<ListIndex> Previous)
+ : DefPath(Loc, Init, Init, Previous) {}
+ };
+
+ const MemorySSA &MSSA;
+ AliasAnalysisType &AA;
+ DominatorTree &DT;
+ UpwardsMemoryQuery *Query;
+ unsigned *UpwardWalkLimit;
+
+ // Phi optimization bookkeeping
+ SmallVector<DefPath, 32> Paths;
+ DenseSet<ConstMemoryAccessPair> VisitedPhis;
+
+ /// Find the nearest def or phi that `From` can legally be optimized to.
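+  /// Concretely, this walks up the immediate-dominator chain of From's block
+  /// and returns the last MemoryDef/MemoryPhi in the nearest strictly
+  /// dominating block that has any, or liveOnEntry if none does.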
+ const MemoryAccess *getWalkTarget(const MemoryPhi *From) const {
+ assert(From->getNumOperands() && "Phi with no operands?");
+
+ BasicBlock *BB = From->getBlock();
+ MemoryAccess *Result = MSSA.getLiveOnEntryDef();
+ DomTreeNode *Node = DT.getNode(BB);
+ while ((Node = Node->getIDom())) {
+ auto *Defs = MSSA.getBlockDefs(Node->getBlock());
+ if (Defs)
+ return &*Defs->rbegin();
+ }
+ return Result;
+ }
+
+ /// Result of calling walkToPhiOrClobber.
+ struct UpwardsWalkResult {
+ /// The "Result" of the walk. Either a clobber, the last thing we walked, or
+ /// both. Include alias info when clobber found.
+ MemoryAccess *Result;
+ bool IsKnownClobber;
+ Optional<AliasResult> AR;
+ };
+
+ /// Walk to the next Phi or Clobber in the def chain starting at Desc.Last.
+ /// This will update Desc.Last as it walks. It will (optionally) also stop at
+ /// StopAt.
+ ///
+  /// This does not test whether StopAt is a clobber.
+ UpwardsWalkResult
+ walkToPhiOrClobber(DefPath &Desc, const MemoryAccess *StopAt = nullptr,
+ const MemoryAccess *SkipStopAt = nullptr) const {
+ assert(!isa<MemoryUse>(Desc.Last) && "Uses don't exist in my world");
+ assert(UpwardWalkLimit && "Need a valid walk limit");
+ bool LimitAlreadyReached = false;
+ // (*UpwardWalkLimit) may be 0 here, due to the loop in tryOptimizePhi. Set
+ // it to 1. This will not do any alias() calls. It either returns in the
+ // first iteration in the loop below, or is set back to 0 if all def chains
+ // are free of MemoryDefs.
+ if (!*UpwardWalkLimit) {
+ *UpwardWalkLimit = 1;
+ LimitAlreadyReached = true;
+ }
+
+ for (MemoryAccess *Current : def_chain(Desc.Last)) {
+ Desc.Last = Current;
+ if (Current == StopAt || Current == SkipStopAt)
+ return {Current, false, MayAlias};
+
+ if (auto *MD = dyn_cast<MemoryDef>(Current)) {
+ if (MSSA.isLiveOnEntryDef(MD))
+ return {MD, true, MustAlias};
+
+ if (!--*UpwardWalkLimit)
+ return {Current, true, MayAlias};
+
+ ClobberAlias CA =
+ instructionClobbersQuery(MD, Desc.Loc, Query->Inst, AA);
+ if (CA.IsClobber)
+ return {MD, true, CA.AR};
+ }
+ }
+
+ if (LimitAlreadyReached)
+ *UpwardWalkLimit = 0;
+
+ assert(isa<MemoryPhi>(Desc.Last) &&
+ "Ended at a non-clobber that's not a phi?");
+ return {Desc.Last, false, MayAlias};
+ }
+
+ void addSearches(MemoryPhi *Phi, SmallVectorImpl<ListIndex> &PausedSearches,
+ ListIndex PriorNode) {
+ auto UpwardDefs = make_range(upward_defs_begin({Phi, Paths[PriorNode].Loc}),
+ upward_defs_end());
+ for (const MemoryAccessPair &P : UpwardDefs) {
+ PausedSearches.push_back(Paths.size());
+ Paths.emplace_back(P.second, P.first, PriorNode);
+ }
+ }
+
+ /// Represents a search that terminated after finding a clobber. This clobber
+ /// may or may not be present in the path of defs from LastNode..SearchStart,
+ /// since it may have been retrieved from cache.
+ struct TerminatedPath {
+ MemoryAccess *Clobber;
+ ListIndex LastNode;
+ };
+
+ /// Get an access that keeps us from optimizing to the given phi.
+ ///
+ /// PausedSearches is an array of indices into the Paths array. Its incoming
+ /// value is the indices of searches that stopped at the last phi optimization
+ /// target. It's left in an unspecified state.
+ ///
+ /// If this returns None, NewPaused is a vector of searches that terminated
+ /// at StopWhere. Otherwise, NewPaused is left in an unspecified state.
+ Optional<TerminatedPath>
+ getBlockingAccess(const MemoryAccess *StopWhere,
+ SmallVectorImpl<ListIndex> &PausedSearches,
+ SmallVectorImpl<ListIndex> &NewPaused,
+ SmallVectorImpl<TerminatedPath> &Terminated) {
+ assert(!PausedSearches.empty() && "No searches to continue?");
+
+ // BFS vs DFS really doesn't make a difference here, so just do a DFS with
+ // PausedSearches as our stack.
+ while (!PausedSearches.empty()) {
+ ListIndex PathIndex = PausedSearches.pop_back_val();
+ DefPath &Node = Paths[PathIndex];
+
+ // If we've already visited this path with this MemoryLocation, we don't
+ // need to do so again.
+ //
+ // NOTE: That we just drop these paths on the ground makes caching
+ // behavior sporadic. e.g. given a diamond:
+ // A
+ // B C
+ // D
+ //
+ // ...If we walk D, B, A, C, we'll only cache the result of phi
+ // optimization for A, B, and D; C will be skipped because it dies here.
+ // This arguably isn't the worst thing ever, since:
+ // - We generally query things in a top-down order, so if we got below D
+ // without needing cache entries for {C, MemLoc}, then chances are
+ // that those cache entries would end up ultimately unused.
+ // - We still cache things for A, so C only needs to walk up a bit.
+ // If this behavior becomes problematic, we can fix without a ton of extra
+ // work.
+ if (!VisitedPhis.insert({Node.Last, Node.Loc}).second)
+ continue;
+
+ const MemoryAccess *SkipStopWhere = nullptr;
+ if (Query->SkipSelfAccess && Node.Loc == Query->StartingLoc) {
+ assert(isa<MemoryDef>(Query->OriginalAccess));
+ SkipStopWhere = Query->OriginalAccess;
+ }
+
+ UpwardsWalkResult Res = walkToPhiOrClobber(Node,
+ /*StopAt=*/StopWhere,
+ /*SkipStopAt=*/SkipStopWhere);
+ if (Res.IsKnownClobber) {
+ assert(Res.Result != StopWhere && Res.Result != SkipStopWhere);
+
+ // If this wasn't a cache hit, we hit a clobber when walking. That's a
+ // failure.
+ TerminatedPath Term{Res.Result, PathIndex};
+ if (!MSSA.dominates(Res.Result, StopWhere))
+ return Term;
+
+ // Otherwise, it's a valid thing to potentially optimize to.
+ Terminated.push_back(Term);
+ continue;
+ }
+
+ if (Res.Result == StopWhere || Res.Result == SkipStopWhere) {
+ // We've hit our target. Save this path off for if we want to continue
+ // walking. If we are in the mode of skipping the OriginalAccess, and
+ // we've reached back to the OriginalAccess, do not save path, we've
+ // just looped back to self.
+ if (Res.Result != SkipStopWhere)
+ NewPaused.push_back(PathIndex);
+ continue;
+ }
+
+ assert(!MSSA.isLiveOnEntryDef(Res.Result) && "liveOnEntry is a clobber");
+ addSearches(cast<MemoryPhi>(Res.Result), PausedSearches, PathIndex);
+ }
+
+ return None;
+ }
+
+ template <typename T, typename Walker>
+ struct generic_def_path_iterator
+ : public iterator_facade_base<generic_def_path_iterator<T, Walker>,
+ std::forward_iterator_tag, T *> {
+ generic_def_path_iterator() {}
+ generic_def_path_iterator(Walker *W, ListIndex N) : W(W), N(N) {}
+
+ T &operator*() const { return curNode(); }
+
+ generic_def_path_iterator &operator++() {
+ N = curNode().Previous;
+ return *this;
+ }
+
+ bool operator==(const generic_def_path_iterator &O) const {
+ if (N.hasValue() != O.N.hasValue())
+ return false;
+ return !N.hasValue() || *N == *O.N;
+ }
+
+ private:
+ T &curNode() const { return W->Paths[*N]; }
+
+ Walker *W = nullptr;
+ Optional<ListIndex> N = None;
+ };
+
+ using def_path_iterator = generic_def_path_iterator<DefPath, ClobberWalker>;
+ using const_def_path_iterator =
+ generic_def_path_iterator<const DefPath, const ClobberWalker>;
+
+ iterator_range<def_path_iterator> def_path(ListIndex From) {
+ return make_range(def_path_iterator(this, From), def_path_iterator());
+ }
+
+ iterator_range<const_def_path_iterator> const_def_path(ListIndex From) const {
+ return make_range(const_def_path_iterator(this, From),
+ const_def_path_iterator());
+ }
+
+ struct OptznResult {
+ /// The path that contains our result.
+ TerminatedPath PrimaryClobber;
+ /// The paths that we can legally cache back from, but that aren't
+ /// necessarily the result of the Phi optimization.
+ SmallVector<TerminatedPath, 4> OtherClobbers;
+ };
+
+ ListIndex defPathIndex(const DefPath &N) const {
+ // The assert looks nicer if we don't need to do &N
+ const DefPath *NP = &N;
+ assert(!Paths.empty() && NP >= &Paths.front() && NP <= &Paths.back() &&
+ "Out of bounds DefPath!");
+ return NP - &Paths.front();
+ }
+
+ /// Try to optimize a phi as best as we can. Returns a SmallVector of Paths
+ /// that act as legal clobbers. Note that this won't return *all* clobbers.
+ ///
+ /// Phi optimization algorithm tl;dr:
+ /// - Find the earliest def/phi, A, we can optimize to
+ /// - Find if all paths from the starting memory access ultimately reach A
+ /// - If not, optimization isn't possible.
+ /// - Otherwise, walk from A to another clobber or phi, A'.
+ /// - If A' is a def, we're done.
+ /// - If A' is a phi, try to optimize it.
+ ///
+ /// A path is a series of {MemoryAccess, MemoryLocation} pairs. A path
+ /// terminates when a MemoryAccess that clobbers said MemoryLocation is found.
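+  ///
+  /// For instance, for a phi at the bottom of a diamond, the optimization can
+  /// succeed only if no def on either arm clobbers the queried location; the
+  /// result is then the nearest clobber at or above the top of the diamond.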
+ OptznResult tryOptimizePhi(MemoryPhi *Phi, MemoryAccess *Start,
+ const MemoryLocation &Loc) {
+ assert(Paths.empty() && VisitedPhis.empty() &&
+ "Reset the optimization state.");
+
+ Paths.emplace_back(Loc, Start, Phi, None);
+ // Stores how many "valid" optimization nodes we had prior to calling
+ // addSearches/getBlockingAccess. Necessary for caching if we had a blocker.
+ auto PriorPathsSize = Paths.size();
+
+ SmallVector<ListIndex, 16> PausedSearches;
+ SmallVector<ListIndex, 8> NewPaused;
+ SmallVector<TerminatedPath, 4> TerminatedPaths;
+
+ addSearches(Phi, PausedSearches, 0);
+
+ // Moves the TerminatedPath with the "most dominated" Clobber to the end of
+ // Paths.
+ auto MoveDominatedPathToEnd = [&](SmallVectorImpl<TerminatedPath> &Paths) {
+ assert(!Paths.empty() && "Need a path to move");
+ auto Dom = Paths.begin();
+ for (auto I = std::next(Dom), E = Paths.end(); I != E; ++I)
+ if (!MSSA.dominates(I->Clobber, Dom->Clobber))
+ Dom = I;
+ auto Last = Paths.end() - 1;
+ if (Last != Dom)
+ std::iter_swap(Last, Dom);
+ };
+
+ MemoryPhi *Current = Phi;
+ while (true) {
+ assert(!MSSA.isLiveOnEntryDef(Current) &&
+ "liveOnEntry wasn't treated as a clobber?");
+
+ const auto *Target = getWalkTarget(Current);
+ // If a TerminatedPath doesn't dominate Target, then it wasn't a legal
+ // optimization for the prior phi.
+ assert(all_of(TerminatedPaths, [&](const TerminatedPath &P) {
+ return MSSA.dominates(P.Clobber, Target);
+ }));
+
+ // FIXME: This is broken, because the Blocker may be reported to be
+      // liveOnEntry, and we'll happily wait for that to disappear (read: never).
+ // For the moment, this is fine, since we do nothing with blocker info.
+ if (Optional<TerminatedPath> Blocker = getBlockingAccess(
+ Target, PausedSearches, NewPaused, TerminatedPaths)) {
+
+ // Find the node we started at. We can't search based on N->Last, since
+ // we may have gone around a loop with a different MemoryLocation.
+ auto Iter = find_if(def_path(Blocker->LastNode), [&](const DefPath &N) {
+ return defPathIndex(N) < PriorPathsSize;
+ });
+ assert(Iter != def_path_iterator());
+
+ DefPath &CurNode = *Iter;
+ assert(CurNode.Last == Current);
+
+ // Two things:
+ // A. We can't reliably cache all of NewPaused back. Consider a case
+ // where we have two paths in NewPaused; one of which can't optimize
+ // above this phi, whereas the other can. If we cache the second path
+ // back, we'll end up with suboptimal cache entries. We can handle
+ // cases like this a bit better when we either try to find all
+ // clobbers that block phi optimization, or when our cache starts
+ // supporting unfinished searches.
+ // B. We can't reliably cache TerminatedPaths back here without doing
+ // extra checks; consider a case like:
+ // T
+ // / \
+ // D C
+ // \ /
+ // S
+ // Where T is our target, C is a node with a clobber on it, D is a
+ // diamond (with a clobber *only* on the left or right node, N), and
+ // S is our start. Say we walk to D, through the node opposite N
+ // (read: ignoring the clobber), and see a cache entry in the top
+ // node of D. That cache entry gets put into TerminatedPaths. We then
+ // walk up to C (N is later in our worklist), find the clobber, and
+ // quit. If we append TerminatedPaths to OtherClobbers, we'll cache
+ // the bottom part of D to the cached clobber, ignoring the clobber
+ // in N. Again, this problem goes away if we start tracking all
+ // blockers for a given phi optimization.
+ TerminatedPath Result{CurNode.Last, defPathIndex(CurNode)};
+ return {Result, {}};
+ }
+
+ // If there's nothing left to search, then all paths led to valid clobbers
+ // that we got from our cache; pick the nearest to the start, and allow
+ // the rest to be cached back.
+ if (NewPaused.empty()) {
+ MoveDominatedPathToEnd(TerminatedPaths);
+ TerminatedPath Result = TerminatedPaths.pop_back_val();
+ return {Result, std::move(TerminatedPaths)};
+ }
+
+ MemoryAccess *DefChainEnd = nullptr;
+ SmallVector<TerminatedPath, 4> Clobbers;
+ for (ListIndex Paused : NewPaused) {
+ UpwardsWalkResult WR = walkToPhiOrClobber(Paths[Paused]);
+ if (WR.IsKnownClobber)
+ Clobbers.push_back({WR.Result, Paused});
+ else
+ // Micro-opt: If we hit the end of the chain, save it.
+ DefChainEnd = WR.Result;
+ }
+
+ if (!TerminatedPaths.empty()) {
+ // If we couldn't find the dominating phi/liveOnEntry in the above loop,
+ // do it now.
+ if (!DefChainEnd)
+ for (auto *MA : def_chain(const_cast<MemoryAccess *>(Target)))
+ DefChainEnd = MA;
+ assert(DefChainEnd && "Failed to find dominating phi/liveOnEntry");
+
+ // If any of the terminated paths don't dominate the phi we'll try to
+ // optimize, we need to figure out what they are and quit.
+ const BasicBlock *ChainBB = DefChainEnd->getBlock();
+ for (const TerminatedPath &TP : TerminatedPaths) {
+ // Because we know that DefChainEnd is as "high" as we can go, we
+ // don't need local dominance checks; BB dominance is sufficient.
+ if (DT.dominates(ChainBB, TP.Clobber->getBlock()))
+ Clobbers.push_back(TP);
+ }
+ }
+
+ // If we have clobbers in the def chain, find the one closest to Current
+ // and quit.
+ if (!Clobbers.empty()) {
+ MoveDominatedPathToEnd(Clobbers);
+ TerminatedPath Result = Clobbers.pop_back_val();
+ return {Result, std::move(Clobbers)};
+ }
+
+ assert(all_of(NewPaused,
+ [&](ListIndex I) { return Paths[I].Last == DefChainEnd; }));
+
+ // Because liveOnEntry is a clobber, this must be a phi.
+ auto *DefChainPhi = cast<MemoryPhi>(DefChainEnd);
+
+ PriorPathsSize = Paths.size();
+ PausedSearches.clear();
+ for (ListIndex I : NewPaused)
+ addSearches(DefChainPhi, PausedSearches, I);
+ NewPaused.clear();
+
+ Current = DefChainPhi;
+ }
+ }
+
+ void verifyOptResult(const OptznResult &R) const {
+ assert(all_of(R.OtherClobbers, [&](const TerminatedPath &P) {
+ return MSSA.dominates(P.Clobber, R.PrimaryClobber.Clobber);
+ }));
+ }
+
+ void resetPhiOptznState() {
+ Paths.clear();
+ VisitedPhis.clear();
+ }
+
+public:
+ ClobberWalker(const MemorySSA &MSSA, AliasAnalysisType &AA, DominatorTree &DT)
+ : MSSA(MSSA), AA(AA), DT(DT) {}
+
+ AliasAnalysisType *getAA() { return &AA; }
+ /// Finds the nearest clobber for the given query, optimizing phis if
+ /// possible.
+ MemoryAccess *findClobber(MemoryAccess *Start, UpwardsMemoryQuery &Q,
+ unsigned &UpWalkLimit) {
+ Query = &Q;
+ UpwardWalkLimit = &UpWalkLimit;
+ // Starting limit must be > 0.
+ if (!UpWalkLimit)
+ UpWalkLimit++;
+
+ MemoryAccess *Current = Start;
+ // This walker pretends uses don't exist. If we're handed one, silently grab
+ // its def. (This has the nice side-effect of ensuring we never cache uses)
+ if (auto *MU = dyn_cast<MemoryUse>(Start))
+ Current = MU->getDefiningAccess();
+
+ DefPath FirstDesc(Q.StartingLoc, Current, Current, None);
+ // Fast path for the overly-common case (no crazy phi optimization
+ // necessary)
+ UpwardsWalkResult WalkResult = walkToPhiOrClobber(FirstDesc);
+ MemoryAccess *Result;
+ if (WalkResult.IsKnownClobber) {
+ Result = WalkResult.Result;
+ Q.AR = WalkResult.AR;
+ } else {
+ OptznResult OptRes = tryOptimizePhi(cast<MemoryPhi>(FirstDesc.Last),
+ Current, Q.StartingLoc);
+ verifyOptResult(OptRes);
+ resetPhiOptznState();
+ Result = OptRes.PrimaryClobber.Clobber;
+ }
+
+#ifdef EXPENSIVE_CHECKS
+ if (!Q.SkipSelfAccess && *UpwardWalkLimit > 0)
+ checkClobberSanity(Current, Result, Q.StartingLoc, MSSA, Q, AA);
+#endif
+ return Result;
+ }
+};
+
+struct RenamePassData {
+ DomTreeNode *DTN;
+ DomTreeNode::const_iterator ChildIt;
+ MemoryAccess *IncomingVal;
+
+ RenamePassData(DomTreeNode *D, DomTreeNode::const_iterator It,
+ MemoryAccess *M)
+ : DTN(D), ChildIt(It), IncomingVal(M) {}
+
+ void swap(RenamePassData &RHS) {
+ std::swap(DTN, RHS.DTN);
+ std::swap(ChildIt, RHS.ChildIt);
+ std::swap(IncomingVal, RHS.IncomingVal);
+ }
+};
+
+} // end anonymous namespace
+
+namespace llvm {
+
+template <class AliasAnalysisType> class MemorySSA::ClobberWalkerBase {
+ ClobberWalker<AliasAnalysisType> Walker;
+ MemorySSA *MSSA;
+
+public:
+ ClobberWalkerBase(MemorySSA *M, AliasAnalysisType *A, DominatorTree *D)
+ : Walker(*M, *A, *D), MSSA(M) {}
+
+ MemoryAccess *getClobberingMemoryAccessBase(MemoryAccess *,
+ const MemoryLocation &,
+ unsigned &);
+  // The third argument (bool) defines whether the clobber search should skip
+  // the
+ // original queried access. If true, there will be a follow-up query searching
+ // for a clobber access past "self". Note that the Optimized access is not
+ // updated if a new clobber is found by this SkipSelf search. If this
+ // additional query becomes heavily used we may decide to cache the result.
+ // Walker instantiations will decide how to set the SkipSelf bool.
+ MemoryAccess *getClobberingMemoryAccessBase(MemoryAccess *, unsigned &, bool);
+};
+
+/// A MemorySSAWalker that does AA walks to disambiguate accesses. It no
+/// longer does caching on its own, but the name has been retained for the
+/// moment.
+template <class AliasAnalysisType>
+class MemorySSA::CachingWalker final : public MemorySSAWalker {
+ ClobberWalkerBase<AliasAnalysisType> *Walker;
+
+public:
+ CachingWalker(MemorySSA *M, ClobberWalkerBase<AliasAnalysisType> *W)
+ : MemorySSAWalker(M), Walker(W) {}
+ ~CachingWalker() override = default;
+
+ using MemorySSAWalker::getClobberingMemoryAccess;
+
+ MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA, unsigned &UWL) {
+ return Walker->getClobberingMemoryAccessBase(MA, UWL, false);
+ }
+ MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA,
+ const MemoryLocation &Loc,
+ unsigned &UWL) {
+ return Walker->getClobberingMemoryAccessBase(MA, Loc, UWL);
+ }
+
+ MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA) override {
+ unsigned UpwardWalkLimit = MaxCheckLimit;
+ return getClobberingMemoryAccess(MA, UpwardWalkLimit);
+ }
+ MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA,
+ const MemoryLocation &Loc) override {
+ unsigned UpwardWalkLimit = MaxCheckLimit;
+ return getClobberingMemoryAccess(MA, Loc, UpwardWalkLimit);
+ }
+
+ void invalidateInfo(MemoryAccess *MA) override {
+ if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA))
+ MUD->resetOptimized();
+ }
+};
+
+template <class AliasAnalysisType>
+class MemorySSA::SkipSelfWalker final : public MemorySSAWalker {
+ ClobberWalkerBase<AliasAnalysisType> *Walker;
+
+public:
+ SkipSelfWalker(MemorySSA *M, ClobberWalkerBase<AliasAnalysisType> *W)
+ : MemorySSAWalker(M), Walker(W) {}
+ ~SkipSelfWalker() override = default;
+
+ using MemorySSAWalker::getClobberingMemoryAccess;
+
+ MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA, unsigned &UWL) {
+ return Walker->getClobberingMemoryAccessBase(MA, UWL, true);
+ }
+ MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA,
+ const MemoryLocation &Loc,
+ unsigned &UWL) {
+ return Walker->getClobberingMemoryAccessBase(MA, Loc, UWL);
+ }
+
+ MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA) override {
+ unsigned UpwardWalkLimit = MaxCheckLimit;
+ return getClobberingMemoryAccess(MA, UpwardWalkLimit);
+ }
+ MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA,
+ const MemoryLocation &Loc) override {
+ unsigned UpwardWalkLimit = MaxCheckLimit;
+ return getClobberingMemoryAccess(MA, Loc, UpwardWalkLimit);
+ }
+
+ void invalidateInfo(MemoryAccess *MA) override {
+ if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA))
+ MUD->resetOptimized();
+ }
+};
+
+} // end namespace llvm
+
+void MemorySSA::renameSuccessorPhis(BasicBlock *BB, MemoryAccess *IncomingVal,
+ bool RenameAllUses) {
+ // Pass through values to our successors
+ for (const BasicBlock *S : successors(BB)) {
+ auto It = PerBlockAccesses.find(S);
+ // Rename the phi nodes in our successor block
+ if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front()))
+ continue;
+ AccessList *Accesses = It->second.get();
+ auto *Phi = cast<MemoryPhi>(&Accesses->front());
+ if (RenameAllUses) {
+ bool ReplacementDone = false;
+ for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
+ if (Phi->getIncomingBlock(I) == BB) {
+ Phi->setIncomingValue(I, IncomingVal);
+ ReplacementDone = true;
+ }
+ (void) ReplacementDone;
+ assert(ReplacementDone && "Incomplete phi during partial rename");
+ } else
+ Phi->addIncoming(IncomingVal, BB);
+ }
+}
+
+/// Rename a single basic block into MemorySSA form.
+/// Uses the standard SSA renaming algorithm.
+/// \returns The new incoming value.
+MemoryAccess *MemorySSA::renameBlock(BasicBlock *BB, MemoryAccess *IncomingVal,
+ bool RenameAllUses) {
+ auto It = PerBlockAccesses.find(BB);
+ // Skip most processing if the list is empty.
+ if (It != PerBlockAccesses.end()) {
+ AccessList *Accesses = It->second.get();
+ for (MemoryAccess &L : *Accesses) {
+ if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(&L)) {
+ if (MUD->getDefiningAccess() == nullptr || RenameAllUses)
+ MUD->setDefiningAccess(IncomingVal);
+ if (isa<MemoryDef>(&L))
+ IncomingVal = &L;
+ } else {
+ IncomingVal = &L;
+ }
+ }
+ }
+ return IncomingVal;
+}
+
+/// This is the standard SSA renaming algorithm.
+///
+/// We walk the dominator tree in preorder, renaming accesses, and then filling
+/// in phi nodes in our successors.
+void MemorySSA::renamePass(DomTreeNode *Root, MemoryAccess *IncomingVal,
+ SmallPtrSetImpl<BasicBlock *> &Visited,
+ bool SkipVisited, bool RenameAllUses) {
+ assert(Root && "Trying to rename accesses in an unreachable block");
+
+ SmallVector<RenamePassData, 32> WorkStack;
+ // Skip everything if we already renamed this block and we are skipping.
+ // Note: You can't sink this into the if, because we need it to occur
+ // regardless of whether we skip blocks or not.
+ bool AlreadyVisited = !Visited.insert(Root->getBlock()).second;
+ if (SkipVisited && AlreadyVisited)
+ return;
+
+ IncomingVal = renameBlock(Root->getBlock(), IncomingVal, RenameAllUses);
+ renameSuccessorPhis(Root->getBlock(), IncomingVal, RenameAllUses);
+ WorkStack.push_back({Root, Root->begin(), IncomingVal});
+
+ while (!WorkStack.empty()) {
+ DomTreeNode *Node = WorkStack.back().DTN;
+ DomTreeNode::const_iterator ChildIt = WorkStack.back().ChildIt;
+ IncomingVal = WorkStack.back().IncomingVal;
+
+ if (ChildIt == Node->end()) {
+ WorkStack.pop_back();
+ } else {
+ DomTreeNode *Child = *ChildIt;
+ ++WorkStack.back().ChildIt;
+ BasicBlock *BB = Child->getBlock();
+ // Note: You can't sink this into the if, because we need it to occur
+ // regardless of whether we skip blocks or not.
+ AlreadyVisited = !Visited.insert(BB).second;
+ if (SkipVisited && AlreadyVisited) {
+ // We already visited this during our renaming, which can happen when
+ // being asked to rename multiple blocks. Figure out the incoming val,
+ // which is the last def.
+ // Incoming value can only change if there is a block def, and in that
+ // case, it's the last block def in the list.
+ if (auto *BlockDefs = getWritableBlockDefs(BB))
+ IncomingVal = &*BlockDefs->rbegin();
+ } else
+ IncomingVal = renameBlock(BB, IncomingVal, RenameAllUses);
+ renameSuccessorPhis(BB, IncomingVal, RenameAllUses);
+ WorkStack.push_back({Child, Child->begin(), IncomingVal});
+ }
+ }
+}
+
+/// This handles unreachable block accesses by deleting phi nodes in
+/// unreachable blocks, and marking all other unreachable MemoryAccess's as
+/// being uses of the live on entry definition.
+void MemorySSA::markUnreachableAsLiveOnEntry(BasicBlock *BB) {
+ assert(!DT->isReachableFromEntry(BB) &&
+ "Reachable block found while handling unreachable blocks");
+
+ // Make sure phi nodes in our reachable successors end up with a
+ // LiveOnEntryDef for our incoming edge, even though our block is forward
+ // unreachable. We could just disconnect these blocks from the CFG fully,
+ // but we do not right now.
+ for (const BasicBlock *S : successors(BB)) {
+ if (!DT->isReachableFromEntry(S))
+ continue;
+ auto It = PerBlockAccesses.find(S);
+ // Rename the phi nodes in our successor block
+ if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front()))
+ continue;
+ AccessList *Accesses = It->second.get();
+ auto *Phi = cast<MemoryPhi>(&Accesses->front());
+ Phi->addIncoming(LiveOnEntryDef.get(), BB);
+ }
+
+ auto It = PerBlockAccesses.find(BB);
+ if (It == PerBlockAccesses.end())
+ return;
+
+ auto &Accesses = It->second;
+ for (auto AI = Accesses->begin(), AE = Accesses->end(); AI != AE;) {
+ auto Next = std::next(AI);
+ // If we have a phi, just remove it. We are going to replace all
+ // users with live on entry.
+ if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(AI))
+ UseOrDef->setDefiningAccess(LiveOnEntryDef.get());
+ else
+ Accesses->erase(AI);
+ AI = Next;
+ }
+}
+
+MemorySSA::MemorySSA(Function &Func, AliasAnalysis *AA, DominatorTree *DT)
+ : AA(nullptr), DT(DT), F(Func), LiveOnEntryDef(nullptr), Walker(nullptr),
+ SkipWalker(nullptr), NextID(0) {
+ // Build MemorySSA using a batch alias analysis. This reuses the internal
+ // state that AA collects during an alias()/getModRefInfo() call. This is
+ // safe because there are no CFG changes while building MemorySSA and can
+ // significantly reduce the time spent by the compiler in AA, because we will
+ // make queries about all the instructions in the Function.
+ BatchAAResults BatchAA(*AA);
+ buildMemorySSA(BatchAA);
+  // Intentionally leave AA as nullptr while building so we don't accidentally
+  // use non-batch AliasAnalysis.
+ this->AA = AA;
+ // Also create the walker here.
+ getWalker();
+}
+
+MemorySSA::~MemorySSA() {
+ // Drop all our references
+ for (const auto &Pair : PerBlockAccesses)
+ for (MemoryAccess &MA : *Pair.second)
+ MA.dropAllReferences();
+}
+
+MemorySSA::AccessList *MemorySSA::getOrCreateAccessList(const BasicBlock *BB) {
+ auto Res = PerBlockAccesses.insert(std::make_pair(BB, nullptr));
+
+ if (Res.second)
+ Res.first->second = std::make_unique<AccessList>();
+ return Res.first->second.get();
+}
+
+MemorySSA::DefsList *MemorySSA::getOrCreateDefsList(const BasicBlock *BB) {
+ auto Res = PerBlockDefs.insert(std::make_pair(BB, nullptr));
+
+ if (Res.second)
+ Res.first->second = std::make_unique<DefsList>();
+ return Res.first->second.get();
+}
+
+namespace llvm {
+
+/// This class is a batch walker of all MemoryUse's in the program, and points
+/// their defining access at the thing that actually clobbers them. Because it
+/// is a batch walker that touches everything, it does not operate like the
+/// other walkers. This walker is basically performing a top-down SSA renaming
+/// pass, where the version stack is used as the cache. This enables it to be
+/// significantly more time and memory efficient than using the regular walker,
+/// which is walking bottom-up.
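+/// In effect, optimizeUses() visits blocks in dominator-tree pre-order,
+/// maintains a stack of MemoryDefs/MemoryPhis (the "version stack"), and
+/// points each MemoryUse at the nearest entry on that stack that clobbers it.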
+class MemorySSA::OptimizeUses {
+public:
+ OptimizeUses(MemorySSA *MSSA, CachingWalker<BatchAAResults> *Walker,
+ BatchAAResults *BAA, DominatorTree *DT)
+ : MSSA(MSSA), Walker(Walker), AA(BAA), DT(DT) {}
+
+ void optimizeUses();
+
+private:
+  /// This represents where a given MemoryLocation is in the stack.
+ struct MemlocStackInfo {
+ // This essentially is keeping track of versions of the stack. Whenever
+ // the stack changes due to pushes or pops, these versions increase.
+ unsigned long StackEpoch;
+ unsigned long PopEpoch;
+ // This is the lower bound of places on the stack to check. It is equal to
+ // the place the last stack walk ended.
+    // Note: Correctness depends on this being initialized to 0, which
+    // DenseMap does.
+ unsigned long LowerBound;
+ const BasicBlock *LowerBoundBlock;
+ // This is where the last walk for this memory location ended.
+ unsigned long LastKill;
+ bool LastKillValid;
+ Optional<AliasResult> AR;
+ };
+
+ void optimizeUsesInBlock(const BasicBlock *, unsigned long &, unsigned long &,
+ SmallVectorImpl<MemoryAccess *> &,
+ DenseMap<MemoryLocOrCall, MemlocStackInfo> &);
+
+ MemorySSA *MSSA;
+ CachingWalker<BatchAAResults> *Walker;
+ BatchAAResults *AA;
+ DominatorTree *DT;
+};
+
+} // end namespace llvm
+
+/// Optimize the uses in a given block. This is basically the SSA renaming
+/// algorithm, with one caveat: We are able to use a single stack for all
+/// MemoryUses. This is because the set of *possible* reaching MemoryDefs is
+/// the same for every MemoryUse. The *actual* clobbering MemoryDef is just
+/// going to be some position in that stack of possible ones.
+///
+/// For each MemoryLocation we track the stack position it needs to check from,
+/// and where its last walk ended, because we only want to check the things
+/// that changed since last time. The same MemoryLocation should get clobbered
+/// by the same store (getModRefInfo does not currently use invariance or
+/// similar properties; if it starts to, we can extend MemoryLocOrCall to
+/// include the relevant data).
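+///
+/// StackEpoch is bumped whenever a def or phi is pushed onto the version
+/// stack, and PopEpoch whenever entries are popped because we moved to a block
+/// they do not dominate; comparing them against the per-location snapshot
+/// tells us whether the cached LowerBound/LastKill information is still valid.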
+void MemorySSA::OptimizeUses::optimizeUsesInBlock(
+ const BasicBlock *BB, unsigned long &StackEpoch, unsigned long &PopEpoch,
+ SmallVectorImpl<MemoryAccess *> &VersionStack,
+ DenseMap<MemoryLocOrCall, MemlocStackInfo> &LocStackInfo) {
+
+ /// If no accesses, nothing to do.
+ MemorySSA::AccessList *Accesses = MSSA->getWritableBlockAccesses(BB);
+ if (Accesses == nullptr)
+ return;
+
+ // Pop everything that doesn't dominate the current block off the stack,
+ // increment the PopEpoch to account for this.
+ while (true) {
+ assert(
+ !VersionStack.empty() &&
+ "Version stack should have liveOnEntry sentinel dominating everything");
+ BasicBlock *BackBlock = VersionStack.back()->getBlock();
+ if (DT->dominates(BackBlock, BB))
+ break;
+ while (VersionStack.back()->getBlock() == BackBlock)
+ VersionStack.pop_back();
+ ++PopEpoch;
+ }
+
+ for (MemoryAccess &MA : *Accesses) {
+ auto *MU = dyn_cast<MemoryUse>(&MA);
+ if (!MU) {
+ VersionStack.push_back(&MA);
+ ++StackEpoch;
+ continue;
+ }
+
+ if (isUseTriviallyOptimizableToLiveOnEntry(*AA, MU->getMemoryInst())) {
+ MU->setDefiningAccess(MSSA->getLiveOnEntryDef(), true, None);
+ continue;
+ }
+
+ MemoryLocOrCall UseMLOC(MU);
+ auto &LocInfo = LocStackInfo[UseMLOC];
+ // If the pop epoch changed, it means we've removed stuff from top of
+ // stack due to changing blocks. We may have to reset the lower bound or
+ // last kill info.
+ if (LocInfo.PopEpoch != PopEpoch) {
+ LocInfo.PopEpoch = PopEpoch;
+ LocInfo.StackEpoch = StackEpoch;
+ // If the lower bound was in something that no longer dominates us, we
+ // have to reset it.
+ // We can't simply track stack size, because the stack may have had
+ // pushes/pops in the meantime.
+      // XXX: This is non-optimal, but is only slower in cases with heavily
+      // branching dominator trees. Getting the optimal number of queries would
+      // require making LowerBound and LastKill a per-location stack, and
+      // popping it until the top of that stack dominates us. This does not
+      // seem worth it at the moment.
+ // A much cheaper optimization would be to always explore the deepest
+ // branch of the dominator tree first. This will guarantee this resets on
+ // the smallest set of blocks.
+ if (LocInfo.LowerBoundBlock && LocInfo.LowerBoundBlock != BB &&
+ !DT->dominates(LocInfo.LowerBoundBlock, BB)) {
+ // Reset the lower bound of things to check.
+ // TODO: Some day we should be able to reset to last kill, rather than
+ // 0.
+ LocInfo.LowerBound = 0;
+ LocInfo.LowerBoundBlock = VersionStack[0]->getBlock();
+ LocInfo.LastKillValid = false;
+ }
+ } else if (LocInfo.StackEpoch != StackEpoch) {
+ // If all that has changed is the StackEpoch, we only have to check the
+ // new things on the stack, because we've checked everything before. In
+ // this case, the lower bound of things to check remains the same.
+ LocInfo.PopEpoch = PopEpoch;
+ LocInfo.StackEpoch = StackEpoch;
+ }
+ if (!LocInfo.LastKillValid) {
+ LocInfo.LastKill = VersionStack.size() - 1;
+ LocInfo.LastKillValid = true;
+ LocInfo.AR = MayAlias;
+ }
+
+ // At this point, we should have corrected last kill and LowerBound to be
+ // in bounds.
+ assert(LocInfo.LowerBound < VersionStack.size() &&
+ "Lower bound out of range");
+ assert(LocInfo.LastKill < VersionStack.size() &&
+ "Last kill info out of range");
+ // In any case, the new upper bound is the top of the stack.
+ unsigned long UpperBound = VersionStack.size() - 1;
+
+ if (UpperBound - LocInfo.LowerBound > MaxCheckLimit) {
+ LLVM_DEBUG(dbgs() << "MemorySSA skipping optimization of " << *MU << " ("
+ << *(MU->getMemoryInst()) << ")"
+ << " because there are "
+ << UpperBound - LocInfo.LowerBound
+ << " stores to disambiguate\n");
+ // Because we did not walk, LastKill is no longer valid, as this may
+ // have been a kill.
+ LocInfo.LastKillValid = false;
+ continue;
+ }
+ bool FoundClobberResult = false;
+ unsigned UpwardWalkLimit = MaxCheckLimit;
+ while (UpperBound > LocInfo.LowerBound) {
+ if (isa<MemoryPhi>(VersionStack[UpperBound])) {
+ // For phis, use the walker, see where we ended up, go there
+ MemoryAccess *Result =
+ Walker->getClobberingMemoryAccess(MU, UpwardWalkLimit);
+ // We are guaranteed to find it or something is wrong
+ while (VersionStack[UpperBound] != Result) {
+ assert(UpperBound != 0);
+ --UpperBound;
+ }
+ FoundClobberResult = true;
+ break;
+ }
+
+ MemoryDef *MD = cast<MemoryDef>(VersionStack[UpperBound]);
+ // If the lifetime of the pointer ends at this instruction, it's live on
+ // entry.
+ if (!UseMLOC.IsCall && lifetimeEndsAt(MD, UseMLOC.getLoc(), *AA)) {
+ // Reset UpperBound to liveOnEntryDef's place in the stack
+ UpperBound = 0;
+ FoundClobberResult = true;
+ LocInfo.AR = MustAlias;
+ break;
+ }
+ ClobberAlias CA = instructionClobbersQuery(MD, MU, UseMLOC, *AA);
+ if (CA.IsClobber) {
+ FoundClobberResult = true;
+ LocInfo.AR = CA.AR;
+ break;
+ }
+ --UpperBound;
+ }
+
+ // Note: Phis always have AliasResult AR set to MayAlias ATM.
+
+    // At the end of this loop, UpperBound is either a clobber, or the lower
+    // bound. PHI walking may cause it to be < LowerBound, and in fact < LastKill.
+ if (FoundClobberResult || UpperBound < LocInfo.LastKill) {
+ // We were last killed now by where we got to
+ if (MSSA->isLiveOnEntryDef(VersionStack[UpperBound]))
+ LocInfo.AR = None;
+ MU->setDefiningAccess(VersionStack[UpperBound], true, LocInfo.AR);
+ LocInfo.LastKill = UpperBound;
+ } else {
+ // Otherwise, we checked all the new ones, and now we know we can get to
+ // LastKill.
+ MU->setDefiningAccess(VersionStack[LocInfo.LastKill], true, LocInfo.AR);
+ }
+ LocInfo.LowerBound = VersionStack.size() - 1;
+ LocInfo.LowerBoundBlock = BB;
+ }
+}
+
+/// Optimize uses to point to their actual clobbering definitions.
+void MemorySSA::OptimizeUses::optimizeUses() {
+ SmallVector<MemoryAccess *, 16> VersionStack;
+ DenseMap<MemoryLocOrCall, MemlocStackInfo> LocStackInfo;
+ VersionStack.push_back(MSSA->getLiveOnEntryDef());
+
+ unsigned long StackEpoch = 1;
+ unsigned long PopEpoch = 1;
+ // We perform a non-recursive top-down dominator tree walk.
+ for (const auto *DomNode : depth_first(DT->getRootNode()))
+ optimizeUsesInBlock(DomNode->getBlock(), StackEpoch, PopEpoch, VersionStack,
+ LocStackInfo);
+}
+
+void MemorySSA::placePHINodes(
+ const SmallPtrSetImpl<BasicBlock *> &DefiningBlocks) {
+ // Determine where our MemoryPhi's should go
+ ForwardIDFCalculator IDFs(*DT);
+ IDFs.setDefiningBlocks(DefiningBlocks);
+ SmallVector<BasicBlock *, 32> IDFBlocks;
+ IDFs.calculate(IDFBlocks);
+
+ // Now place MemoryPhi nodes.
+ for (auto &BB : IDFBlocks)
+ createMemoryPhi(BB);
+}
+
+void MemorySSA::buildMemorySSA(BatchAAResults &BAA) {
+ // We create an access to represent "live on entry", for things like
+ // arguments or users of globals, where the memory they use is defined before
+ // the beginning of the function. We do not actually insert it into the IR.
+ // We do not define a live on exit for the immediate uses, and thus our
+ // semantics do *not* imply that something with no immediate uses can simply
+ // be removed.
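+  // The live-on-entry access itself is modelled as a MemoryDef with a null
+  // underlying instruction; isLiveOnEntryDef() recognizes it by identity.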
+ BasicBlock &StartingPoint = F.getEntryBlock();
+ LiveOnEntryDef.reset(new MemoryDef(F.getContext(), nullptr, nullptr,
+ &StartingPoint, NextID++));
+
+ // We maintain lists of memory accesses per-block, trading memory for time. We
+ // could just look up the memory access for every possible instruction in the
+ // stream.
+ SmallPtrSet<BasicBlock *, 32> DefiningBlocks;
+ // Go through each block, figure out where defs occur, and chain together all
+ // the accesses.
+ for (BasicBlock &B : F) {
+ bool InsertIntoDef = false;
+ AccessList *Accesses = nullptr;
+ DefsList *Defs = nullptr;
+ for (Instruction &I : B) {
+ MemoryUseOrDef *MUD = createNewAccess(&I, &BAA);
+ if (!MUD)
+ continue;
+
+ if (!Accesses)
+ Accesses = getOrCreateAccessList(&B);
+ Accesses->push_back(MUD);
+ if (isa<MemoryDef>(MUD)) {
+ InsertIntoDef = true;
+ if (!Defs)
+ Defs = getOrCreateDefsList(&B);
+ Defs->push_back(*MUD);
+ }
+ }
+ if (InsertIntoDef)
+ DefiningBlocks.insert(&B);
+ }
+ placePHINodes(DefiningBlocks);
+
+ // Now do regular SSA renaming on the MemoryDef/MemoryUse. Visited will get
+ // filled in with all blocks.
+ SmallPtrSet<BasicBlock *, 16> Visited;
+ renamePass(DT->getRootNode(), LiveOnEntryDef.get(), Visited);
+
+ ClobberWalkerBase<BatchAAResults> WalkerBase(this, &BAA, DT);
+ CachingWalker<BatchAAResults> WalkerLocal(this, &WalkerBase);
+ OptimizeUses(this, &WalkerLocal, &BAA, DT).optimizeUses();
+
+ // Mark the uses in unreachable blocks as live on entry, so that they go
+ // somewhere.
+ for (auto &BB : F)
+ if (!Visited.count(&BB))
+ markUnreachableAsLiveOnEntry(&BB);
+}
+
+MemorySSAWalker *MemorySSA::getWalker() { return getWalkerImpl(); }
+
+MemorySSA::CachingWalker<AliasAnalysis> *MemorySSA::getWalkerImpl() {
+ if (Walker)
+ return Walker.get();
+
+ if (!WalkerBase)
+ WalkerBase =
+ std::make_unique<ClobberWalkerBase<AliasAnalysis>>(this, AA, DT);
+
+ Walker =
+ std::make_unique<CachingWalker<AliasAnalysis>>(this, WalkerBase.get());
+ return Walker.get();
+}
+
+MemorySSAWalker *MemorySSA::getSkipSelfWalker() {
+ if (SkipWalker)
+ return SkipWalker.get();
+
+ if (!WalkerBase)
+ WalkerBase =
+ std::make_unique<ClobberWalkerBase<AliasAnalysis>>(this, AA, DT);
+
+ SkipWalker =
+ std::make_unique<SkipSelfWalker<AliasAnalysis>>(this, WalkerBase.get());
+ return SkipWalker.get();
+}
+
+// This is a helper function used by the creation routines. It places NewAccess
+// into the access and defs lists for a given basic block, at the given
+// insertion point.
+void MemorySSA::insertIntoListsForBlock(MemoryAccess *NewAccess,
+ const BasicBlock *BB,
+ InsertionPlace Point) {
+ auto *Accesses = getOrCreateAccessList(BB);
+ if (Point == Beginning) {
+ // If it's a phi node, it goes first, otherwise, it goes after any phi
+ // nodes.
+ if (isa<MemoryPhi>(NewAccess)) {
+ Accesses->push_front(NewAccess);
+ auto *Defs = getOrCreateDefsList(BB);
+ Defs->push_front(*NewAccess);
+ } else {
+ auto AI = find_if_not(
+ *Accesses, [](const MemoryAccess &MA) { return isa<MemoryPhi>(MA); });
+ Accesses->insert(AI, NewAccess);
+ if (!isa<MemoryUse>(NewAccess)) {
+ auto *Defs = getOrCreateDefsList(BB);
+ auto DI = find_if_not(
+ *Defs, [](const MemoryAccess &MA) { return isa<MemoryPhi>(MA); });
+ Defs->insert(DI, *NewAccess);
+ }
+ }
+ } else {
+ Accesses->push_back(NewAccess);
+ if (!isa<MemoryUse>(NewAccess)) {
+ auto *Defs = getOrCreateDefsList(BB);
+ Defs->push_back(*NewAccess);
+ }
+ }
+ BlockNumberingValid.erase(BB);
+}
+
+void MemorySSA::insertIntoListsBefore(MemoryAccess *What, const BasicBlock *BB,
+ AccessList::iterator InsertPt) {
+ auto *Accesses = getWritableBlockAccesses(BB);
+ bool WasEnd = InsertPt == Accesses->end();
+ Accesses->insert(AccessList::iterator(InsertPt), What);
+ if (!isa<MemoryUse>(What)) {
+ auto *Defs = getOrCreateDefsList(BB);
+ // If we got asked to insert at the end, we have an easy job, just shove it
+ // at the end. If we got asked to insert before an existing def, we also get
+ // an iterator. If we got asked to insert before a use, we have to hunt for
+ // the next def.
+ if (WasEnd) {
+ Defs->push_back(*What);
+ } else if (isa<MemoryDef>(InsertPt)) {
+ Defs->insert(InsertPt->getDefsIterator(), *What);
+ } else {
+ while (InsertPt != Accesses->end() && !isa<MemoryDef>(InsertPt))
+ ++InsertPt;
+ // Either we found a def, or we are inserting at the end
+ if (InsertPt == Accesses->end())
+ Defs->push_back(*What);
+ else
+ Defs->insert(InsertPt->getDefsIterator(), *What);
+ }
+ }
+ BlockNumberingValid.erase(BB);
+}
+
+void MemorySSA::prepareForMoveTo(MemoryAccess *What, BasicBlock *BB) {
+ // Keep it in the lookup tables, remove from the lists
+ removeFromLists(What, false);
+
+ // Note that moving should implicitly invalidate the optimized state of a
+ // MemoryUse (and Phis can't be optimized). However, it doesn't do so for a
+ // MemoryDef.
+ if (auto *MD = dyn_cast<MemoryDef>(What))
+ MD->resetOptimized();
+ What->setBlock(BB);
+}
+
+// Move What before Where in the IR. The end result is that What will belong to
+// the right lists and have the right Block set, but will not otherwise be
+// correct. It will not have the right defining access, and if it is a def,
+// things below it will not properly be updated.
+void MemorySSA::moveTo(MemoryUseOrDef *What, BasicBlock *BB,
+ AccessList::iterator Where) {
+ prepareForMoveTo(What, BB);
+ insertIntoListsBefore(What, BB, Where);
+}
+
+void MemorySSA::moveTo(MemoryAccess *What, BasicBlock *BB,
+ InsertionPlace Point) {
+ if (isa<MemoryPhi>(What)) {
+ assert(Point == Beginning &&
+ "Can only move a Phi at the beginning of the block");
+ // Update lookup table entry
+ ValueToMemoryAccess.erase(What->getBlock());
+ bool Inserted = ValueToMemoryAccess.insert({BB, What}).second;
+ (void)Inserted;
+ assert(Inserted && "Cannot move a Phi to a block that already has one");
+ }
+
+ prepareForMoveTo(What, BB);
+ insertIntoListsForBlock(What, BB, Point);
+}
+
+MemoryPhi *MemorySSA::createMemoryPhi(BasicBlock *BB) {
+ assert(!getMemoryAccess(BB) && "MemoryPhi already exists for this BB");
+ MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++);
+ // Phis are always placed at the front of the block.
+ insertIntoListsForBlock(Phi, BB, Beginning);
+ ValueToMemoryAccess[BB] = Phi;
+ return Phi;
+}
+
+MemoryUseOrDef *MemorySSA::createDefinedAccess(Instruction *I,
+ MemoryAccess *Definition,
+ const MemoryUseOrDef *Template,
+ bool CreationMustSucceed) {
+ assert(!isa<PHINode>(I) && "Cannot create a defined access for a PHI");
+ MemoryUseOrDef *NewAccess = createNewAccess(I, AA, Template);
+ if (CreationMustSucceed)
+ assert(NewAccess != nullptr && "Tried to create a memory access for a "
+ "non-memory touching instruction");
+ if (NewAccess)
+ NewAccess->setDefiningAccess(Definition);
+ return NewAccess;
+}
+
+// Return true if the instruction has ordering constraints.
+// Note specifically that this only considers stores and loads
+// because others are still considered ModRef by getModRefInfo.
+static inline bool isOrdered(const Instruction *I) {
+ if (auto *SI = dyn_cast<StoreInst>(I)) {
+ if (!SI->isUnordered())
+ return true;
+ } else if (auto *LI = dyn_cast<LoadInst>(I)) {
+ if (!LI->isUnordered())
+ return true;
+ }
+ return false;
+}
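+
+// Editorial sketch (not part of the original change): under the definition
+// above, volatile and atomic (non-unordered) loads and stores count as
+// "ordered", so createNewAccess below models them as MemoryDefs, while a
+// plain load stays a MemoryUse. The IR below is hypothetical.
+//
+//   store volatile i32 0, i32* %p                  ; ordered   -> MemoryDef
+//   %a = load atomic i32, i32* %p acquire, align 4 ; ordered   -> MemoryDef
+//   %b = load i32, i32* %p                         ; unordered -> MemoryUse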
+
+/// Helper function to create new memory accesses
+template <typename AliasAnalysisType>
+MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I,
+ AliasAnalysisType *AAP,
+ const MemoryUseOrDef *Template) {
+ // The assume intrinsic has a control dependency which we model by claiming
+ // that it writes arbitrarily. Debuginfo intrinsics may be considered
+ // clobbers when we have a nonstandard AA pipeline. Ignore these fake memory
+ // dependencies here.
+ // FIXME: Replace this special casing with a more accurate modelling of
+ // assume's control dependency.
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
+ if (II->getIntrinsicID() == Intrinsic::assume)
+ return nullptr;
+
+ // Using a nonstandard AA pipeline might leave us with unexpected modref
+ // results for I, so add a check to avoid modeling instructions that may not
+ // read from or write to memory. This is necessary for correctness.
+ if (!I->mayReadFromMemory() && !I->mayWriteToMemory())
+ return nullptr;
+
+ bool Def, Use;
+ if (Template) {
+ Def = dyn_cast_or_null<MemoryDef>(Template) != nullptr;
+ Use = dyn_cast_or_null<MemoryUse>(Template) != nullptr;
+#if !defined(NDEBUG)
+ ModRefInfo ModRef = AAP->getModRefInfo(I, None);
+ bool DefCheck, UseCheck;
+ DefCheck = isModSet(ModRef) || isOrdered(I);
+ UseCheck = isRefSet(ModRef);
+ assert(Def == DefCheck && (Def || Use == UseCheck) && "Invalid template");
+#endif
+ } else {
+ // Find out what effect this instruction has on memory.
+ ModRefInfo ModRef = AAP->getModRefInfo(I, None);
+ // The isOrdered check is used to ensure that volatiles end up as defs
+ // (atomics end up as ModRef right now anyway). Until we separate the
+ // ordering chain from the memory chain, this enables people to see at least
+ // some relative ordering to volatiles. Note that getClobberingMemoryAccess
+ // will still give an answer that bypasses other volatile loads. TODO:
+ // Separate memory aliasing and ordering into two different chains so that
+ // we can precisely represent both "what memory will this read/write/is
+ // clobbered by" and "what instructions can I move this past".
+ Def = isModSet(ModRef) || isOrdered(I);
+ Use = isRefSet(ModRef);
+ }
+
+ // It's possible for an instruction to not modify memory at all. During
+ // construction, we ignore them.
+ if (!Def && !Use)
+ return nullptr;
+
+ MemoryUseOrDef *MUD;
+ if (Def)
+ MUD = new MemoryDef(I->getContext(), nullptr, I, I->getParent(), NextID++);
+ else
+ MUD = new MemoryUse(I->getContext(), nullptr, I, I->getParent());
+ ValueToMemoryAccess[I] = MUD;
+ return MUD;
+}
+
+/// Returns true if \p Replacer dominates \p Replacee .
+bool MemorySSA::dominatesUse(const MemoryAccess *Replacer,
+ const MemoryAccess *Replacee) const {
+ if (isa<MemoryUseOrDef>(Replacee))
+ return DT->dominates(Replacer->getBlock(), Replacee->getBlock());
+ const auto *MP = cast<MemoryPhi>(Replacee);
+ // For a phi node, the use occurs in the predecessor block of the phi node.
+ // Since we may occur multiple times in the phi node, we have to check each
+ // operand to ensure Replacer dominates each operand where Replacee occurs.
+ for (const Use &Arg : MP->operands()) {
+ if (Arg.get() != Replacee &&
+ !DT->dominates(Replacer->getBlock(), MP->getIncomingBlock(Arg)))
+ return false;
+ }
+ return true;
+}
+
+/// Properly remove \p MA from all of MemorySSA's lookup tables.
+void MemorySSA::removeFromLookups(MemoryAccess *MA) {
+ assert(MA->use_empty() &&
+ "Trying to remove memory access that still has uses");
+ BlockNumbering.erase(MA);
+ if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA))
+ MUD->setDefiningAccess(nullptr);
+ // Invalidate our walker's cache if necessary
+ if (!isa<MemoryUse>(MA))
+ getWalker()->invalidateInfo(MA);
+
+ Value *MemoryInst;
+ if (const auto *MUD = dyn_cast<MemoryUseOrDef>(MA))
+ MemoryInst = MUD->getMemoryInst();
+ else
+ MemoryInst = MA->getBlock();
+
+ auto VMA = ValueToMemoryAccess.find(MemoryInst);
+ if (VMA->second == MA)
+ ValueToMemoryAccess.erase(VMA);
+}
+
+/// Properly remove \p MA from all of MemorySSA's lists.
+///
+/// Because of the way the intrusive list and use lists work, it is important to
+/// do removal in the right order.
+/// ShouldDelete defaults to true, and will cause the memory access to also be
+/// deleted, not just removed.
+void MemorySSA::removeFromLists(MemoryAccess *MA, bool ShouldDelete) {
+ BasicBlock *BB = MA->getBlock();
+ // The access list owns the reference, so we erase it from the non-owning list
+ // first.
+ if (!isa<MemoryUse>(MA)) {
+ auto DefsIt = PerBlockDefs.find(BB);
+ std::unique_ptr<DefsList> &Defs = DefsIt->second;
+ Defs->remove(*MA);
+ if (Defs->empty())
+ PerBlockDefs.erase(DefsIt);
+ }
+
+ // The erase call here will delete it. If we don't want it deleted, we call
+ // remove instead.
+ auto AccessIt = PerBlockAccesses.find(BB);
+ std::unique_ptr<AccessList> &Accesses = AccessIt->second;
+ if (ShouldDelete)
+ Accesses->erase(MA);
+ else
+ Accesses->remove(MA);
+
+ if (Accesses->empty()) {
+ PerBlockAccesses.erase(AccessIt);
+ BlockNumberingValid.erase(BB);
+ }
+}
+
+void MemorySSA::print(raw_ostream &OS) const {
+ MemorySSAAnnotatedWriter Writer(this);
+ F.print(OS, &Writer);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void MemorySSA::dump() const { print(dbgs()); }
+#endif
+
+void MemorySSA::verifyMemorySSA() const {
+ verifyDefUses(F);
+ verifyDomination(F);
+ verifyOrdering(F);
+ verifyDominationNumbers(F);
+ verifyPrevDefInPhis(F);
+ // Previously, the verification used to also verify that the clobberingAccess
+ // cached by MemorySSA is the same as the clobberingAccess found at a later
+ // query to AA. This does not hold true in general due to the current fragility
+ // of BasicAA which has arbitrary caps on the things it analyzes before giving
+ // up. As a result, transformations that are correct will lead to BasicAA
+ // returning different Alias answers before and after that transformation.
+ // Invalidating MemorySSA is not an option, as the results in BasicAA can be
+ // so unstable that, in the worst case, we'd need to rebuild MemorySSA from
+ // scratch after every transformation, which defeats the purpose of using it.
+ // For such an example, see test4 added in D51960.
+}
+
+void MemorySSA::verifyPrevDefInPhis(Function &F) const {
+#if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
+ for (const BasicBlock &BB : F) {
+ if (MemoryPhi *Phi = getMemoryAccess(&BB)) {
+ for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) {
+ auto *Pred = Phi->getIncomingBlock(I);
+ auto *IncAcc = Phi->getIncomingValue(I);
+ // If Pred has no unreachable predecessors, get last def looking at
+ // IDoms. If, while walking IDoms, any of these has an unreachable
+ // predecessor, then the incoming def can be any access.
+ if (auto *DTNode = DT->getNode(Pred)) {
+ while (DTNode) {
+ if (auto *DefList = getBlockDefs(DTNode->getBlock())) {
+ auto *LastAcc = &*(--DefList->end());
+ assert(LastAcc == IncAcc &&
+ "Incorrect incoming access into phi.");
+ break;
+ }
+ DTNode = DTNode->getIDom();
+ }
+ } else {
+ // If Pred has unreachable predecessors, but has at least a Def, the
+ // incoming access can be the last Def in Pred, or it could have been
+ // optimized to LoE. After an update, though, the LoE may have been
+ // replaced by another access, so IncAcc may be any access.
+ // If Pred has unreachable predecessors and no Defs, incoming access
+ // should be LoE; However, after an update, it may be any access.
+ }
+ }
+ }
+ }
+#endif
+}
+
+/// Verify that all of the blocks we believe to have valid domination numbers
+/// actually have valid domination numbers.
+void MemorySSA::verifyDominationNumbers(const Function &F) const {
+#ifndef NDEBUG
+ if (BlockNumberingValid.empty())
+ return;
+
+ SmallPtrSet<const BasicBlock *, 16> ValidBlocks = BlockNumberingValid;
+ for (const BasicBlock &BB : F) {
+ if (!ValidBlocks.count(&BB))
+ continue;
+
+ ValidBlocks.erase(&BB);
+
+ const AccessList *Accesses = getBlockAccesses(&BB);
+ // It's correct to say an empty block has valid numbering.
+ if (!Accesses)
+ continue;
+
+ // Block numbering starts at 1.
+ unsigned long LastNumber = 0;
+ for (const MemoryAccess &MA : *Accesses) {
+ auto ThisNumberIter = BlockNumbering.find(&MA);
+ assert(ThisNumberIter != BlockNumbering.end() &&
+ "MemoryAccess has no domination number in a valid block!");
+
+ unsigned long ThisNumber = ThisNumberIter->second;
+ assert(ThisNumber > LastNumber &&
+ "Domination numbers should be strictly increasing!");
+ LastNumber = ThisNumber;
+ }
+ }
+
+ assert(ValidBlocks.empty() &&
+ "All valid BasicBlocks should exist in F -- dangling pointers?");
+#endif
+}
+
+/// Verify that the order and existence of MemoryAccesses matches the
+/// order and existence of memory affecting instructions.
+void MemorySSA::verifyOrdering(Function &F) const {
+#ifndef NDEBUG
+ // Walk all the blocks, comparing what the lookups think and what the access
+ // lists think, as well as the order in the blocks vs the order in the access
+ // lists.
+ SmallVector<MemoryAccess *, 32> ActualAccesses;
+ SmallVector<MemoryAccess *, 32> ActualDefs;
+ for (BasicBlock &B : F) {
+ const AccessList *AL = getBlockAccesses(&B);
+ const auto *DL = getBlockDefs(&B);
+ MemoryAccess *Phi = getMemoryAccess(&B);
+ if (Phi) {
+ ActualAccesses.push_back(Phi);
+ ActualDefs.push_back(Phi);
+ }
+
+ for (Instruction &I : B) {
+ MemoryAccess *MA = getMemoryAccess(&I);
+ assert((!MA || (AL && (isa<MemoryUse>(MA) || DL))) &&
+ "We have memory affecting instructions "
+ "in this block but they are not in the "
+ "access list or defs list");
+ if (MA) {
+ ActualAccesses.push_back(MA);
+ if (isa<MemoryDef>(MA))
+ ActualDefs.push_back(MA);
+ }
+ }
+ // Either we hit the assert, really have no accesses, or we have both
+ // accesses and an access list.
+ // Same with defs.
+ if (!AL && !DL)
+ continue;
+ assert(AL->size() == ActualAccesses.size() &&
+ "We don't have the same number of accesses in the block as on the "
+ "access list");
+ assert((DL || ActualDefs.size() == 0) &&
+ "Either we should have a defs list, or we should have no defs");
+ assert((!DL || DL->size() == ActualDefs.size()) &&
+ "We don't have the same number of defs in the block as on the "
+ "def list");
+ auto ALI = AL->begin();
+ auto AAI = ActualAccesses.begin();
+ while (ALI != AL->end() && AAI != ActualAccesses.end()) {
+ assert(&*ALI == *AAI && "Not the same accesses in the same order");
+ ++ALI;
+ ++AAI;
+ }
+ ActualAccesses.clear();
+ if (DL) {
+ auto DLI = DL->begin();
+ auto ADI = ActualDefs.begin();
+ while (DLI != DL->end() && ADI != ActualDefs.end()) {
+ assert(&*DLI == *ADI && "Not the same defs in the same order");
+ ++DLI;
+ ++ADI;
+ }
+ }
+ ActualDefs.clear();
+ }
+#endif
+}
+
+/// Verify the domination properties of MemorySSA by checking that each
+/// definition dominates all of its uses.
+void MemorySSA::verifyDomination(Function &F) const {
+#ifndef NDEBUG
+ for (BasicBlock &B : F) {
+ // Phi nodes are attached to basic blocks
+ if (MemoryPhi *MP = getMemoryAccess(&B))
+ for (const Use &U : MP->uses())
+ assert(dominates(MP, U) && "Memory PHI does not dominate its uses");
+
+ for (Instruction &I : B) {
+ MemoryAccess *MD = dyn_cast_or_null<MemoryDef>(getMemoryAccess(&I));
+ if (!MD)
+ continue;
+
+ for (const Use &U : MD->uses())
+ assert(dominates(MD, U) && "Memory Def does not dominate its uses");
+ }
+ }
+#endif
+}
+
+/// Verify the def-use lists in MemorySSA, by verifying that \p Use
+/// appears in the use list of \p Def.
+void MemorySSA::verifyUseInDefs(MemoryAccess *Def, MemoryAccess *Use) const {
+#ifndef NDEBUG
+ // The live on entry use may cause us to get a NULL def here
+ if (!Def)
+ assert(isLiveOnEntryDef(Use) &&
+ "Null def but use not point to live on entry def");
+ else
+ assert(is_contained(Def->users(), Use) &&
+ "Did not find use in def's use list");
+#endif
+}
+
+/// Verify the immediate use information, by walking all the memory
+/// accesses and verifying that, for each use, it appears in the
+/// appropriate def's use list
+void MemorySSA::verifyDefUses(Function &F) const {
+#if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
+ for (BasicBlock &B : F) {
+ // Phi nodes are attached to basic blocks
+ if (MemoryPhi *Phi = getMemoryAccess(&B)) {
+ assert(Phi->getNumOperands() == static_cast<unsigned>(std::distance(
+ pred_begin(&B), pred_end(&B))) &&
+ "Incomplete MemoryPhi Node");
+ for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) {
+ verifyUseInDefs(Phi->getIncomingValue(I), Phi);
+ assert(find(predecessors(&B), Phi->getIncomingBlock(I)) !=
+ pred_end(&B) &&
+ "Incoming phi block not a block predecessor");
+ }
+ }
+
+ for (Instruction &I : B) {
+ if (MemoryUseOrDef *MA = getMemoryAccess(&I)) {
+ verifyUseInDefs(MA->getDefiningAccess(), MA);
+ }
+ }
+ }
+#endif
+}
+
+/// Perform a local numbering on blocks so that instruction ordering can be
+/// determined in constant time.
+/// TODO: We currently just number in order. If we numbered by N, we could
+/// allow at least N-1 sequences of insertBefore or insertAfter (and at least
+/// log2(N) sequences of mixed before and after) without needing to invalidate
+/// the numbering.
+void MemorySSA::renumberBlock(const BasicBlock *B) const {
+ // The pre-increment ensures the numbers really start at 1.
+ unsigned long CurrentNumber = 0;
+ const AccessList *AL = getBlockAccesses(B);
+ assert(AL != nullptr && "Asking to renumber an empty block");
+ for (const auto &I : *AL)
+ BlockNumbering[&I] = ++CurrentNumber;
+ BlockNumberingValid.insert(B);
+}
+
+/// Determine, for two memory accesses in the same block,
+/// whether \p Dominator dominates \p Dominatee.
+/// \returns True if \p Dominator dominates \p Dominatee.
+bool MemorySSA::locallyDominates(const MemoryAccess *Dominator,
+ const MemoryAccess *Dominatee) const {
+ const BasicBlock *DominatorBlock = Dominator->getBlock();
+
+ assert((DominatorBlock == Dominatee->getBlock()) &&
+ "Asking for local domination when accesses are in different blocks!");
+ // A node dominates itself.
+ if (Dominatee == Dominator)
+ return true;
+
+ // When Dominatee is defined on function entry, it is not dominated by another
+ // memory access.
+ if (isLiveOnEntryDef(Dominatee))
+ return false;
+
+ // When Dominator is defined on function entry, it dominates the other memory
+ // access.
+ if (isLiveOnEntryDef(Dominator))
+ return true;
+
+ if (!BlockNumberingValid.count(DominatorBlock))
+ renumberBlock(DominatorBlock);
+
+ unsigned long DominatorNum = BlockNumbering.lookup(Dominator);
+ // All numbers start with 1
+ assert(DominatorNum != 0 && "Block was not numbered properly");
+ unsigned long DominateeNum = BlockNumbering.lookup(Dominatee);
+ assert(DominateeNum != 0 && "Block was not numbered properly");
+ return DominatorNum < DominateeNum;
+}
+
+bool MemorySSA::dominates(const MemoryAccess *Dominator,
+ const MemoryAccess *Dominatee) const {
+ if (Dominator == Dominatee)
+ return true;
+
+ if (isLiveOnEntryDef(Dominatee))
+ return false;
+
+ if (Dominator->getBlock() != Dominatee->getBlock())
+ return DT->dominates(Dominator->getBlock(), Dominatee->getBlock());
+ return locallyDominates(Dominator, Dominatee);
+}
+
+bool MemorySSA::dominates(const MemoryAccess *Dominator,
+ const Use &Dominatee) const {
+ if (MemoryPhi *MP = dyn_cast<MemoryPhi>(Dominatee.getUser())) {
+ BasicBlock *UseBB = MP->getIncomingBlock(Dominatee);
+ // The def must dominate the incoming block of the phi.
+ if (UseBB != Dominator->getBlock())
+ return DT->dominates(Dominator->getBlock(), UseBB);
+ // If the UseBB and the DefBB are the same, compare locally.
+ return locallyDominates(Dominator, cast<MemoryAccess>(Dominatee));
+ }
+ // If it's not a PHI node use, the normal dominates can already handle it.
+ return dominates(Dominator, cast<MemoryAccess>(Dominatee.getUser()));
+}
+
+const static char LiveOnEntryStr[] = "liveOnEntry";
+
+void MemoryAccess::print(raw_ostream &OS) const {
+ switch (getValueID()) {
+ case MemoryPhiVal: return static_cast<const MemoryPhi *>(this)->print(OS);
+ case MemoryDefVal: return static_cast<const MemoryDef *>(this)->print(OS);
+ case MemoryUseVal: return static_cast<const MemoryUse *>(this)->print(OS);
+ }
+ llvm_unreachable("invalid value id");
+}
+
+void MemoryDef::print(raw_ostream &OS) const {
+ MemoryAccess *UO = getDefiningAccess();
+
+ auto printID = [&OS](MemoryAccess *A) {
+ if (A && A->getID())
+ OS << A->getID();
+ else
+ OS << LiveOnEntryStr;
+ };
+
+ OS << getID() << " = MemoryDef(";
+ printID(UO);
+ OS << ")";
+
+ if (isOptimized()) {
+ OS << "->";
+ printID(getOptimized());
+
+ if (Optional<AliasResult> AR = getOptimizedAccessType())
+ OS << " " << *AR;
+ }
+}
+
+void MemoryPhi::print(raw_ostream &OS) const {
+ bool First = true;
+ OS << getID() << " = MemoryPhi(";
+ for (const auto &Op : operands()) {
+ BasicBlock *BB = getIncomingBlock(Op);
+ MemoryAccess *MA = cast<MemoryAccess>(Op);
+ if (!First)
+ OS << ',';
+ else
+ First = false;
+
+ OS << '{';
+ if (BB->hasName())
+ OS << BB->getName();
+ else
+ BB->printAsOperand(OS, false);
+ OS << ',';
+ if (unsigned ID = MA->getID())
+ OS << ID;
+ else
+ OS << LiveOnEntryStr;
+ OS << '}';
+ }
+ OS << ')';
+}
+
+void MemoryUse::print(raw_ostream &OS) const {
+ MemoryAccess *UO = getDefiningAccess();
+ OS << "MemoryUse(";
+ if (UO && UO->getID())
+ OS << UO->getID();
+ else
+ OS << LiveOnEntryStr;
+ OS << ')';
+
+ if (Optional<AliasResult> AR = getOptimizedAccessType())
+ OS << " " << *AR;
+}
+
+void MemoryAccess::dump() const {
+// Cannot completely remove virtual function even in release mode.
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ print(dbgs());
+ dbgs() << "\n";
+#endif
+}
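+
+// Editorial sketch (not part of the original change): with the printers above,
+// MemorySSA::print annotates the IR roughly as follows; the IR, block names,
+// and access IDs are hypothetical.
+//
+//   ; 1 = MemoryDef(liveOnEntry)
+//     store i32 4, i32* %a
+//   ; MemoryUse(1)
+//     %v = load i32, i32* %a
+//   ; 2 = MemoryPhi({entry,1},{if.then,liveOnEntry})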
+
+char MemorySSAPrinterLegacyPass::ID = 0;
+
+MemorySSAPrinterLegacyPass::MemorySSAPrinterLegacyPass() : FunctionPass(ID) {
+ initializeMemorySSAPrinterLegacyPassPass(*PassRegistry::getPassRegistry());
+}
+
+void MemorySSAPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequired<MemorySSAWrapperPass>();
+}
+
+bool MemorySSAPrinterLegacyPass::runOnFunction(Function &F) {
+ auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA();
+ MSSA.print(dbgs());
+ if (VerifyMemorySSA)
+ MSSA.verifyMemorySSA();
+ return false;
+}
+
+AnalysisKey MemorySSAAnalysis::Key;
+
+MemorySSAAnalysis::Result MemorySSAAnalysis::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ auto &AA = AM.getResult<AAManager>(F);
+ return MemorySSAAnalysis::Result(std::make_unique<MemorySSA>(F, &AA, &DT));
+}
+
+bool MemorySSAAnalysis::Result::invalidate(
+ Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &Inv) {
+ auto PAC = PA.getChecker<MemorySSAAnalysis>();
+ return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
+ Inv.invalidate<AAManager>(F, PA) ||
+ Inv.invalidate<DominatorTreeAnalysis>(F, PA);
+}
+
+PreservedAnalyses MemorySSAPrinterPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ OS << "MemorySSA for function: " << F.getName() << "\n";
+ AM.getResult<MemorySSAAnalysis>(F).getMSSA().print(OS);
+
+ return PreservedAnalyses::all();
+}
+
+PreservedAnalyses MemorySSAVerifierPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ AM.getResult<MemorySSAAnalysis>(F).getMSSA().verifyMemorySSA();
+
+ return PreservedAnalyses::all();
+}
+
+char MemorySSAWrapperPass::ID = 0;
+
+MemorySSAWrapperPass::MemorySSAWrapperPass() : FunctionPass(ID) {
+ initializeMemorySSAWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+void MemorySSAWrapperPass::releaseMemory() { MSSA.reset(); }
+
+void MemorySSAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequiredTransitive<DominatorTreeWrapperPass>();
+ AU.addRequiredTransitive<AAResultsWrapperPass>();
+}
+
+bool MemorySSAWrapperPass::runOnFunction(Function &F) {
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
+ MSSA.reset(new MemorySSA(F, &AA, &DT));
+ return false;
+}
+
+void MemorySSAWrapperPass::verifyAnalysis() const { MSSA->verifyMemorySSA(); }
+
+void MemorySSAWrapperPass::print(raw_ostream &OS, const Module *M) const {
+ MSSA->print(OS);
+}
+
+MemorySSAWalker::MemorySSAWalker(MemorySSA *M) : MSSA(M) {}
+
+/// Walk the use-def chains starting at \p StartingAccess and find
+/// the MemoryAccess that actually clobbers Loc.
+///
+/// \returns our clobbering memory access
+template <typename AliasAnalysisType>
+MemoryAccess *
+MemorySSA::ClobberWalkerBase<AliasAnalysisType>::getClobberingMemoryAccessBase(
+ MemoryAccess *StartingAccess, const MemoryLocation &Loc,
+ unsigned &UpwardWalkLimit) {
+ if (isa<MemoryPhi>(StartingAccess))
+ return StartingAccess;
+
+ auto *StartingUseOrDef = cast<MemoryUseOrDef>(StartingAccess);
+ if (MSSA->isLiveOnEntryDef(StartingUseOrDef))
+ return StartingUseOrDef;
+
+ Instruction *I = StartingUseOrDef->getMemoryInst();
+
+ // Conservatively, fences are always clobbers, so don't perform the walk if we
+ // hit a fence.
+ if (!isa<CallBase>(I) && I->isFenceLike())
+ return StartingUseOrDef;
+
+ UpwardsMemoryQuery Q;
+ Q.OriginalAccess = StartingUseOrDef;
+ Q.StartingLoc = Loc;
+ Q.Inst = I;
+ Q.IsCall = false;
+
+ // Unlike the other function, do not walk to the def of a def, because we are
+ // handed something we already believe is the clobbering access.
+ // We never set SkipSelf to true in Q in this method.
+ MemoryAccess *DefiningAccess = isa<MemoryUse>(StartingUseOrDef)
+ ? StartingUseOrDef->getDefiningAccess()
+ : StartingUseOrDef;
+
+ MemoryAccess *Clobber =
+ Walker.findClobber(DefiningAccess, Q, UpwardWalkLimit);
+ LLVM_DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is ");
+ LLVM_DEBUG(dbgs() << *StartingUseOrDef << "\n");
+ LLVM_DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is ");
+ LLVM_DEBUG(dbgs() << *Clobber << "\n");
+ return Clobber;
+}
+
+template <typename AliasAnalysisType>
+MemoryAccess *
+MemorySSA::ClobberWalkerBase<AliasAnalysisType>::getClobberingMemoryAccessBase(
+ MemoryAccess *MA, unsigned &UpwardWalkLimit, bool SkipSelf) {
+ auto *StartingAccess = dyn_cast<MemoryUseOrDef>(MA);
+ // If this is a MemoryPhi, we can't do anything.
+ if (!StartingAccess)
+ return MA;
+
+ bool IsOptimized = false;
+
+ // If this is an already optimized use or def, return the optimized result.
+ // Note: Currently, we store the optimized def result in a separate field,
+ // since we can't use the defining access.
+ if (StartingAccess->isOptimized()) {
+ if (!SkipSelf || !isa<MemoryDef>(StartingAccess))
+ return StartingAccess->getOptimized();
+ IsOptimized = true;
+ }
+
+ const Instruction *I = StartingAccess->getMemoryInst();
+ // We can't sanely do anything with fences, since they conservatively clobber
+ // all memory, and have no locations to get pointers from to try to
+ // disambiguate.
+ if (!isa<CallBase>(I) && I->isFenceLike())
+ return StartingAccess;
+
+ UpwardsMemoryQuery Q(I, StartingAccess);
+
+ if (isUseTriviallyOptimizableToLiveOnEntry(*Walker.getAA(), I)) {
+ MemoryAccess *LiveOnEntry = MSSA->getLiveOnEntryDef();
+ StartingAccess->setOptimized(LiveOnEntry);
+ StartingAccess->setOptimizedAccessType(None);
+ return LiveOnEntry;
+ }
+
+ MemoryAccess *OptimizedAccess;
+ if (!IsOptimized) {
+ // Start with the thing we already think clobbers this location
+ MemoryAccess *DefiningAccess = StartingAccess->getDefiningAccess();
+
+ // At this point, DefiningAccess may be the live on entry def.
+ // If it is, we will not get a better result.
+ if (MSSA->isLiveOnEntryDef(DefiningAccess)) {
+ StartingAccess->setOptimized(DefiningAccess);
+ StartingAccess->setOptimizedAccessType(None);
+ return DefiningAccess;
+ }
+
+ OptimizedAccess = Walker.findClobber(DefiningAccess, Q, UpwardWalkLimit);
+ StartingAccess->setOptimized(OptimizedAccess);
+ if (MSSA->isLiveOnEntryDef(OptimizedAccess))
+ StartingAccess->setOptimizedAccessType(None);
+ else if (Q.AR == MustAlias)
+ StartingAccess->setOptimizedAccessType(MustAlias);
+ } else
+ OptimizedAccess = StartingAccess->getOptimized();
+
+ LLVM_DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is ");
+ LLVM_DEBUG(dbgs() << *StartingAccess << "\n");
+ LLVM_DEBUG(dbgs() << "Optimized Memory SSA clobber for " << *I << " is ");
+ LLVM_DEBUG(dbgs() << *OptimizedAccess << "\n");
+
+ MemoryAccess *Result;
+ if (SkipSelf && isa<MemoryPhi>(OptimizedAccess) &&
+ isa<MemoryDef>(StartingAccess) && UpwardWalkLimit) {
+ assert(isa<MemoryDef>(Q.OriginalAccess));
+ Q.SkipSelfAccess = true;
+ Result = Walker.findClobber(OptimizedAccess, Q, UpwardWalkLimit);
+ } else
+ Result = OptimizedAccess;
+
+ LLVM_DEBUG(dbgs() << "Result Memory SSA clobber [SkipSelf = " << SkipSelf);
+ LLVM_DEBUG(dbgs() << "] for " << *I << " is " << *Result << "\n");
+
+ return Result;
+}
+
+MemoryAccess *
+DoNothingMemorySSAWalker::getClobberingMemoryAccess(MemoryAccess *MA) {
+ if (auto *Use = dyn_cast<MemoryUseOrDef>(MA))
+ return Use->getDefiningAccess();
+ return MA;
+}
+
+MemoryAccess *DoNothingMemorySSAWalker::getClobberingMemoryAccess(
+ MemoryAccess *StartingAccess, const MemoryLocation &) {
+ if (auto *Use = dyn_cast<MemoryUseOrDef>(StartingAccess))
+ return Use->getDefiningAccess();
+ return StartingAccess;
+}
+
+void MemoryPhi::deleteMe(DerivedUser *Self) {
+ delete static_cast<MemoryPhi *>(Self);
+}
+
+void MemoryDef::deleteMe(DerivedUser *Self) {
+ delete static_cast<MemoryDef *>(Self);
+}
+
+void MemoryUse::deleteMe(DerivedUser *Self) {
+ delete static_cast<MemoryUse *>(Self);
+}
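+
+// Editorial sketch (not part of the original change): typical client use of
+// the interfaces defined in this file, assuming MSSA is an up-to-date
+// MemorySSA for the enclosing function and I is a memory-touching
+// instruction. Variable names are hypothetical.
+//
+//   if (MemoryUseOrDef *MA = MSSA.getMemoryAccess(&I)) {
+//     MemoryAccess *Clobber = MSSA.getWalker()->getClobberingMemoryAccess(MA);
+//     if (MSSA.isLiveOnEntryDef(Clobber)) {
+//       // Nothing in this function clobbers I's memory location.
+//     }
+//   }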
diff --git a/llvm/lib/Analysis/MemorySSAUpdater.cpp b/llvm/lib/Analysis/MemorySSAUpdater.cpp
new file mode 100644
index 000000000000..f2d56b05d968
--- /dev/null
+++ b/llvm/lib/Analysis/MemorySSAUpdater.cpp
@@ -0,0 +1,1440 @@
+//===-- MemorySSAUpdater.cpp - Memory SSA Updater--------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------===//
+//
+// This file implements the MemorySSAUpdater class.
+//
+//===----------------------------------------------------------------===//
+#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/IteratedDominanceFrontier.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormattedStream.h"
+#include <algorithm>
+
+#define DEBUG_TYPE "memoryssa"
+using namespace llvm;
+
+// This is the marker algorithm from "Simple and Efficient Construction of
+// Static Single Assignment Form"
+// The simple, non-marker algorithm places phi nodes at any join
+// Here, we place markers, and only place phi nodes if they end up necessary.
+// They are only necessary if they break a cycle (IE we recursively visit
+// ourselves again), or we discover, while getting the value of the operands,
+// that there are two or more definitions needing to be merged.
+// This will still leave a non-minimal form in the case of irreducible control
+// flow, where phi nodes may be in cycles with themselves but are unnecessary.
+MemoryAccess *MemorySSAUpdater::getPreviousDefRecursive(
+ BasicBlock *BB,
+ DenseMap<BasicBlock *, TrackingVH<MemoryAccess>> &CachedPreviousDef) {
+ // First, do a cache lookup. Without this cache, certain CFG structures
+ // (like a series of if statements) take exponential time to visit.
+ auto Cached = CachedPreviousDef.find(BB);
+ if (Cached != CachedPreviousDef.end())
+ return Cached->second;
+
+ // If this method is called from an unreachable block, return LoE.
+ if (!MSSA->DT->isReachableFromEntry(BB))
+ return MSSA->getLiveOnEntryDef();
+
+ if (BasicBlock *Pred = BB->getUniquePredecessor()) {
+ VisitedBlocks.insert(BB);
+ // Single predecessor case, just recurse, we can only have one definition.
+ MemoryAccess *Result = getPreviousDefFromEnd(Pred, CachedPreviousDef);
+ CachedPreviousDef.insert({BB, Result});
+ return Result;
+ }
+
+ if (VisitedBlocks.count(BB)) {
+ // We hit our node again, meaning we had a cycle, we must insert a phi
+ // node to break it so we have an operand. The only case this will
+ // insert useless phis is if we have irreducible control flow.
+ MemoryAccess *Result = MSSA->createMemoryPhi(BB);
+ CachedPreviousDef.insert({BB, Result});
+ return Result;
+ }
+
+ if (VisitedBlocks.insert(BB).second) {
+ // Mark us visited so we can detect a cycle
+ SmallVector<TrackingVH<MemoryAccess>, 8> PhiOps;
+
+ // Recurse to get the values in our predecessors for placement of a
+ // potential phi node. This will insert phi nodes if we cycle in order to
+ // break the cycle and have an operand.
+ bool UniqueIncomingAccess = true;
+ MemoryAccess *SingleAccess = nullptr;
+ for (auto *Pred : predecessors(BB)) {
+ if (MSSA->DT->isReachableFromEntry(Pred)) {
+ auto *IncomingAccess = getPreviousDefFromEnd(Pred, CachedPreviousDef);
+ if (!SingleAccess)
+ SingleAccess = IncomingAccess;
+ else if (IncomingAccess != SingleAccess)
+ UniqueIncomingAccess = false;
+ PhiOps.push_back(IncomingAccess);
+ } else
+ PhiOps.push_back(MSSA->getLiveOnEntryDef());
+ }
+
+ // Now try to simplify the ops to avoid placing a phi.
+ // This may return null if we never created a phi yet, that's okay
+ MemoryPhi *Phi = dyn_cast_or_null<MemoryPhi>(MSSA->getMemoryAccess(BB));
+
+ // See if we can avoid the phi by simplifying it.
+ auto *Result = tryRemoveTrivialPhi(Phi, PhiOps);
+ // If we couldn't simplify, we may have to create a phi
+ if (Result == Phi && UniqueIncomingAccess && SingleAccess) {
+ // A concrete Phi only exists if we created an empty one to break a cycle.
+ if (Phi) {
+ assert(Phi->operands().empty() && "Expected empty Phi");
+ Phi->replaceAllUsesWith(SingleAccess);
+ removeMemoryAccess(Phi);
+ }
+ Result = SingleAccess;
+ } else if (Result == Phi && !(UniqueIncomingAccess && SingleAccess)) {
+ if (!Phi)
+ Phi = MSSA->createMemoryPhi(BB);
+
+ // See if the existing phi operands match what we need.
+ // Unlike normal SSA, we only allow one phi node per block, so we can't just
+ // create a new one.
+ if (Phi->getNumOperands() != 0) {
+ // FIXME: Figure out whether this is dead code and if so remove it.
+ if (!std::equal(Phi->op_begin(), Phi->op_end(), PhiOps.begin())) {
+ // These will have been filled in by the recursive read we did above.
+ llvm::copy(PhiOps, Phi->op_begin());
+ std::copy(pred_begin(BB), pred_end(BB), Phi->block_begin());
+ }
+ } else {
+ unsigned i = 0;
+ for (auto *Pred : predecessors(BB))
+ Phi->addIncoming(&*PhiOps[i++], Pred);
+ InsertedPHIs.push_back(Phi);
+ }
+ Result = Phi;
+ }
+
+ // Set ourselves up for the next variable by resetting visited state.
+ VisitedBlocks.erase(BB);
+ CachedPreviousDef.insert({BB, Result});
+ return Result;
+ }
+ llvm_unreachable("Should have hit one of the three cases above");
+}
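+
+// Editorial sketch (not part of the original change): a small worked example
+// of the marker algorithm above, with hypothetical blocks. Consider a diamond
+// A -> {B, C} -> D where only B contains a MemoryDef. Asking for the previous
+// def at the top of D gathers one operand per predecessor: B contributes its
+// MemoryDef, while C (having no defs) recurses through its unique predecessor
+// A and contributes A's last def (or liveOnEntry). The operands differ, so a
+// MemoryPhi is created in D. If neither B nor C contained a def, both operands
+// would be identical, no MemoryPhi would be materialized, and the single
+// previous def would be returned directly.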
+
+// This starts at the memory access, and goes backwards in the block to find the
+// previous definition. If a definition is not found in the block of the access,
+// it continues globally, creating phi nodes to ensure we have a single
+// definition.
+MemoryAccess *MemorySSAUpdater::getPreviousDef(MemoryAccess *MA) {
+ if (auto *LocalResult = getPreviousDefInBlock(MA))
+ return LocalResult;
+ DenseMap<BasicBlock *, TrackingVH<MemoryAccess>> CachedPreviousDef;
+ return getPreviousDefRecursive(MA->getBlock(), CachedPreviousDef);
+}
+
+// This starts at the memory access, and goes backwards in the block to find
+// the previous definition. If the definition is not found in the block of the
+// access, it returns nullptr.
+MemoryAccess *MemorySSAUpdater::getPreviousDefInBlock(MemoryAccess *MA) {
+ auto *Defs = MSSA->getWritableBlockDefs(MA->getBlock());
+
+ // It's possible there are no defs, or we got handed the first def to start.
+ if (Defs) {
+ // If this is a def, we can just use the def iterators.
+ if (!isa<MemoryUse>(MA)) {
+ auto Iter = MA->getReverseDefsIterator();
+ ++Iter;
+ if (Iter != Defs->rend())
+ return &*Iter;
+ } else {
+ // Otherwise, have to walk the all access iterator.
+ auto End = MSSA->getWritableBlockAccesses(MA->getBlock())->rend();
+ for (auto &U : make_range(++MA->getReverseIterator(), End))
+ if (!isa<MemoryUse>(U))
+ return cast<MemoryAccess>(&U);
+ // Note that if MA comes before Defs->begin(), we won't hit a def.
+ return nullptr;
+ }
+ }
+ return nullptr;
+}
+
+// This starts at the end of the block.
+MemoryAccess *MemorySSAUpdater::getPreviousDefFromEnd(
+ BasicBlock *BB,
+ DenseMap<BasicBlock *, TrackingVH<MemoryAccess>> &CachedPreviousDef) {
+ auto *Defs = MSSA->getWritableBlockDefs(BB);
+
+ if (Defs) {
+ CachedPreviousDef.insert({BB, &*Defs->rbegin()});
+ return &*Defs->rbegin();
+ }
+
+ return getPreviousDefRecursive(BB, CachedPreviousDef);
+}
+
+// Recurse over a set of phi uses to eliminate the trivial ones
+MemoryAccess *MemorySSAUpdater::recursePhi(MemoryAccess *Phi) {
+ if (!Phi)
+ return nullptr;
+ TrackingVH<MemoryAccess> Res(Phi);
+ SmallVector<TrackingVH<Value>, 8> Uses;
+ std::copy(Phi->user_begin(), Phi->user_end(), std::back_inserter(Uses));
+ for (auto &U : Uses)
+ if (MemoryPhi *UsePhi = dyn_cast<MemoryPhi>(&*U))
+ tryRemoveTrivialPhi(UsePhi);
+ return Res;
+}
+
+// Eliminate trivial phis
+// Phis are trivial if they are defined either by themselves, or all the same
+// argument.
+// IE phi(a, a) or b = phi(a, b) or c = phi(a, a, c)
+// We recursively try to remove them.
+MemoryAccess *MemorySSAUpdater::tryRemoveTrivialPhi(MemoryPhi *Phi) {
+ assert(Phi && "Can only remove concrete Phi.");
+ auto OperRange = Phi->operands();
+ return tryRemoveTrivialPhi(Phi, OperRange);
+}
+template <class RangeType>
+MemoryAccess *MemorySSAUpdater::tryRemoveTrivialPhi(MemoryPhi *Phi,
+ RangeType &Operands) {
+ // Bail out on non-opt Phis.
+ if (NonOptPhis.count(Phi))
+ return Phi;
+
+ // Detect equal or self arguments
+ MemoryAccess *Same = nullptr;
+ for (auto &Op : Operands) {
+ // If the same or self, good so far
+ if (Op == Phi || Op == Same)
+ continue;
+ // Not the same, return the phi since it's not eliminable by us.
+ if (Same)
+ return Phi;
+ Same = cast<MemoryAccess>(&*Op);
+ }
+ // Never found a non-self reference, the phi is undef
+ if (Same == nullptr)
+ return MSSA->getLiveOnEntryDef();
+ if (Phi) {
+ Phi->replaceAllUsesWith(Same);
+ removeMemoryAccess(Phi);
+ }
+
+ // We should only end up recursing in case we replaced something, in which
+ // case, we may have made other Phis trivial.
+ return recursePhi(Same);
+}
+
+void MemorySSAUpdater::insertUse(MemoryUse *MU, bool RenameUses) {
+ InsertedPHIs.clear();
+ MU->setDefiningAccess(getPreviousDef(MU));
+
+ // In cases without unreachable blocks, because uses do not create new
+ // may-defs, there are only two cases:
+ // 1. There was a def already below us, and therefore, we should not have
+ // created a phi node because it was already needed for the def.
+ //
+ // 2. There is no def below us, and therefore, there is no extra renaming work
+ // to do.
+
+ // In cases with unreachable blocks, where the unnecessary Phis were
+ // optimized out, adding the Use may re-insert those Phis. Hence, when
+ // inserting Uses outside of the MSSA creation process, and new Phis were
+ // added, rename all uses if we are asked.
+
+ if (!RenameUses && !InsertedPHIs.empty()) {
+ auto *Defs = MSSA->getBlockDefs(MU->getBlock());
+ (void)Defs;
+ assert((!Defs || (++Defs->begin() == Defs->end())) &&
+ "Block may have only a Phi or no defs");
+ }
+
+ if (RenameUses && InsertedPHIs.size()) {
+ SmallPtrSet<BasicBlock *, 16> Visited;
+ BasicBlock *StartBlock = MU->getBlock();
+
+ if (auto *Defs = MSSA->getWritableBlockDefs(StartBlock)) {
+ MemoryAccess *FirstDef = &*Defs->begin();
+ // Convert to incoming value if it's a memorydef. A phi *is* already an
+ // incoming value.
+ if (auto *MD = dyn_cast<MemoryDef>(FirstDef))
+ FirstDef = MD->getDefiningAccess();
+
+ MSSA->renamePass(MU->getBlock(), FirstDef, Visited);
+ }
+ // We just inserted a phi into this block, so the incoming value will
+ // become the phi anyway, so it does not matter what we pass.
+ for (auto &MP : InsertedPHIs)
+ if (MemoryPhi *Phi = cast_or_null<MemoryPhi>(MP))
+ MSSA->renamePass(Phi->getBlock(), nullptr, Visited);
+ }
+}
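+
+// Editorial sketch (not part of the original change): a pass that clones or
+// moves a load typically creates the new access first and then calls
+// insertUse. This assumes the updater's createMemoryAccessInBB helper and uses
+// hypothetical variable names.
+//
+//   MemorySSAUpdater MSSAU(&MSSA);
+//   auto *NewMA = MSSAU.createMemoryAccessInBB(
+//       NewLoad, /*Definition=*/nullptr, NewLoad->getParent(), MemorySSA::End);
+//   MSSAU.insertUse(cast<MemoryUse>(NewMA), /*RenameUses=*/true);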
+
+// Set every incoming edge {BB, MP->getBlock()} of MemoryPhi MP to NewDef.
+static void setMemoryPhiValueForBlock(MemoryPhi *MP, const BasicBlock *BB,
+ MemoryAccess *NewDef) {
+ // Replace the incoming value of every operand whose incoming block is BB
+ // with the new defining access.
+ int i = MP->getBasicBlockIndex(BB);
+ assert(i != -1 && "Should have found the basic block in the phi");
+ // We can't just compare i against getNumOperands since one is signed and the
+ // other not. So use it to index into the block iterator.
+ for (auto BBIter = MP->block_begin() + i; BBIter != MP->block_end();
+ ++BBIter) {
+ if (*BBIter != BB)
+ break;
+ MP->setIncomingValue(i, NewDef);
+ ++i;
+ }
+}
+
+// A brief description of the algorithm:
+// First, we compute what should define the new def, using the SSA
+// construction algorithm.
+// Then, we update the defs below us (and any new phi nodes) in the graph to
+// point to the correct new defs, to ensure we only have one variable, and no
+// disconnected stores.
+void MemorySSAUpdater::insertDef(MemoryDef *MD, bool RenameUses) {
+ InsertedPHIs.clear();
+
+ // See if we had a local def, and if not, go hunting.
+ MemoryAccess *DefBefore = getPreviousDef(MD);
+ bool DefBeforeSameBlock = false;
+ if (DefBefore->getBlock() == MD->getBlock() &&
+ !(isa<MemoryPhi>(DefBefore) &&
+ std::find(InsertedPHIs.begin(), InsertedPHIs.end(), DefBefore) !=
+ InsertedPHIs.end()))
+ DefBeforeSameBlock = true;
+
+ // There is a def before us, which means we can replace any store/phi uses
+ // of that thing with us, since we are in the way of whatever was there
+ // before.
+ // We now define that def's memorydefs and memoryphis
+ if (DefBeforeSameBlock) {
+ DefBefore->replaceUsesWithIf(MD, [MD](Use &U) {
+ // Leave the MemoryUses alone.
+ // Also make sure we skip ourselves to avoid self references.
+ User *Usr = U.getUser();
+ return !isa<MemoryUse>(Usr) && Usr != MD;
+ // Defs are automatically unoptimized when the user is set to MD below,
+ // because the isOptimized() call will fail to find the same ID.
+ });
+ }
+
+ // and that def is now our defining access.
+ MD->setDefiningAccess(DefBefore);
+
+ SmallVector<WeakVH, 8> FixupList(InsertedPHIs.begin(), InsertedPHIs.end());
+
+ // Remember the index where we may insert new phis.
+ unsigned NewPhiIndex = InsertedPHIs.size();
+ if (!DefBeforeSameBlock) {
+ // If there was a local def before us, we must have the same effect it
+ // did. Because every may-def defines the same memory state, any phis/etc. we
+ // would need to create, it would already have created. If there was no local
+ // def before us, we
+ // performed a global update, and have to search all successors and make
+ // sure we update the first def in each of them (following all paths until
+ // we hit the first def along each path). This may also insert phi nodes.
+ // TODO: There are other cases where we can skip this work, such as when we
+ // have a single successor, and only used a straight line of single pred blocks
+ // backwards to find the def. To make that work, we'd have to track whether
+ // getDefRecursive only ever used the single predecessor case. These types
+ // of paths also only exist in between CFG simplifications.
+
+ // If this is the first def in the block and this insert is in an arbitrary
+ // place, compute IDF and place phis.
+ SmallPtrSet<BasicBlock *, 2> DefiningBlocks;
+
+ // If this is the last Def in the block, also compute IDF based on MD, since
+ // this may be a newly added Def, and we may need additional Phis.
+ auto Iter = MD->getDefsIterator();
+ ++Iter;
+ auto IterEnd = MSSA->getBlockDefs(MD->getBlock())->end();
+ if (Iter == IterEnd)
+ DefiningBlocks.insert(MD->getBlock());
+
+ for (const auto &VH : InsertedPHIs)
+ if (const auto *RealPHI = cast_or_null<MemoryPhi>(VH))
+ DefiningBlocks.insert(RealPHI->getBlock());
+ ForwardIDFCalculator IDFs(*MSSA->DT);
+ SmallVector<BasicBlock *, 32> IDFBlocks;
+ IDFs.setDefiningBlocks(DefiningBlocks);
+ IDFs.calculate(IDFBlocks);
+ SmallVector<AssertingVH<MemoryPhi>, 4> NewInsertedPHIs;
+ for (auto *BBIDF : IDFBlocks) {
+ auto *MPhi = MSSA->getMemoryAccess(BBIDF);
+ if (!MPhi) {
+ MPhi = MSSA->createMemoryPhi(BBIDF);
+ NewInsertedPHIs.push_back(MPhi);
+ }
+ // Add the phis created into the IDF blocks to NonOptPhis, so they are not
+ // optimized out as trivial by the call to getPreviousDefFromEnd below.
+ // Once they are complete, all these Phis are added to the FixupList, and
+ // removed from NonOptPhis inside fixupDefs(). Existing Phis in IDF may
+ // need fixing as well, and potentially be trivial before this insertion,
+ // hence add all IDF Phis. See PR43044.
+ NonOptPhis.insert(MPhi);
+ }
+ for (auto &MPhi : NewInsertedPHIs) {
+ auto *BBIDF = MPhi->getBlock();
+ for (auto *Pred : predecessors(BBIDF)) {
+ DenseMap<BasicBlock *, TrackingVH<MemoryAccess>> CachedPreviousDef;
+ MPhi->addIncoming(getPreviousDefFromEnd(Pred, CachedPreviousDef), Pred);
+ }
+ }
+
+ // Re-take the index where we're adding the new phis, because the above call
+ // to getPreviousDefFromEnd may have inserted into InsertedPHIs.
+ NewPhiIndex = InsertedPHIs.size();
+ for (auto &MPhi : NewInsertedPHIs) {
+ InsertedPHIs.push_back(&*MPhi);
+ FixupList.push_back(&*MPhi);
+ }
+
+ FixupList.push_back(MD);
+ }
+
+ // Remember the index where we stopped inserting new phis above, since the
+ // fixupDefs call in the loop below may insert more that are already minimal.
+ unsigned NewPhiIndexEnd = InsertedPHIs.size();
+
+ while (!FixupList.empty()) {
+ unsigned StartingPHISize = InsertedPHIs.size();
+ fixupDefs(FixupList);
+ FixupList.clear();
+ // Put any new phis on the fixup list, and process them
+ FixupList.append(InsertedPHIs.begin() + StartingPHISize, InsertedPHIs.end());
+ }
+
+ // Optimize potentially non-minimal phis added in this method.
+ unsigned NewPhiSize = NewPhiIndexEnd - NewPhiIndex;
+ if (NewPhiSize)
+ tryRemoveTrivialPhis(ArrayRef<WeakVH>(&InsertedPHIs[NewPhiIndex], NewPhiSize));
+
+ // Now that all fixups are done, rename all uses if we are asked.
+ if (RenameUses) {
+ SmallPtrSet<BasicBlock *, 16> Visited;
+ BasicBlock *StartBlock = MD->getBlock();
+ // We are guaranteed there is a def in the block, because we just got it
+ // handed to us in this function.
+ MemoryAccess *FirstDef = &*MSSA->getWritableBlockDefs(StartBlock)->begin();
+ // Convert to incoming value if it's a memorydef. A phi *is* already an
+ // incoming value.
+ if (auto *MD = dyn_cast<MemoryDef>(FirstDef))
+ FirstDef = MD->getDefiningAccess();
+
+ MSSA->renamePass(MD->getBlock(), FirstDef, Visited);
+ // We just inserted a phi into this block, so the incoming value will become
+ // the phi anyway, so it does not matter what we pass.
+ for (auto &MP : InsertedPHIs) {
+ MemoryPhi *Phi = dyn_cast_or_null<MemoryPhi>(MP);
+ if (Phi)
+ MSSA->renamePass(Phi->getBlock(), nullptr, Visited);
+ }
+ }
+}
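+
+// Editorial sketch (not part of the original change): the MemoryDef
+// counterpart of the insertUse example above. After materializing a new
+// store, a pass creates the access at the desired position and lets insertDef
+// wire up (and, where needed, create) the Phis below it. Names are
+// hypothetical; createMemoryAccessInBB is assumed to be available on the
+// updater.
+//
+//   auto *NewMA = MSSAU.createMemoryAccessInBB(
+//       NewStore, /*Definition=*/nullptr, NewStore->getParent(),
+//       MemorySSA::End);
+//   MSSAU.insertDef(cast<MemoryDef>(NewMA), /*RenameUses=*/true);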
+
+void MemorySSAUpdater::fixupDefs(const SmallVectorImpl<WeakVH> &Vars) {
+ SmallPtrSet<const BasicBlock *, 8> Seen;
+ SmallVector<const BasicBlock *, 16> Worklist;
+ for (auto &Var : Vars) {
+ MemoryAccess *NewDef = dyn_cast_or_null<MemoryAccess>(Var);
+ if (!NewDef)
+ continue;
+ // First, see if there is a local def after the operand.
+ auto *Defs = MSSA->getWritableBlockDefs(NewDef->getBlock());
+ auto DefIter = NewDef->getDefsIterator();
+
+ // The temporary Phi is being fixed; remove it from the set of Phis that
+ // should not be optimized.
+ if (MemoryPhi *Phi = dyn_cast<MemoryPhi>(NewDef))
+ NonOptPhis.erase(Phi);
+
+ // If there is a local def after us, we only have to rename that.
+ if (++DefIter != Defs->end()) {
+ cast<MemoryDef>(DefIter)->setDefiningAccess(NewDef);
+ continue;
+ }
+
+ // Otherwise, we need to search down through the CFG.
+ // For each of our successors, handle it directly if there is a phi, or
+ // place it on the fixup worklist.
+ for (const auto *S : successors(NewDef->getBlock())) {
+ if (auto *MP = MSSA->getMemoryAccess(S))
+ setMemoryPhiValueForBlock(MP, NewDef->getBlock(), NewDef);
+ else
+ Worklist.push_back(S);
+ }
+
+ while (!Worklist.empty()) {
+ const BasicBlock *FixupBlock = Worklist.back();
+ Worklist.pop_back();
+
+ // Get the first def in the block that isn't a phi node.
+ if (auto *Defs = MSSA->getWritableBlockDefs(FixupBlock)) {
+ auto *FirstDef = &*Defs->begin();
+ // The loop above and below should have taken care of phi nodes
+ assert(!isa<MemoryPhi>(FirstDef) &&
+ "Should have already handled phi nodes!");
+ // We are now this def's defining access, make sure we actually dominate
+ // it
+ assert(MSSA->dominates(NewDef, FirstDef) &&
+ "Should have dominated the new access");
+
+ // This may insert new phi nodes, because we are not guaranteed the
+ // block we are processing has a single pred, and depending where the
+ // store was inserted, it may require phi nodes below it.
+ cast<MemoryDef>(FirstDef)->setDefiningAccess(getPreviousDef(FirstDef));
+ return;
+ }
+ // We didn't find a def, so we must continue.
+ for (const auto *S : successors(FixupBlock)) {
+ // If there is a phi node, handle it.
+ // Otherwise, put the block on the worklist
+ if (auto *MP = MSSA->getMemoryAccess(S))
+ setMemoryPhiValueForBlock(MP, FixupBlock, NewDef);
+ else {
+ // If we cycle, we should have ended up at a phi node that we already
+ // processed. FIXME: Double check this
+ if (!Seen.insert(S).second)
+ continue;
+ Worklist.push_back(S);
+ }
+ }
+ }
+ }
+}
+
+void MemorySSAUpdater::removeEdge(BasicBlock *From, BasicBlock *To) {
+ if (MemoryPhi *MPhi = MSSA->getMemoryAccess(To)) {
+ MPhi->unorderedDeleteIncomingBlock(From);
+ tryRemoveTrivialPhi(MPhi);
+ }
+}
+
+void MemorySSAUpdater::removeDuplicatePhiEdgesBetween(const BasicBlock *From,
+ const BasicBlock *To) {
+ if (MemoryPhi *MPhi = MSSA->getMemoryAccess(To)) {
+ bool Found = false;
+ MPhi->unorderedDeleteIncomingIf([&](const MemoryAccess *, BasicBlock *B) {
+ if (From != B)
+ return false;
+ if (Found)
+ return true;
+ Found = true;
+ return false;
+ });
+ tryRemoveTrivialPhi(MPhi);
+ }
+}
+
+static MemoryAccess *getNewDefiningAccessForClone(MemoryAccess *MA,
+ const ValueToValueMapTy &VMap,
+ PhiToDefMap &MPhiMap,
+ bool CloneWasSimplified,
+ MemorySSA *MSSA) {
+ MemoryAccess *InsnDefining = MA;
+ if (MemoryDef *DefMUD = dyn_cast<MemoryDef>(InsnDefining)) {
+ if (!MSSA->isLiveOnEntryDef(DefMUD)) {
+ Instruction *DefMUDI = DefMUD->getMemoryInst();
+ assert(DefMUDI && "Found MemoryUseOrDef with no Instruction.");
+ if (Instruction *NewDefMUDI =
+ cast_or_null<Instruction>(VMap.lookup(DefMUDI))) {
+ InsnDefining = MSSA->getMemoryAccess(NewDefMUDI);
+ if (!CloneWasSimplified)
+ assert(InsnDefining && "Defining instruction cannot be nullptr.");
+ else if (!InsnDefining || isa<MemoryUse>(InsnDefining)) {
+ // The clone was simplified; it's no longer a MemoryDef, so look further up.
+ auto DefIt = DefMUD->getDefsIterator();
+ // Since simplified clones only occur in single block cloning, a
+ // previous definition must exist, otherwise NewDefMUDI would not
+ // have been found in VMap.
+ assert(DefIt != MSSA->getBlockDefs(DefMUD->getBlock())->begin() &&
+ "Previous def must exist");
+ InsnDefining = getNewDefiningAccessForClone(
+ &*(--DefIt), VMap, MPhiMap, CloneWasSimplified, MSSA);
+ }
+ }
+ }
+ } else {
+ MemoryPhi *DefPhi = cast<MemoryPhi>(InsnDefining);
+ if (MemoryAccess *NewDefPhi = MPhiMap.lookup(DefPhi))
+ InsnDefining = NewDefPhi;
+ }
+ assert(InsnDefining && "Defining instruction cannot be nullptr.");
+ return InsnDefining;
+}
+
+void MemorySSAUpdater::cloneUsesAndDefs(BasicBlock *BB, BasicBlock *NewBB,
+ const ValueToValueMapTy &VMap,
+ PhiToDefMap &MPhiMap,
+ bool CloneWasSimplified) {
+ const MemorySSA::AccessList *Acc = MSSA->getBlockAccesses(BB);
+ if (!Acc)
+ return;
+ for (const MemoryAccess &MA : *Acc) {
+ if (const MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(&MA)) {
+ Instruction *Insn = MUD->getMemoryInst();
+ // Entry does not exist if the clone of the block did not clone all
+ // instructions. This occurs in LoopRotate when cloning instructions
+ // from the old header to the old preheader. The cloned instruction may
+ // also be a simplified Value, not an Instruction (see LoopRotate).
+ // Also in LoopRotate, even when it's an instruction, due to it being
+ // simplified, it may be a Use rather than a Def, so we cannot use MUD as a
+ // template. Calls coming from updateForClonedBlockIntoPred ensure this.
+ if (Instruction *NewInsn =
+ dyn_cast_or_null<Instruction>(VMap.lookup(Insn))) {
+ MemoryAccess *NewUseOrDef = MSSA->createDefinedAccess(
+ NewInsn,
+ getNewDefiningAccessForClone(MUD->getDefiningAccess(), VMap,
+ MPhiMap, CloneWasSimplified, MSSA),
+ /*Template=*/CloneWasSimplified ? nullptr : MUD,
+ /*CreationMustSucceed=*/CloneWasSimplified ? false : true);
+ if (NewUseOrDef)
+ MSSA->insertIntoListsForBlock(NewUseOrDef, NewBB, MemorySSA::End);
+ }
+ }
+ }
+}
+
+void MemorySSAUpdater::updatePhisWhenInsertingUniqueBackedgeBlock(
+ BasicBlock *Header, BasicBlock *Preheader, BasicBlock *BEBlock) {
+ auto *MPhi = MSSA->getMemoryAccess(Header);
+ if (!MPhi)
+ return;
+
+ // Create phi node in the backedge block and populate it with the same
+ // incoming values as MPhi. Skip incoming values coming from Preheader.
+ auto *NewMPhi = MSSA->createMemoryPhi(BEBlock);
+ bool HasUniqueIncomingValue = true;
+ MemoryAccess *UniqueValue = nullptr;
+ for (unsigned I = 0, E = MPhi->getNumIncomingValues(); I != E; ++I) {
+ BasicBlock *IBB = MPhi->getIncomingBlock(I);
+ MemoryAccess *IV = MPhi->getIncomingValue(I);
+ if (IBB != Preheader) {
+ NewMPhi->addIncoming(IV, IBB);
+ if (HasUniqueIncomingValue) {
+ if (!UniqueValue)
+ UniqueValue = IV;
+ else if (UniqueValue != IV)
+ HasUniqueIncomingValue = false;
+ }
+ }
+ }
+
+ // Update incoming edges into MPhi. Remove all but the incoming edge from
+ // Preheader. Add an edge from NewMPhi
+ auto *AccFromPreheader = MPhi->getIncomingValueForBlock(Preheader);
+ MPhi->setIncomingValue(0, AccFromPreheader);
+ MPhi->setIncomingBlock(0, Preheader);
+ for (unsigned I = MPhi->getNumIncomingValues() - 1; I >= 1; --I)
+ MPhi->unorderedDeleteIncoming(I);
+ MPhi->addIncoming(NewMPhi, BEBlock);
+
+ // If NewMPhi is a trivial phi, remove it. Its use in the header MPhi will be
+ // replaced with the unique value.
+ tryRemoveTrivialPhi(NewMPhi);
+}
+
+void MemorySSAUpdater::updateForClonedLoop(const LoopBlocksRPO &LoopBlocks,
+ ArrayRef<BasicBlock *> ExitBlocks,
+ const ValueToValueMapTy &VMap,
+ bool IgnoreIncomingWithNoClones) {
+ PhiToDefMap MPhiMap;
+
+ auto FixPhiIncomingValues = [&](MemoryPhi *Phi, MemoryPhi *NewPhi) {
+ assert(Phi && NewPhi && "Invalid Phi nodes.");
+ BasicBlock *NewPhiBB = NewPhi->getBlock();
+ SmallPtrSet<BasicBlock *, 4> NewPhiBBPreds(pred_begin(NewPhiBB),
+ pred_end(NewPhiBB));
+ for (unsigned It = 0, E = Phi->getNumIncomingValues(); It < E; ++It) {
+ MemoryAccess *IncomingAccess = Phi->getIncomingValue(It);
+ BasicBlock *IncBB = Phi->getIncomingBlock(It);
+
+ if (BasicBlock *NewIncBB = cast_or_null<BasicBlock>(VMap.lookup(IncBB)))
+ IncBB = NewIncBB;
+ else if (IgnoreIncomingWithNoClones)
+ continue;
+
+ // Now we have IncBB, and will need to add incoming from it to NewPhi.
+
+ // If IncBB is not a predecessor of NewPhiBB, then do not add it.
+ // NewPhiBB was cloned without that edge.
+ if (!NewPhiBBPreds.count(IncBB))
+ continue;
+
+ // Determine incoming value and add it as incoming from IncBB.
+ if (MemoryUseOrDef *IncMUD = dyn_cast<MemoryUseOrDef>(IncomingAccess)) {
+ if (!MSSA->isLiveOnEntryDef(IncMUD)) {
+ Instruction *IncI = IncMUD->getMemoryInst();
+ assert(IncI && "Found MemoryUseOrDef with no Instruction.");
+ if (Instruction *NewIncI =
+ cast_or_null<Instruction>(VMap.lookup(IncI))) {
+ IncMUD = MSSA->getMemoryAccess(NewIncI);
+ assert(IncMUD &&
+ "MemoryUseOrDef cannot be null, all preds processed.");
+ }
+ }
+ NewPhi->addIncoming(IncMUD, IncBB);
+ } else {
+ MemoryPhi *IncPhi = cast<MemoryPhi>(IncomingAccess);
+ if (MemoryAccess *NewDefPhi = MPhiMap.lookup(IncPhi))
+ NewPhi->addIncoming(NewDefPhi, IncBB);
+ else
+ NewPhi->addIncoming(IncPhi, IncBB);
+ }
+ }
+ };
+
+ auto ProcessBlock = [&](BasicBlock *BB) {
+ BasicBlock *NewBlock = cast_or_null<BasicBlock>(VMap.lookup(BB));
+ if (!NewBlock)
+ return;
+
+ assert(!MSSA->getWritableBlockAccesses(NewBlock) &&
+ "Cloned block should have no accesses");
+
+ // Add MemoryPhi.
+ if (MemoryPhi *MPhi = MSSA->getMemoryAccess(BB)) {
+ MemoryPhi *NewPhi = MSSA->createMemoryPhi(NewBlock);
+ MPhiMap[MPhi] = NewPhi;
+ }
+ // Update Uses and Defs.
+ cloneUsesAndDefs(BB, NewBlock, VMap, MPhiMap);
+ };
+
+ for (auto BB : llvm::concat<BasicBlock *const>(LoopBlocks, ExitBlocks))
+ ProcessBlock(BB);
+
+ for (auto BB : llvm::concat<BasicBlock *const>(LoopBlocks, ExitBlocks))
+ if (MemoryPhi *MPhi = MSSA->getMemoryAccess(BB))
+ if (MemoryAccess *NewPhi = MPhiMap.lookup(MPhi))
+ FixPhiIncomingValues(MPhi, cast<MemoryPhi>(NewPhi));
+}
+
+void MemorySSAUpdater::updateForClonedBlockIntoPred(
+ BasicBlock *BB, BasicBlock *P1, const ValueToValueMapTy &VM) {
+ // All defs/phis from outside BB that are used in BB are valid uses in P1,
+ // since those defs/phis must have dominated BB and therefore also dominate P1.
+ // Defs from BB being used in BB will be replaced with the cloned defs from
+ // VM. The uses of BB's Phi (if it exists) in BB will be replaced by the
+ // incoming def into the Phi from P1.
+ // Instructions cloned into the predecessor are in practice sometimes
+ // simplified, so disable the use of the template, and create an access from
+ // scratch.
+ PhiToDefMap MPhiMap;
+ if (MemoryPhi *MPhi = MSSA->getMemoryAccess(BB))
+ MPhiMap[MPhi] = MPhi->getIncomingValueForBlock(P1);
+ cloneUsesAndDefs(BB, P1, VM, MPhiMap, /*CloneWasSimplified=*/true);
+}
+
+template <typename Iter>
+void MemorySSAUpdater::privateUpdateExitBlocksForClonedLoop(
+ ArrayRef<BasicBlock *> ExitBlocks, Iter ValuesBegin, Iter ValuesEnd,
+ DominatorTree &DT) {
+ SmallVector<CFGUpdate, 4> Updates;
+ // Update/insert phis in all successors of exit blocks.
+ for (auto *Exit : ExitBlocks)
+ for (const ValueToValueMapTy *VMap : make_range(ValuesBegin, ValuesEnd))
+ if (BasicBlock *NewExit = cast_or_null<BasicBlock>(VMap->lookup(Exit))) {
+ BasicBlock *ExitSucc = NewExit->getTerminator()->getSuccessor(0);
+ Updates.push_back({DT.Insert, NewExit, ExitSucc});
+ }
+ applyInsertUpdates(Updates, DT);
+}
+
+void MemorySSAUpdater::updateExitBlocksForClonedLoop(
+ ArrayRef<BasicBlock *> ExitBlocks, const ValueToValueMapTy &VMap,
+ DominatorTree &DT) {
+ const ValueToValueMapTy *const Arr[] = {&VMap};
+ privateUpdateExitBlocksForClonedLoop(ExitBlocks, std::begin(Arr),
+ std::end(Arr), DT);
+}
+
+void MemorySSAUpdater::updateExitBlocksForClonedLoop(
+ ArrayRef<BasicBlock *> ExitBlocks,
+ ArrayRef<std::unique_ptr<ValueToValueMapTy>> VMaps, DominatorTree &DT) {
+ auto GetPtr = [&](const std::unique_ptr<ValueToValueMapTy> &I) {
+ return I.get();
+ };
+ using MappedIteratorType =
+ mapped_iterator<const std::unique_ptr<ValueToValueMapTy> *,
+ decltype(GetPtr)>;
+ auto MapBegin = MappedIteratorType(VMaps.begin(), GetPtr);
+ auto MapEnd = MappedIteratorType(VMaps.end(), GetPtr);
+ privateUpdateExitBlocksForClonedLoop(ExitBlocks, MapBegin, MapEnd, DT);
+}
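
The two overloads above only register the new exit edges; the caller is expected to have already cloned the loop blocks and updated the DominatorTree. As a hedged illustration (the wrapper name and parameters are hypothetical, only the updateExitBlocksForClonedLoop call comes from this patch), a loop-cloning transform might finish its MemorySSA bookkeeping like this:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
using namespace llvm;

// Hypothetical wrapper: ExitBlocks are the original loop exits, VMap maps the
// original blocks/values to their clones, and DT already reflects the new CFG.
static void finishLoopCloneMSSAUpdate(MemorySSAUpdater &MSSAU,
                                      ArrayRef<BasicBlock *> ExitBlocks,
                                      const ValueToValueMapTy &VMap,
                                      DominatorTree &DT) {
  // Registers the edge from each cloned exit block to its single successor
  // and inserts or extends MemoryPhis there as needed.
  MSSAU.updateExitBlocksForClonedLoop(ExitBlocks, VMap, DT);
}
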
+
+void MemorySSAUpdater::applyUpdates(ArrayRef<CFGUpdate> Updates,
+ DominatorTree &DT) {
+ SmallVector<CFGUpdate, 4> RevDeleteUpdates;
+ SmallVector<CFGUpdate, 4> InsertUpdates;
+ for (auto &Update : Updates) {
+ if (Update.getKind() == DT.Insert)
+ InsertUpdates.push_back({DT.Insert, Update.getFrom(), Update.getTo()});
+ else
+ RevDeleteUpdates.push_back({DT.Insert, Update.getFrom(), Update.getTo()});
+ }
+
+ if (!RevDeleteUpdates.empty()) {
+    // Update for inserted edges: use NewDT and a snapshot CFG as if the
+    // deletes had not occurred.
+    // FIXME: This creates a new DT, so it's more expensive to mix deletes and
+    // inserts than to do inserts alone. We could instead do an incremental
+    // update on the DT to revert the deletes, then re-delete the edges.
+    // Teaching DT to do this is part of a pending cleanup.
+ DominatorTree NewDT(DT, RevDeleteUpdates);
+ GraphDiff<BasicBlock *> GD(RevDeleteUpdates);
+ applyInsertUpdates(InsertUpdates, NewDT, &GD);
+ } else {
+ GraphDiff<BasicBlock *> GD;
+ applyInsertUpdates(InsertUpdates, DT, &GD);
+ }
+
+ // Update for deleted edges
+ for (auto &Update : RevDeleteUpdates)
+ removeEdge(Update.getFrom(), Update.getTo());
+}
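
For context, a minimal sketch of the caller side (the helper and the specific edges are hypothetical; only the applyUpdates signature and the CFGUpdate alias are taken from this code). The DominatorTree is assumed to already reflect the new CFG, which GetLastDef below relies on:

#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/IR/Dominators.h"
using namespace llvm;

// Hypothetical: a transform removed the edge OldPred->BB and added NewPred->BB.
static void reportEdgeChanges(MemorySSAUpdater &MSSAU, DominatorTree &DT,
                              BasicBlock *OldPred, BasicBlock *NewPred,
                              BasicBlock *BB) {
  SmallVector<CFGUpdate, 2> Updates;
  Updates.push_back({DT.Delete, OldPred, BB});
  Updates.push_back({DT.Insert, NewPred, BB});
  // Inserts are processed first against a snapshot CFG with the delete
  // reverted, then the deleted edge is removed, as implemented above.
  MSSAU.applyUpdates(Updates, DT);
}
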
+
+void MemorySSAUpdater::applyInsertUpdates(ArrayRef<CFGUpdate> Updates,
+ DominatorTree &DT) {
+ GraphDiff<BasicBlock *> GD;
+ applyInsertUpdates(Updates, DT, &GD);
+}
+
+void MemorySSAUpdater::applyInsertUpdates(ArrayRef<CFGUpdate> Updates,
+ DominatorTree &DT,
+ const GraphDiff<BasicBlock *> *GD) {
+  // Get the recursive last Def, assuming well-formed MSSA and an updated DT.
+ auto GetLastDef = [&](BasicBlock *BB) -> MemoryAccess * {
+ while (true) {
+ MemorySSA::DefsList *Defs = MSSA->getWritableBlockDefs(BB);
+ // Return last Def or Phi in BB, if it exists.
+ if (Defs)
+ return &*(--Defs->end());
+
+      // Check the number of predecessors; we only care whether there's more
+      // than one.
+ unsigned Count = 0;
+ BasicBlock *Pred = nullptr;
+ for (auto &Pair : children<GraphDiffInvBBPair>({GD, BB})) {
+ Pred = Pair.second;
+ Count++;
+ if (Count == 2)
+ break;
+ }
+
+ // If BB has multiple predecessors, get last definition from IDom.
+ if (Count != 1) {
+        // [SimpleLoopUnswitch] If BB is a dead block about to be deleted, its
+        // DT node is invalidated. Return LoE as its last def. This will be
+        // added to the MemoryPhi node, and later deleted when the block is
+        // deleted.
+ if (!DT.getNode(BB))
+ return MSSA->getLiveOnEntryDef();
+ if (auto *IDom = DT.getNode(BB)->getIDom())
+ if (IDom->getBlock() != BB) {
+ BB = IDom->getBlock();
+ continue;
+ }
+ return MSSA->getLiveOnEntryDef();
+ } else {
+        // Single predecessor: BB cannot be dead. GetLastDef of Pred.
+        assert(Count == 1 && Pred && "Single predecessor expected.");
+        // BB can still be unreachable, though; return LoE in that case.
+ if (!DT.getNode(BB))
+ return MSSA->getLiveOnEntryDef();
+ BB = Pred;
+ }
+ };
+ llvm_unreachable("Unable to get last definition.");
+ };
+
+ // Get nearest IDom given a set of blocks.
+ // TODO: this can be optimized by starting the search at the node with the
+ // lowest level (highest in the tree).
+ auto FindNearestCommonDominator =
+ [&](const SmallSetVector<BasicBlock *, 2> &BBSet) -> BasicBlock * {
+ BasicBlock *PrevIDom = *BBSet.begin();
+ for (auto *BB : BBSet)
+ PrevIDom = DT.findNearestCommonDominator(PrevIDom, BB);
+ return PrevIDom;
+ };
+
+  // Get all blocks that dominate PrevIDom; stop when reaching CurrIDom. Do not
+  // include CurrIDom.
+ auto GetNoLongerDomBlocks =
+ [&](BasicBlock *PrevIDom, BasicBlock *CurrIDom,
+ SmallVectorImpl<BasicBlock *> &BlocksPrevDom) {
+ if (PrevIDom == CurrIDom)
+ return;
+ BlocksPrevDom.push_back(PrevIDom);
+ BasicBlock *NextIDom = PrevIDom;
+ while (BasicBlock *UpIDom =
+ DT.getNode(NextIDom)->getIDom()->getBlock()) {
+ if (UpIDom == CurrIDom)
+ break;
+ BlocksPrevDom.push_back(UpIDom);
+ NextIDom = UpIDom;
+ }
+ };
+
+  // Map a BB to its predecessors: added + previously existing. To get a
+  // deterministic order, store predecessors as SetVectors. The order in each
+  // will be defined by the order in Updates (fixed) and the order given by
+  // children<> (also fixed). Since we further iterate over these ordered sets,
+  // we lose the information of multiple edges possibly existing between two
+  // blocks, so we'll keep an EdgeCount map for that.
+  // An alternate implementation could keep an unordered set for the
+  // predecessors, traverse either Updates or children<> each time to get the
+  // deterministic order, and drop the usage of EdgeCount. That alternate
+  // approach would still require querying the maps for each predecessor, and
+  // the children<> call has additional computation inside for creating the
+  // snapshot-graph predecessors. As such, we favor using a little additional
+  // storage and less compute time. This decision can be revisited if we find
+  // the alternative more favorable.
+
+ struct PredInfo {
+ SmallSetVector<BasicBlock *, 2> Added;
+ SmallSetVector<BasicBlock *, 2> Prev;
+ };
+ SmallDenseMap<BasicBlock *, PredInfo> PredMap;
+
+ for (auto &Edge : Updates) {
+ BasicBlock *BB = Edge.getTo();
+ auto &AddedBlockSet = PredMap[BB].Added;
+ AddedBlockSet.insert(Edge.getFrom());
+ }
+
+  // Store all existing predecessors for each BB; at least one must exist.
+ SmallDenseMap<std::pair<BasicBlock *, BasicBlock *>, int> EdgeCountMap;
+ SmallPtrSet<BasicBlock *, 2> NewBlocks;
+ for (auto &BBPredPair : PredMap) {
+ auto *BB = BBPredPair.first;
+ const auto &AddedBlockSet = BBPredPair.second.Added;
+ auto &PrevBlockSet = BBPredPair.second.Prev;
+ for (auto &Pair : children<GraphDiffInvBBPair>({GD, BB})) {
+ BasicBlock *Pi = Pair.second;
+ if (!AddedBlockSet.count(Pi))
+ PrevBlockSet.insert(Pi);
+ EdgeCountMap[{Pi, BB}]++;
+ }
+
+ if (PrevBlockSet.empty()) {
+ assert(pred_size(BB) == AddedBlockSet.size() && "Duplicate edges added.");
+ LLVM_DEBUG(
+ dbgs()
+ << "Adding a predecessor to a block with no predecessors. "
+ "This must be an edge added to a new, likely cloned, block. "
+ "Its memory accesses must be already correct, assuming completed "
+ "via the updateExitBlocksForClonedLoop API. "
+ "Assert a single such edge is added so no phi addition or "
+ "additional processing is required.\n");
+ assert(AddedBlockSet.size() == 1 &&
+ "Can only handle adding one predecessor to a new block.");
+ // Need to remove new blocks from PredMap. Remove below to not invalidate
+ // iterator here.
+ NewBlocks.insert(BB);
+ }
+ }
+ // Nothing to process for new/cloned blocks.
+ for (auto *BB : NewBlocks)
+ PredMap.erase(BB);
+
+ SmallVector<BasicBlock *, 16> BlocksWithDefsToReplace;
+ SmallVector<WeakVH, 8> InsertedPhis;
+
+ // First create MemoryPhis in all blocks that don't have one. Create in the
+ // order found in Updates, not in PredMap, to get deterministic numbering.
+ for (auto &Edge : Updates) {
+ BasicBlock *BB = Edge.getTo();
+ if (PredMap.count(BB) && !MSSA->getMemoryAccess(BB))
+ InsertedPhis.push_back(MSSA->createMemoryPhi(BB));
+ }
+
+ // Now we'll fill in the MemoryPhis with the right incoming values.
+ for (auto &BBPredPair : PredMap) {
+ auto *BB = BBPredPair.first;
+ const auto &PrevBlockSet = BBPredPair.second.Prev;
+ const auto &AddedBlockSet = BBPredPair.second.Added;
+ assert(!PrevBlockSet.empty() &&
+ "At least one previous predecessor must exist.");
+
+    // TODO: if this becomes a bottleneck, we can save on GetLastDef calls by
+    // hoisting this map out of the loop. We can reuse already populated
+    // entries if an edge is added from the same predecessor to two different
+    // blocks, and this does happen in rotate. Note that the map then needs to
+    // be updated when deleting unnecessary phis below: if the phi is in the
+    // map, replace its value with DefP1.
+ SmallDenseMap<BasicBlock *, MemoryAccess *> LastDefAddedPred;
+ for (auto *AddedPred : AddedBlockSet) {
+ auto *DefPn = GetLastDef(AddedPred);
+ assert(DefPn != nullptr && "Unable to find last definition.");
+ LastDefAddedPred[AddedPred] = DefPn;
+ }
+
+ MemoryPhi *NewPhi = MSSA->getMemoryAccess(BB);
+ // If Phi is not empty, add an incoming edge from each added pred. Must
+ // still compute blocks with defs to replace for this block below.
+ if (NewPhi->getNumOperands()) {
+ for (auto *Pred : AddedBlockSet) {
+ auto *LastDefForPred = LastDefAddedPred[Pred];
+ for (int I = 0, E = EdgeCountMap[{Pred, BB}]; I < E; ++I)
+ NewPhi->addIncoming(LastDefForPred, Pred);
+ }
+ } else {
+ // Pick any existing predecessor and get its definition. All other
+ // existing predecessors should have the same one, since no phi existed.
+ auto *P1 = *PrevBlockSet.begin();
+ MemoryAccess *DefP1 = GetLastDef(P1);
+
+      // Check DefP1 against all Defs in LastDefAddedPred. If they are all the
+      // same, there is nothing to add.
+ bool InsertPhi = false;
+ for (auto LastDefPredPair : LastDefAddedPred)
+ if (DefP1 != LastDefPredPair.second) {
+ InsertPhi = true;
+ break;
+ }
+ if (!InsertPhi) {
+ // Since NewPhi may be used in other newly added Phis, replace all uses
+ // of NewPhi with the definition coming from all predecessors (DefP1),
+ // before deleting it.
+ NewPhi->replaceAllUsesWith(DefP1);
+ removeMemoryAccess(NewPhi);
+ continue;
+ }
+
+ // Update Phi with new values for new predecessors and old value for all
+ // other predecessors. Since AddedBlockSet and PrevBlockSet are ordered
+ // sets, the order of entries in NewPhi is deterministic.
+ for (auto *Pred : AddedBlockSet) {
+ auto *LastDefForPred = LastDefAddedPred[Pred];
+ for (int I = 0, E = EdgeCountMap[{Pred, BB}]; I < E; ++I)
+ NewPhi->addIncoming(LastDefForPred, Pred);
+ }
+ for (auto *Pred : PrevBlockSet)
+ for (int I = 0, E = EdgeCountMap[{Pred, BB}]; I < E; ++I)
+ NewPhi->addIncoming(DefP1, Pred);
+ }
+
+ // Get all blocks that used to dominate BB and no longer do after adding
+ // AddedBlockSet, where PrevBlockSet are the previously known predecessors.
+ assert(DT.getNode(BB)->getIDom() && "BB does not have valid idom");
+ BasicBlock *PrevIDom = FindNearestCommonDominator(PrevBlockSet);
+    assert(PrevIDom && "Previous IDom should exist");
+ BasicBlock *NewIDom = DT.getNode(BB)->getIDom()->getBlock();
+ assert(NewIDom && "BB should have a new valid idom");
+ assert(DT.dominates(NewIDom, PrevIDom) &&
+ "New idom should dominate old idom");
+ GetNoLongerDomBlocks(PrevIDom, NewIDom, BlocksWithDefsToReplace);
+ }
+
+ tryRemoveTrivialPhis(InsertedPhis);
+ // Create the set of blocks that now have a definition. We'll use this to
+ // compute IDF and add Phis there next.
+ SmallVector<BasicBlock *, 8> BlocksToProcess;
+ for (auto &VH : InsertedPhis)
+ if (auto *MPhi = cast_or_null<MemoryPhi>(VH))
+ BlocksToProcess.push_back(MPhi->getBlock());
+
+ // Compute IDF and add Phis in all IDF blocks that do not have one.
+ SmallVector<BasicBlock *, 32> IDFBlocks;
+ if (!BlocksToProcess.empty()) {
+ ForwardIDFCalculator IDFs(DT, GD);
+ SmallPtrSet<BasicBlock *, 16> DefiningBlocks(BlocksToProcess.begin(),
+ BlocksToProcess.end());
+ IDFs.setDefiningBlocks(DefiningBlocks);
+ IDFs.calculate(IDFBlocks);
+
+ SmallSetVector<MemoryPhi *, 4> PhisToFill;
+ // First create all needed Phis.
+ for (auto *BBIDF : IDFBlocks)
+ if (!MSSA->getMemoryAccess(BBIDF)) {
+ auto *IDFPhi = MSSA->createMemoryPhi(BBIDF);
+ InsertedPhis.push_back(IDFPhi);
+ PhisToFill.insert(IDFPhi);
+ }
+ // Then update or insert their correct incoming values.
+ for (auto *BBIDF : IDFBlocks) {
+ auto *IDFPhi = MSSA->getMemoryAccess(BBIDF);
+ assert(IDFPhi && "Phi must exist");
+ if (!PhisToFill.count(IDFPhi)) {
+ // Update existing Phi.
+ // FIXME: some updates may be redundant, try to optimize and skip some.
+ for (unsigned I = 0, E = IDFPhi->getNumIncomingValues(); I < E; ++I)
+ IDFPhi->setIncomingValue(I, GetLastDef(IDFPhi->getIncomingBlock(I)));
+ } else {
+ for (auto &Pair : children<GraphDiffInvBBPair>({GD, BBIDF})) {
+ BasicBlock *Pi = Pair.second;
+ IDFPhi->addIncoming(GetLastDef(Pi), Pi);
+ }
+ }
+ }
+ }
+
+ // Now for all defs in BlocksWithDefsToReplace, if there are uses they no
+ // longer dominate, replace those with the closest dominating def.
+ // This will also update optimized accesses, as they're also uses.
+ for (auto *BlockWithDefsToReplace : BlocksWithDefsToReplace) {
+ if (auto DefsList = MSSA->getWritableBlockDefs(BlockWithDefsToReplace)) {
+ for (auto &DefToReplaceUses : *DefsList) {
+ BasicBlock *DominatingBlock = DefToReplaceUses.getBlock();
+ Value::use_iterator UI = DefToReplaceUses.use_begin(),
+ E = DefToReplaceUses.use_end();
+ for (; UI != E;) {
+ Use &U = *UI;
+ ++UI;
+ MemoryAccess *Usr = cast<MemoryAccess>(U.getUser());
+ if (MemoryPhi *UsrPhi = dyn_cast<MemoryPhi>(Usr)) {
+ BasicBlock *DominatedBlock = UsrPhi->getIncomingBlock(U);
+ if (!DT.dominates(DominatingBlock, DominatedBlock))
+ U.set(GetLastDef(DominatedBlock));
+ } else {
+ BasicBlock *DominatedBlock = Usr->getBlock();
+ if (!DT.dominates(DominatingBlock, DominatedBlock)) {
+ if (auto *DomBlPhi = MSSA->getMemoryAccess(DominatedBlock))
+ U.set(DomBlPhi);
+ else {
+ auto *IDom = DT.getNode(DominatedBlock)->getIDom();
+ assert(IDom && "Block must have a valid IDom.");
+ U.set(GetLastDef(IDom->getBlock()));
+ }
+ cast<MemoryUseOrDef>(Usr)->resetOptimized();
+ }
+ }
+ }
+ }
+ }
+ }
+ tryRemoveTrivialPhis(InsertedPhis);
+}
+
+// Move What before Where in the MemorySSA IR.
+template <class WhereType>
+void MemorySSAUpdater::moveTo(MemoryUseOrDef *What, BasicBlock *BB,
+ WhereType Where) {
+ // Mark MemoryPhi users of What not to be optimized.
+ for (auto *U : What->users())
+ if (MemoryPhi *PhiUser = dyn_cast<MemoryPhi>(U))
+ NonOptPhis.insert(PhiUser);
+
+ // Replace all our users with our defining access.
+ What->replaceAllUsesWith(What->getDefiningAccess());
+
+ // Let MemorySSA take care of moving it around in the lists.
+ MSSA->moveTo(What, BB, Where);
+
+ // Now reinsert it into the IR and do whatever fixups needed.
+ if (auto *MD = dyn_cast<MemoryDef>(What))
+ insertDef(MD, /*RenameUses=*/true);
+ else
+ insertUse(cast<MemoryUse>(What), /*RenameUses=*/true);
+
+ // Clear dangling pointers. We added all MemoryPhi users, but not all
+ // of them are removed by fixupDefs().
+ NonOptPhis.clear();
+}
+
+// Move What before Where in the MemorySSA IR.
+void MemorySSAUpdater::moveBefore(MemoryUseOrDef *What, MemoryUseOrDef *Where) {
+ moveTo(What, Where->getBlock(), Where->getIterator());
+}
+
+// Move What after Where in the MemorySSA IR.
+void MemorySSAUpdater::moveAfter(MemoryUseOrDef *What, MemoryUseOrDef *Where) {
+ moveTo(What, Where->getBlock(), ++Where->getIterator());
+}
+
+void MemorySSAUpdater::moveToPlace(MemoryUseOrDef *What, BasicBlock *BB,
+ MemorySSA::InsertionPlace Where) {
+ return moveTo(What, BB, Where);
+}
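
The moveBefore/moveAfter/moveToPlace entry points above update only MemorySSA; the IR instruction move is done separately by the caller. A hedged sketch of that pairing (the helper and its preconditions are assumptions, not part of this patch):

#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

// Hypothetical helper: both I and HoistPt are assumed to already have
// MemorySSA accesses (e.g. both are loads or stores).
static void hoistBefore(MemorySSA &MSSA, MemorySSAUpdater &MSSAU,
                        Instruction *I, Instruction *HoistPt) {
  // Move the instruction in the IR, then mirror the move in MemorySSA.
  I->moveBefore(HoistPt);
  MSSAU.moveBefore(cast<MemoryUseOrDef>(MSSA.getMemoryAccess(I)),
                   cast<MemoryUseOrDef>(MSSA.getMemoryAccess(HoistPt)));
}
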
+
+// All accesses from Start onward now belong to To but are still listed under
+// From. Move them to the end of To and update the access lists.
+void MemorySSAUpdater::moveAllAccesses(BasicBlock *From, BasicBlock *To,
+ Instruction *Start) {
+
+ MemorySSA::AccessList *Accs = MSSA->getWritableBlockAccesses(From);
+ if (!Accs)
+ return;
+
+ assert(Start->getParent() == To && "Incorrect Start instruction");
+ MemoryAccess *FirstInNew = nullptr;
+ for (Instruction &I : make_range(Start->getIterator(), To->end()))
+ if ((FirstInNew = MSSA->getMemoryAccess(&I)))
+ break;
+ if (FirstInNew) {
+ auto *MUD = cast<MemoryUseOrDef>(FirstInNew);
+ do {
+ auto NextIt = ++MUD->getIterator();
+ MemoryUseOrDef *NextMUD = (!Accs || NextIt == Accs->end())
+ ? nullptr
+ : cast<MemoryUseOrDef>(&*NextIt);
+ MSSA->moveTo(MUD, To, MemorySSA::End);
+      // Moving MUD from Accs in the moveTo above may delete Accs, so we need
+      // to retrieve it again.
+ Accs = MSSA->getWritableBlockAccesses(From);
+ MUD = NextMUD;
+ } while (MUD);
+ }
+
+ // If all accesses were moved and only a trivial Phi remains, we try to remove
+ // that Phi. This is needed when From is going to be deleted.
+ auto *Defs = MSSA->getWritableBlockDefs(From);
+ if (Defs && !Defs->empty())
+ if (auto *Phi = dyn_cast<MemoryPhi>(&*Defs->begin()))
+ tryRemoveTrivialPhi(Phi);
+}
+
+void MemorySSAUpdater::moveAllAfterSpliceBlocks(BasicBlock *From,
+ BasicBlock *To,
+ Instruction *Start) {
+ assert(MSSA->getBlockAccesses(To) == nullptr &&
+ "To block is expected to be free of MemoryAccesses.");
+ moveAllAccesses(From, To, Start);
+ for (BasicBlock *Succ : successors(To))
+ if (MemoryPhi *MPhi = MSSA->getMemoryAccess(Succ))
+ MPhi->setIncomingBlock(MPhi->getBasicBlockIndex(From), To);
+}
+
+void MemorySSAUpdater::moveAllAfterMergeBlocks(BasicBlock *From, BasicBlock *To,
+ Instruction *Start) {
+ assert(From->getUniquePredecessor() == To &&
+ "From block is expected to have a single predecessor (To).");
+ moveAllAccesses(From, To, Start);
+ for (BasicBlock *Succ : successors(From))
+ if (MemoryPhi *MPhi = MSSA->getMemoryAccess(Succ))
+ MPhi->setIncomingBlock(MPhi->getBasicBlockIndex(From), To);
+}
+
+/// If all arguments of a MemoryPHI are defined by the same incoming
+/// argument, return that argument.
+static MemoryAccess *onlySingleValue(MemoryPhi *MP) {
+ MemoryAccess *MA = nullptr;
+
+ for (auto &Arg : MP->operands()) {
+ if (!MA)
+ MA = cast<MemoryAccess>(Arg);
+ else if (MA != Arg)
+ return nullptr;
+ }
+ return MA;
+}
+
+void MemorySSAUpdater::wireOldPredecessorsToNewImmediatePredecessor(
+ BasicBlock *Old, BasicBlock *New, ArrayRef<BasicBlock *> Preds,
+ bool IdenticalEdgesWereMerged) {
+ assert(!MSSA->getWritableBlockAccesses(New) &&
+ "Access list should be null for a new block.");
+ MemoryPhi *Phi = MSSA->getMemoryAccess(Old);
+ if (!Phi)
+ return;
+ if (Old->hasNPredecessors(1)) {
+ assert(pred_size(New) == Preds.size() &&
+ "Should have moved all predecessors.");
+ MSSA->moveTo(Phi, New, MemorySSA::Beginning);
+ } else {
+ assert(!Preds.empty() && "Must be moving at least one predecessor to the "
+ "new immediate predecessor.");
+ MemoryPhi *NewPhi = MSSA->createMemoryPhi(New);
+ SmallPtrSet<BasicBlock *, 16> PredsSet(Preds.begin(), Preds.end());
+ // Currently only support the case of removing a single incoming edge when
+ // identical edges were not merged.
+ if (!IdenticalEdgesWereMerged)
+ assert(PredsSet.size() == Preds.size() &&
+ "If identical edges were not merged, we cannot have duplicate "
+ "blocks in the predecessors");
+ Phi->unorderedDeleteIncomingIf([&](MemoryAccess *MA, BasicBlock *B) {
+ if (PredsSet.count(B)) {
+ NewPhi->addIncoming(MA, B);
+ if (!IdenticalEdgesWereMerged)
+ PredsSet.erase(B);
+ return true;
+ }
+ return false;
+ });
+ Phi->addIncoming(NewPhi, New);
+ tryRemoveTrivialPhi(NewPhi);
+ }
+}
+
+void MemorySSAUpdater::removeMemoryAccess(MemoryAccess *MA, bool OptimizePhis) {
+ assert(!MSSA->isLiveOnEntryDef(MA) &&
+ "Trying to remove the live on entry def");
+ // We can only delete phi nodes if they have no uses, or we can replace all
+ // uses with a single definition.
+ MemoryAccess *NewDefTarget = nullptr;
+ if (MemoryPhi *MP = dyn_cast<MemoryPhi>(MA)) {
+ // Note that it is sufficient to know that all edges of the phi node have
+ // the same argument. If they do, by the definition of dominance frontiers
+ // (which we used to place this phi), that argument must dominate this phi,
+ // and thus, must dominate the phi's uses, and so we will not hit the assert
+ // below.
+ NewDefTarget = onlySingleValue(MP);
+ assert((NewDefTarget || MP->use_empty()) &&
+ "We can't delete this memory phi");
+ } else {
+ NewDefTarget = cast<MemoryUseOrDef>(MA)->getDefiningAccess();
+ }
+
+ SmallSetVector<MemoryPhi *, 4> PhisToCheck;
+
+ // Re-point the uses at our defining access
+ if (!isa<MemoryUse>(MA) && !MA->use_empty()) {
+ // Reset optimized on users of this store, and reset the uses.
+ // A few notes:
+ // 1. This is a slightly modified version of RAUW to avoid walking the
+ // uses twice here.
+ // 2. If we wanted to be complete, we would have to reset the optimized
+ // flags on users of phi nodes if doing the below makes a phi node have all
+ // the same arguments. Instead, we prefer users to removeMemoryAccess those
+ // phi nodes, because doing it here would be N^3.
+ if (MA->hasValueHandle())
+ ValueHandleBase::ValueIsRAUWd(MA, NewDefTarget);
+ // Note: We assume MemorySSA is not used in metadata since it's not really
+ // part of the IR.
+
+ while (!MA->use_empty()) {
+ Use &U = *MA->use_begin();
+ if (auto *MUD = dyn_cast<MemoryUseOrDef>(U.getUser()))
+ MUD->resetOptimized();
+ if (OptimizePhis)
+ if (MemoryPhi *MP = dyn_cast<MemoryPhi>(U.getUser()))
+ PhisToCheck.insert(MP);
+ U.set(NewDefTarget);
+ }
+ }
+
+ // The call below to erase will destroy MA, so we can't change the order we
+ // are doing things here
+ MSSA->removeFromLookups(MA);
+ MSSA->removeFromLists(MA);
+
+ // Optionally optimize Phi uses. This will recursively remove trivial phis.
+ if (!PhisToCheck.empty()) {
+ SmallVector<WeakVH, 16> PhisToOptimize{PhisToCheck.begin(),
+ PhisToCheck.end()};
+ PhisToCheck.clear();
+
+ unsigned PhisSize = PhisToOptimize.size();
+ while (PhisSize-- > 0)
+ if (MemoryPhi *MP =
+ cast_or_null<MemoryPhi>(PhisToOptimize.pop_back_val()))
+ tryRemoveTrivialPhi(MP);
+ }
+}
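
A typical caller-side pattern for removeMemoryAccess, sketched here with hypothetical names: the access is removed before the IR instruction itself is erased, so the updater can still re-point the access's uses.

#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Hypothetical: Dead is a store a transform has proven to be dead.
static void eraseDeadStore(MemorySSA &MSSA, MemorySSAUpdater &MSSAU,
                           StoreInst *Dead) {
  if (MemoryAccess *MA = MSSA.getMemoryAccess(Dead))
    // Re-points uses of the store's MemoryDef at its defining access and
    // optionally cleans up phis that become trivial.
    MSSAU.removeMemoryAccess(MA, /*OptimizePhis=*/true);
  Dead->eraseFromParent();
}
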
+
+void MemorySSAUpdater::removeBlocks(
+ const SmallSetVector<BasicBlock *, 8> &DeadBlocks) {
+ // First delete all uses of BB in MemoryPhis.
+ for (BasicBlock *BB : DeadBlocks) {
+ Instruction *TI = BB->getTerminator();
+ assert(TI && "Basic block expected to have a terminator instruction");
+ for (BasicBlock *Succ : successors(TI))
+ if (!DeadBlocks.count(Succ))
+ if (MemoryPhi *MP = MSSA->getMemoryAccess(Succ)) {
+ MP->unorderedDeleteIncomingBlock(BB);
+ tryRemoveTrivialPhi(MP);
+ }
+ // Drop all references of all accesses in BB
+ if (MemorySSA::AccessList *Acc = MSSA->getWritableBlockAccesses(BB))
+ for (MemoryAccess &MA : *Acc)
+ MA.dropAllReferences();
+ }
+
+ // Next, delete all memory accesses in each block
+ for (BasicBlock *BB : DeadBlocks) {
+ MemorySSA::AccessList *Acc = MSSA->getWritableBlockAccesses(BB);
+ if (!Acc)
+ continue;
+ for (auto AB = Acc->begin(), AE = Acc->end(); AB != AE;) {
+ MemoryAccess *MA = &*AB;
+ ++AB;
+ MSSA->removeFromLookups(MA);
+ MSSA->removeFromLists(MA);
+ }
+ }
+}
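
removeBlocks only cleans up the MemorySSA side. A rough sketch of how a transform might combine it with the actual IR deletion (the helper is hypothetical; the drop-then-erase order is an assumption, not taken from this patch):

#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/IR/BasicBlock.h"
using namespace llvm;

// Hypothetical: DeadBlocks is the set of blocks the caller is about to erase.
static void deleteDeadRegion(MemorySSAUpdater &MSSAU,
                             const SmallSetVector<BasicBlock *, 8> &DeadBlocks) {
  // Update MemorySSA first, while the blocks and their terminators still exist.
  MSSAU.removeBlocks(DeadBlocks);
  // Drop all references first so no dead block still points at another one,
  // then erase them.
  for (BasicBlock *BB : DeadBlocks)
    BB->dropAllReferences();
  for (BasicBlock *BB : DeadBlocks)
    BB->eraseFromParent();
}
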
+
+void MemorySSAUpdater::tryRemoveTrivialPhis(ArrayRef<WeakVH> UpdatedPHIs) {
+ for (auto &VH : UpdatedPHIs)
+ if (auto *MPhi = cast_or_null<MemoryPhi>(VH))
+ tryRemoveTrivialPhi(MPhi);
+}
+
+void MemorySSAUpdater::changeToUnreachable(const Instruction *I) {
+ const BasicBlock *BB = I->getParent();
+ // Remove memory accesses in BB for I and all following instructions.
+ auto BBI = I->getIterator(), BBE = BB->end();
+ // FIXME: If this becomes too expensive, iterate until the first instruction
+ // with a memory access, then iterate over MemoryAccesses.
+ while (BBI != BBE)
+ removeMemoryAccess(&*(BBI++));
+ // Update phis in BB's successors to remove BB.
+ SmallVector<WeakVH, 16> UpdatedPHIs;
+ for (const BasicBlock *Successor : successors(BB)) {
+ removeDuplicatePhiEdgesBetween(BB, Successor);
+ if (MemoryPhi *MPhi = MSSA->getMemoryAccess(Successor)) {
+ MPhi->unorderedDeleteIncomingBlock(BB);
+ UpdatedPHIs.push_back(MPhi);
+ }
+ }
+ // Optimize trivial phis.
+ tryRemoveTrivialPhis(UpdatedPHIs);
+}
+
+void MemorySSAUpdater::changeCondBranchToUnconditionalTo(const BranchInst *BI,
+ const BasicBlock *To) {
+ const BasicBlock *BB = BI->getParent();
+ SmallVector<WeakVH, 16> UpdatedPHIs;
+ for (const BasicBlock *Succ : successors(BB)) {
+ removeDuplicatePhiEdgesBetween(BB, Succ);
+ if (Succ != To)
+ if (auto *MPhi = MSSA->getMemoryAccess(Succ)) {
+ MPhi->unorderedDeleteIncomingBlock(BB);
+ UpdatedPHIs.push_back(MPhi);
+ }
+ }
+ // Optimize trivial phis.
+ tryRemoveTrivialPhis(UpdatedPHIs);
+}
+
+MemoryAccess *MemorySSAUpdater::createMemoryAccessInBB(
+ Instruction *I, MemoryAccess *Definition, const BasicBlock *BB,
+ MemorySSA::InsertionPlace Point) {
+ MemoryUseOrDef *NewAccess = MSSA->createDefinedAccess(I, Definition);
+ MSSA->insertIntoListsForBlock(NewAccess, BB, Point);
+ return NewAccess;
+}
+
+MemoryUseOrDef *MemorySSAUpdater::createMemoryAccessBefore(
+ Instruction *I, MemoryAccess *Definition, MemoryUseOrDef *InsertPt) {
+ assert(I->getParent() == InsertPt->getBlock() &&
+ "New and old access must be in the same block");
+ MemoryUseOrDef *NewAccess = MSSA->createDefinedAccess(I, Definition);
+ MSSA->insertIntoListsBefore(NewAccess, InsertPt->getBlock(),
+ InsertPt->getIterator());
+ return NewAccess;
+}
+
+MemoryUseOrDef *MemorySSAUpdater::createMemoryAccessAfter(
+ Instruction *I, MemoryAccess *Definition, MemoryAccess *InsertPt) {
+ assert(I->getParent() == InsertPt->getBlock() &&
+ "New and old access must be in the same block");
+ MemoryUseOrDef *NewAccess = MSSA->createDefinedAccess(I, Definition);
+ MSSA->insertIntoListsBefore(NewAccess, InsertPt->getBlock(),
+ ++InsertPt->getIterator());
+ return NewAccess;
+}
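
The three createMemoryAccess* helpers above only create the access and link it into the block's lists; they do not rename existing uses (insertDef/insertUse, referenced earlier in this file, handle that). A hedged example for the common case of registering a freshly inserted load, with hypothetical names:

#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Hypothetical: NewLoad was just inserted immediately before InsertPt's
// instruction, and DefiningAccess is the access that clobbers it.
static void registerNewLoad(MemorySSAUpdater &MSSAU, LoadInst *NewLoad,
                            MemoryAccess *DefiningAccess,
                            MemoryUseOrDef *InsertPt) {
  MSSAU.createMemoryAccessBefore(NewLoad, DefiningAccess, InsertPt);
}
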
diff --git a/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp b/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp
new file mode 100644
index 000000000000..519242759824
--- /dev/null
+++ b/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp
@@ -0,0 +1,127 @@
+//===-- ModuleDebugInfoPrinter.cpp - Prints module debug info metadata ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass decodes the debug info metadata in a module and prints it in a
+// (sufficiently prepared) human-readable form.
+//
+// For example, run this pass from opt along with the -analyze option, and
+// it'll print to standard output.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/IR/DebugInfo.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace llvm;
+
+namespace {
+ class ModuleDebugInfoPrinter : public ModulePass {
+ DebugInfoFinder Finder;
+ public:
+ static char ID; // Pass identification, replacement for typeid
+ ModuleDebugInfoPrinter() : ModulePass(ID) {
+ initializeModuleDebugInfoPrinterPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnModule(Module &M) override;
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ }
+ void print(raw_ostream &O, const Module *M) const override;
+ };
+}
+
+char ModuleDebugInfoPrinter::ID = 0;
+INITIALIZE_PASS(ModuleDebugInfoPrinter, "module-debuginfo",
+ "Decodes module-level debug info", false, true)
+
+ModulePass *llvm::createModuleDebugInfoPrinterPass() {
+ return new ModuleDebugInfoPrinter();
+}
+
+bool ModuleDebugInfoPrinter::runOnModule(Module &M) {
+ Finder.processModule(M);
+ return false;
+}
+
+static void printFile(raw_ostream &O, StringRef Filename, StringRef Directory,
+ unsigned Line = 0) {
+ if (Filename.empty())
+ return;
+
+ O << " from ";
+ if (!Directory.empty())
+ O << Directory << "/";
+ O << Filename;
+ if (Line)
+ O << ":" << Line;
+}
+
+void ModuleDebugInfoPrinter::print(raw_ostream &O, const Module *M) const {
+ // Printing the nodes directly isn't particularly helpful (since they
+ // reference other nodes that won't be printed, particularly for the
+ // filenames), so just print a few useful things.
+ for (DICompileUnit *CU : Finder.compile_units()) {
+ O << "Compile unit: ";
+ auto Lang = dwarf::LanguageString(CU->getSourceLanguage());
+ if (!Lang.empty())
+ O << Lang;
+ else
+ O << "unknown-language(" << CU->getSourceLanguage() << ")";
+ printFile(O, CU->getFilename(), CU->getDirectory());
+ O << '\n';
+ }
+
+ for (DISubprogram *S : Finder.subprograms()) {
+ O << "Subprogram: " << S->getName();
+ printFile(O, S->getFilename(), S->getDirectory(), S->getLine());
+ if (!S->getLinkageName().empty())
+ O << " ('" << S->getLinkageName() << "')";
+ O << '\n';
+ }
+
+ for (auto GVU : Finder.global_variables()) {
+ const auto *GV = GVU->getVariable();
+ O << "Global variable: " << GV->getName();
+ printFile(O, GV->getFilename(), GV->getDirectory(), GV->getLine());
+ if (!GV->getLinkageName().empty())
+ O << " ('" << GV->getLinkageName() << "')";
+ O << '\n';
+ }
+
+ for (const DIType *T : Finder.types()) {
+ O << "Type:";
+ if (!T->getName().empty())
+ O << ' ' << T->getName();
+ printFile(O, T->getFilename(), T->getDirectory(), T->getLine());
+ if (auto *BT = dyn_cast<DIBasicType>(T)) {
+ O << " ";
+ auto Encoding = dwarf::AttributeEncodingString(BT->getEncoding());
+ if (!Encoding.empty())
+ O << Encoding;
+ else
+ O << "unknown-encoding(" << BT->getEncoding() << ')';
+ } else {
+ O << ' ';
+ auto Tag = dwarf::TagString(T->getTag());
+ if (!Tag.empty())
+ O << Tag;
+ else
+ O << "unknown-tag(" << T->getTag() << ")";
+ }
+ if (auto *CT = dyn_cast<DICompositeType>(T)) {
+ if (auto *S = CT->getRawIdentifier())
+ O << " (identifier: '" << S->getString() << "')";
+ }
+ O << '\n';
+ }
+}
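
This printer is a thin wrapper around DebugInfoFinder; per the header comment it is meant to be run through opt with -analyze (the pass is registered above as module-debuginfo). A minimal sketch of using the underlying API directly, with an illustrative output format:

#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

static void listCompileUnits(llvm::Module &M) {
  llvm::DebugInfoFinder Finder;
  Finder.processModule(M);
  for (llvm::DICompileUnit *CU : Finder.compile_units())
    llvm::errs() << "CU: " << CU->getDirectory() << "/" << CU->getFilename()
                 << "\n";
}
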
diff --git a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
new file mode 100644
index 000000000000..8232bf07cafc
--- /dev/null
+++ b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
@@ -0,0 +1,881 @@
+//===- ModuleSummaryAnalysis.cpp - Module summary index builder -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass builds a ModuleSummaryIndex object for the module, to be written
+// to bitcode or LLVM assembly.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ModuleSummaryAnalysis.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/IndirectCallPromotionAnalysis.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ProfileSummaryInfo.h"
+#include "llvm/Analysis/TypeMetadataUtils.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalAlias.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/ModuleSummaryIndex.h"
+#include "llvm/IR/Use.h"
+#include "llvm/IR/User.h"
+#include "llvm/Object/ModuleSymbolTable.h"
+#include "llvm/Object/SymbolicFile.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <vector>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "module-summary-analysis"
+
+// Option to force edges cold which will block importing when the
+// -import-cold-multiplier is set to 0. Useful for debugging.
+FunctionSummary::ForceSummaryHotnessType ForceSummaryEdgesCold =
+ FunctionSummary::FSHT_None;
+cl::opt<FunctionSummary::ForceSummaryHotnessType, true> FSEC(
+ "force-summary-edges-cold", cl::Hidden, cl::location(ForceSummaryEdgesCold),
+ cl::desc("Force all edges in the function summary to cold"),
+ cl::values(clEnumValN(FunctionSummary::FSHT_None, "none", "None."),
+ clEnumValN(FunctionSummary::FSHT_AllNonCritical,
+ "all-non-critical", "All non-critical edges."),
+ clEnumValN(FunctionSummary::FSHT_All, "all", "All edges.")));
+
+cl::opt<std::string> ModuleSummaryDotFile(
+ "module-summary-dot-file", cl::init(""), cl::Hidden,
+ cl::value_desc("filename"),
+ cl::desc("File to emit dot graph of new summary into."));
+
+// Walk through the operands of a given User via worklist iteration and
+// populate the set of GlobalValue references encountered. Invoked either on an
+// Instruction or a GlobalVariable (which walks its initializer).
+// Return true if any of the operands contains a blockaddress. This is
+// important to know when computing the summary for a global variable, because
+// if a global variable references a basic block address we can't import it
+// separately from the function containing that basic block. For simplicity we
+// currently don't import such global variables at all. When importing a
+// function we aren't interested in whether any instruction in it takes the
+// address of a basic block, because an instruction can only take the address
+// of a basic block located in the same function.
+static bool findRefEdges(ModuleSummaryIndex &Index, const User *CurUser,
+ SetVector<ValueInfo> &RefEdges,
+ SmallPtrSet<const User *, 8> &Visited) {
+ bool HasBlockAddress = false;
+ SmallVector<const User *, 32> Worklist;
+ Worklist.push_back(CurUser);
+
+ while (!Worklist.empty()) {
+ const User *U = Worklist.pop_back_val();
+
+ if (!Visited.insert(U).second)
+ continue;
+
+ ImmutableCallSite CS(U);
+
+ for (const auto &OI : U->operands()) {
+ const User *Operand = dyn_cast<User>(OI);
+ if (!Operand)
+ continue;
+ if (isa<BlockAddress>(Operand)) {
+ HasBlockAddress = true;
+ continue;
+ }
+ if (auto *GV = dyn_cast<GlobalValue>(Operand)) {
+ // We have a reference to a global value. This should be added to
+ // the reference set unless it is a callee. Callees are handled
+ // specially by WriteFunction and are added to a separate list.
+ if (!(CS && CS.isCallee(&OI)))
+ RefEdges.insert(Index.getOrInsertValueInfo(GV));
+ continue;
+ }
+ Worklist.push_back(Operand);
+ }
+ }
+ return HasBlockAddress;
+}
+
+static CalleeInfo::HotnessType getHotness(uint64_t ProfileCount,
+ ProfileSummaryInfo *PSI) {
+ if (!PSI)
+ return CalleeInfo::HotnessType::Unknown;
+ if (PSI->isHotCount(ProfileCount))
+ return CalleeInfo::HotnessType::Hot;
+ if (PSI->isColdCount(ProfileCount))
+ return CalleeInfo::HotnessType::Cold;
+ return CalleeInfo::HotnessType::None;
+}
+
+static bool isNonRenamableLocal(const GlobalValue &GV) {
+ return GV.hasSection() && GV.hasLocalLinkage();
+}
+
+/// Determine whether this call has all constant integer arguments (excluding
+/// "this") and summarize it to VCalls or ConstVCalls as appropriate.
+static void addVCallToSet(DevirtCallSite Call, GlobalValue::GUID Guid,
+ SetVector<FunctionSummary::VFuncId> &VCalls,
+ SetVector<FunctionSummary::ConstVCall> &ConstVCalls) {
+ std::vector<uint64_t> Args;
+ // Start from the second argument to skip the "this" pointer.
+ for (auto &Arg : make_range(Call.CS.arg_begin() + 1, Call.CS.arg_end())) {
+ auto *CI = dyn_cast<ConstantInt>(Arg);
+ if (!CI || CI->getBitWidth() > 64) {
+ VCalls.insert({Guid, Call.Offset});
+ return;
+ }
+ Args.push_back(CI->getZExtValue());
+ }
+ ConstVCalls.insert({{Guid, Call.Offset}, std::move(Args)});
+}
+
+/// If this intrinsic call requires that we add information to the function
+/// summary, do so via the non-constant reference arguments.
+static void addIntrinsicToSummary(
+ const CallInst *CI, SetVector<GlobalValue::GUID> &TypeTests,
+ SetVector<FunctionSummary::VFuncId> &TypeTestAssumeVCalls,
+ SetVector<FunctionSummary::VFuncId> &TypeCheckedLoadVCalls,
+ SetVector<FunctionSummary::ConstVCall> &TypeTestAssumeConstVCalls,
+ SetVector<FunctionSummary::ConstVCall> &TypeCheckedLoadConstVCalls,
+ DominatorTree &DT) {
+ switch (CI->getCalledFunction()->getIntrinsicID()) {
+ case Intrinsic::type_test: {
+ auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(1));
+ auto *TypeId = dyn_cast<MDString>(TypeMDVal->getMetadata());
+ if (!TypeId)
+ break;
+ GlobalValue::GUID Guid = GlobalValue::getGUID(TypeId->getString());
+
+ // Produce a summary from type.test intrinsics. We only summarize type.test
+ // intrinsics that are used other than by an llvm.assume intrinsic.
+ // Intrinsics that are assumed are relevant only to the devirtualization
+ // pass, not the type test lowering pass.
+ bool HasNonAssumeUses = llvm::any_of(CI->uses(), [](const Use &CIU) {
+ auto *AssumeCI = dyn_cast<CallInst>(CIU.getUser());
+ if (!AssumeCI)
+ return true;
+ Function *F = AssumeCI->getCalledFunction();
+ return !F || F->getIntrinsicID() != Intrinsic::assume;
+ });
+ if (HasNonAssumeUses)
+ TypeTests.insert(Guid);
+
+ SmallVector<DevirtCallSite, 4> DevirtCalls;
+ SmallVector<CallInst *, 4> Assumes;
+ findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT);
+ for (auto &Call : DevirtCalls)
+ addVCallToSet(Call, Guid, TypeTestAssumeVCalls,
+ TypeTestAssumeConstVCalls);
+
+ break;
+ }
+
+ case Intrinsic::type_checked_load: {
+ auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(2));
+ auto *TypeId = dyn_cast<MDString>(TypeMDVal->getMetadata());
+ if (!TypeId)
+ break;
+ GlobalValue::GUID Guid = GlobalValue::getGUID(TypeId->getString());
+
+ SmallVector<DevirtCallSite, 4> DevirtCalls;
+ SmallVector<Instruction *, 4> LoadedPtrs;
+ SmallVector<Instruction *, 4> Preds;
+ bool HasNonCallUses = false;
+ findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
+ HasNonCallUses, CI, DT);
+ // Any non-call uses of the result of llvm.type.checked.load will
+ // prevent us from optimizing away the llvm.type.test.
+ if (HasNonCallUses)
+ TypeTests.insert(Guid);
+ for (auto &Call : DevirtCalls)
+ addVCallToSet(Call, Guid, TypeCheckedLoadVCalls,
+ TypeCheckedLoadConstVCalls);
+
+ break;
+ }
+ default:
+ break;
+ }
+}
+
+static bool isNonVolatileLoad(const Instruction *I) {
+ if (const auto *LI = dyn_cast<LoadInst>(I))
+ return !LI->isVolatile();
+
+ return false;
+}
+
+static bool isNonVolatileStore(const Instruction *I) {
+ if (const auto *SI = dyn_cast<StoreInst>(I))
+ return !SI->isVolatile();
+
+ return false;
+}
+
+static void computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M,
+ const Function &F, BlockFrequencyInfo *BFI,
+ ProfileSummaryInfo *PSI, DominatorTree &DT,
+ bool HasLocalsInUsedOrAsm,
+ DenseSet<GlobalValue::GUID> &CantBePromoted,
+ bool IsThinLTO) {
+  // Summaries are not currently supported for anonymous functions; they
+  // should have been named by now.
+ assert(F.hasName());
+
+ unsigned NumInsts = 0;
+ // Map from callee ValueId to profile count. Used to accumulate profile
+ // counts for all static calls to a given callee.
+ MapVector<ValueInfo, CalleeInfo> CallGraphEdges;
+ SetVector<ValueInfo> RefEdges, LoadRefEdges, StoreRefEdges;
+ SetVector<GlobalValue::GUID> TypeTests;
+ SetVector<FunctionSummary::VFuncId> TypeTestAssumeVCalls,
+ TypeCheckedLoadVCalls;
+ SetVector<FunctionSummary::ConstVCall> TypeTestAssumeConstVCalls,
+ TypeCheckedLoadConstVCalls;
+ ICallPromotionAnalysis ICallAnalysis;
+ SmallPtrSet<const User *, 8> Visited;
+
+ // Add personality function, prefix data and prologue data to function's ref
+ // list.
+ findRefEdges(Index, &F, RefEdges, Visited);
+ std::vector<const Instruction *> NonVolatileLoads;
+ std::vector<const Instruction *> NonVolatileStores;
+
+ bool HasInlineAsmMaybeReferencingInternal = false;
+ for (const BasicBlock &BB : F)
+ for (const Instruction &I : BB) {
+ if (isa<DbgInfoIntrinsic>(I))
+ continue;
+ ++NumInsts;
+      // A regular LTO module doesn't participate in ThinLTO import, so no
+      // reference from it can be read- or write-only, since this would
+      // require importing the variable as a local copy.
+ if (IsThinLTO) {
+ if (isNonVolatileLoad(&I)) {
+ // Postpone processing of non-volatile load instructions
+ // See comments below
+ Visited.insert(&I);
+ NonVolatileLoads.push_back(&I);
+ continue;
+ } else if (isNonVolatileStore(&I)) {
+ Visited.insert(&I);
+ NonVolatileStores.push_back(&I);
+          // All references from the second operand of a store (the
+          // destination address) can be considered write-only if they're not
+          // referenced by any non-store instruction. References from the first
+          // operand of a store (the stored value) can't be treated as either
+          // read- or write-only, so we add them to RefEdges as we do with all
+          // other instructions except non-volatile loads.
+ Value *Stored = I.getOperand(0);
+ if (auto *GV = dyn_cast<GlobalValue>(Stored))
+ // findRefEdges will try to examine GV operands, so instead
+ // of calling it we should add GV to RefEdges directly.
+ RefEdges.insert(Index.getOrInsertValueInfo(GV));
+ else if (auto *U = dyn_cast<User>(Stored))
+ findRefEdges(Index, U, RefEdges, Visited);
+ continue;
+ }
+ }
+ findRefEdges(Index, &I, RefEdges, Visited);
+ auto CS = ImmutableCallSite(&I);
+ if (!CS)
+ continue;
+
+ const auto *CI = dyn_cast<CallInst>(&I);
+ // Since we don't know exactly which local values are referenced in inline
+ // assembly, conservatively mark the function as possibly referencing
+ // a local value from inline assembly to ensure we don't export a
+ // reference (which would require renaming and promotion of the
+ // referenced value).
+ if (HasLocalsInUsedOrAsm && CI && CI->isInlineAsm())
+ HasInlineAsmMaybeReferencingInternal = true;
+
+ auto *CalledValue = CS.getCalledValue();
+ auto *CalledFunction = CS.getCalledFunction();
+ if (CalledValue && !CalledFunction) {
+ CalledValue = CalledValue->stripPointerCasts();
+ // Stripping pointer casts can reveal a called function.
+ CalledFunction = dyn_cast<Function>(CalledValue);
+ }
+ // Check if this is an alias to a function. If so, get the
+ // called aliasee for the checks below.
+ if (auto *GA = dyn_cast<GlobalAlias>(CalledValue)) {
+ assert(!CalledFunction && "Expected null called function in callsite for alias");
+ CalledFunction = dyn_cast<Function>(GA->getBaseObject());
+ }
+ // Check if this is a direct call to a known function or a known
+ // intrinsic, or an indirect call with profile data.
+ if (CalledFunction) {
+ if (CI && CalledFunction->isIntrinsic()) {
+ addIntrinsicToSummary(
+ CI, TypeTests, TypeTestAssumeVCalls, TypeCheckedLoadVCalls,
+ TypeTestAssumeConstVCalls, TypeCheckedLoadConstVCalls, DT);
+ continue;
+ }
+ // We should have named any anonymous globals
+ assert(CalledFunction->hasName());
+ auto ScaledCount = PSI->getProfileCount(&I, BFI);
+ auto Hotness = ScaledCount ? getHotness(ScaledCount.getValue(), PSI)
+ : CalleeInfo::HotnessType::Unknown;
+ if (ForceSummaryEdgesCold != FunctionSummary::FSHT_None)
+ Hotness = CalleeInfo::HotnessType::Cold;
+
+ // Use the original CalledValue, in case it was an alias. We want
+ // to record the call edge to the alias in that case. Eventually
+ // an alias summary will be created to associate the alias and
+ // aliasee.
+ auto &ValueInfo = CallGraphEdges[Index.getOrInsertValueInfo(
+ cast<GlobalValue>(CalledValue))];
+ ValueInfo.updateHotness(Hotness);
+ // Add the relative block frequency to CalleeInfo if there is no profile
+ // information.
+ if (BFI != nullptr && Hotness == CalleeInfo::HotnessType::Unknown) {
+ uint64_t BBFreq = BFI->getBlockFreq(&BB).getFrequency();
+ uint64_t EntryFreq = BFI->getEntryFreq();
+ ValueInfo.updateRelBlockFreq(BBFreq, EntryFreq);
+ }
+ } else {
+ // Skip inline assembly calls.
+ if (CI && CI->isInlineAsm())
+ continue;
+ // Skip direct calls.
+ if (!CalledValue || isa<Constant>(CalledValue))
+ continue;
+
+ // Check if the instruction has a callees metadata. If so, add callees
+ // to CallGraphEdges to reflect the references from the metadata, and
+ // to enable importing for subsequent indirect call promotion and
+ // inlining.
+ if (auto *MD = I.getMetadata(LLVMContext::MD_callees)) {
+ for (auto &Op : MD->operands()) {
+ Function *Callee = mdconst::extract_or_null<Function>(Op);
+ if (Callee)
+ CallGraphEdges[Index.getOrInsertValueInfo(Callee)];
+ }
+ }
+
+ uint32_t NumVals, NumCandidates;
+ uint64_t TotalCount;
+ auto CandidateProfileData =
+ ICallAnalysis.getPromotionCandidatesForInstruction(
+ &I, NumVals, TotalCount, NumCandidates);
+ for (auto &Candidate : CandidateProfileData)
+ CallGraphEdges[Index.getOrInsertValueInfo(Candidate.Value)]
+ .updateHotness(getHotness(Candidate.Count, PSI));
+ }
+ }
+
+ std::vector<ValueInfo> Refs;
+ if (IsThinLTO) {
+ auto AddRefEdges = [&](const std::vector<const Instruction *> &Instrs,
+ SetVector<ValueInfo> &Edges,
+ SmallPtrSet<const User *, 8> &Cache) {
+ for (const auto *I : Instrs) {
+ Cache.erase(I);
+ findRefEdges(Index, I, Edges, Cache);
+ }
+ };
+
+    // By now we have processed all instructions in the function except
+    // non-volatile loads and non-volatile value stores. Let's find the
+    // ref edges for both instruction sets.
+ AddRefEdges(NonVolatileLoads, LoadRefEdges, Visited);
+    // We can add some values to the Visited set when processing load
+    // instructions which are also used by stores in NonVolatileStores.
+    // For example, this can happen if we have the following code:
+    //
+    // store %Derived* @foo, %Derived** bitcast (%Base** @bar to %Derived**)
+    // %42 = load %Derived*, %Derived** bitcast (%Base** @bar to %Derived**)
+    //
+    // After processing the loads we'll add the bitcast to the Visited set, and
+    // if we use the same set while processing the stores, we'll never see the
+    // store to @bar, and @bar will be mistakenly treated as readonly.
+ SmallPtrSet<const llvm::User *, 8> StoreCache;
+ AddRefEdges(NonVolatileStores, StoreRefEdges, StoreCache);
+
+    // If both a load and a store instruction reference the same variable,
+    // we won't be able to optimize it. Add all such reference edges to the
+    // RefEdges set.
+ for (auto &VI : StoreRefEdges)
+ if (LoadRefEdges.remove(VI))
+ RefEdges.insert(VI);
+
+ unsigned RefCnt = RefEdges.size();
+    // All new reference edges inserted in the two loops below are either
+    // read-only or write-only. They will be grouped at the end of the RefEdges
+    // vector, so we can use a single integer value to identify them.
+ for (auto &VI : LoadRefEdges)
+ RefEdges.insert(VI);
+
+ unsigned FirstWORef = RefEdges.size();
+ for (auto &VI : StoreRefEdges)
+ RefEdges.insert(VI);
+
+ Refs = RefEdges.takeVector();
+ for (; RefCnt < FirstWORef; ++RefCnt)
+ Refs[RefCnt].setReadOnly();
+
+ for (; RefCnt < Refs.size(); ++RefCnt)
+ Refs[RefCnt].setWriteOnly();
+ } else {
+ Refs = RefEdges.takeVector();
+ }
+  // Explicitly add hot edges to enforce importing for designated GUIDs for
+  // sample PGO, to enable the same inlines as the profiled optimized binary.
+ for (auto &I : F.getImportGUIDs())
+ CallGraphEdges[Index.getOrInsertValueInfo(I)].updateHotness(
+ ForceSummaryEdgesCold == FunctionSummary::FSHT_All
+ ? CalleeInfo::HotnessType::Cold
+ : CalleeInfo::HotnessType::Critical);
+
+ bool NonRenamableLocal = isNonRenamableLocal(F);
+ bool NotEligibleForImport =
+ NonRenamableLocal || HasInlineAsmMaybeReferencingInternal;
+ GlobalValueSummary::GVFlags Flags(F.getLinkage(), NotEligibleForImport,
+ /* Live = */ false, F.isDSOLocal(),
+ F.hasLinkOnceODRLinkage() && F.hasGlobalUnnamedAddr());
+ FunctionSummary::FFlags FunFlags{
+ F.hasFnAttribute(Attribute::ReadNone),
+ F.hasFnAttribute(Attribute::ReadOnly),
+ F.hasFnAttribute(Attribute::NoRecurse), F.returnDoesNotAlias(),
+ // FIXME: refactor this to use the same code that inliner is using.
+ // Don't try to import functions with noinline attribute.
+ F.getAttributes().hasFnAttribute(Attribute::NoInline)};
+ auto FuncSummary = std::make_unique<FunctionSummary>(
+ Flags, NumInsts, FunFlags, /*EntryCount=*/0, std::move(Refs),
+ CallGraphEdges.takeVector(), TypeTests.takeVector(),
+ TypeTestAssumeVCalls.takeVector(), TypeCheckedLoadVCalls.takeVector(),
+ TypeTestAssumeConstVCalls.takeVector(),
+ TypeCheckedLoadConstVCalls.takeVector());
+ if (NonRenamableLocal)
+ CantBePromoted.insert(F.getGUID());
+ Index.addGlobalValueSummary(F, std::move(FuncSummary));
+}
+
+/// Find function pointers referenced within the given vtable initializer
+/// (or subset of an initializer) \p I. The starting offset of \p I within
+/// the vtable initializer is \p StartingOffset. Any discovered function
+/// pointers are added to \p VTableFuncs along with their cumulative offset
+/// within the initializer.
+static void findFuncPointers(const Constant *I, uint64_t StartingOffset,
+ const Module &M, ModuleSummaryIndex &Index,
+ VTableFuncList &VTableFuncs) {
+ // First check if this is a function pointer.
+ if (I->getType()->isPointerTy()) {
+ auto Fn = dyn_cast<Function>(I->stripPointerCasts());
+ // We can disregard __cxa_pure_virtual as a possible call target, as
+ // calls to pure virtuals are UB.
+ if (Fn && Fn->getName() != "__cxa_pure_virtual")
+ VTableFuncs.push_back({Index.getOrInsertValueInfo(Fn), StartingOffset});
+ return;
+ }
+
+ // Walk through the elements in the constant struct or array and recursively
+ // look for virtual function pointers.
+ const DataLayout &DL = M.getDataLayout();
+ if (auto *C = dyn_cast<ConstantStruct>(I)) {
+ StructType *STy = dyn_cast<StructType>(C->getType());
+ assert(STy);
+ const StructLayout *SL = DL.getStructLayout(C->getType());
+
+ for (StructType::element_iterator EB = STy->element_begin(), EI = EB,
+ EE = STy->element_end();
+ EI != EE; ++EI) {
+ auto Offset = SL->getElementOffset(EI - EB);
+ unsigned Op = SL->getElementContainingOffset(Offset);
+ findFuncPointers(cast<Constant>(I->getOperand(Op)),
+ StartingOffset + Offset, M, Index, VTableFuncs);
+ }
+ } else if (auto *C = dyn_cast<ConstantArray>(I)) {
+ ArrayType *ATy = C->getType();
+ Type *EltTy = ATy->getElementType();
+ uint64_t EltSize = DL.getTypeAllocSize(EltTy);
+ for (unsigned i = 0, e = ATy->getNumElements(); i != e; ++i) {
+ findFuncPointers(cast<Constant>(I->getOperand(i)),
+ StartingOffset + i * EltSize, M, Index, VTableFuncs);
+ }
+ }
+}
+
+// Identify the function pointers referenced by vtable definition \p V.
+static void computeVTableFuncs(ModuleSummaryIndex &Index,
+ const GlobalVariable &V, const Module &M,
+ VTableFuncList &VTableFuncs) {
+ if (!V.isConstant())
+ return;
+
+ findFuncPointers(V.getInitializer(), /*StartingOffset=*/0, M, Index,
+ VTableFuncs);
+
+#ifndef NDEBUG
+ // Validate that the VTableFuncs list is ordered by offset.
+ uint64_t PrevOffset = 0;
+ for (auto &P : VTableFuncs) {
+    // The findFuncPointers traversal should have encountered the
+    // functions in offset order. We need to use ">=" since PrevOffset
+    // starts at 0.
+ assert(P.VTableOffset >= PrevOffset);
+ PrevOffset = P.VTableOffset;
+ }
+#endif
+}
+
+/// Record vtable definition \p V for each type metadata it references.
+static void
+recordTypeIdCompatibleVtableReferences(ModuleSummaryIndex &Index,
+ const GlobalVariable &V,
+ SmallVectorImpl<MDNode *> &Types) {
+ for (MDNode *Type : Types) {
+ auto TypeID = Type->getOperand(1).get();
+
+ uint64_t Offset =
+ cast<ConstantInt>(
+ cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
+ ->getZExtValue();
+
+ if (auto *TypeId = dyn_cast<MDString>(TypeID))
+ Index.getOrInsertTypeIdCompatibleVtableSummary(TypeId->getString())
+ .push_back({Offset, Index.getOrInsertValueInfo(&V)});
+ }
+}
+
+static void computeVariableSummary(ModuleSummaryIndex &Index,
+ const GlobalVariable &V,
+ DenseSet<GlobalValue::GUID> &CantBePromoted,
+ const Module &M,
+ SmallVectorImpl<MDNode *> &Types) {
+ SetVector<ValueInfo> RefEdges;
+ SmallPtrSet<const User *, 8> Visited;
+ bool HasBlockAddress = findRefEdges(Index, &V, RefEdges, Visited);
+ bool NonRenamableLocal = isNonRenamableLocal(V);
+ GlobalValueSummary::GVFlags Flags(V.getLinkage(), NonRenamableLocal,
+ /* Live = */ false, V.isDSOLocal(),
+ V.hasLinkOnceODRLinkage() && V.hasGlobalUnnamedAddr());
+
+ VTableFuncList VTableFuncs;
+ // If splitting is not enabled, then we compute the summary information
+ // necessary for index-based whole program devirtualization.
+ if (!Index.enableSplitLTOUnit()) {
+ Types.clear();
+ V.getMetadata(LLVMContext::MD_type, Types);
+ if (!Types.empty()) {
+ // Identify the function pointers referenced by this vtable definition.
+ computeVTableFuncs(Index, V, M, VTableFuncs);
+
+ // Record this vtable definition for each type metadata it references.
+ recordTypeIdCompatibleVtableReferences(Index, V, Types);
+ }
+ }
+
+ // Don't mark variables we won't be able to internalize as read/write-only.
+ bool CanBeInternalized =
+ !V.hasComdat() && !V.hasAppendingLinkage() && !V.isInterposable() &&
+ !V.hasAvailableExternallyLinkage() && !V.hasDLLExportStorageClass();
+ GlobalVarSummary::GVarFlags VarFlags(CanBeInternalized, CanBeInternalized);
+ auto GVarSummary = std::make_unique<GlobalVarSummary>(Flags, VarFlags,
+ RefEdges.takeVector());
+ if (NonRenamableLocal)
+ CantBePromoted.insert(V.getGUID());
+ if (HasBlockAddress)
+ GVarSummary->setNotEligibleToImport();
+ if (!VTableFuncs.empty())
+ GVarSummary->setVTableFuncs(VTableFuncs);
+ Index.addGlobalValueSummary(V, std::move(GVarSummary));
+}
+
+static void
+computeAliasSummary(ModuleSummaryIndex &Index, const GlobalAlias &A,
+ DenseSet<GlobalValue::GUID> &CantBePromoted) {
+ bool NonRenamableLocal = isNonRenamableLocal(A);
+ GlobalValueSummary::GVFlags Flags(A.getLinkage(), NonRenamableLocal,
+ /* Live = */ false, A.isDSOLocal(),
+ A.hasLinkOnceODRLinkage() && A.hasGlobalUnnamedAddr());
+ auto AS = std::make_unique<AliasSummary>(Flags);
+ auto *Aliasee = A.getBaseObject();
+ auto AliaseeVI = Index.getValueInfo(Aliasee->getGUID());
+ assert(AliaseeVI && "Alias expects aliasee summary to be available");
+ assert(AliaseeVI.getSummaryList().size() == 1 &&
+ "Expected a single entry per aliasee in per-module index");
+ AS->setAliasee(AliaseeVI, AliaseeVI.getSummaryList()[0].get());
+ if (NonRenamableLocal)
+ CantBePromoted.insert(A.getGUID());
+ Index.addGlobalValueSummary(A, std::move(AS));
+}
+
+// Set LiveRoot flag on entries matching the given value name.
+static void setLiveRoot(ModuleSummaryIndex &Index, StringRef Name) {
+ if (ValueInfo VI = Index.getValueInfo(GlobalValue::getGUID(Name)))
+ for (auto &Summary : VI.getSummaryList())
+ Summary->setLive(true);
+}
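
Before the entry point defined next, a hedged sketch of how a driver might invoke it (the wrapper is hypothetical; the signature and the null-BFI-callback behavior come from the definition below):

#include "llvm/Analysis/ModuleSummaryAnalysis.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/IR/Module.h"
using namespace llvm;

// Hypothetical driver: PSI must be non-null (asserted below). With a null BFI
// callback, the analysis computes BlockFrequencyInfo itself for functions
// that have profile data.
static ModuleSummaryIndex buildIndexFor(const Module &M,
                                        ProfileSummaryInfo &PSI) {
  return buildModuleSummaryIndex(M, /*GetBFICallback=*/nullptr, &PSI);
}
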
+
+ModuleSummaryIndex llvm::buildModuleSummaryIndex(
+ const Module &M,
+ std::function<BlockFrequencyInfo *(const Function &F)> GetBFICallback,
+ ProfileSummaryInfo *PSI) {
+ assert(PSI);
+ bool EnableSplitLTOUnit = false;
+ if (auto *MD = mdconst::extract_or_null<ConstantInt>(
+ M.getModuleFlag("EnableSplitLTOUnit")))
+ EnableSplitLTOUnit = MD->getZExtValue();
+ ModuleSummaryIndex Index(/*HaveGVs=*/true, EnableSplitLTOUnit);
+
+ // Identify the local values in the llvm.used and llvm.compiler.used sets,
+ // which should not be exported as they would then require renaming and
+ // promotion, but we may have opaque uses e.g. in inline asm. We collect them
+ // here because we use this information to mark functions containing inline
+ // assembly calls as not importable.
+ SmallPtrSet<GlobalValue *, 8> LocalsUsed;
+ SmallPtrSet<GlobalValue *, 8> Used;
+ // First collect those in the llvm.used set.
+ collectUsedGlobalVariables(M, Used, /*CompilerUsed*/ false);
+ // Next collect those in the llvm.compiler.used set.
+ collectUsedGlobalVariables(M, Used, /*CompilerUsed*/ true);
+ DenseSet<GlobalValue::GUID> CantBePromoted;
+ for (auto *V : Used) {
+ if (V->hasLocalLinkage()) {
+ LocalsUsed.insert(V);
+ CantBePromoted.insert(V->getGUID());
+ }
+ }
+
+ bool HasLocalInlineAsmSymbol = false;
+ if (!M.getModuleInlineAsm().empty()) {
+ // Collect the local values defined by module level asm, and set up
+ // summaries for these symbols so that they can be marked as NoRename,
+ // to prevent export of any use of them in regular IR that would require
+ // renaming within the module level asm. Note we don't need to create a
+ // summary for weak or global defs, as they don't need to be flagged as
+ // NoRename, and defs in module level asm can't be imported anyway.
+ // Also, any values used but not defined within module level asm should
+ // be listed on the llvm.used or llvm.compiler.used global and marked as
+ // referenced from there.
+ ModuleSymbolTable::CollectAsmSymbols(
+ M, [&](StringRef Name, object::BasicSymbolRef::Flags Flags) {
+ // Symbols not marked as Weak or Global are local definitions.
+ if (Flags & (object::BasicSymbolRef::SF_Weak |
+ object::BasicSymbolRef::SF_Global))
+ return;
+ HasLocalInlineAsmSymbol = true;
+ GlobalValue *GV = M.getNamedValue(Name);
+ if (!GV)
+ return;
+ assert(GV->isDeclaration() && "Def in module asm already has definition");
+ GlobalValueSummary::GVFlags GVFlags(GlobalValue::InternalLinkage,
+ /* NotEligibleToImport = */ true,
+ /* Live = */ true,
+ /* Local */ GV->isDSOLocal(),
+ GV->hasLinkOnceODRLinkage() && GV->hasGlobalUnnamedAddr());
+ CantBePromoted.insert(GV->getGUID());
+ // Create the appropriate summary type.
+ if (Function *F = dyn_cast<Function>(GV)) {
+ std::unique_ptr<FunctionSummary> Summary =
+ std::make_unique<FunctionSummary>(
+ GVFlags, /*InstCount=*/0,
+ FunctionSummary::FFlags{
+ F->hasFnAttribute(Attribute::ReadNone),
+ F->hasFnAttribute(Attribute::ReadOnly),
+ F->hasFnAttribute(Attribute::NoRecurse),
+ F->returnDoesNotAlias(),
+ /* NoInline = */ false},
+ /*EntryCount=*/0, ArrayRef<ValueInfo>{},
+ ArrayRef<FunctionSummary::EdgeTy>{},
+ ArrayRef<GlobalValue::GUID>{},
+ ArrayRef<FunctionSummary::VFuncId>{},
+ ArrayRef<FunctionSummary::VFuncId>{},
+ ArrayRef<FunctionSummary::ConstVCall>{},
+ ArrayRef<FunctionSummary::ConstVCall>{});
+ Index.addGlobalValueSummary(*GV, std::move(Summary));
+ } else {
+ std::unique_ptr<GlobalVarSummary> Summary =
+ std::make_unique<GlobalVarSummary>(
+ GVFlags, GlobalVarSummary::GVarFlags(false, false),
+ ArrayRef<ValueInfo>{});
+ Index.addGlobalValueSummary(*GV, std::move(Summary));
+ }
+ });
+ }
+
+ bool IsThinLTO = true;
+ if (auto *MD =
+ mdconst::extract_or_null<ConstantInt>(M.getModuleFlag("ThinLTO")))
+ IsThinLTO = MD->getZExtValue();
+
+ // Compute summaries for all functions defined in module, and save in the
+ // index.
+ for (auto &F : M) {
+ if (F.isDeclaration())
+ continue;
+
+ DominatorTree DT(const_cast<Function &>(F));
+ BlockFrequencyInfo *BFI = nullptr;
+ std::unique_ptr<BlockFrequencyInfo> BFIPtr;
+ if (GetBFICallback)
+ BFI = GetBFICallback(F);
+ else if (F.hasProfileData()) {
+ LoopInfo LI{DT};
+ BranchProbabilityInfo BPI{F, LI};
+ BFIPtr = std::make_unique<BlockFrequencyInfo>(F, BPI, LI);
+ BFI = BFIPtr.get();
+ }
+
+ computeFunctionSummary(Index, M, F, BFI, PSI, DT,
+ !LocalsUsed.empty() || HasLocalInlineAsmSymbol,
+ CantBePromoted, IsThinLTO);
+ }
+
+ // Compute summaries for all variables defined in module, and save in the
+ // index.
+ SmallVector<MDNode *, 2> Types;
+ for (const GlobalVariable &G : M.globals()) {
+ if (G.isDeclaration())
+ continue;
+ computeVariableSummary(Index, G, CantBePromoted, M, Types);
+ }
+
+ // Compute summaries for all aliases defined in module, and save in the
+ // index.
+ for (const GlobalAlias &A : M.aliases())
+ computeAliasSummary(Index, A, CantBePromoted);
+
+ for (auto *V : LocalsUsed) {
+ auto *Summary = Index.getGlobalValueSummary(*V);
+ assert(Summary && "Missing summary for global value");
+ Summary->setNotEligibleToImport();
+ }
+
+ // The linker doesn't know about these LLVM produced values, so we need
+ // to flag them as live in the index to ensure index-based dead value
+ // analysis treats them as live roots of the analysis.
+ setLiveRoot(Index, "llvm.used");
+ setLiveRoot(Index, "llvm.compiler.used");
+ setLiveRoot(Index, "llvm.global_ctors");
+ setLiveRoot(Index, "llvm.global_dtors");
+ setLiveRoot(Index, "llvm.global.annotations");
+
+ for (auto &GlobalList : Index) {
+ // Ignore entries for references that are undefined in the current module.
+ if (GlobalList.second.SummaryList.empty())
+ continue;
+
+ assert(GlobalList.second.SummaryList.size() == 1 &&
+ "Expected module's index to have one summary per GUID");
+ auto &Summary = GlobalList.second.SummaryList[0];
+ if (!IsThinLTO) {
+ Summary->setNotEligibleToImport();
+ continue;
+ }
+
+ bool AllRefsCanBeExternallyReferenced =
+ llvm::all_of(Summary->refs(), [&](const ValueInfo &VI) {
+ return !CantBePromoted.count(VI.getGUID());
+ });
+ if (!AllRefsCanBeExternallyReferenced) {
+ Summary->setNotEligibleToImport();
+ continue;
+ }
+
+ if (auto *FuncSummary = dyn_cast<FunctionSummary>(Summary.get())) {
+ bool AllCallsCanBeExternallyReferenced = llvm::all_of(
+ FuncSummary->calls(), [&](const FunctionSummary::EdgeTy &Edge) {
+ return !CantBePromoted.count(Edge.first.getGUID());
+ });
+ if (!AllCallsCanBeExternallyReferenced)
+ Summary->setNotEligibleToImport();
+ }
+ }
+
+ if (!ModuleSummaryDotFile.empty()) {
+ std::error_code EC;
+ raw_fd_ostream OSDot(ModuleSummaryDotFile, EC, sys::fs::OpenFlags::OF_None);
+ if (EC)
+ report_fatal_error(Twine("Failed to open dot file ") +
+ ModuleSummaryDotFile + ": " + EC.message() + "\n");
+ Index.exportToDot(OSDot);
+ }
+
+ return Index;
+}
+
+AnalysisKey ModuleSummaryIndexAnalysis::Key;
+
+ModuleSummaryIndex
+ModuleSummaryIndexAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
+ ProfileSummaryInfo &PSI = AM.getResult<ProfileSummaryAnalysis>(M);
+ auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ return buildModuleSummaryIndex(
+ M,
+ [&FAM](const Function &F) {
+ return &FAM.getResult<BlockFrequencyAnalysis>(
+ *const_cast<Function *>(&F));
+ },
+ &PSI);
+}
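
// Editor's sketch (not part of this patch): a minimal illustration of how a
// hypothetical new-pass-manager module pass could consume this analysis.
// `SummaryConsumerPass` is an assumed name, not an existing pass.
//
//   PreservedAnalyses SummaryConsumerPass::run(Module &M,
//                                              ModuleAnalysisManager &AM) {
//     const ModuleSummaryIndex &Index =
//         AM.getResult<ModuleSummaryIndexAnalysis>(M);
//     for (const auto &GlobalList : Index)
//       (void)GlobalList; // inspect the per-GUID summary lists here
//     return PreservedAnalyses::all();
//   }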
+
+char ModuleSummaryIndexWrapperPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(ModuleSummaryIndexWrapperPass, "module-summary-analysis",
+ "Module Summary Analysis", false, true)
+INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
+INITIALIZE_PASS_END(ModuleSummaryIndexWrapperPass, "module-summary-analysis",
+ "Module Summary Analysis", false, true)
+
+ModulePass *llvm::createModuleSummaryIndexWrapperPass() {
+ return new ModuleSummaryIndexWrapperPass();
+}
+
+ModuleSummaryIndexWrapperPass::ModuleSummaryIndexWrapperPass()
+ : ModulePass(ID) {
+ initializeModuleSummaryIndexWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+bool ModuleSummaryIndexWrapperPass::runOnModule(Module &M) {
+ auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
+ Index.emplace(buildModuleSummaryIndex(
+ M,
+ [this](const Function &F) {
+ return &(this->getAnalysis<BlockFrequencyInfoWrapperPass>(
+ *const_cast<Function *>(&F))
+ .getBFI());
+ },
+ PSI));
+ return false;
+}
+
+bool ModuleSummaryIndexWrapperPass::doFinalization(Module &M) {
+ Index.reset();
+ return false;
+}
+
+void ModuleSummaryIndexWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequired<BlockFrequencyInfoWrapperPass>();
+ AU.addRequired<ProfileSummaryInfoWrapperPass>();
+}
diff --git a/llvm/lib/Analysis/MustExecute.cpp b/llvm/lib/Analysis/MustExecute.cpp
new file mode 100644
index 000000000000..44527773115d
--- /dev/null
+++ b/llvm/lib/Analysis/MustExecute.cpp
@@ -0,0 +1,516 @@
+//===- MustExecute.cpp - Printer for isGuaranteedToExecute ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/MustExecute.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/AssemblyAnnotationWriter.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/FormattedStream.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "must-execute"
+
+const DenseMap<BasicBlock *, ColorVector> &
+LoopSafetyInfo::getBlockColors() const {
+ return BlockColors;
+}
+
+void LoopSafetyInfo::copyColors(BasicBlock *New, BasicBlock *Old) {
+ ColorVector &ColorsForNewBlock = BlockColors[New];
+ ColorVector &ColorsForOldBlock = BlockColors[Old];
+ ColorsForNewBlock = ColorsForOldBlock;
+}
+
+bool SimpleLoopSafetyInfo::blockMayThrow(const BasicBlock *BB) const {
+ (void)BB;
+ return anyBlockMayThrow();
+}
+
+bool SimpleLoopSafetyInfo::anyBlockMayThrow() const {
+ return MayThrow;
+}
+
+void SimpleLoopSafetyInfo::computeLoopSafetyInfo(const Loop *CurLoop) {
+ assert(CurLoop != nullptr && "CurLoop can't be null");
+ BasicBlock *Header = CurLoop->getHeader();
+ // Iterate over header and compute safety info.
+ HeaderMayThrow = !isGuaranteedToTransferExecutionToSuccessor(Header);
+ MayThrow = HeaderMayThrow;
+ // Iterate over loop instructions and compute safety info.
+ // Skip header as it has been computed and stored in HeaderMayThrow.
+ // The first block in loopinfo.Blocks is guaranteed to be the header.
+ assert(Header == *CurLoop->getBlocks().begin() &&
+ "First block must be header");
+ for (Loop::block_iterator BB = std::next(CurLoop->block_begin()),
+ BBE = CurLoop->block_end();
+ (BB != BBE) && !MayThrow; ++BB)
+ MayThrow |= !isGuaranteedToTransferExecutionToSuccessor(*BB);
+
+ computeBlockColors(CurLoop);
+}
+
+bool ICFLoopSafetyInfo::blockMayThrow(const BasicBlock *BB) const {
+ return ICF.hasICF(BB);
+}
+
+bool ICFLoopSafetyInfo::anyBlockMayThrow() const {
+ return MayThrow;
+}
+
+void ICFLoopSafetyInfo::computeLoopSafetyInfo(const Loop *CurLoop) {
+ assert(CurLoop != nullptr && "CurLoop can't be null");
+ ICF.clear();
+ MW.clear();
+ MayThrow = false;
+  // Determine whether at least one block in the loop may throw.
+ for (auto &BB : CurLoop->blocks())
+ if (ICF.hasICF(&*BB)) {
+ MayThrow = true;
+ break;
+ }
+ computeBlockColors(CurLoop);
+}
+
+void ICFLoopSafetyInfo::insertInstructionTo(const Instruction *Inst,
+ const BasicBlock *BB) {
+ ICF.insertInstructionTo(Inst, BB);
+ MW.insertInstructionTo(Inst, BB);
+}
+
+void ICFLoopSafetyInfo::removeInstruction(const Instruction *Inst) {
+ ICF.removeInstruction(Inst);
+ MW.removeInstruction(Inst);
+}
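
// Editor's sketch (not part of this patch): a transform that moves an
// instruction while keeping an ICFLoopSafetyInfo up to date. `SafetyInfo`,
// `I` and `InsertPt` are assumed to be provided by the caller.
//
//   SafetyInfo.removeInstruction(&I);            // forget I's old location
//   I.moveBefore(InsertPt);                      // the actual code motion
//   SafetyInfo.insertInstructionTo(&I, InsertPt->getParent());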
+
+void LoopSafetyInfo::computeBlockColors(const Loop *CurLoop) {
+ // Compute funclet colors if we might sink/hoist in a function with a funclet
+ // personality routine.
+ Function *Fn = CurLoop->getHeader()->getParent();
+ if (Fn->hasPersonalityFn())
+ if (Constant *PersonalityFn = Fn->getPersonalityFn())
+ if (isScopedEHPersonality(classifyEHPersonality(PersonalityFn)))
+ BlockColors = colorEHFunclets(*Fn);
+}
+
+/// Return true if we can prove that the given ExitBlock is not reached on the
+/// first iteration of the given loop. That is, the backedge of the loop must
+/// be executed before the ExitBlock is executed in any dynamic execution trace.
+static bool CanProveNotTakenFirstIteration(const BasicBlock *ExitBlock,
+ const DominatorTree *DT,
+ const Loop *CurLoop) {
+ auto *CondExitBlock = ExitBlock->getSinglePredecessor();
+ if (!CondExitBlock)
+ // expect unique exits
+ return false;
+ assert(CurLoop->contains(CondExitBlock) && "meaning of exit block");
+ auto *BI = dyn_cast<BranchInst>(CondExitBlock->getTerminator());
+ if (!BI || !BI->isConditional())
+ return false;
+ // If condition is constant and false leads to ExitBlock then we always
+ // execute the true branch.
+ if (auto *Cond = dyn_cast<ConstantInt>(BI->getCondition()))
+ return BI->getSuccessor(Cond->getZExtValue() ? 1 : 0) == ExitBlock;
+ auto *Cond = dyn_cast<CmpInst>(BI->getCondition());
+ if (!Cond)
+ return false;
+  // TODO: This would be a lot more powerful if we used SCEV, but all the
+  // plumbing is currently missing to pass a pointer in from the pass.
+  // Check for cmp (phi [x, preheader] ...), y where (pred x, y) is known.
+ auto *LHS = dyn_cast<PHINode>(Cond->getOperand(0));
+ auto *RHS = Cond->getOperand(1);
+ if (!LHS || LHS->getParent() != CurLoop->getHeader())
+ return false;
+ auto DL = ExitBlock->getModule()->getDataLayout();
+ auto *IVStart = LHS->getIncomingValueForBlock(CurLoop->getLoopPreheader());
+ auto *SimpleValOrNull = SimplifyCmpInst(Cond->getPredicate(),
+ IVStart, RHS,
+ {DL, /*TLI*/ nullptr,
+ DT, /*AC*/ nullptr, BI});
+ auto *SimpleCst = dyn_cast_or_null<Constant>(SimpleValOrNull);
+ if (!SimpleCst)
+ return false;
+ if (ExitBlock == BI->getSuccessor(0))
+ return SimpleCst->isZeroValue();
+ assert(ExitBlock == BI->getSuccessor(1) && "implied by above");
+ return SimpleCst->isAllOnesValue();
+}
+
+/// Collect all blocks from \p CurLoop which lie on all possible paths from
+/// the header of \p CurLoop (inclusive) to BB (exclusive) into the set
+/// \p Predecessors. If \p BB is the header, \p Predecessors will be empty.
+static void collectTransitivePredecessors(
+ const Loop *CurLoop, const BasicBlock *BB,
+ SmallPtrSetImpl<const BasicBlock *> &Predecessors) {
+ assert(Predecessors.empty() && "Garbage in predecessors set?");
+ assert(CurLoop->contains(BB) && "Should only be called for loop blocks!");
+ if (BB == CurLoop->getHeader())
+ return;
+ SmallVector<const BasicBlock *, 4> WorkList;
+ for (auto *Pred : predecessors(BB)) {
+ Predecessors.insert(Pred);
+ WorkList.push_back(Pred);
+ }
+ while (!WorkList.empty()) {
+ auto *Pred = WorkList.pop_back_val();
+ assert(CurLoop->contains(Pred) && "Should only reach loop blocks!");
+ // We are not interested in backedges and we don't want to leave loop.
+ if (Pred == CurLoop->getHeader())
+ continue;
+ // TODO: If BB lies in an inner loop of CurLoop, this will traverse over all
+ // blocks of this inner loop, even those that are always executed AFTER the
+ // BB. It may make our analysis more conservative than it could be, see test
+ // @nested and @nested_no_throw in test/Analysis/MustExecute/loop-header.ll.
+    // We can ignore the backedges of all loops containing BB to get a slightly more
+ // optimistic result.
+ for (auto *PredPred : predecessors(Pred))
+ if (Predecessors.insert(PredPred).second)
+ WorkList.push_back(PredPred);
+ }
+}
+
+bool LoopSafetyInfo::allLoopPathsLeadToBlock(const Loop *CurLoop,
+ const BasicBlock *BB,
+ const DominatorTree *DT) const {
+ assert(CurLoop->contains(BB) && "Should only be called for loop blocks!");
+
+ // Fast path: header is always reached once the loop is entered.
+ if (BB == CurLoop->getHeader())
+ return true;
+
+ // Collect all transitive predecessors of BB in the same loop. This set will
+ // be a subset of the blocks within the loop.
+ SmallPtrSet<const BasicBlock *, 4> Predecessors;
+ collectTransitivePredecessors(CurLoop, BB, Predecessors);
+
+  // Make sure that all successors of all predecessors of BB that are not
+  // dominated by BB are either:
+  // 1) BB itself,
+  // 2) also predecessors of BB, or
+  // 3) exit blocks which are not taken on the 1st iteration.
+ // Memoize blocks we've already checked.
+ SmallPtrSet<const BasicBlock *, 4> CheckedSuccessors;
+ for (auto *Pred : Predecessors) {
+ // Predecessor block may throw, so it has a side exit.
+ if (blockMayThrow(Pred))
+ return false;
+
+ // BB dominates Pred, so if Pred runs, BB must run.
+ // This is true when Pred is a loop latch.
+ if (DT->dominates(BB, Pred))
+ continue;
+
+ for (auto *Succ : successors(Pred))
+ if (CheckedSuccessors.insert(Succ).second &&
+ Succ != BB && !Predecessors.count(Succ))
+ // By discharging conditions that are not executed on the 1st iteration,
+ // we guarantee that *at least* on the first iteration all paths from
+ // header that *may* execute will lead us to the block of interest. So
+ // that if we had virtually peeled one iteration away, in this peeled
+ // iteration the set of predecessors would contain only paths from
+ // header to BB without any exiting edges that may execute.
+ //
+ // TODO: We only do it for exiting edges currently. We could use the
+ // same function to skip some of the edges within the loop if we know
+ // that they will not be taken on the 1st iteration.
+ //
+ // TODO: If we somehow know the number of iterations in loop, the same
+ // check may be done for any arbitrary N-th iteration as long as N is
+ // not greater than minimum number of iterations in this loop.
+ if (CurLoop->contains(Succ) ||
+ !CanProveNotTakenFirstIteration(Succ, DT, CurLoop))
+ return false;
+ }
+
+ // All predecessors can only lead us to BB.
+ return true;
+}
+
+/// Returns true if the instruction in a loop is guaranteed to execute at least
+/// once.
+bool SimpleLoopSafetyInfo::isGuaranteedToExecute(const Instruction &Inst,
+ const DominatorTree *DT,
+ const Loop *CurLoop) const {
+ // If the instruction is in the header block for the loop (which is very
+ // common), it is always guaranteed to dominate the exit blocks. Since this
+ // is a common case, and can save some work, check it now.
+ if (Inst.getParent() == CurLoop->getHeader())
+ // If there's a throw in the header block, we can't guarantee we'll reach
+ // Inst unless we can prove that Inst comes before the potential implicit
+ // exit. At the moment, we use a (cheap) hack for the common case where
+ // the instruction of interest is the first one in the block.
+ return !HeaderMayThrow ||
+ Inst.getParent()->getFirstNonPHIOrDbg() == &Inst;
+
+ // If there is a path from header to exit or latch that doesn't lead to our
+ // instruction's block, return false.
+ return allLoopPathsLeadToBlock(CurLoop, Inst.getParent(), DT);
+}
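
// Editor's sketch (not part of this patch): how a hoisting transform might
// consult this analysis. `L`, `DT` and `I` are assumed to come from the
// caller; the hoisting itself is elided.
//
//   SimpleLoopSafetyInfo SafetyInfo;
//   SafetyInfo.computeLoopSafetyInfo(L);
//   if (SafetyInfo.isGuaranteedToExecute(I, &DT, L)) {
//     // I executes whenever L is entered, so hoisting it cannot introduce a
//     // fault or side effect that the original program could have skipped.
//   }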
+
+bool ICFLoopSafetyInfo::isGuaranteedToExecute(const Instruction &Inst,
+ const DominatorTree *DT,
+ const Loop *CurLoop) const {
+ return !ICF.isDominatedByICFIFromSameBlock(&Inst) &&
+ allLoopPathsLeadToBlock(CurLoop, Inst.getParent(), DT);
+}
+
+bool ICFLoopSafetyInfo::doesNotWriteMemoryBefore(const BasicBlock *BB,
+ const Loop *CurLoop) const {
+ assert(CurLoop->contains(BB) && "Should only be called for loop blocks!");
+
+ // Fast path: there are no instructions before header.
+ if (BB == CurLoop->getHeader())
+ return true;
+
+ // Collect all transitive predecessors of BB in the same loop. This set will
+ // be a subset of the blocks within the loop.
+ SmallPtrSet<const BasicBlock *, 4> Predecessors;
+ collectTransitivePredecessors(CurLoop, BB, Predecessors);
+  // Check whether any instruction in one of the predecessors could write
+  // to memory.
+ for (auto *Pred : Predecessors)
+ if (MW.mayWriteToMemory(Pred))
+ return false;
+ return true;
+}
+
+bool ICFLoopSafetyInfo::doesNotWriteMemoryBefore(const Instruction &I,
+ const Loop *CurLoop) const {
+ auto *BB = I.getParent();
+ assert(CurLoop->contains(BB) && "Should only be called for loop blocks!");
+ return !MW.isDominatedByMemoryWriteFromSameBlock(&I) &&
+ doesNotWriteMemoryBefore(BB, CurLoop);
+}
+
+namespace {
+ struct MustExecutePrinter : public FunctionPass {
+
+ static char ID; // Pass identification, replacement for typeid
+ MustExecutePrinter() : FunctionPass(ID) {
+ initializeMustExecutePrinterPass(*PassRegistry::getPassRegistry());
+ }
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+ }
+ bool runOnFunction(Function &F) override;
+ };
+ struct MustBeExecutedContextPrinter : public ModulePass {
+ static char ID;
+
+ MustBeExecutedContextPrinter() : ModulePass(ID) {
+ initializeMustBeExecutedContextPrinterPass(*PassRegistry::getPassRegistry());
+ }
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ }
+ bool runOnModule(Module &M) override;
+ };
+}
+
+char MustExecutePrinter::ID = 0;
+INITIALIZE_PASS_BEGIN(MustExecutePrinter, "print-mustexecute",
+ "Instructions which execute on loop entry", false, true)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(MustExecutePrinter, "print-mustexecute",
+ "Instructions which execute on loop entry", false, true)
+
+FunctionPass *llvm::createMustExecutePrinter() {
+ return new MustExecutePrinter();
+}
+
+char MustBeExecutedContextPrinter::ID = 0;
+INITIALIZE_PASS_BEGIN(
+ MustBeExecutedContextPrinter, "print-must-be-executed-contexts",
+ "print the must-be-executed-contexed for all instructions", false, true)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(MustBeExecutedContextPrinter,
+                    "print-must-be-executed-contexts",
+                    "print the must-be-executed-contexts for all instructions",
+                    false, true)
+
+ModulePass *llvm::createMustBeExecutedContextPrinter() {
+ return new MustBeExecutedContextPrinter();
+}
+
+bool MustBeExecutedContextPrinter::runOnModule(Module &M) {
+  MustBeExecutedContextExplorer Explorer(/* ExploreInterBlock */ true);
+ for (Function &F : M) {
+ for (Instruction &I : instructions(F)) {
+ dbgs() << "-- Explore context of: " << I << "\n";
+ for (const Instruction *CI : Explorer.range(&I))
+ dbgs() << " [F: " << CI->getFunction()->getName() << "] " << *CI
+ << "\n";
+ }
+ }
+
+ return false;
+}
+
+static bool isMustExecuteIn(const Instruction &I, Loop *L, DominatorTree *DT) {
+ // TODO: merge these two routines. For the moment, we display the best
+ // result obtained by *either* implementation. This is a bit unfair since no
+ // caller actually gets the full power at the moment.
+ SimpleLoopSafetyInfo LSI;
+ LSI.computeLoopSafetyInfo(L);
+ return LSI.isGuaranteedToExecute(I, DT, L) ||
+ isGuaranteedToExecuteForEveryIteration(&I, L);
+}
+
+namespace {
+/// An assembly annotator class to print must execute information in
+/// comments.
+class MustExecuteAnnotatedWriter : public AssemblyAnnotationWriter {
+ DenseMap<const Value*, SmallVector<Loop*, 4> > MustExec;
+
+public:
+ MustExecuteAnnotatedWriter(const Function &F,
+ DominatorTree &DT, LoopInfo &LI) {
+ for (auto &I: instructions(F)) {
+ Loop *L = LI.getLoopFor(I.getParent());
+ while (L) {
+ if (isMustExecuteIn(I, L, &DT)) {
+ MustExec[&I].push_back(L);
+ }
+ L = L->getParentLoop();
+ };
+ }
+ }
+ MustExecuteAnnotatedWriter(const Module &M,
+ DominatorTree &DT, LoopInfo &LI) {
+ for (auto &F : M)
+ for (auto &I: instructions(F)) {
+ Loop *L = LI.getLoopFor(I.getParent());
+ while (L) {
+ if (isMustExecuteIn(I, L, &DT)) {
+ MustExec[&I].push_back(L);
+ }
+ L = L->getParentLoop();
+ };
+ }
+ }
+
+
+ void printInfoComment(const Value &V, formatted_raw_ostream &OS) override {
+ if (!MustExec.count(&V))
+ return;
+
+ const auto &Loops = MustExec.lookup(&V);
+ const auto NumLoops = Loops.size();
+ if (NumLoops > 1)
+ OS << " ; (mustexec in " << NumLoops << " loops: ";
+ else
+ OS << " ; (mustexec in: ";
+
+ bool first = true;
+ for (const Loop *L : Loops) {
+ if (!first)
+ OS << ", ";
+ first = false;
+ OS << L->getHeader()->getName();
+ }
+ OS << ")";
+ }
+};
+} // namespace
+
+bool MustExecutePrinter::runOnFunction(Function &F) {
+ auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+
+ MustExecuteAnnotatedWriter Writer(F, DT, LI);
+ F.print(dbgs(), &Writer);
+
+ return false;
+}
+
+const Instruction *
+MustBeExecutedContextExplorer::getMustBeExecutedNextInstruction(
+ MustBeExecutedIterator &It, const Instruction *PP) {
+ if (!PP)
+ return PP;
+ LLVM_DEBUG(dbgs() << "Find next instruction for " << *PP << "\n");
+
+ // If we explore only inside a given basic block we stop at terminators.
+ if (!ExploreInterBlock && PP->isTerminator()) {
+ LLVM_DEBUG(dbgs() << "\tReached terminator in intra-block mode, done\n");
+ return nullptr;
+ }
+
+ // If we do not traverse the call graph we check if we can make progress in
+ // the current function. First, check if the instruction is guaranteed to
+ // transfer execution to the successor.
+ bool TransfersExecution = isGuaranteedToTransferExecutionToSuccessor(PP);
+ if (!TransfersExecution)
+ return nullptr;
+
+ // If this is not a terminator we know that there is a single instruction
+  // after this one that is executed next if control is transferred. If not,
+  // we can try to go back to a call site we entered earlier. If none exists, we
+  // do not know any instruction that has to be executed next.
+ if (!PP->isTerminator()) {
+ const Instruction *NextPP = PP->getNextNode();
+ LLVM_DEBUG(dbgs() << "\tIntermediate instruction does transfer control\n");
+ return NextPP;
+ }
+
+ // Finally, we have to handle terminators, trivial ones first.
+ assert(PP->isTerminator() && "Expected a terminator!");
+
+ // A terminator without a successor is not handled yet.
+ if (PP->getNumSuccessors() == 0) {
+ LLVM_DEBUG(dbgs() << "\tUnhandled terminator\n");
+ return nullptr;
+ }
+
+ // A terminator with a single successor, we will continue at the beginning of
+ // that one.
+ if (PP->getNumSuccessors() == 1) {
+ LLVM_DEBUG(
+ dbgs() << "\tUnconditional terminator, continue with successor\n");
+ return &PP->getSuccessor(0)->front();
+ }
+
+ LLVM_DEBUG(dbgs() << "\tNo join point found\n");
+ return nullptr;
+}
+
+MustBeExecutedIterator::MustBeExecutedIterator(
+ MustBeExecutedContextExplorer &Explorer, const Instruction *I)
+ : Explorer(Explorer), CurInst(I) {
+ reset(I);
+}
+
+void MustBeExecutedIterator::reset(const Instruction *I) {
+ CurInst = I;
+ Visited.clear();
+ Visited.insert(I);
+}
+
+const Instruction *MustBeExecutedIterator::advance() {
+ assert(CurInst && "Cannot advance an end iterator!");
+ const Instruction *Next =
+ Explorer.getMustBeExecutedNextInstruction(*this, CurInst);
+ if (Next && !Visited.insert(Next).second)
+ Next = nullptr;
+ return Next;
+}
diff --git a/llvm/lib/Analysis/ObjCARCAliasAnalysis.cpp b/llvm/lib/Analysis/ObjCARCAliasAnalysis.cpp
new file mode 100644
index 000000000000..811033e73147
--- /dev/null
+++ b/llvm/lib/Analysis/ObjCARCAliasAnalysis.cpp
@@ -0,0 +1,164 @@
+//===- ObjCARCAliasAnalysis.cpp - ObjC ARC Optimization -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file defines a simple ARC-aware AliasAnalysis using special knowledge
+/// of Objective C to enhance other optimization passes which rely on the Alias
+/// Analysis infrastructure.
+///
+/// WARNING: This file knows about certain library functions. It recognizes them
+/// by name, and hardwires knowledge of their semantics.
+///
+/// WARNING: This file knows about how certain Objective-C library functions are
+/// used. Naive LLVM IR transformations which would otherwise be
+/// behavior-preserving may break these assumptions.
+///
+/// TODO: Theoretically we could check for dependencies between objc_* calls
+/// and FMRB_OnlyAccessesArgumentPointees calls or other well-behaved calls.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ObjCARCAliasAnalysis.h"
+#include "llvm/Analysis/ObjCARCAnalysisUtils.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Value.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/PassAnalysisSupport.h"
+#include "llvm/PassSupport.h"
+
+#define DEBUG_TYPE "objc-arc-aa"
+
+using namespace llvm;
+using namespace llvm::objcarc;
+
+AliasResult ObjCARCAAResult::alias(const MemoryLocation &LocA,
+ const MemoryLocation &LocB,
+ AAQueryInfo &AAQI) {
+ if (!EnableARCOpts)
+ return AAResultBase::alias(LocA, LocB, AAQI);
+
+ // First, strip off no-ops, including ObjC-specific no-ops, and try making a
+ // precise alias query.
+ const Value *SA = GetRCIdentityRoot(LocA.Ptr);
+ const Value *SB = GetRCIdentityRoot(LocB.Ptr);
+ AliasResult Result =
+ AAResultBase::alias(MemoryLocation(SA, LocA.Size, LocA.AATags),
+ MemoryLocation(SB, LocB.Size, LocB.AATags), AAQI);
+ if (Result != MayAlias)
+ return Result;
+
+ // If that failed, climb to the underlying object, including climbing through
+ // ObjC-specific no-ops, and try making an imprecise alias query.
+ const Value *UA = GetUnderlyingObjCPtr(SA, DL);
+ const Value *UB = GetUnderlyingObjCPtr(SB, DL);
+ if (UA != SA || UB != SB) {
+ Result = AAResultBase::alias(MemoryLocation(UA), MemoryLocation(UB), AAQI);
+ // We can't use MustAlias or PartialAlias results here because
+ // GetUnderlyingObjCPtr may return an offsetted pointer value.
+ if (Result == NoAlias)
+ return NoAlias;
+ }
+
+ // If that failed, fail. We don't need to chain here, since that's covered
+ // by the earlier precise query.
+ return MayAlias;
+}
+
+bool ObjCARCAAResult::pointsToConstantMemory(const MemoryLocation &Loc,
+ AAQueryInfo &AAQI, bool OrLocal) {
+ if (!EnableARCOpts)
+ return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal);
+
+ // First, strip off no-ops, including ObjC-specific no-ops, and try making
+ // a precise alias query.
+ const Value *S = GetRCIdentityRoot(Loc.Ptr);
+ if (AAResultBase::pointsToConstantMemory(
+ MemoryLocation(S, Loc.Size, Loc.AATags), AAQI, OrLocal))
+ return true;
+
+ // If that failed, climb to the underlying object, including climbing through
+ // ObjC-specific no-ops, and try making an imprecise alias query.
+ const Value *U = GetUnderlyingObjCPtr(S, DL);
+ if (U != S)
+ return AAResultBase::pointsToConstantMemory(MemoryLocation(U), AAQI,
+ OrLocal);
+
+ // If that failed, fail. We don't need to chain here, since that's covered
+ // by the earlier precise query.
+ return false;
+}
+
+FunctionModRefBehavior ObjCARCAAResult::getModRefBehavior(const Function *F) {
+ if (!EnableARCOpts)
+ return AAResultBase::getModRefBehavior(F);
+
+ switch (GetFunctionClass(F)) {
+ case ARCInstKind::NoopCast:
+ return FMRB_DoesNotAccessMemory;
+ default:
+ break;
+ }
+
+ return AAResultBase::getModRefBehavior(F);
+}
+
+ModRefInfo ObjCARCAAResult::getModRefInfo(const CallBase *Call,
+ const MemoryLocation &Loc,
+ AAQueryInfo &AAQI) {
+ if (!EnableARCOpts)
+ return AAResultBase::getModRefInfo(Call, Loc, AAQI);
+
+ switch (GetBasicARCInstKind(Call)) {
+ case ARCInstKind::Retain:
+ case ARCInstKind::RetainRV:
+ case ARCInstKind::Autorelease:
+ case ARCInstKind::AutoreleaseRV:
+ case ARCInstKind::NoopCast:
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::FusedRetainAutorelease:
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ // These functions don't access any memory visible to the compiler.
+ // Note that this doesn't include objc_retainBlock, because it updates
+ // pointers when it copies block data.
+ return ModRefInfo::NoModRef;
+ default:
+ break;
+ }
+
+ return AAResultBase::getModRefInfo(Call, Loc, AAQI);
+}
+
+ObjCARCAAResult ObjCARCAA::run(Function &F, FunctionAnalysisManager &AM) {
+ return ObjCARCAAResult(F.getParent()->getDataLayout());
+}
+
+char ObjCARCAAWrapperPass::ID = 0;
+INITIALIZE_PASS(ObjCARCAAWrapperPass, "objc-arc-aa",
+ "ObjC-ARC-Based Alias Analysis", false, true)
+
+ImmutablePass *llvm::createObjCARCAAWrapperPass() {
+ return new ObjCARCAAWrapperPass();
+}
+
+ObjCARCAAWrapperPass::ObjCARCAAWrapperPass() : ImmutablePass(ID) {
+ initializeObjCARCAAWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+bool ObjCARCAAWrapperPass::doInitialization(Module &M) {
+ Result.reset(new ObjCARCAAResult(M.getDataLayout()));
+ return false;
+}
+
+bool ObjCARCAAWrapperPass::doFinalization(Module &M) {
+ Result.reset();
+ return false;
+}
+
+void ObjCARCAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+}
diff --git a/llvm/lib/Analysis/ObjCARCAnalysisUtils.cpp b/llvm/lib/Analysis/ObjCARCAnalysisUtils.cpp
new file mode 100644
index 000000000000..56d1cb421225
--- /dev/null
+++ b/llvm/lib/Analysis/ObjCARCAnalysisUtils.cpp
@@ -0,0 +1,25 @@
+//===- ObjCARCAnalysisUtils.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements common infrastructure for libLLVMObjCARCOpts.a, which
+// implements several scalar transformations over the LLVM intermediate
+// representation, including the C bindings for that library.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ObjCARCAnalysisUtils.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace llvm;
+using namespace llvm::objcarc;
+
+/// A handy option to enable/disable all ARC Optimizations.
+bool llvm::objcarc::EnableARCOpts;
+static cl::opt<bool, true> EnableARCOptimizations(
+ "enable-objc-arc-opts", cl::desc("enable/disable all ARC Optimizations"),
+ cl::location(EnableARCOpts), cl::init(true), cl::Hidden);
diff --git a/llvm/lib/Analysis/ObjCARCInstKind.cpp b/llvm/lib/Analysis/ObjCARCInstKind.cpp
new file mode 100644
index 000000000000..0e96c6e975c9
--- /dev/null
+++ b/llvm/lib/Analysis/ObjCARCInstKind.cpp
@@ -0,0 +1,705 @@
+//===- ARCInstKind.cpp - ObjC ARC Optimization ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file defines several utility functions used by various ARC
+/// optimizations which are IMHO too big to be in a header file.
+///
+/// WARNING: This file knows about certain library functions. It recognizes them
+/// by name, and hardwires knowledge of their semantics.
+///
+/// WARNING: This file knows about how certain Objective-C library functions are
+/// used. Naive LLVM IR transformations which would otherwise be
+/// behavior-preserving may break these assumptions.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ObjCARCInstKind.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Analysis/ObjCARCAnalysisUtils.h"
+#include "llvm/IR/Intrinsics.h"
+
+using namespace llvm;
+using namespace llvm::objcarc;
+
+raw_ostream &llvm::objcarc::operator<<(raw_ostream &OS,
+ const ARCInstKind Class) {
+ switch (Class) {
+ case ARCInstKind::Retain:
+ return OS << "ARCInstKind::Retain";
+ case ARCInstKind::RetainRV:
+ return OS << "ARCInstKind::RetainRV";
+ case ARCInstKind::ClaimRV:
+ return OS << "ARCInstKind::ClaimRV";
+ case ARCInstKind::RetainBlock:
+ return OS << "ARCInstKind::RetainBlock";
+ case ARCInstKind::Release:
+ return OS << "ARCInstKind::Release";
+ case ARCInstKind::Autorelease:
+ return OS << "ARCInstKind::Autorelease";
+ case ARCInstKind::AutoreleaseRV:
+ return OS << "ARCInstKind::AutoreleaseRV";
+ case ARCInstKind::AutoreleasepoolPush:
+ return OS << "ARCInstKind::AutoreleasepoolPush";
+ case ARCInstKind::AutoreleasepoolPop:
+ return OS << "ARCInstKind::AutoreleasepoolPop";
+ case ARCInstKind::NoopCast:
+ return OS << "ARCInstKind::NoopCast";
+ case ARCInstKind::FusedRetainAutorelease:
+ return OS << "ARCInstKind::FusedRetainAutorelease";
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ return OS << "ARCInstKind::FusedRetainAutoreleaseRV";
+ case ARCInstKind::LoadWeakRetained:
+ return OS << "ARCInstKind::LoadWeakRetained";
+ case ARCInstKind::StoreWeak:
+ return OS << "ARCInstKind::StoreWeak";
+ case ARCInstKind::InitWeak:
+ return OS << "ARCInstKind::InitWeak";
+ case ARCInstKind::LoadWeak:
+ return OS << "ARCInstKind::LoadWeak";
+ case ARCInstKind::MoveWeak:
+ return OS << "ARCInstKind::MoveWeak";
+ case ARCInstKind::CopyWeak:
+ return OS << "ARCInstKind::CopyWeak";
+ case ARCInstKind::DestroyWeak:
+ return OS << "ARCInstKind::DestroyWeak";
+ case ARCInstKind::StoreStrong:
+ return OS << "ARCInstKind::StoreStrong";
+ case ARCInstKind::CallOrUser:
+ return OS << "ARCInstKind::CallOrUser";
+ case ARCInstKind::Call:
+ return OS << "ARCInstKind::Call";
+ case ARCInstKind::User:
+ return OS << "ARCInstKind::User";
+ case ARCInstKind::IntrinsicUser:
+ return OS << "ARCInstKind::IntrinsicUser";
+ case ARCInstKind::None:
+ return OS << "ARCInstKind::None";
+ }
+ llvm_unreachable("Unknown instruction class!");
+}
+
+ARCInstKind llvm::objcarc::GetFunctionClass(const Function *F) {
+
+ Intrinsic::ID ID = F->getIntrinsicID();
+ switch (ID) {
+ default:
+ return ARCInstKind::CallOrUser;
+ case Intrinsic::objc_autorelease:
+ return ARCInstKind::Autorelease;
+ case Intrinsic::objc_autoreleasePoolPop:
+ return ARCInstKind::AutoreleasepoolPop;
+ case Intrinsic::objc_autoreleasePoolPush:
+ return ARCInstKind::AutoreleasepoolPush;
+ case Intrinsic::objc_autoreleaseReturnValue:
+ return ARCInstKind::AutoreleaseRV;
+ case Intrinsic::objc_copyWeak:
+ return ARCInstKind::CopyWeak;
+ case Intrinsic::objc_destroyWeak:
+ return ARCInstKind::DestroyWeak;
+ case Intrinsic::objc_initWeak:
+ return ARCInstKind::InitWeak;
+ case Intrinsic::objc_loadWeak:
+ return ARCInstKind::LoadWeak;
+ case Intrinsic::objc_loadWeakRetained:
+ return ARCInstKind::LoadWeakRetained;
+ case Intrinsic::objc_moveWeak:
+ return ARCInstKind::MoveWeak;
+ case Intrinsic::objc_release:
+ return ARCInstKind::Release;
+ case Intrinsic::objc_retain:
+ return ARCInstKind::Retain;
+ case Intrinsic::objc_retainAutorelease:
+ return ARCInstKind::FusedRetainAutorelease;
+ case Intrinsic::objc_retainAutoreleaseReturnValue:
+ return ARCInstKind::FusedRetainAutoreleaseRV;
+ case Intrinsic::objc_retainAutoreleasedReturnValue:
+ return ARCInstKind::RetainRV;
+ case Intrinsic::objc_retainBlock:
+ return ARCInstKind::RetainBlock;
+ case Intrinsic::objc_storeStrong:
+ return ARCInstKind::StoreStrong;
+ case Intrinsic::objc_storeWeak:
+ return ARCInstKind::StoreWeak;
+ case Intrinsic::objc_clang_arc_use:
+ return ARCInstKind::IntrinsicUser;
+ case Intrinsic::objc_unsafeClaimAutoreleasedReturnValue:
+ return ARCInstKind::ClaimRV;
+ case Intrinsic::objc_retainedObject:
+ return ARCInstKind::NoopCast;
+ case Intrinsic::objc_unretainedObject:
+ return ARCInstKind::NoopCast;
+ case Intrinsic::objc_unretainedPointer:
+ return ARCInstKind::NoopCast;
+ case Intrinsic::objc_retain_autorelease:
+ return ARCInstKind::FusedRetainAutorelease;
+ case Intrinsic::objc_sync_enter:
+ return ARCInstKind::User;
+ case Intrinsic::objc_sync_exit:
+ return ARCInstKind::User;
+ case Intrinsic::objc_arc_annotation_topdown_bbstart:
+ case Intrinsic::objc_arc_annotation_topdown_bbend:
+ case Intrinsic::objc_arc_annotation_bottomup_bbstart:
+ case Intrinsic::objc_arc_annotation_bottomup_bbend:
+ // Ignore annotation calls. This is important to stop the
+ // optimizer from treating annotations as uses which would
+ // make the state of the pointers they are attempting to
+ // elucidate to be incorrect.
+ return ARCInstKind::None;
+ }
+}
+
+// A whitelist of intrinsics that we know do not use objc pointers or decrement
+// ref counts.
+static bool isInertIntrinsic(unsigned ID) {
+ // TODO: Make this into a covered switch.
+ switch (ID) {
+ case Intrinsic::returnaddress:
+ case Intrinsic::addressofreturnaddress:
+ case Intrinsic::frameaddress:
+ case Intrinsic::stacksave:
+ case Intrinsic::stackrestore:
+ case Intrinsic::vastart:
+ case Intrinsic::vacopy:
+ case Intrinsic::vaend:
+ case Intrinsic::objectsize:
+ case Intrinsic::prefetch:
+ case Intrinsic::stackprotector:
+ case Intrinsic::eh_return_i32:
+ case Intrinsic::eh_return_i64:
+ case Intrinsic::eh_typeid_for:
+ case Intrinsic::eh_dwarf_cfa:
+ case Intrinsic::eh_sjlj_lsda:
+ case Intrinsic::eh_sjlj_functioncontext:
+ case Intrinsic::init_trampoline:
+ case Intrinsic::adjust_trampoline:
+ case Intrinsic::lifetime_start:
+ case Intrinsic::lifetime_end:
+ case Intrinsic::invariant_start:
+ case Intrinsic::invariant_end:
+ // Don't let dbg info affect our results.
+ case Intrinsic::dbg_declare:
+ case Intrinsic::dbg_value:
+ case Intrinsic::dbg_label:
+ // Short cut: Some intrinsics obviously don't use ObjC pointers.
+ return true;
+ default:
+ return false;
+ }
+}
+
+// A whitelist of intrinsics that we know do not decrement ref counts, even
+// though they may use objc pointers.
+static bool isUseOnlyIntrinsic(unsigned ID) {
+ // We are conservative and even though intrinsics are unlikely to touch
+ // reference counts, we white list them for safety.
+ //
+ // TODO: Expand this into a covered switch. There is a lot more here.
+ switch (ID) {
+ case Intrinsic::memcpy:
+ case Intrinsic::memmove:
+ case Intrinsic::memset:
+ return true;
+ default:
+ return false;
+ }
+}
+
+/// Determine what kind of construct V is.
+ARCInstKind llvm::objcarc::GetARCInstKind(const Value *V) {
+ if (const Instruction *I = dyn_cast<Instruction>(V)) {
+    // Any instruction other than bitcast and gep with a pointer operand has a
+ // use of an objc pointer. Bitcasts, GEPs, Selects, PHIs transfer a pointer
+ // to a subsequent use, rather than using it themselves, in this sense.
+ // As a short cut, several other opcodes are known to have no pointer
+ // operands of interest. And ret is never followed by a release, so it's
+ // not interesting to examine.
+ switch (I->getOpcode()) {
+ case Instruction::Call: {
+ const CallInst *CI = cast<CallInst>(I);
+ // See if we have a function that we know something about.
+ if (const Function *F = CI->getCalledFunction()) {
+ ARCInstKind Class = GetFunctionClass(F);
+ if (Class != ARCInstKind::CallOrUser)
+ return Class;
+ Intrinsic::ID ID = F->getIntrinsicID();
+ if (isInertIntrinsic(ID))
+ return ARCInstKind::None;
+ if (isUseOnlyIntrinsic(ID))
+ return ARCInstKind::User;
+ }
+
+ // Otherwise, be conservative.
+ return GetCallSiteClass(CI);
+ }
+ case Instruction::Invoke:
+ // Otherwise, be conservative.
+ return GetCallSiteClass(cast<InvokeInst>(I));
+ case Instruction::BitCast:
+ case Instruction::GetElementPtr:
+ case Instruction::Select:
+ case Instruction::PHI:
+ case Instruction::Ret:
+ case Instruction::Br:
+ case Instruction::Switch:
+ case Instruction::IndirectBr:
+ case Instruction::Alloca:
+ case Instruction::VAArg:
+ case Instruction::Add:
+ case Instruction::FAdd:
+ case Instruction::Sub:
+ case Instruction::FSub:
+ case Instruction::Mul:
+ case Instruction::FMul:
+ case Instruction::SDiv:
+ case Instruction::UDiv:
+ case Instruction::FDiv:
+ case Instruction::SRem:
+ case Instruction::URem:
+ case Instruction::FRem:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Xor:
+ case Instruction::SExt:
+ case Instruction::ZExt:
+ case Instruction::Trunc:
+ case Instruction::IntToPtr:
+ case Instruction::FCmp:
+ case Instruction::FPTrunc:
+ case Instruction::FPExt:
+ case Instruction::FPToUI:
+ case Instruction::FPToSI:
+ case Instruction::UIToFP:
+ case Instruction::SIToFP:
+ case Instruction::InsertElement:
+ case Instruction::ExtractElement:
+ case Instruction::ShuffleVector:
+ case Instruction::ExtractValue:
+ break;
+ case Instruction::ICmp:
+ // Comparing a pointer with null, or any other constant, isn't an
+ // interesting use, because we don't care what the pointer points to, or
+ // about the values of any other dynamic reference-counted pointers.
+ if (IsPotentialRetainableObjPtr(I->getOperand(1)))
+ return ARCInstKind::User;
+ break;
+ default:
+ // For anything else, check all the operands.
+ // Note that this includes both operands of a Store: while the first
+ // operand isn't actually being dereferenced, it is being stored to
+ // memory where we can no longer track who might read it and dereference
+ // it, so we have to consider it potentially used.
+ for (User::const_op_iterator OI = I->op_begin(), OE = I->op_end();
+ OI != OE; ++OI)
+ if (IsPotentialRetainableObjPtr(*OI))
+ return ARCInstKind::User;
+ }
+ }
+
+ // Otherwise, it's totally inert for ARC purposes.
+ return ARCInstKind::None;
+}
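
// Editor's sketch (not part of this patch): typical classification of an
// instruction by an ARC-aware transform; `I` is an assumed Instruction.
//
//   switch (GetARCInstKind(&I)) {
//   case ARCInstKind::Retain:
//   case ARCInstKind::RetainRV:
//     // The reference count is incremented; a candidate for pairing with a
//     // later release.
//     break;
//   case ARCInstKind::None:
//     break; // inert for ARC purposes
//   default:
//     break; // treat everything else conservatively
//   }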
+
+/// Test if the given class is a kind of user.
+bool llvm::objcarc::IsUser(ARCInstKind Class) {
+ switch (Class) {
+ case ARCInstKind::User:
+ case ARCInstKind::CallOrUser:
+ case ARCInstKind::IntrinsicUser:
+ return true;
+ case ARCInstKind::Retain:
+ case ARCInstKind::RetainRV:
+ case ARCInstKind::RetainBlock:
+ case ARCInstKind::Release:
+ case ARCInstKind::Autorelease:
+ case ARCInstKind::AutoreleaseRV:
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::AutoreleasepoolPop:
+ case ARCInstKind::NoopCast:
+ case ARCInstKind::FusedRetainAutorelease:
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ case ARCInstKind::LoadWeakRetained:
+ case ARCInstKind::StoreWeak:
+ case ARCInstKind::InitWeak:
+ case ARCInstKind::LoadWeak:
+ case ARCInstKind::MoveWeak:
+ case ARCInstKind::CopyWeak:
+ case ARCInstKind::DestroyWeak:
+ case ARCInstKind::StoreStrong:
+ case ARCInstKind::Call:
+ case ARCInstKind::None:
+ case ARCInstKind::ClaimRV:
+ return false;
+ }
+ llvm_unreachable("covered switch isn't covered?");
+}
+
+/// Test if the given class is objc_retain or equivalent.
+bool llvm::objcarc::IsRetain(ARCInstKind Class) {
+ switch (Class) {
+ case ARCInstKind::Retain:
+ case ARCInstKind::RetainRV:
+ return true;
+ // I believe we treat retain block as not a retain since it can copy its
+ // block.
+ case ARCInstKind::RetainBlock:
+ case ARCInstKind::Release:
+ case ARCInstKind::Autorelease:
+ case ARCInstKind::AutoreleaseRV:
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::AutoreleasepoolPop:
+ case ARCInstKind::NoopCast:
+ case ARCInstKind::FusedRetainAutorelease:
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ case ARCInstKind::LoadWeakRetained:
+ case ARCInstKind::StoreWeak:
+ case ARCInstKind::InitWeak:
+ case ARCInstKind::LoadWeak:
+ case ARCInstKind::MoveWeak:
+ case ARCInstKind::CopyWeak:
+ case ARCInstKind::DestroyWeak:
+ case ARCInstKind::StoreStrong:
+ case ARCInstKind::IntrinsicUser:
+ case ARCInstKind::CallOrUser:
+ case ARCInstKind::Call:
+ case ARCInstKind::User:
+ case ARCInstKind::None:
+ case ARCInstKind::ClaimRV:
+ return false;
+ }
+ llvm_unreachable("covered switch isn't covered?");
+}
+
+/// Test if the given class is objc_autorelease or equivalent.
+bool llvm::objcarc::IsAutorelease(ARCInstKind Class) {
+ switch (Class) {
+ case ARCInstKind::Autorelease:
+ case ARCInstKind::AutoreleaseRV:
+ return true;
+ case ARCInstKind::Retain:
+ case ARCInstKind::RetainRV:
+ case ARCInstKind::ClaimRV:
+ case ARCInstKind::RetainBlock:
+ case ARCInstKind::Release:
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::AutoreleasepoolPop:
+ case ARCInstKind::NoopCast:
+ case ARCInstKind::FusedRetainAutorelease:
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ case ARCInstKind::LoadWeakRetained:
+ case ARCInstKind::StoreWeak:
+ case ARCInstKind::InitWeak:
+ case ARCInstKind::LoadWeak:
+ case ARCInstKind::MoveWeak:
+ case ARCInstKind::CopyWeak:
+ case ARCInstKind::DestroyWeak:
+ case ARCInstKind::StoreStrong:
+ case ARCInstKind::IntrinsicUser:
+ case ARCInstKind::CallOrUser:
+ case ARCInstKind::Call:
+ case ARCInstKind::User:
+ case ARCInstKind::None:
+ return false;
+ }
+ llvm_unreachable("covered switch isn't covered?");
+}
+
+/// Test if the given class represents instructions which return their
+/// argument verbatim.
+bool llvm::objcarc::IsForwarding(ARCInstKind Class) {
+ switch (Class) {
+ case ARCInstKind::Retain:
+ case ARCInstKind::RetainRV:
+ case ARCInstKind::ClaimRV:
+ case ARCInstKind::Autorelease:
+ case ARCInstKind::AutoreleaseRV:
+ case ARCInstKind::NoopCast:
+ return true;
+ case ARCInstKind::RetainBlock:
+ case ARCInstKind::Release:
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::AutoreleasepoolPop:
+ case ARCInstKind::FusedRetainAutorelease:
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ case ARCInstKind::LoadWeakRetained:
+ case ARCInstKind::StoreWeak:
+ case ARCInstKind::InitWeak:
+ case ARCInstKind::LoadWeak:
+ case ARCInstKind::MoveWeak:
+ case ARCInstKind::CopyWeak:
+ case ARCInstKind::DestroyWeak:
+ case ARCInstKind::StoreStrong:
+ case ARCInstKind::IntrinsicUser:
+ case ARCInstKind::CallOrUser:
+ case ARCInstKind::Call:
+ case ARCInstKind::User:
+ case ARCInstKind::None:
+ return false;
+ }
+ llvm_unreachable("covered switch isn't covered?");
+}
+
+/// Test if the given class represents instructions which do nothing if
+/// passed a null pointer.
+bool llvm::objcarc::IsNoopOnNull(ARCInstKind Class) {
+ switch (Class) {
+ case ARCInstKind::Retain:
+ case ARCInstKind::RetainRV:
+ case ARCInstKind::ClaimRV:
+ case ARCInstKind::Release:
+ case ARCInstKind::Autorelease:
+ case ARCInstKind::AutoreleaseRV:
+ case ARCInstKind::RetainBlock:
+ return true;
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::AutoreleasepoolPop:
+ case ARCInstKind::FusedRetainAutorelease:
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ case ARCInstKind::LoadWeakRetained:
+ case ARCInstKind::StoreWeak:
+ case ARCInstKind::InitWeak:
+ case ARCInstKind::LoadWeak:
+ case ARCInstKind::MoveWeak:
+ case ARCInstKind::CopyWeak:
+ case ARCInstKind::DestroyWeak:
+ case ARCInstKind::StoreStrong:
+ case ARCInstKind::IntrinsicUser:
+ case ARCInstKind::CallOrUser:
+ case ARCInstKind::Call:
+ case ARCInstKind::User:
+ case ARCInstKind::None:
+ case ARCInstKind::NoopCast:
+ return false;
+ }
+ llvm_unreachable("covered switch isn't covered?");
+}
+
+/// Test if the given class represents instructions which do nothing if
+/// passed a global variable.
+bool llvm::objcarc::IsNoopOnGlobal(ARCInstKind Class) {
+ switch (Class) {
+ case ARCInstKind::Retain:
+ case ARCInstKind::RetainRV:
+ case ARCInstKind::ClaimRV:
+ case ARCInstKind::Release:
+ case ARCInstKind::Autorelease:
+ case ARCInstKind::AutoreleaseRV:
+ case ARCInstKind::RetainBlock:
+ case ARCInstKind::FusedRetainAutorelease:
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ return true;
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::AutoreleasepoolPop:
+ case ARCInstKind::LoadWeakRetained:
+ case ARCInstKind::StoreWeak:
+ case ARCInstKind::InitWeak:
+ case ARCInstKind::LoadWeak:
+ case ARCInstKind::MoveWeak:
+ case ARCInstKind::CopyWeak:
+ case ARCInstKind::DestroyWeak:
+ case ARCInstKind::StoreStrong:
+ case ARCInstKind::IntrinsicUser:
+ case ARCInstKind::CallOrUser:
+ case ARCInstKind::Call:
+ case ARCInstKind::User:
+ case ARCInstKind::None:
+ case ARCInstKind::NoopCast:
+ return false;
+ }
+ llvm_unreachable("covered switch isn't covered?");
+}
+
+/// Test if the given class represents instructions which are always safe
+/// to mark with the "tail" keyword.
+bool llvm::objcarc::IsAlwaysTail(ARCInstKind Class) {
+ // ARCInstKind::RetainBlock may be given a stack argument.
+ switch (Class) {
+ case ARCInstKind::Retain:
+ case ARCInstKind::RetainRV:
+ case ARCInstKind::ClaimRV:
+ case ARCInstKind::AutoreleaseRV:
+ return true;
+ case ARCInstKind::Release:
+ case ARCInstKind::Autorelease:
+ case ARCInstKind::RetainBlock:
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::AutoreleasepoolPop:
+ case ARCInstKind::FusedRetainAutorelease:
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ case ARCInstKind::LoadWeakRetained:
+ case ARCInstKind::StoreWeak:
+ case ARCInstKind::InitWeak:
+ case ARCInstKind::LoadWeak:
+ case ARCInstKind::MoveWeak:
+ case ARCInstKind::CopyWeak:
+ case ARCInstKind::DestroyWeak:
+ case ARCInstKind::StoreStrong:
+ case ARCInstKind::IntrinsicUser:
+ case ARCInstKind::CallOrUser:
+ case ARCInstKind::Call:
+ case ARCInstKind::User:
+ case ARCInstKind::None:
+ case ARCInstKind::NoopCast:
+ return false;
+ }
+ llvm_unreachable("covered switch isn't covered?");
+}
+
+/// Test if the given class represents instructions which are never safe
+/// to mark with the "tail" keyword.
+bool llvm::objcarc::IsNeverTail(ARCInstKind Class) {
+  /// It is never safe to tail call objc_autorelease: tail calling it enables
+  /// fast autoreleasing, which can cause our object to be reclaimed from the
+  /// autorelease pool, violating the semantics of __autoreleasing types in
+  /// ARC.
+ switch (Class) {
+ case ARCInstKind::Autorelease:
+ return true;
+ case ARCInstKind::Retain:
+ case ARCInstKind::RetainRV:
+ case ARCInstKind::ClaimRV:
+ case ARCInstKind::AutoreleaseRV:
+ case ARCInstKind::Release:
+ case ARCInstKind::RetainBlock:
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::AutoreleasepoolPop:
+ case ARCInstKind::FusedRetainAutorelease:
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ case ARCInstKind::LoadWeakRetained:
+ case ARCInstKind::StoreWeak:
+ case ARCInstKind::InitWeak:
+ case ARCInstKind::LoadWeak:
+ case ARCInstKind::MoveWeak:
+ case ARCInstKind::CopyWeak:
+ case ARCInstKind::DestroyWeak:
+ case ARCInstKind::StoreStrong:
+ case ARCInstKind::IntrinsicUser:
+ case ARCInstKind::CallOrUser:
+ case ARCInstKind::Call:
+ case ARCInstKind::User:
+ case ARCInstKind::None:
+ case ARCInstKind::NoopCast:
+ return false;
+ }
+ llvm_unreachable("covered switch isn't covered?");
+}
+
+/// Test if the given class represents instructions which are always safe
+/// to mark with the nounwind attribute.
+bool llvm::objcarc::IsNoThrow(ARCInstKind Class) {
+ // objc_retainBlock is not nounwind because it calls user copy constructors
+ // which could theoretically throw.
+ switch (Class) {
+ case ARCInstKind::Retain:
+ case ARCInstKind::RetainRV:
+ case ARCInstKind::ClaimRV:
+ case ARCInstKind::Release:
+ case ARCInstKind::Autorelease:
+ case ARCInstKind::AutoreleaseRV:
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::AutoreleasepoolPop:
+ return true;
+ case ARCInstKind::RetainBlock:
+ case ARCInstKind::FusedRetainAutorelease:
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ case ARCInstKind::LoadWeakRetained:
+ case ARCInstKind::StoreWeak:
+ case ARCInstKind::InitWeak:
+ case ARCInstKind::LoadWeak:
+ case ARCInstKind::MoveWeak:
+ case ARCInstKind::CopyWeak:
+ case ARCInstKind::DestroyWeak:
+ case ARCInstKind::StoreStrong:
+ case ARCInstKind::IntrinsicUser:
+ case ARCInstKind::CallOrUser:
+ case ARCInstKind::Call:
+ case ARCInstKind::User:
+ case ARCInstKind::None:
+ case ARCInstKind::NoopCast:
+ return false;
+ }
+ llvm_unreachable("covered switch isn't covered?");
+}
+
+/// Test whether the given instruction can autorelease any pointer or cause an
+/// autoreleasepool pop.
+///
+/// This means that it *could* interrupt the RV optimization.
+bool llvm::objcarc::CanInterruptRV(ARCInstKind Class) {
+ switch (Class) {
+ case ARCInstKind::AutoreleasepoolPop:
+ case ARCInstKind::CallOrUser:
+ case ARCInstKind::Call:
+ case ARCInstKind::Autorelease:
+ case ARCInstKind::AutoreleaseRV:
+ case ARCInstKind::FusedRetainAutorelease:
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ return true;
+ case ARCInstKind::Retain:
+ case ARCInstKind::RetainRV:
+ case ARCInstKind::ClaimRV:
+ case ARCInstKind::Release:
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::RetainBlock:
+ case ARCInstKind::LoadWeakRetained:
+ case ARCInstKind::StoreWeak:
+ case ARCInstKind::InitWeak:
+ case ARCInstKind::LoadWeak:
+ case ARCInstKind::MoveWeak:
+ case ARCInstKind::CopyWeak:
+ case ARCInstKind::DestroyWeak:
+ case ARCInstKind::StoreStrong:
+ case ARCInstKind::IntrinsicUser:
+ case ARCInstKind::User:
+ case ARCInstKind::None:
+ case ARCInstKind::NoopCast:
+ return false;
+ }
+ llvm_unreachable("covered switch isn't covered?");
+}
+
+bool llvm::objcarc::CanDecrementRefCount(ARCInstKind Kind) {
+ switch (Kind) {
+ case ARCInstKind::Retain:
+ case ARCInstKind::RetainRV:
+ case ARCInstKind::Autorelease:
+ case ARCInstKind::AutoreleaseRV:
+ case ARCInstKind::NoopCast:
+ case ARCInstKind::FusedRetainAutorelease:
+ case ARCInstKind::FusedRetainAutoreleaseRV:
+ case ARCInstKind::IntrinsicUser:
+ case ARCInstKind::User:
+ case ARCInstKind::None:
+ return false;
+
+ // The cases below are conservative.
+
+ // RetainBlock can result in user defined copy constructors being called
+ // implying releases may occur.
+ case ARCInstKind::RetainBlock:
+ case ARCInstKind::Release:
+ case ARCInstKind::AutoreleasepoolPush:
+ case ARCInstKind::AutoreleasepoolPop:
+ case ARCInstKind::LoadWeakRetained:
+ case ARCInstKind::StoreWeak:
+ case ARCInstKind::InitWeak:
+ case ARCInstKind::LoadWeak:
+ case ARCInstKind::MoveWeak:
+ case ARCInstKind::CopyWeak:
+ case ARCInstKind::DestroyWeak:
+ case ARCInstKind::StoreStrong:
+ case ARCInstKind::CallOrUser:
+ case ARCInstKind::Call:
+ case ARCInstKind::ClaimRV:
+ return true;
+ }
+
+ llvm_unreachable("covered switch isn't covered?");
+}
diff --git a/llvm/lib/Analysis/OptimizationRemarkEmitter.cpp b/llvm/lib/Analysis/OptimizationRemarkEmitter.cpp
new file mode 100644
index 000000000000..07a5619a35b9
--- /dev/null
+++ b/llvm/lib/Analysis/OptimizationRemarkEmitter.cpp
@@ -0,0 +1,133 @@
+//===- OptimizationRemarkEmitter.cpp - Optimization Diagnostic --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Optimization diagnostic interfaces. It's packaged as an analysis pass so
+// that by using this service passes become dependent on BFI as well. BFI is
+// used to compute the "hotness" of the diagnostic message.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/LazyBlockFrequencyInfo.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/LLVMContext.h"
+
+using namespace llvm;
+
+OptimizationRemarkEmitter::OptimizationRemarkEmitter(const Function *F)
+ : F(F), BFI(nullptr) {
+ if (!F->getContext().getDiagnosticsHotnessRequested())
+ return;
+
+ // First create a dominator tree.
+ DominatorTree DT;
+ DT.recalculate(*const_cast<Function *>(F));
+
+ // Generate LoopInfo from it.
+ LoopInfo LI;
+ LI.analyze(DT);
+
+ // Then compute BranchProbabilityInfo.
+ BranchProbabilityInfo BPI;
+ BPI.calculate(*F, LI);
+
+ // Finally compute BFI.
+ OwnedBFI = std::make_unique<BlockFrequencyInfo>(*F, BPI, LI);
+ BFI = OwnedBFI.get();
+}
+
+bool OptimizationRemarkEmitter::invalidate(
+ Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &Inv) {
+ // This analysis has no state and so can be trivially preserved but it needs
+ // a fresh view of BFI if it was constructed with one.
+ if (BFI && Inv.invalidate<BlockFrequencyAnalysis>(F, PA))
+ return true;
+
+ // Otherwise this analysis result remains valid.
+ return false;
+}
+
+Optional<uint64_t> OptimizationRemarkEmitter::computeHotness(const Value *V) {
+ if (!BFI)
+ return None;
+
+ return BFI->getBlockProfileCount(cast<BasicBlock>(V));
+}
+
+void OptimizationRemarkEmitter::computeHotness(
+ DiagnosticInfoIROptimization &OptDiag) {
+ const Value *V = OptDiag.getCodeRegion();
+ if (V)
+ OptDiag.setHotness(computeHotness(V));
+}
+
+void OptimizationRemarkEmitter::emit(
+ DiagnosticInfoOptimizationBase &OptDiagBase) {
+ auto &OptDiag = cast<DiagnosticInfoIROptimization>(OptDiagBase);
+ computeHotness(OptDiag);
+
+ // Only emit it if its hotness meets the threshold.
+ if (OptDiag.getHotness().getValueOr(0) <
+ F->getContext().getDiagnosticsHotnessThreshold()) {
+ return;
+ }
+
+ F->getContext().diagnose(OptDiag);
+}
+
+OptimizationRemarkEmitterWrapperPass::OptimizationRemarkEmitterWrapperPass()
+ : FunctionPass(ID) {
+ initializeOptimizationRemarkEmitterWrapperPassPass(
+ *PassRegistry::getPassRegistry());
+}
+
+bool OptimizationRemarkEmitterWrapperPass::runOnFunction(Function &Fn) {
+ BlockFrequencyInfo *BFI;
+
+ if (Fn.getContext().getDiagnosticsHotnessRequested())
+ BFI = &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI();
+ else
+ BFI = nullptr;
+
+ ORE = std::make_unique<OptimizationRemarkEmitter>(&Fn, BFI);
+ return false;
+}
+
+void OptimizationRemarkEmitterWrapperPass::getAnalysisUsage(
+ AnalysisUsage &AU) const {
+ LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU);
+ AU.setPreservesAll();
+}
+
+AnalysisKey OptimizationRemarkEmitterAnalysis::Key;
+
+OptimizationRemarkEmitter
+OptimizationRemarkEmitterAnalysis::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ BlockFrequencyInfo *BFI;
+
+ if (F.getContext().getDiagnosticsHotnessRequested())
+ BFI = &AM.getResult<BlockFrequencyAnalysis>(F);
+ else
+ BFI = nullptr;
+
+ return OptimizationRemarkEmitter(&F, BFI);
+}
+
+char OptimizationRemarkEmitterWrapperPass::ID = 0;
+static const char ore_name[] = "Optimization Remark Emitter";
+#define ORE_NAME "opt-remark-emitter"
+
+INITIALIZE_PASS_BEGIN(OptimizationRemarkEmitterWrapperPass, ORE_NAME, ore_name,
+ false, true)
+INITIALIZE_PASS_DEPENDENCY(LazyBFIPass)
+INITIALIZE_PASS_END(OptimizationRemarkEmitterWrapperPass, ORE_NAME, ore_name,
+ false, true)
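
A rough usage sketch for a transform pass that already holds an OptimizationRemarkEmitter (for example, obtained from OptimizationRemarkEmitterAnalysis); the pass name "sketch-pass", the remark name, and reportCalls are made-up names:

    #include "llvm/Analysis/OptimizationRemarkEmitter.h"
    #include "llvm/IR/DiagnosticInfo.h"
    #include "llvm/IR/Function.h"
    #include "llvm/IR/Instructions.h"

    using namespace llvm;

    // Hypothetical helper: emit one remark per call instruction. When hotness
    // tracking is requested on the LLVMContext, the emitter attaches the
    // block's profile count and drops remarks below the hotness threshold.
    static void reportCalls(Function &F, OptimizationRemarkEmitter &ORE) {
      for (BasicBlock &BB : F)
        for (Instruction &I : BB)
          if (isa<CallInst>(I))
            ORE.emit(OptimizationRemark("sketch-pass", "CallSeen", &I)
                     << "call instruction encountered");
    }
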
diff --git a/llvm/lib/Analysis/OrderedBasicBlock.cpp b/llvm/lib/Analysis/OrderedBasicBlock.cpp
new file mode 100644
index 000000000000..48f2a4020c66
--- /dev/null
+++ b/llvm/lib/Analysis/OrderedBasicBlock.cpp
@@ -0,0 +1,111 @@
+//===- OrderedBasicBlock.cpp ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the OrderedBasicBlock class. OrderedBasicBlock
+// maintains an interface where clients can query if one instruction comes
+// before another in a BasicBlock. Since BasicBlock currently lacks a reliable
+// way to query the relative position of instructions, one can use
+// OrderedBasicBlock to do such queries. OrderedBasicBlock is lazily built on a
+// source BasicBlock and maintains an internal Instruction -> Position map. An
+// OrderedBasicBlock instance should be discarded whenever the source
+// BasicBlock changes.
+//
+// It's currently used by the CaptureTracker in order to find relative
+// positions of a pair of instructions inside a BasicBlock.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/OrderedBasicBlock.h"
+#include "llvm/IR/Instruction.h"
+using namespace llvm;
+
+OrderedBasicBlock::OrderedBasicBlock(const BasicBlock *BasicB)
+ : NextInstPos(0), BB(BasicB) {
+ LastInstFound = BB->end();
+}
+
+/// Given no cached results, find if \p A comes before \p B in \p BB.
+/// Cache and number instructions while walking \p BB.
+bool OrderedBasicBlock::comesBefore(const Instruction *A,
+ const Instruction *B) {
+ const Instruction *Inst = nullptr;
+ assert(!(LastInstFound == BB->end() && NextInstPos != 0) &&
+ "Instruction supposed to be in NumberedInsts");
+ assert(A->getParent() == BB && "Instruction supposed to be in the block!");
+ assert(B->getParent() == BB && "Instruction supposed to be in the block!");
+
+ // Start the search with the instruction found in the last lookup round.
+ auto II = BB->begin();
+ auto IE = BB->end();
+ if (LastInstFound != IE)
+ II = std::next(LastInstFound);
+
+ // Number all instructions up to the point where we find 'A' or 'B'.
+ for (; II != IE; ++II) {
+ Inst = cast<Instruction>(II);
+ NumberedInsts[Inst] = NextInstPos++;
+ if (Inst == A || Inst == B)
+ break;
+ }
+
+ assert(II != IE && "Instruction not found?");
+ assert((Inst == A || Inst == B) && "Should find A or B");
+ LastInstFound = II;
+ return Inst != B;
+}
+
+/// Find out whether \p A dominates \p B, meaning whether \p A
+/// comes before \p B in \p BB. This is a simplification that considers
+/// cached instruction positions and ignores other basic blocks, being
+/// only relevant to compare relative instruction positions inside \p BB.
+bool OrderedBasicBlock::dominates(const Instruction *A, const Instruction *B) {
+ assert(A->getParent() == B->getParent() &&
+ "Instructions must be in the same basic block!");
+ assert(A->getParent() == BB && "Instructions must be in the tracked block!");
+
+ // First we look up the instructions. If they don't exist, lookup will give
+ // us back ::end(). If they both exist, we compare the numbers. Otherwise,
+ // if NA exists and NB doesn't, it means NA must come before NB because we
+ // would have numbered NB as well if it didn't. The same is true for NB. If
+ // it exists, but NA does not, NA must come after it. If neither exists, we
+ // need to number the block and cache the results (by calling comesBefore).
+ auto NAI = NumberedInsts.find(A);
+ auto NBI = NumberedInsts.find(B);
+ if (NAI != NumberedInsts.end() && NBI != NumberedInsts.end())
+ return NAI->second < NBI->second;
+ if (NAI != NumberedInsts.end())
+ return true;
+ if (NBI != NumberedInsts.end())
+ return false;
+
+ return comesBefore(A, B);
+}
+
+void OrderedBasicBlock::eraseInstruction(const Instruction *I) {
+ if (LastInstFound != BB->end() && I == &*LastInstFound) {
+ if (LastInstFound == BB->begin()) {
+ LastInstFound = BB->end();
+ NextInstPos = 0;
+ } else
+ LastInstFound--;
+ }
+
+ NumberedInsts.erase(I);
+}
+
+void OrderedBasicBlock::replaceInstruction(const Instruction *Old,
+ const Instruction *New) {
+ auto OI = NumberedInsts.find(Old);
+ if (OI == NumberedInsts.end())
+ return;
+
+ NumberedInsts.insert({New, OI->second});
+ if (LastInstFound != BB->end() && Old == &*LastInstFound)
+ LastInstFound = New->getIterator();
+ NumberedInsts.erase(Old);
+}
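
A small usage sketch; executesBefore is a hypothetical helper. Real clients keep one OrderedBasicBlock alive per block so repeated queries reuse the cached numbering; constructing it per query, as below, is only for brevity:

    #include "llvm/Analysis/OrderedBasicBlock.h"
    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/Instruction.h"
    #include <cassert>

    using namespace llvm;

    // Hypothetical query: does A execute before B within their shared block?
    static bool executesBefore(const Instruction *A, const Instruction *B) {
      assert(A->getParent() == B->getParent() && "expected a single block");
      OrderedBasicBlock OBB(A->getParent());
      return OBB.dominates(A, B);
    }
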
diff --git a/llvm/lib/Analysis/OrderedInstructions.cpp b/llvm/lib/Analysis/OrderedInstructions.cpp
new file mode 100644
index 000000000000..e947e5e388a8
--- /dev/null
+++ b/llvm/lib/Analysis/OrderedInstructions.cpp
@@ -0,0 +1,50 @@
+//===-- OrderedInstructions.cpp - Instruction dominance function ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a utility to check the dominance relation of two
+// instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/OrderedInstructions.h"
+using namespace llvm;
+
+bool OrderedInstructions::localDominates(const Instruction *InstA,
+ const Instruction *InstB) const {
+ assert(InstA->getParent() == InstB->getParent() &&
+ "Instructions must be in the same basic block");
+
+ const BasicBlock *IBB = InstA->getParent();
+ auto OBB = OBBMap.find(IBB);
+ if (OBB == OBBMap.end())
+ OBB = OBBMap.insert({IBB, std::make_unique<OrderedBasicBlock>(IBB)}).first;
+ return OBB->second->dominates(InstA, InstB);
+}
+
+/// Given two instructions, check the dominance relation using
+/// OrderedBasicBlock if the instructions are in the same basic block.
+/// Otherwise, use the dominator tree.
+bool OrderedInstructions::dominates(const Instruction *InstA,
+ const Instruction *InstB) const {
+ // Use the ordered basic block to do the dominance check when the two
+ // instructions are in the same basic block.
+ if (InstA->getParent() == InstB->getParent())
+ return localDominates(InstA, InstB);
+ return DT->dominates(InstA->getParent(), InstB->getParent());
+}
+
+bool OrderedInstructions::dfsBefore(const Instruction *InstA,
+ const Instruction *InstB) const {
+ // Use the ordered basic block when the two instructions are in the same
+ // basic block.
+ if (InstA->getParent() == InstB->getParent())
+ return localDominates(InstA, InstB);
+
+ DomTreeNode *DA = DT->getNode(InstA->getParent());
+ DomTreeNode *DB = DT->getNode(InstB->getParent());
+ return DA->getDFSNumIn() < DB->getDFSNumIn();
+}
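
A sketch of how a client might drive this; dominatesQuery is a hypothetical name, and a real pass would reuse the DominatorTree it already owns rather than rebuilding one per query:

    #include "llvm/Analysis/OrderedInstructions.h"
    #include "llvm/IR/Dominators.h"
    #include "llvm/IR/Function.h"

    using namespace llvm;

    // Hypothetical helper: does A dominate B? Same-block queries go through
    // the lazily built OrderedBasicBlock; cross-block queries use the tree.
    static bool dominatesQuery(Function &F, const Instruction *A,
                               const Instruction *B) {
      DominatorTree DT(F);
      OrderedInstructions OI(&DT);
      return OI.dominates(A, B);
    }
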
diff --git a/llvm/lib/Analysis/PHITransAddr.cpp b/llvm/lib/Analysis/PHITransAddr.cpp
new file mode 100644
index 000000000000..7f77ab146c4c
--- /dev/null
+++ b/llvm/lib/Analysis/PHITransAddr.cpp
@@ -0,0 +1,439 @@
+//===- PHITransAddr.cpp - PHI Translation for Addresses -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the PHITransAddr class.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/PHITransAddr.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace llvm;
+
+static bool CanPHITrans(Instruction *Inst) {
+ if (isa<PHINode>(Inst) ||
+ isa<GetElementPtrInst>(Inst))
+ return true;
+
+ if (isa<CastInst>(Inst) &&
+ isSafeToSpeculativelyExecute(Inst))
+ return true;
+
+ if (Inst->getOpcode() == Instruction::Add &&
+ isa<ConstantInt>(Inst->getOperand(1)))
+ return true;
+
+ // cerr << "MEMDEP: Could not PHI translate: " << *Pointer;
+ // if (isa<BitCastInst>(PtrInst) || isa<GetElementPtrInst>(PtrInst))
+ // cerr << "OP:\t\t\t\t" << *PtrInst->getOperand(0);
+ return false;
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void PHITransAddr::dump() const {
+ if (!Addr) {
+ dbgs() << "PHITransAddr: null\n";
+ return;
+ }
+ dbgs() << "PHITransAddr: " << *Addr << "\n";
+ for (unsigned i = 0, e = InstInputs.size(); i != e; ++i)
+ dbgs() << " Input #" << i << " is " << *InstInputs[i] << "\n";
+}
+#endif
+
+
+static bool VerifySubExpr(Value *Expr,
+ SmallVectorImpl<Instruction*> &InstInputs) {
+ // If this is a non-instruction value, there is nothing to do.
+ Instruction *I = dyn_cast<Instruction>(Expr);
+ if (!I) return true;
+
+ // If it's an instruction, it is either in Tmp or its operands recursively
+ // are.
+ SmallVectorImpl<Instruction *>::iterator Entry = find(InstInputs, I);
+ if (Entry != InstInputs.end()) {
+ InstInputs.erase(Entry);
+ return true;
+ }
+
+ // If it isn't in the InstInputs list, it is a subexpr incorporated into the
+ // address. Sanity check that it is phi translatable.
+ if (!CanPHITrans(I)) {
+ errs() << "Instruction in PHITransAddr is not phi-translatable:\n";
+ errs() << *I << '\n';
+ llvm_unreachable("Either something is missing from InstInputs or "
+ "CanPHITrans is wrong.");
+ }
+
+ // Validate the operands of the instruction.
+ for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
+ if (!VerifySubExpr(I->getOperand(i), InstInputs))
+ return false;
+
+ return true;
+}
+
+/// Verify - Check internal consistency of this data structure. If the
+/// structure is valid, it returns true. If invalid, it prints errors and
+/// returns false.
+bool PHITransAddr::Verify() const {
+ if (!Addr) return true;
+
+ SmallVector<Instruction*, 8> Tmp(InstInputs.begin(), InstInputs.end());
+
+ if (!VerifySubExpr(Addr, Tmp))
+ return false;
+
+ if (!Tmp.empty()) {
+ errs() << "PHITransAddr contains extra instructions:\n";
+ for (unsigned i = 0, e = InstInputs.size(); i != e; ++i)
+ errs() << " InstInput #" << i << " is " << *InstInputs[i] << "\n";
+ llvm_unreachable("This is unexpected.");
+ }
+
+ // a-ok.
+ return true;
+}
+
+
+/// IsPotentiallyPHITranslatable - If this needs PHI translation, return true
+/// if we have some hope of doing it. This should be used as a filter to
+/// avoid calling PHITranslateValue in hopeless situations.
+bool PHITransAddr::IsPotentiallyPHITranslatable() const {
+ // If the input value is not an instruction, or if it is not defined in CurBB,
+ // then we don't need to phi translate it.
+ Instruction *Inst = dyn_cast<Instruction>(Addr);
+ return !Inst || CanPHITrans(Inst);
+}
+
+
+static void RemoveInstInputs(Value *V,
+ SmallVectorImpl<Instruction*> &InstInputs) {
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (!I) return;
+
+ // If the instruction is in the InstInputs list, remove it.
+ SmallVectorImpl<Instruction *>::iterator Entry = find(InstInputs, I);
+ if (Entry != InstInputs.end()) {
+ InstInputs.erase(Entry);
+ return;
+ }
+
+ assert(!isa<PHINode>(I) && "Error, removing something that isn't an input");
+
+ // Otherwise, it must have instruction inputs itself. Zap them recursively.
+ for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
+ if (Instruction *Op = dyn_cast<Instruction>(I->getOperand(i)))
+ RemoveInstInputs(Op, InstInputs);
+ }
+}
+
+Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB,
+ BasicBlock *PredBB,
+ const DominatorTree *DT) {
+ // If this is a non-instruction value, it can't require PHI translation.
+ Instruction *Inst = dyn_cast<Instruction>(V);
+ if (!Inst) return V;
+
+ // Determine whether 'Inst' is an input to our PHI translatable expression.
+ bool isInput = is_contained(InstInputs, Inst);
+
+ // Handle input instructions if needed.
+ if (isInput) {
+ if (Inst->getParent() != CurBB) {
+ // If it is an input defined in a different block, then it remains an
+ // input.
+ return Inst;
+ }
+
+ // If 'Inst' is defined in this block and is an input that needs to be phi
+ // translated, we need to incorporate the value into the expression or fail.
+
+ // In either case, the instruction itself isn't an input any longer.
+ InstInputs.erase(find(InstInputs, Inst));
+
+ // If this is a PHI, go ahead and translate it.
+ if (PHINode *PN = dyn_cast<PHINode>(Inst))
+ return AddAsInput(PN->getIncomingValueForBlock(PredBB));
+
+ // If this is a non-phi value, and it is analyzable, we can incorporate it
+ // into the expression by making all instruction operands be inputs.
+ if (!CanPHITrans(Inst))
+ return nullptr;
+
+ // All instruction operands are now inputs (and of course, they may also be
+ // defined in this block, so they may need to be phi translated themselves).
+ for (unsigned i = 0, e = Inst->getNumOperands(); i != e; ++i)
+ if (Instruction *Op = dyn_cast<Instruction>(Inst->getOperand(i)))
+ InstInputs.push_back(Op);
+ }
+
+ // Ok, it must be an intermediate result (either because it started that way
+ // or because we just incorporated it into the expression). See if its
+ // operands need to be phi translated, and if so, reconstruct it.
+
+ if (CastInst *Cast = dyn_cast<CastInst>(Inst)) {
+ if (!isSafeToSpeculativelyExecute(Cast)) return nullptr;
+ Value *PHIIn = PHITranslateSubExpr(Cast->getOperand(0), CurBB, PredBB, DT);
+ if (!PHIIn) return nullptr;
+ if (PHIIn == Cast->getOperand(0))
+ return Cast;
+
+ // Find an available version of this cast.
+
+ // Constants are trivial to find.
+ if (Constant *C = dyn_cast<Constant>(PHIIn))
+ return AddAsInput(ConstantExpr::getCast(Cast->getOpcode(),
+ C, Cast->getType()));
+
+ // Otherwise we have to see if a casted version of the incoming pointer
+ // is available. If so, we can use it, otherwise we have to fail.
+ for (User *U : PHIIn->users()) {
+ if (CastInst *CastI = dyn_cast<CastInst>(U))
+ if (CastI->getOpcode() == Cast->getOpcode() &&
+ CastI->getType() == Cast->getType() &&
+ (!DT || DT->dominates(CastI->getParent(), PredBB)))
+ return CastI;
+ }
+ return nullptr;
+ }
+
+ // Handle getelementptr with at least one PHI translatable operand.
+ if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Inst)) {
+ SmallVector<Value*, 8> GEPOps;
+ bool AnyChanged = false;
+ for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) {
+ Value *GEPOp = PHITranslateSubExpr(GEP->getOperand(i), CurBB, PredBB, DT);
+ if (!GEPOp) return nullptr;
+
+ AnyChanged |= GEPOp != GEP->getOperand(i);
+ GEPOps.push_back(GEPOp);
+ }
+
+ if (!AnyChanged)
+ return GEP;
+
+ // Simplify the GEP to handle 'gep x, 0' -> x etc.
+ if (Value *V = SimplifyGEPInst(GEP->getSourceElementType(),
+ GEPOps, {DL, TLI, DT, AC})) {
+ for (unsigned i = 0, e = GEPOps.size(); i != e; ++i)
+ RemoveInstInputs(GEPOps[i], InstInputs);
+
+ return AddAsInput(V);
+ }
+
+ // Scan to see if we have this GEP available.
+ Value *APHIOp = GEPOps[0];
+ for (User *U : APHIOp->users()) {
+ if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(U))
+ if (GEPI->getType() == GEP->getType() &&
+ GEPI->getNumOperands() == GEPOps.size() &&
+ GEPI->getParent()->getParent() == CurBB->getParent() &&
+ (!DT || DT->dominates(GEPI->getParent(), PredBB))) {
+ if (std::equal(GEPOps.begin(), GEPOps.end(), GEPI->op_begin()))
+ return GEPI;
+ }
+ }
+ return nullptr;
+ }
+
+ // Handle add with a constant RHS.
+ if (Inst->getOpcode() == Instruction::Add &&
+ isa<ConstantInt>(Inst->getOperand(1))) {
+ // PHI translate the LHS.
+ Constant *RHS = cast<ConstantInt>(Inst->getOperand(1));
+ bool isNSW = cast<BinaryOperator>(Inst)->hasNoSignedWrap();
+ bool isNUW = cast<BinaryOperator>(Inst)->hasNoUnsignedWrap();
+
+ Value *LHS = PHITranslateSubExpr(Inst->getOperand(0), CurBB, PredBB, DT);
+ if (!LHS) return nullptr;
+
+ // If the PHI translated LHS is an add of a constant, fold the immediates.
+ if (BinaryOperator *BOp = dyn_cast<BinaryOperator>(LHS))
+ if (BOp->getOpcode() == Instruction::Add)
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(BOp->getOperand(1))) {
+ LHS = BOp->getOperand(0);
+ RHS = ConstantExpr::getAdd(RHS, CI);
+ isNSW = isNUW = false;
+
+ // If the old 'LHS' was an input, add the new 'LHS' as an input.
+ if (is_contained(InstInputs, BOp)) {
+ RemoveInstInputs(BOp, InstInputs);
+ AddAsInput(LHS);
+ }
+ }
+
+ // See if the add simplifies away.
+ if (Value *Res = SimplifyAddInst(LHS, RHS, isNSW, isNUW, {DL, TLI, DT, AC})) {
+ // If we simplified the operands, the LHS is no longer an input, but Res
+ // is.
+ RemoveInstInputs(LHS, InstInputs);
+ return AddAsInput(Res);
+ }
+
+ // If we didn't modify the add, just return it.
+ if (LHS == Inst->getOperand(0) && RHS == Inst->getOperand(1))
+ return Inst;
+
+ // Otherwise, see if we have this add available somewhere.
+ for (User *U : LHS->users()) {
+ if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U))
+ if (BO->getOpcode() == Instruction::Add &&
+ BO->getOperand(0) == LHS && BO->getOperand(1) == RHS &&
+ BO->getParent()->getParent() == CurBB->getParent() &&
+ (!DT || DT->dominates(BO->getParent(), PredBB)))
+ return BO;
+ }
+
+ return nullptr;
+ }
+
+ // Otherwise, we failed.
+ return nullptr;
+}
+
+
+/// PHITranslateValue - PHI translate the current address up the CFG from
+/// CurBB to PredBB, updating our state to reflect any needed changes. If
+/// 'MustDominate' is true, the translated value must dominate
+/// PredBB. This returns true on failure and sets Addr to null.
+bool PHITransAddr::PHITranslateValue(BasicBlock *CurBB, BasicBlock *PredBB,
+ const DominatorTree *DT,
+ bool MustDominate) {
+ assert(DT || !MustDominate);
+ assert(Verify() && "Invalid PHITransAddr!");
+ if (DT && DT->isReachableFromEntry(PredBB))
+ Addr =
+ PHITranslateSubExpr(Addr, CurBB, PredBB, MustDominate ? DT : nullptr);
+ else
+ Addr = nullptr;
+ assert(Verify() && "Invalid PHITransAddr!");
+
+ if (MustDominate)
+ // Make sure the value is live in the predecessor.
+ if (Instruction *Inst = dyn_cast_or_null<Instruction>(Addr))
+ if (!DT->dominates(Inst->getParent(), PredBB))
+ Addr = nullptr;
+
+ return Addr == nullptr;
+}
+
+/// PHITranslateWithInsertion - PHI translate this value into the specified
+/// predecessor block, inserting a computation of the value if it is
+/// unavailable.
+///
+/// All newly created instructions are added to the NewInsts list. This
+/// returns null on failure.
+///
+Value *PHITransAddr::
+PHITranslateWithInsertion(BasicBlock *CurBB, BasicBlock *PredBB,
+ const DominatorTree &DT,
+ SmallVectorImpl<Instruction*> &NewInsts) {
+ unsigned NISize = NewInsts.size();
+
+ // Attempt to PHI translate with insertion.
+ Addr = InsertPHITranslatedSubExpr(Addr, CurBB, PredBB, DT, NewInsts);
+
+ // If successful, return the new value.
+ if (Addr) return Addr;
+
+ // If not, destroy any intermediate instructions inserted.
+ while (NewInsts.size() != NISize)
+ NewInsts.pop_back_val()->eraseFromParent();
+ return nullptr;
+}
+
+
+/// InsertPHITranslatedSubExpr - Insert a computation of the PHI translated
+/// version of 'InVal' for the edge PredBB->CurBB into the end of the PredBB
+/// block. All newly created instructions are added to the NewInsts list.
+/// This returns null on failure.
+///
+Value *PHITransAddr::
+InsertPHITranslatedSubExpr(Value *InVal, BasicBlock *CurBB,
+ BasicBlock *PredBB, const DominatorTree &DT,
+ SmallVectorImpl<Instruction*> &NewInsts) {
+ // See if we have a version of this value already available and dominating
+ // PredBB. If so, there is no need to insert a new instance of it.
+ PHITransAddr Tmp(InVal, DL, AC);
+ if (!Tmp.PHITranslateValue(CurBB, PredBB, &DT, /*MustDominate=*/true))
+ return Tmp.getAddr();
+
+ // We don't need to PHI translate values which aren't instructions.
+ auto *Inst = dyn_cast<Instruction>(InVal);
+ if (!Inst)
+ return nullptr;
+
+ // Handle cast of PHI translatable value.
+ if (CastInst *Cast = dyn_cast<CastInst>(Inst)) {
+ if (!isSafeToSpeculativelyExecute(Cast)) return nullptr;
+ Value *OpVal = InsertPHITranslatedSubExpr(Cast->getOperand(0),
+ CurBB, PredBB, DT, NewInsts);
+ if (!OpVal) return nullptr;
+
+ // Otherwise insert a cast at the end of PredBB.
+ CastInst *New = CastInst::Create(Cast->getOpcode(), OpVal, InVal->getType(),
+ InVal->getName() + ".phi.trans.insert",
+ PredBB->getTerminator());
+ New->setDebugLoc(Inst->getDebugLoc());
+ NewInsts.push_back(New);
+ return New;
+ }
+
+ // Handle getelementptr with at least one PHI operand.
+ if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Inst)) {
+ SmallVector<Value*, 8> GEPOps;
+ BasicBlock *CurBB = GEP->getParent();
+ for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) {
+ Value *OpVal = InsertPHITranslatedSubExpr(GEP->getOperand(i),
+ CurBB, PredBB, DT, NewInsts);
+ if (!OpVal) return nullptr;
+ GEPOps.push_back(OpVal);
+ }
+
+ GetElementPtrInst *Result = GetElementPtrInst::Create(
+ GEP->getSourceElementType(), GEPOps[0], makeArrayRef(GEPOps).slice(1),
+ InVal->getName() + ".phi.trans.insert", PredBB->getTerminator());
+ Result->setDebugLoc(Inst->getDebugLoc());
+ Result->setIsInBounds(GEP->isInBounds());
+ NewInsts.push_back(Result);
+ return Result;
+ }
+
+#if 0
+ // FIXME: This code works, but it is unclear that we actually want to insert
+ // a big chain of computation in order to make a value available in a block.
+ // This needs to be evaluated carefully to consider its cost trade offs.
+
+ // Handle add with a constant RHS.
+ if (Inst->getOpcode() == Instruction::Add &&
+ isa<ConstantInt>(Inst->getOperand(1))) {
+ // PHI translate the LHS.
+ Value *OpVal = InsertPHITranslatedSubExpr(Inst->getOperand(0),
+ CurBB, PredBB, DT, NewInsts);
+ if (OpVal == 0) return 0;
+
+ BinaryOperator *Res = BinaryOperator::CreateAdd(OpVal, Inst->getOperand(1),
+ InVal->getName()+".phi.trans.insert",
+ PredBB->getTerminator());
+ Res->setHasNoSignedWrap(cast<BinaryOperator>(Inst)->hasNoSignedWrap());
+ Res->setHasNoUnsignedWrap(cast<BinaryOperator>(Inst)->hasNoUnsignedWrap());
+ NewInsts.push_back(Res);
+ return Res;
+ }
+#endif
+
+ return nullptr;
+}
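
A rough sketch of the typical client pattern; translateAcrossEdge is hypothetical and mirrors how a memory-dependence client would phi-translate a pointer when following a non-local dependency into a predecessor:

    #include "llvm/Analysis/AssumptionCache.h"
    #include "llvm/Analysis/PHITransAddr.h"
    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/DataLayout.h"
    #include "llvm/IR/Dominators.h"

    using namespace llvm;

    // Hypothetical helper: return the value equivalent to Ptr in PredBB, or
    // null if it cannot be phi-translated. AC may be null.
    static Value *translateAcrossEdge(Value *Ptr, BasicBlock *CurBB,
                                      BasicBlock *PredBB, const DataLayout &DL,
                                      AssumptionCache *AC,
                                      const DominatorTree &DT) {
      PHITransAddr Address(Ptr, DL, AC);
      if (!Address.IsPotentiallyPHITranslatable())
        return nullptr;
      // Note: PHITranslateValue returns true on failure and nulls the address.
      if (Address.PHITranslateValue(CurBB, PredBB, &DT, /*MustDominate=*/false))
        return nullptr;
      return Address.getAddr();
    }
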
diff --git a/llvm/lib/Analysis/PhiValues.cpp b/llvm/lib/Analysis/PhiValues.cpp
new file mode 100644
index 000000000000..49749bc44746
--- /dev/null
+++ b/llvm/lib/Analysis/PhiValues.cpp
@@ -0,0 +1,212 @@
+//===- PhiValues.cpp - Phi Value Analysis ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/PhiValues.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/Instructions.h"
+
+using namespace llvm;
+
+void PhiValues::PhiValuesCallbackVH::deleted() {
+ PV->invalidateValue(getValPtr());
+}
+
+void PhiValues::PhiValuesCallbackVH::allUsesReplacedWith(Value *) {
+ // We could potentially update the cached values we have with the new value,
+ // but it's simpler to just treat the old value as invalidated.
+ PV->invalidateValue(getValPtr());
+}
+
+bool PhiValues::invalidate(Function &, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &) {
+ // PhiValues is invalidated if it isn't preserved.
+ auto PAC = PA.getChecker<PhiValuesAnalysis>();
+ return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>());
+}
+
+// The goal here is to find all of the non-phi values reachable from this phi,
+// and to do the same for all of the phis reachable from this phi, as doing so
+// is necessary anyway in order to get the values for this phi. We do this using
+// Tarjan's algorithm with Nuutila's improvements to find the strongly connected
+// components of the phi graph rooted in this phi:
+// * All phis in a strongly connected component will have the same reachable
+// non-phi values. The SCC may not be the maximal subgraph for that set of
+// reachable values, but finding that out isn't really necessary (it would
+// only reduce the amount of memory needed to store the values).
+// * Tarjan's algorithm completes components in a bottom-up manner, i.e. it
+// never completes a component before the components reachable from it have
+// been completed. This means that when we complete a component we have
+// everything we need to collect the values reachable from that component.
+// * We collect both the non-phi values reachable from each SCC, as that's what
+// we're ultimately interested in, and all of the reachable values, i.e.
+// including phis, as that makes invalidateValue easier.
+void PhiValues::processPhi(const PHINode *Phi,
+ SmallVector<const PHINode *, 8> &Stack) {
+ // Initialize the phi with the next depth number.
+ assert(DepthMap.lookup(Phi) == 0);
+ assert(NextDepthNumber != UINT_MAX);
+ unsigned int DepthNumber = ++NextDepthNumber;
+ DepthMap[Phi] = DepthNumber;
+
+ // Recursively process the incoming phis of this phi.
+ TrackedValues.insert(PhiValuesCallbackVH(const_cast<PHINode *>(Phi), this));
+ for (Value *PhiOp : Phi->incoming_values()) {
+ if (PHINode *PhiPhiOp = dyn_cast<PHINode>(PhiOp)) {
+ // Recurse if the phi has not yet been visited.
+ if (DepthMap.lookup(PhiPhiOp) == 0)
+ processPhi(PhiPhiOp, Stack);
+ assert(DepthMap.lookup(PhiPhiOp) != 0);
+ // If the phi did not become part of a component then this phi and that
+ // phi are part of the same component, so adjust the depth number.
+ if (!ReachableMap.count(DepthMap[PhiPhiOp]))
+ DepthMap[Phi] = std::min(DepthMap[Phi], DepthMap[PhiPhiOp]);
+ } else {
+ TrackedValues.insert(PhiValuesCallbackVH(PhiOp, this));
+ }
+ }
+
+ // Now that incoming phis have been handled, push this phi to the stack.
+ Stack.push_back(Phi);
+
+ // If the depth number has not changed then we've finished collecting the phis
+ // of a strongly connected component.
+ if (DepthMap[Phi] == DepthNumber) {
+ // Collect the reachable values for this component. The phis of this
+ // component will be those on top of the depth stack with the same or
+ // greater depth number.
+ ConstValueSet Reachable;
+ while (!Stack.empty() && DepthMap[Stack.back()] >= DepthNumber) {
+ const PHINode *ComponentPhi = Stack.pop_back_val();
+ Reachable.insert(ComponentPhi);
+ DepthMap[ComponentPhi] = DepthNumber;
+ for (Value *Op : ComponentPhi->incoming_values()) {
+ if (PHINode *PhiOp = dyn_cast<PHINode>(Op)) {
+ // If this phi is not part of the same component then that component
+ // is guaranteed to have been completed before this one. Therefore we
+ // can just add its reachable values to the reachable values of this
+ // component.
+ auto It = ReachableMap.find(DepthMap[PhiOp]);
+ if (It != ReachableMap.end())
+ Reachable.insert(It->second.begin(), It->second.end());
+ } else {
+ Reachable.insert(Op);
+ }
+ }
+ }
+ ReachableMap.insert({DepthNumber,Reachable});
+
+ // Filter out phis to get the non-phi reachable values.
+ ValueSet NonPhi;
+ for (const Value *V : Reachable)
+ if (!isa<PHINode>(V))
+ NonPhi.insert(const_cast<Value*>(V));
+ NonPhiReachableMap.insert({DepthNumber,NonPhi});
+ }
+}
+
+const PhiValues::ValueSet &PhiValues::getValuesForPhi(const PHINode *PN) {
+ if (DepthMap.count(PN) == 0) {
+ SmallVector<const PHINode *, 8> Stack;
+ processPhi(PN, Stack);
+ assert(Stack.empty());
+ }
+ assert(DepthMap.lookup(PN) != 0);
+ return NonPhiReachableMap[DepthMap[PN]];
+}
+
+void PhiValues::invalidateValue(const Value *V) {
+ // Components that can reach V are invalid.
+ SmallVector<unsigned int, 8> InvalidComponents;
+ for (auto &Pair : ReachableMap)
+ if (Pair.second.count(V))
+ InvalidComponents.push_back(Pair.first);
+
+ for (unsigned int N : InvalidComponents) {
+ for (const Value *V : ReachableMap[N])
+ if (const PHINode *PN = dyn_cast<PHINode>(V))
+ DepthMap.erase(PN);
+ NonPhiReachableMap.erase(N);
+ ReachableMap.erase(N);
+ }
+ // This value is no longer tracked.
+ auto It = TrackedValues.find_as(V);
+ if (It != TrackedValues.end())
+ TrackedValues.erase(It);
+}
+
+void PhiValues::releaseMemory() {
+ DepthMap.clear();
+ NonPhiReachableMap.clear();
+ ReachableMap.clear();
+}
+
+void PhiValues::print(raw_ostream &OS) const {
+ // Iterate through the phi nodes of the function rather than iterating through
+ // DepthMap in order to get predictable ordering.
+ for (const BasicBlock &BB : F) {
+ for (const PHINode &PN : BB.phis()) {
+ OS << "PHI ";
+ PN.printAsOperand(OS, false);
+ OS << " has values:\n";
+ unsigned int N = DepthMap.lookup(&PN);
+ auto It = NonPhiReachableMap.find(N);
+ if (It == NonPhiReachableMap.end())
+ OS << " UNKNOWN\n";
+ else if (It->second.empty())
+ OS << " NONE\n";
+ else
+ for (Value *V : It->second)
+ // Printing of an instruction prints two spaces at the start, so
+ // handle instructions and everything else slightly differently in
+ // order to get consistent indenting.
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ OS << *I << "\n";
+ else
+ OS << " " << *V << "\n";
+ }
+ }
+}
+
+AnalysisKey PhiValuesAnalysis::Key;
+PhiValues PhiValuesAnalysis::run(Function &F, FunctionAnalysisManager &) {
+ return PhiValues(F);
+}
+
+PreservedAnalyses PhiValuesPrinterPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ OS << "PHI Values for function: " << F.getName() << "\n";
+ PhiValues &PI = AM.getResult<PhiValuesAnalysis>(F);
+ for (const BasicBlock &BB : F)
+ for (const PHINode &PN : BB.phis())
+ PI.getValuesForPhi(&PN);
+ PI.print(OS);
+ return PreservedAnalyses::all();
+}
+
+PhiValuesWrapperPass::PhiValuesWrapperPass() : FunctionPass(ID) {
+ initializePhiValuesWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+bool PhiValuesWrapperPass::runOnFunction(Function &F) {
+ Result.reset(new PhiValues(F));
+ return false;
+}
+
+void PhiValuesWrapperPass::releaseMemory() {
+ Result->releaseMemory();
+}
+
+void PhiValuesWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+}
+
+char PhiValuesWrapperPass::ID = 0;
+
+INITIALIZE_PASS(PhiValuesWrapperPass, "phi-values", "Phi Values Analysis", false,
+ true)
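
A usage sketch; dumpPhiValues is a hypothetical name. The first getValuesForPhi call per strongly connected component pays for the walk described above, and later queries on phis in the same component hit the cache:

    #include "llvm/Analysis/PhiValues.h"
    #include "llvm/IR/Function.h"
    #include "llvm/IR/Instructions.h"
    #include "llvm/Support/raw_ostream.h"

    using namespace llvm;

    // Hypothetical dump: for every phi in F, list the non-phi values it can
    // ultimately carry.
    static void dumpPhiValues(Function &F) {
      PhiValues PV(F);
      for (BasicBlock &BB : F)
        for (PHINode &PN : BB.phis())
          for (Value *V : PV.getValuesForPhi(&PN))
            errs() << PN.getName() << " may be " << *V << "\n";
    }
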
diff --git a/llvm/lib/Analysis/PostDominators.cpp b/llvm/lib/Analysis/PostDominators.cpp
new file mode 100644
index 000000000000..4afe22bd5342
--- /dev/null
+++ b/llvm/lib/Analysis/PostDominators.cpp
@@ -0,0 +1,84 @@
+//===- PostDominators.cpp - Post-Dominator Calculation --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the post-dominator construction algorithms.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "postdomtree"
+
+#ifdef EXPENSIVE_CHECKS
+static constexpr bool ExpensiveChecksEnabled = true;
+#else
+static constexpr bool ExpensiveChecksEnabled = false;
+#endif
+
+//===----------------------------------------------------------------------===//
+// PostDominatorTree Implementation
+//===----------------------------------------------------------------------===//
+
+char PostDominatorTreeWrapperPass::ID = 0;
+
+INITIALIZE_PASS(PostDominatorTreeWrapperPass, "postdomtree",
+ "Post-Dominator Tree Construction", true, true)
+
+bool PostDominatorTree::invalidate(Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &) {
+ // Check whether the analysis, all analyses on functions, or the function's
+ // CFG have been preserved.
+ auto PAC = PA.getChecker<PostDominatorTreeAnalysis>();
+ return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() ||
+ PAC.preservedSet<CFGAnalyses>());
+}
+
+bool PostDominatorTreeWrapperPass::runOnFunction(Function &F) {
+ DT.recalculate(F);
+ return false;
+}
+
+void PostDominatorTreeWrapperPass::verifyAnalysis() const {
+ if (VerifyDomInfo)
+ assert(DT.verify(PostDominatorTree::VerificationLevel::Full));
+ else if (ExpensiveChecksEnabled)
+ assert(DT.verify(PostDominatorTree::VerificationLevel::Basic));
+}
+
+void PostDominatorTreeWrapperPass::print(raw_ostream &OS, const Module *) const {
+ DT.print(OS);
+}
+
+FunctionPass* llvm::createPostDomTree() {
+ return new PostDominatorTreeWrapperPass();
+}
+
+AnalysisKey PostDominatorTreeAnalysis::Key;
+
+PostDominatorTree PostDominatorTreeAnalysis::run(Function &F,
+ FunctionAnalysisManager &) {
+ PostDominatorTree PDT(F);
+ return PDT;
+}
+
+PostDominatorTreePrinterPass::PostDominatorTreePrinterPass(raw_ostream &OS)
+ : OS(OS) {}
+
+PreservedAnalyses
+PostDominatorTreePrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
+ OS << "PostDominatorTree for function: " << F.getName() << "\n";
+ AM.getResult<PostDominatorTreeAnalysis>(F).print(OS);
+
+ return PreservedAnalyses::all();
+}
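
A small illustrative query; mustPassThrough is hypothetical, and building the tree per call is only for brevity (a pass would request PostDominatorTreeAnalysis instead):

    #include "llvm/Analysis/PostDominators.h"
    #include "llvm/IR/Function.h"

    using namespace llvm;

    // Hypothetical helper: B post-dominates A iff every path from A to the
    // function's exit passes through B.
    static bool mustPassThrough(Function &F, const BasicBlock *A,
                                const BasicBlock *B) {
      PostDominatorTree PDT(F);
      return PDT.dominates(B, A);
    }
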
diff --git a/llvm/lib/Analysis/ProfileSummaryInfo.cpp b/llvm/lib/Analysis/ProfileSummaryInfo.cpp
new file mode 100644
index 000000000000..b99b75715025
--- /dev/null
+++ b/llvm/lib/Analysis/ProfileSummaryInfo.cpp
@@ -0,0 +1,392 @@
+//===- ProfileSummaryInfo.cpp - Global profile summary information --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a pass that provides access to the global profile summary
+// information.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ProfileSummaryInfo.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/ProfileSummary.h"
+using namespace llvm;
+
+// The following two parameters determine the threshold for a count to be
+// considered hot/cold. These two parameters are percentile values (multiplied
+// by 10000). If the counts are sorted in descending order, the minimum count to
+// reach ProfileSummaryCutoffHot gives the threshold to determine a hot count.
+// Similarly, the minimum count to reach ProfileSummaryCutoffCold gives the
+// threshold for determining cold count (everything <= this threshold is
+// considered cold).
+
+static cl::opt<int> ProfileSummaryCutoffHot(
+ "profile-summary-cutoff-hot", cl::Hidden, cl::init(990000), cl::ZeroOrMore,
+ cl::desc("A count is hot if it exceeds the minimum count to"
+ " reach this percentile of total counts."));
+
+static cl::opt<int> ProfileSummaryCutoffCold(
+ "profile-summary-cutoff-cold", cl::Hidden, cl::init(999999), cl::ZeroOrMore,
+ cl::desc("A count is cold if it is below the minimum count"
+ " to reach this percentile of total counts."));
+
+static cl::opt<unsigned> ProfileSummaryHugeWorkingSetSizeThreshold(
+ "profile-summary-huge-working-set-size-threshold", cl::Hidden,
+ cl::init(15000), cl::ZeroOrMore,
+ cl::desc("The code working set size is considered huge if the number of"
+ " blocks required to reach the -profile-summary-cutoff-hot"
+ " percentile exceeds this count."));
+
+static cl::opt<unsigned> ProfileSummaryLargeWorkingSetSizeThreshold(
+ "profile-summary-large-working-set-size-threshold", cl::Hidden,
+ cl::init(12500), cl::ZeroOrMore,
+ cl::desc("The code working set size is considered large if the number of"
+ " blocks required to reach the -profile-summary-cutoff-hot"
+ " percentile exceeds this count."));
+
+// The next two options override the counts derived from summary computation and
+// are useful for debugging purposes.
+static cl::opt<int> ProfileSummaryHotCount(
+ "profile-summary-hot-count", cl::ReallyHidden, cl::ZeroOrMore,
+ cl::desc("A fixed hot count that overrides the count derived from"
+ " profile-summary-cutoff-hot"));
+
+static cl::opt<int> ProfileSummaryColdCount(
+ "profile-summary-cold-count", cl::ReallyHidden, cl::ZeroOrMore,
+ cl::desc("A fixed cold count that overrides the count derived from"
+ " profile-summary-cutoff-cold"));
+
+// Find the summary entry for a desired percentile of counts.
+static const ProfileSummaryEntry &getEntryForPercentile(SummaryEntryVector &DS,
+ uint64_t Percentile) {
+ auto It = partition_point(DS, [=](const ProfileSummaryEntry &Entry) {
+ return Entry.Cutoff < Percentile;
+ });
+ // The required percentile has to be <= one of the percentiles in the
+ // detailed summary.
+ if (It == DS.end())
+ report_fatal_error("Desired percentile exceeds the maximum cutoff");
+ return *It;
+}
+
+// The profile summary metadata may be attached either by the frontend or by
+// any backend passes (IR level instrumentation, for example). This method
+// checks if the Summary is null and if so checks if the summary metadata is now
+// available in the module and parses it to get the Summary object. Returns true
+// if a valid Summary is available.
+bool ProfileSummaryInfo::computeSummary() {
+ if (Summary)
+ return true;
+ // First try to get context sensitive ProfileSummary.
+ auto *SummaryMD = M.getProfileSummary(/* IsCS */ true);
+ if (SummaryMD) {
+ Summary.reset(ProfileSummary::getFromMD(SummaryMD));
+ return true;
+ }
+ // This will actually return PSK_Instr or PSK_Sample summary.
+ SummaryMD = M.getProfileSummary(/* IsCS */ false);
+ if (!SummaryMD)
+ return false;
+ Summary.reset(ProfileSummary::getFromMD(SummaryMD));
+ return true;
+}
+
+Optional<uint64_t>
+ProfileSummaryInfo::getProfileCount(const Instruction *Inst,
+ BlockFrequencyInfo *BFI,
+ bool AllowSynthetic) {
+ if (!Inst)
+ return None;
+ assert((isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) &&
+ "We can only get profile count for call/invoke instruction.");
+ if (hasSampleProfile()) {
+ // In sample PGO mode, check if there is profile metadata on the
+ // instruction. If it is present, determine hotness solely based on that,
+ // since the sampled entry count may not be accurate. If there is no
+ // annotation on the instruction, return None.
+ uint64_t TotalCount;
+ if (Inst->extractProfTotalWeight(TotalCount))
+ return TotalCount;
+ return None;
+ }
+ if (BFI)
+ return BFI->getBlockProfileCount(Inst->getParent(), AllowSynthetic);
+ return None;
+}
+
+/// Returns true if the function's entry is hot. If it returns false, it
+/// either means it is not hot or it is unknown whether it is hot or not (for
+/// example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionEntryHot(const Function *F) {
+ if (!F || !computeSummary())
+ return false;
+ auto FunctionCount = F->getEntryCount();
+ // FIXME: The heuristic used below for determining hotness is based on
+ // preliminary SPEC tuning for inliner. This will eventually be a
+ // convenience method that calls isHotCount.
+ return FunctionCount && isHotCount(FunctionCount.getCount());
+}
+
+/// Returns true if the function contains hot code. This can include a hot
+/// function entry count, hot basic block, or (in the case of Sample PGO)
+/// hot total call edge count.
+/// If it returns false, it either means it is not hot or it is unknown
+/// (for example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionHotInCallGraph(const Function *F,
+ BlockFrequencyInfo &BFI) {
+ if (!F || !computeSummary())
+ return false;
+ if (auto FunctionCount = F->getEntryCount())
+ if (isHotCount(FunctionCount.getCount()))
+ return true;
+
+ if (hasSampleProfile()) {
+ uint64_t TotalCallCount = 0;
+ for (const auto &BB : *F)
+ for (const auto &I : BB)
+ if (isa<CallInst>(I) || isa<InvokeInst>(I))
+ if (auto CallCount = getProfileCount(&I, nullptr))
+ TotalCallCount += CallCount.getValue();
+ if (isHotCount(TotalCallCount))
+ return true;
+ }
+ for (const auto &BB : *F)
+ if (isHotBlock(&BB, &BFI))
+ return true;
+ return false;
+}
+
+/// Returns true if the function only contains cold code. This means that
+/// the function entry and blocks are all cold, and (in the case of Sample PGO)
+/// the total call edge count is cold.
+/// If it returns false, it either means it is not cold or it is unknown
+/// (for example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionColdInCallGraph(const Function *F,
+ BlockFrequencyInfo &BFI) {
+ if (!F || !computeSummary())
+ return false;
+ if (auto FunctionCount = F->getEntryCount())
+ if (!isColdCount(FunctionCount.getCount()))
+ return false;
+
+ if (hasSampleProfile()) {
+ uint64_t TotalCallCount = 0;
+ for (const auto &BB : *F)
+ for (const auto &I : BB)
+ if (isa<CallInst>(I) || isa<InvokeInst>(I))
+ if (auto CallCount = getProfileCount(&I, nullptr))
+ TotalCallCount += CallCount.getValue();
+ if (!isColdCount(TotalCallCount))
+ return false;
+ }
+ for (const auto &BB : *F)
+ if (!isColdBlock(&BB, &BFI))
+ return false;
+ return true;
+}
+
+// Like isFunctionHotInCallGraph but for a given cutoff.
+bool ProfileSummaryInfo::isFunctionHotInCallGraphNthPercentile(
+ int PercentileCutoff, const Function *F, BlockFrequencyInfo &BFI) {
+ if (!F || !computeSummary())
+ return false;
+ if (auto FunctionCount = F->getEntryCount())
+ if (isHotCountNthPercentile(PercentileCutoff, FunctionCount.getCount()))
+ return true;
+
+ if (hasSampleProfile()) {
+ uint64_t TotalCallCount = 0;
+ for (const auto &BB : *F)
+ for (const auto &I : BB)
+ if (isa<CallInst>(I) || isa<InvokeInst>(I))
+ if (auto CallCount = getProfileCount(&I, nullptr))
+ TotalCallCount += CallCount.getValue();
+ if (isHotCountNthPercentile(PercentileCutoff, TotalCallCount))
+ return true;
+ }
+ for (const auto &BB : *F)
+ if (isHotBlockNthPercentile(PercentileCutoff, &BB, &BFI))
+ return true;
+ return false;
+}
+
+/// Returns true if the function's entry is cold. If it returns false, it
+/// either means it is not cold or it is unknown whether it is cold or not (for
+/// example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionEntryCold(const Function *F) {
+ if (!F)
+ return false;
+ if (F->hasFnAttribute(Attribute::Cold))
+ return true;
+ if (!computeSummary())
+ return false;
+ auto FunctionCount = F->getEntryCount();
+ // FIXME: The heuristic used below for determining coldness is based on
+ // preliminary SPEC tuning for inliner. This will eventually be a
+ // convenience method that calls isColdCount.
+ return FunctionCount && isColdCount(FunctionCount.getCount());
+}
+
+/// Compute the hot and cold thresholds.
+void ProfileSummaryInfo::computeThresholds() {
+ if (!computeSummary())
+ return;
+ auto &DetailedSummary = Summary->getDetailedSummary();
+ auto &HotEntry =
+ getEntryForPercentile(DetailedSummary, ProfileSummaryCutoffHot);
+ HotCountThreshold = HotEntry.MinCount;
+ if (ProfileSummaryHotCount.getNumOccurrences() > 0)
+ HotCountThreshold = ProfileSummaryHotCount;
+ auto &ColdEntry =
+ getEntryForPercentile(DetailedSummary, ProfileSummaryCutoffCold);
+ ColdCountThreshold = ColdEntry.MinCount;
+ if (ProfileSummaryColdCount.getNumOccurrences() > 0)
+ ColdCountThreshold = ProfileSummaryColdCount;
+ assert(ColdCountThreshold <= HotCountThreshold &&
+ "Cold count threshold cannot exceed hot count threshold!");
+ HasHugeWorkingSetSize =
+ HotEntry.NumCounts > ProfileSummaryHugeWorkingSetSizeThreshold;
+ HasLargeWorkingSetSize =
+ HotEntry.NumCounts > ProfileSummaryLargeWorkingSetSizeThreshold;
+}
+
+Optional<uint64_t> ProfileSummaryInfo::computeThreshold(int PercentileCutoff) {
+ if (!computeSummary())
+ return None;
+ auto iter = ThresholdCache.find(PercentileCutoff);
+ if (iter != ThresholdCache.end()) {
+ return iter->second;
+ }
+ auto &DetailedSummary = Summary->getDetailedSummary();
+ auto &Entry =
+ getEntryForPercentile(DetailedSummary, PercentileCutoff);
+ uint64_t CountThreshold = Entry.MinCount;
+ ThresholdCache[PercentileCutoff] = CountThreshold;
+ return CountThreshold;
+}
+
+bool ProfileSummaryInfo::hasHugeWorkingSetSize() {
+ if (!HasHugeWorkingSetSize)
+ computeThresholds();
+ return HasHugeWorkingSetSize && HasHugeWorkingSetSize.getValue();
+}
+
+bool ProfileSummaryInfo::hasLargeWorkingSetSize() {
+ if (!HasLargeWorkingSetSize)
+ computeThresholds();
+ return HasLargeWorkingSetSize && HasLargeWorkingSetSize.getValue();
+}
+
+bool ProfileSummaryInfo::isHotCount(uint64_t C) {
+ if (!HotCountThreshold)
+ computeThresholds();
+ return HotCountThreshold && C >= HotCountThreshold.getValue();
+}
+
+bool ProfileSummaryInfo::isColdCount(uint64_t C) {
+ if (!ColdCountThreshold)
+ computeThresholds();
+ return ColdCountThreshold && C <= ColdCountThreshold.getValue();
+}
+
+bool ProfileSummaryInfo::isHotCountNthPercentile(int PercentileCutoff, uint64_t C) {
+ auto CountThreshold = computeThreshold(PercentileCutoff);
+ return CountThreshold && C >= CountThreshold.getValue();
+}
+
+uint64_t ProfileSummaryInfo::getOrCompHotCountThreshold() {
+ if (!HotCountThreshold)
+ computeThresholds();
+ return HotCountThreshold ? HotCountThreshold.getValue() : UINT64_MAX;
+}
+
+uint64_t ProfileSummaryInfo::getOrCompColdCountThreshold() {
+ if (!ColdCountThreshold)
+ computeThresholds();
+ return ColdCountThreshold ? ColdCountThreshold.getValue() : 0;
+}
+
+bool ProfileSummaryInfo::isHotBlock(const BasicBlock *BB, BlockFrequencyInfo *BFI) {
+ auto Count = BFI->getBlockProfileCount(BB);
+ return Count && isHotCount(*Count);
+}
+
+bool ProfileSummaryInfo::isColdBlock(const BasicBlock *BB,
+ BlockFrequencyInfo *BFI) {
+ auto Count = BFI->getBlockProfileCount(BB);
+ return Count && isColdCount(*Count);
+}
+
+bool ProfileSummaryInfo::isHotBlockNthPercentile(int PercentileCutoff,
+ const BasicBlock *BB,
+ BlockFrequencyInfo *BFI) {
+ auto Count = BFI->getBlockProfileCount(BB);
+ return Count && isHotCountNthPercentile(PercentileCutoff, *Count);
+}
+
+bool ProfileSummaryInfo::isHotCallSite(const CallSite &CS,
+ BlockFrequencyInfo *BFI) {
+ auto C = getProfileCount(CS.getInstruction(), BFI);
+ return C && isHotCount(*C);
+}
+
+bool ProfileSummaryInfo::isColdCallSite(const CallSite &CS,
+ BlockFrequencyInfo *BFI) {
+ auto C = getProfileCount(CS.getInstruction(), BFI);
+ if (C)
+ return isColdCount(*C);
+
+ // In SamplePGO, if the caller has been sampled, and there is no profile
+ // annotated on the callsite, we consider the callsite as cold.
+ return hasSampleProfile() && CS.getCaller()->hasProfileData();
+}
+
+INITIALIZE_PASS(ProfileSummaryInfoWrapperPass, "profile-summary-info",
+ "Profile summary info", false, true)
+
+ProfileSummaryInfoWrapperPass::ProfileSummaryInfoWrapperPass()
+ : ImmutablePass(ID) {
+ initializeProfileSummaryInfoWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+bool ProfileSummaryInfoWrapperPass::doInitialization(Module &M) {
+ PSI.reset(new ProfileSummaryInfo(M));
+ return false;
+}
+
+bool ProfileSummaryInfoWrapperPass::doFinalization(Module &M) {
+ PSI.reset();
+ return false;
+}
+
+AnalysisKey ProfileSummaryAnalysis::Key;
+ProfileSummaryInfo ProfileSummaryAnalysis::run(Module &M,
+ ModuleAnalysisManager &) {
+ return ProfileSummaryInfo(M);
+}
+
+PreservedAnalyses ProfileSummaryPrinterPass::run(Module &M,
+ ModuleAnalysisManager &AM) {
+ ProfileSummaryInfo &PSI = AM.getResult<ProfileSummaryAnalysis>(M);
+
+ OS << "Functions in " << M.getName() << " with hot/cold annotations: \n";
+ for (auto &F : M) {
+ OS << F.getName();
+ if (PSI.isFunctionEntryHot(&F))
+ OS << " :hot entry ";
+ else if (PSI.isFunctionEntryCold(&F))
+ OS << " :cold entry ";
+ OS << "\n";
+ }
+ return PreservedAnalyses::all();
+}
+
+char ProfileSummaryInfoWrapperPass::ID = 0;
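
A hedged usage sketch; worthOptimizingAggressively is a made-up name. A client combines the summary-derived thresholds with per-block counts from BFI:

    #include "llvm/Analysis/BlockFrequencyInfo.h"
    #include "llvm/Analysis/ProfileSummaryInfo.h"
    #include "llvm/IR/BasicBlock.h"

    using namespace llvm;

    // Hypothetical gate for a size-increasing transform: only fire on blocks
    // whose profile count clears the hot threshold derived from
    // -profile-summary-cutoff-hot.
    static bool worthOptimizingAggressively(ProfileSummaryInfo &PSI,
                                            BlockFrequencyInfo &BFI,
                                            const BasicBlock &BB) {
      if (!PSI.hasProfileSummary())
        return false; // no profile: stay conservative
      return PSI.isHotBlock(&BB, &BFI);
    }
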
diff --git a/llvm/lib/Analysis/PtrUseVisitor.cpp b/llvm/lib/Analysis/PtrUseVisitor.cpp
new file mode 100644
index 000000000000..9a834ba4866a
--- /dev/null
+++ b/llvm/lib/Analysis/PtrUseVisitor.cpp
@@ -0,0 +1,44 @@
+//===- PtrUseVisitor.cpp - InstVisitors over a pointer's uses -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Implementation of the pointer use visitors.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/PtrUseVisitor.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include <algorithm>
+
+using namespace llvm;
+
+void detail::PtrUseVisitorBase::enqueueUsers(Instruction &I) {
+ for (Use &U : I.uses()) {
+ if (VisitedUses.insert(&U).second) {
+ UseToVisit NewU = {
+ UseToVisit::UseAndIsOffsetKnownPair(&U, IsOffsetKnown),
+ Offset
+ };
+ Worklist.push_back(std::move(NewU));
+ }
+ }
+}
+
+bool detail::PtrUseVisitorBase::adjustOffsetForGEP(GetElementPtrInst &GEPI) {
+ if (!IsOffsetKnown)
+ return false;
+
+ APInt TmpOffset(DL.getIndexTypeSizeInBits(GEPI.getType()), 0);
+ if (GEPI.accumulateConstantOffset(DL, TmpOffset)) {
+ Offset += TmpOffset.sextOrTrunc(Offset.getBitWidth());
+ return true;
+ }
+
+ return false;
+}
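
For illustration, the same offset folding can be written standalone; foldGEPOffset is a hypothetical helper that mirrors adjustOffsetForGEP above:

    #include "llvm/ADT/APInt.h"
    #include "llvm/IR/DataLayout.h"
    #include "llvm/IR/Instructions.h"

    using namespace llvm;

    // Hypothetical standalone version of the offset folding: add a GEP's byte
    // offset to RunningOffset only when every index is a compile-time constant.
    static bool foldGEPOffset(const DataLayout &DL, GetElementPtrInst &GEPI,
                              APInt &RunningOffset) {
      APInt GEPOffset(DL.getIndexTypeSizeInBits(GEPI.getType()), 0);
      if (!GEPI.accumulateConstantOffset(DL, GEPOffset))
        return false; // a non-constant index makes the offset unknown
      RunningOffset += GEPOffset.sextOrTrunc(RunningOffset.getBitWidth());
      return true;
    }
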
diff --git a/llvm/lib/Analysis/RegionInfo.cpp b/llvm/lib/Analysis/RegionInfo.cpp
new file mode 100644
index 000000000000..8ba38adfb0d2
--- /dev/null
+++ b/llvm/lib/Analysis/RegionInfo.cpp
@@ -0,0 +1,215 @@
+//===- RegionInfo.cpp - SESE region detection analysis --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Detects single entry single exit regions in the control flow graph.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/RegionInfo.h"
+#include "llvm/ADT/Statistic.h"
+#ifndef NDEBUG
+#include "llvm/Analysis/RegionPrinter.h"
+#endif
+#include "llvm/Analysis/RegionInfoImpl.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "region"
+
+namespace llvm {
+
+template class RegionBase<RegionTraits<Function>>;
+template class RegionNodeBase<RegionTraits<Function>>;
+template class RegionInfoBase<RegionTraits<Function>>;
+
+} // end namespace llvm
+
+STATISTIC(numRegions, "The # of regions");
+STATISTIC(numSimpleRegions, "The # of simple regions");
+
+// Always verify if expensive checking is enabled.
+
+static cl::opt<bool,true>
+VerifyRegionInfoX(
+ "verify-region-info",
+ cl::location(RegionInfoBase<RegionTraits<Function>>::VerifyRegionInfo),
+ cl::desc("Verify region info (time consuming)"));
+
+static cl::opt<Region::PrintStyle, true> printStyleX("print-region-style",
+ cl::location(RegionInfo::printStyle),
+ cl::Hidden,
+ cl::desc("style of printing regions"),
+ cl::values(
+ clEnumValN(Region::PrintNone, "none", "print no details"),
+ clEnumValN(Region::PrintBB, "bb",
+ "print regions in detail with block_iterator"),
+ clEnumValN(Region::PrintRN, "rn",
+ "print regions in detail with element_iterator")));
+
+//===----------------------------------------------------------------------===//
+// Region implementation
+//
+
+Region::Region(BasicBlock *Entry, BasicBlock *Exit,
+ RegionInfo* RI,
+ DominatorTree *DT, Region *Parent) :
+ RegionBase<RegionTraits<Function>>(Entry, Exit, RI, DT, Parent) {
+
+}
+
+Region::~Region() = default;
+
+//===----------------------------------------------------------------------===//
+// RegionInfo implementation
+//
+
+RegionInfo::RegionInfo() = default;
+
+RegionInfo::~RegionInfo() = default;
+
+bool RegionInfo::invalidate(Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &) {
+ // Check whether the analysis, all analyses on functions, or the function's
+ // CFG has been preserved.
+ auto PAC = PA.getChecker<RegionInfoAnalysis>();
+ return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() ||
+ PAC.preservedSet<CFGAnalyses>());
+}
+
+void RegionInfo::updateStatistics(Region *R) {
+ ++numRegions;
+
+ // TODO: Slow. Should only be enabled if -stats is used.
+ if (R->isSimple())
+ ++numSimpleRegions;
+}
+
+void RegionInfo::recalculate(Function &F, DominatorTree *DT_,
+ PostDominatorTree *PDT_, DominanceFrontier *DF_) {
+ DT = DT_;
+ PDT = PDT_;
+ DF = DF_;
+
+ TopLevelRegion = new Region(&F.getEntryBlock(), nullptr,
+ this, DT, nullptr);
+ updateStatistics(TopLevelRegion);
+ calculate(F);
+}
+
+#ifndef NDEBUG
+void RegionInfo::view() { viewRegion(this); }
+
+void RegionInfo::viewOnly() { viewRegionOnly(this); }
+#endif
+
+//===----------------------------------------------------------------------===//
+// RegionInfoPass implementation
+//
+
+RegionInfoPass::RegionInfoPass() : FunctionPass(ID) {
+ initializeRegionInfoPassPass(*PassRegistry::getPassRegistry());
+}
+
+RegionInfoPass::~RegionInfoPass() = default;
+
+bool RegionInfoPass::runOnFunction(Function &F) {
+ releaseMemory();
+
+ auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto PDT = &getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
+ auto DF = &getAnalysis<DominanceFrontierWrapperPass>().getDominanceFrontier();
+
+ RI.recalculate(F, DT, PDT, DF);
+ return false;
+}
+
+void RegionInfoPass::releaseMemory() {
+ RI.releaseMemory();
+}
+
+void RegionInfoPass::verifyAnalysis() const {
+ RI.verifyAnalysis();
+}
+
+void RegionInfoPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequiredTransitive<DominatorTreeWrapperPass>();
+ AU.addRequired<PostDominatorTreeWrapperPass>();
+ AU.addRequired<DominanceFrontierWrapperPass>();
+}
+
+void RegionInfoPass::print(raw_ostream &OS, const Module *) const {
+ RI.print(OS);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void RegionInfoPass::dump() const {
+ RI.dump();
+}
+#endif
+
+char RegionInfoPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(RegionInfoPass, "regions",
+ "Detect single entry single exit regions", true, true)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominanceFrontierWrapperPass)
+INITIALIZE_PASS_END(RegionInfoPass, "regions",
+ "Detect single entry single exit regions", true, true)
+
+// Create methods available outside of this file, so that they can be
+// referenced from "include/llvm/LinkAllPasses.h". Otherwise the pass would be
+// deleted by link-time optimization.
+
+namespace llvm {
+
+ FunctionPass *createRegionInfoPass() {
+ return new RegionInfoPass();
+ }
+
+} // end namespace llvm
+
+//===----------------------------------------------------------------------===//
+// RegionInfoAnalysis implementation
+//
+
+AnalysisKey RegionInfoAnalysis::Key;
+
+RegionInfo RegionInfoAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
+ RegionInfo RI;
+ auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
+ auto *PDT = &AM.getResult<PostDominatorTreeAnalysis>(F);
+ auto *DF = &AM.getResult<DominanceFrontierAnalysis>(F);
+
+ RI.recalculate(F, DT, PDT, DF);
+ return RI;
+}
+
+RegionInfoPrinterPass::RegionInfoPrinterPass(raw_ostream &OS)
+ : OS(OS) {}
+
+PreservedAnalyses RegionInfoPrinterPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ OS << "Region Tree for function: " << F.getName() << "\n";
+ AM.getResult<RegionInfoAnalysis>(F).print(OS);
+
+ return PreservedAnalyses::all();
+}
+
+PreservedAnalyses RegionInfoVerifierPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ AM.getResult<RegionInfoAnalysis>(F).verifyAnalysis();
+
+ return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Analysis/RegionPass.cpp b/llvm/lib/Analysis/RegionPass.cpp
new file mode 100644
index 000000000000..6c0d17b45c62
--- /dev/null
+++ b/llvm/lib/Analysis/RegionPass.cpp
@@ -0,0 +1,299 @@
+//===- RegionPass.cpp - Region Pass and Region Pass Manager ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements RegionPass and RGPassManager. All region optimization
+// and transformation passes are derived from RegionPass. RGPassManager is
+// responsible for managing RegionPasses.
+// Most of this code has been COPIED from LoopPass.cpp.
+//
+//===----------------------------------------------------------------------===//
+#include "llvm/Analysis/RegionPass.h"
+#include "llvm/IR/OptBisect.h"
+#include "llvm/IR/PassTimingInfo.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Timer.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace llvm;
+
+#define DEBUG_TYPE "regionpassmgr"
+
+//===----------------------------------------------------------------------===//
+// RGPassManager
+//
+
+char RGPassManager::ID = 0;
+
+RGPassManager::RGPassManager()
+ : FunctionPass(ID), PMDataManager() {
+ skipThisRegion = false;
+ redoThisRegion = false;
+ RI = nullptr;
+ CurrentRegion = nullptr;
+}
+
+// Recursively add the region R and all of its subregions into RQ.
+static void addRegionIntoQueue(Region &R, std::deque<Region *> &RQ) {
+ RQ.push_back(&R);
+ for (const auto &E : R)
+ addRegionIntoQueue(*E, RQ);
+}
+
+/// Pass Manager itself does not invalidate any analysis info.
+void RGPassManager::getAnalysisUsage(AnalysisUsage &Info) const {
+ Info.addRequired<RegionInfoPass>();
+ Info.setPreservesAll();
+}
+
+/// run - Execute all of the passes scheduled for execution. Keep track of
+/// whether any of the passes modifies the function, and if so, return true.
+bool RGPassManager::runOnFunction(Function &F) {
+ RI = &getAnalysis<RegionInfoPass>().getRegionInfo();
+ bool Changed = false;
+
+ // Collect inherited analysis from Module level pass manager.
+ populateInheritedAnalysis(TPM->activeStack);
+
+ addRegionIntoQueue(*RI->getTopLevelRegion(), RQ);
+
+ if (RQ.empty()) // No regions, skip calling finalizers
+ return false;
+
+ // Initialization
+ for (Region *R : RQ) {
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ RegionPass *RP = (RegionPass *)getContainedPass(Index);
+ Changed |= RP->doInitialization(R, *this);
+ }
+ }
+
+ // Walk Regions
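+  // Note that addRegionIntoQueue pushes a region before any of its
+  // subregions, so taking regions from the back of the deque processes all
+  // subregions before their parent, i.e. the region tree is walked bottom-up.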
+ while (!RQ.empty()) {
+
+ CurrentRegion = RQ.back();
+ skipThisRegion = false;
+ redoThisRegion = false;
+
+ // Run all passes on the current Region.
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ RegionPass *P = (RegionPass*)getContainedPass(Index);
+
+ if (isPassDebuggingExecutionsOrMore()) {
+ dumpPassInfo(P, EXECUTION_MSG, ON_REGION_MSG,
+ CurrentRegion->getNameStr());
+ dumpRequiredSet(P);
+ }
+
+ initializeAnalysisImpl(P);
+
+ {
+ PassManagerPrettyStackEntry X(P, *CurrentRegion->getEntry());
+
+ TimeRegion PassTimer(getPassTimer(P));
+ Changed |= P->runOnRegion(CurrentRegion, *this);
+ }
+
+ if (isPassDebuggingExecutionsOrMore()) {
+ if (Changed)
+ dumpPassInfo(P, MODIFICATION_MSG, ON_REGION_MSG,
+ skipThisRegion ? "<deleted>" :
+ CurrentRegion->getNameStr());
+ dumpPreservedSet(P);
+ }
+
+ if (!skipThisRegion) {
+ // Manually check that this region is still healthy. This is done
+ // instead of relying on RegionInfo::verifyRegion since RegionInfo
+ // is a function pass and it's really expensive to verify every
+ // Region in the function every time. That level of checking can be
+ // enabled with the -verify-region-info option.
+ {
+ TimeRegion PassTimer(getPassTimer(P));
+ CurrentRegion->verifyRegion();
+ }
+
+ // Then call the regular verifyAnalysis functions.
+ verifyPreservedAnalysis(P);
+ }
+
+ removeNotPreservedAnalysis(P);
+ recordAvailableAnalysis(P);
+ removeDeadPasses(P,
+ (!isPassDebuggingExecutionsOrMore() || skipThisRegion) ?
+ "<deleted>" : CurrentRegion->getNameStr(),
+ ON_REGION_MSG);
+
+ if (skipThisRegion)
+ // Do not run other passes on this region.
+ break;
+ }
+
+ // If the region was deleted, release all the region passes. This frees up
+ // some memory, and avoids trouble with the pass manager trying to call
+ // verifyAnalysis on them.
+ if (skipThisRegion)
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ Pass *P = getContainedPass(Index);
+ freePass(P, "<deleted>", ON_REGION_MSG);
+ }
+
+ // Pop the region from queue after running all passes.
+ RQ.pop_back();
+
+ if (redoThisRegion)
+ RQ.push_back(CurrentRegion);
+
+ // Free all region nodes created in region passes.
+ RI->clearNodeCache();
+ }
+
+ // Finalization
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ RegionPass *P = (RegionPass*)getContainedPass(Index);
+ Changed |= P->doFinalization();
+ }
+
+  // Print the region tree after all passes have run.
+  LLVM_DEBUG(dbgs() << "\nRegion tree of function " << F.getName()
+                    << " after all region passes:\n";
+ RI->dump(); dbgs() << "\n";);
+
+ return Changed;
+}
+
+/// Print passes managed by this manager
+void RGPassManager::dumpPassStructure(unsigned Offset) {
+ errs().indent(Offset*2) << "Region Pass Manager\n";
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ Pass *P = getContainedPass(Index);
+ P->dumpPassStructure(Offset + 1);
+ dumpLastUses(P, Offset+1);
+ }
+}
+
+namespace {
+//===----------------------------------------------------------------------===//
+// PrintRegionPass
+class PrintRegionPass : public RegionPass {
+private:
+ std::string Banner;
+ raw_ostream &Out; // raw_ostream to print on.
+
+public:
+ static char ID;
+ PrintRegionPass(const std::string &B, raw_ostream &o)
+ : RegionPass(ID), Banner(B), Out(o) {}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ }
+
+ bool runOnRegion(Region *R, RGPassManager &RGM) override {
+ Out << Banner;
+ for (const auto *BB : R->blocks()) {
+ if (BB)
+ BB->print(Out);
+ else
+ Out << "Printing <null> Block";
+ }
+
+ return false;
+ }
+
+ StringRef getPassName() const override { return "Print Region IR"; }
+};
+
+char PrintRegionPass::ID = 0;
+} //end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// RegionPass
+
+// Check if this pass is suitable for the current RGPassManager, if
+// available. This pass P is not suitable for a RGPassManager if P
+// does not preserve higher level analysis info used by other
+// RGPassManager passes. In such a case, pop the RGPassManager from the
+// stack. This will force assignPassManager() to create a new
+// RGPassManager as expected.
+void RegionPass::preparePassManager(PMStack &PMS) {
+
+ // Find RGPassManager
+ while (!PMS.empty() &&
+ PMS.top()->getPassManagerType() > PMT_RegionPassManager)
+ PMS.pop();
+
+
+  // If this pass is destroying high level information that is used
+  // by other passes that are managed by this RGPassManager, then do not
+  // insert this pass into the current manager. Use a new RGPassManager.
+ if (PMS.top()->getPassManagerType() == PMT_RegionPassManager &&
+ !PMS.top()->preserveHigherLevelAnalysis(this))
+ PMS.pop();
+}
+
+/// Assign pass manager to manage this pass.
+void RegionPass::assignPassManager(PMStack &PMS,
+ PassManagerType PreferredType) {
+ // Find RGPassManager
+ while (!PMS.empty() &&
+ PMS.top()->getPassManagerType() > PMT_RegionPassManager)
+ PMS.pop();
+
+ RGPassManager *RGPM;
+
+ // Create new Region Pass Manager if it does not exist.
+ if (PMS.top()->getPassManagerType() == PMT_RegionPassManager)
+ RGPM = (RGPassManager*)PMS.top();
+ else {
+
+ assert (!PMS.empty() && "Unable to create Region Pass Manager");
+ PMDataManager *PMD = PMS.top();
+
+ // [1] Create new Region Pass Manager
+ RGPM = new RGPassManager();
+ RGPM->populateInheritedAnalysis(PMS);
+
+ // [2] Set up new manager's top level manager
+ PMTopLevelManager *TPM = PMD->getTopLevelManager();
+ TPM->addIndirectPassManager(RGPM);
+
+ // [3] Assign manager to manage this new manager. This may create
+ // and push new managers into PMS
+ TPM->schedulePass(RGPM);
+
+ // [4] Push new manager into PMS
+ PMS.push(RGPM);
+ }
+
+ RGPM->add(this);
+}
+
+/// Get the printer pass
+Pass *RegionPass::createPrinterPass(raw_ostream &O,
+ const std::string &Banner) const {
+ return new PrintRegionPass(Banner, O);
+}
+
+static std::string getDescription(const Region &R) {
+ return "region";
+}
+
+bool RegionPass::skipRegion(Region &R) const {
+ Function &F = *R.getEntry()->getParent();
+ OptPassGate &Gate = F.getContext().getOptPassGate();
+ if (Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(R)))
+ return true;
+
+ if (F.hasOptNone()) {
+ // Report this only once per function.
+ if (R.getEntry() == &F.getEntryBlock())
+ LLVM_DEBUG(dbgs() << "Skipping pass '" << getPassName()
+ << "' on function " << F.getName() << "\n");
+ return true;
+ }
+ return false;
+}
diff --git a/llvm/lib/Analysis/RegionPrinter.cpp b/llvm/lib/Analysis/RegionPrinter.cpp
new file mode 100644
index 000000000000..5bdcb31fbe99
--- /dev/null
+++ b/llvm/lib/Analysis/RegionPrinter.cpp
@@ -0,0 +1,266 @@
+//===- RegionPrinter.cpp - Print region tree pass -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Print out the region tree of a function using dotty/graphviz.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/RegionPrinter.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/DOTGraphTraitsPass.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/Analysis/RegionInfo.h"
+#include "llvm/Analysis/RegionIterator.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#ifndef NDEBUG
+#include "llvm/IR/LegacyPassManager.h"
+#endif
+
+using namespace llvm;
+
+//===----------------------------------------------------------------------===//
+/// onlySimpleRegions - Show only the simple regions in the RegionViewer.
+static cl::opt<bool>
+onlySimpleRegions("only-simple-regions",
+ cl::desc("Show only simple regions in the graphviz viewer"),
+ cl::Hidden,
+ cl::init(false));
+
+namespace llvm {
+template<>
+struct DOTGraphTraits<RegionNode*> : public DefaultDOTGraphTraits {
+
+ DOTGraphTraits (bool isSimple=false)
+ : DefaultDOTGraphTraits(isSimple) {}
+
+ std::string getNodeLabel(RegionNode *Node, RegionNode *Graph) {
+
+ if (!Node->isSubRegion()) {
+ BasicBlock *BB = Node->getNodeAs<BasicBlock>();
+
+ if (isSimple())
+ return DOTGraphTraits<const Function*>
+ ::getSimpleNodeLabel(BB, BB->getParent());
+ else
+ return DOTGraphTraits<const Function*>
+ ::getCompleteNodeLabel(BB, BB->getParent());
+ }
+
+ return "Not implemented";
+ }
+};
+
+template <>
+struct DOTGraphTraits<RegionInfo *> : public DOTGraphTraits<RegionNode *> {
+
+ DOTGraphTraits (bool isSimple = false)
+ : DOTGraphTraits<RegionNode*>(isSimple) {}
+
+ static std::string getGraphName(const RegionInfo *) { return "Region Graph"; }
+
+ std::string getNodeLabel(RegionNode *Node, RegionInfo *G) {
+ return DOTGraphTraits<RegionNode *>::getNodeLabel(
+ Node, reinterpret_cast<RegionNode *>(G->getTopLevelRegion()));
+ }
+
+ std::string getEdgeAttributes(RegionNode *srcNode,
+ GraphTraits<RegionInfo *>::ChildIteratorType CI,
+ RegionInfo *G) {
+ RegionNode *destNode = *CI;
+
+ if (srcNode->isSubRegion() || destNode->isSubRegion())
+ return "";
+
+ // In case of a backedge, do not use it to define the layout of the nodes.
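+    // (Returning "constraint=false" keeps the edge in the drawing but tells
+    // graphviz not to use it when ranking the nodes.)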
+ BasicBlock *srcBB = srcNode->getNodeAs<BasicBlock>();
+ BasicBlock *destBB = destNode->getNodeAs<BasicBlock>();
+
+ Region *R = G->getRegionFor(destBB);
+
+ while (R && R->getParent())
+ if (R->getParent()->getEntry() == destBB)
+ R = R->getParent();
+ else
+ break;
+
+ if (R && R->getEntry() == destBB && R->contains(srcBB))
+ return "constraint=false";
+
+ return "";
+ }
+
+ // Print the cluster of the subregions. This groups the single basic blocks
+ // and adds a different background color for each group.
+ static void printRegionCluster(const Region &R, GraphWriter<RegionInfo *> &GW,
+ unsigned depth = 0) {
+ raw_ostream &O = GW.getOStream();
+ O.indent(2 * depth) << "subgraph cluster_" << static_cast<const void*>(&R)
+ << " {\n";
+ O.indent(2 * (depth + 1)) << "label = \"\";\n";
+
+ if (!onlySimpleRegions || R.isSimple()) {
+ O.indent(2 * (depth + 1)) << "style = filled;\n";
+ O.indent(2 * (depth + 1)) << "color = "
+ << ((R.getDepth() * 2 % 12) + 1) << "\n";
+
+ } else {
+ O.indent(2 * (depth + 1)) << "style = solid;\n";
+ O.indent(2 * (depth + 1)) << "color = "
+ << ((R.getDepth() * 2 % 12) + 2) << "\n";
+ }
+
+ for (const auto &RI : R)
+ printRegionCluster(*RI, GW, depth + 1);
+
+ const RegionInfo &RI = *static_cast<const RegionInfo*>(R.getRegionInfo());
+
+ for (auto *BB : R.blocks())
+ if (RI.getRegionFor(BB) == &R)
+ O.indent(2 * (depth + 1)) << "Node"
+ << static_cast<const void*>(RI.getTopLevelRegion()->getBBNode(BB))
+ << ";\n";
+
+ O.indent(2 * depth) << "}\n";
+ }
+
+ static void addCustomGraphFeatures(const RegionInfo *G,
+ GraphWriter<RegionInfo *> &GW) {
+ raw_ostream &O = GW.getOStream();
+ O << "\tcolorscheme = \"paired12\"\n";
+ printRegionCluster(*G->getTopLevelRegion(), GW, 4);
+ }
+};
+} //end namespace llvm
+
+namespace {
+
+struct RegionInfoPassGraphTraits {
+ static RegionInfo *getGraph(RegionInfoPass *RIP) {
+ return &RIP->getRegionInfo();
+ }
+};
+
+struct RegionPrinter
+ : public DOTGraphTraitsPrinter<RegionInfoPass, false, RegionInfo *,
+ RegionInfoPassGraphTraits> {
+ static char ID;
+ RegionPrinter()
+ : DOTGraphTraitsPrinter<RegionInfoPass, false, RegionInfo *,
+ RegionInfoPassGraphTraits>("reg", ID) {
+ initializeRegionPrinterPass(*PassRegistry::getPassRegistry());
+ }
+};
+char RegionPrinter::ID = 0;
+
+struct RegionOnlyPrinter
+ : public DOTGraphTraitsPrinter<RegionInfoPass, true, RegionInfo *,
+ RegionInfoPassGraphTraits> {
+ static char ID;
+ RegionOnlyPrinter()
+ : DOTGraphTraitsPrinter<RegionInfoPass, true, RegionInfo *,
+ RegionInfoPassGraphTraits>("reg", ID) {
+ initializeRegionOnlyPrinterPass(*PassRegistry::getPassRegistry());
+ }
+};
+char RegionOnlyPrinter::ID = 0;
+
+struct RegionViewer
+ : public DOTGraphTraitsViewer<RegionInfoPass, false, RegionInfo *,
+ RegionInfoPassGraphTraits> {
+ static char ID;
+ RegionViewer()
+ : DOTGraphTraitsViewer<RegionInfoPass, false, RegionInfo *,
+ RegionInfoPassGraphTraits>("reg", ID) {
+ initializeRegionViewerPass(*PassRegistry::getPassRegistry());
+ }
+};
+char RegionViewer::ID = 0;
+
+struct RegionOnlyViewer
+ : public DOTGraphTraitsViewer<RegionInfoPass, true, RegionInfo *,
+ RegionInfoPassGraphTraits> {
+ static char ID;
+ RegionOnlyViewer()
+ : DOTGraphTraitsViewer<RegionInfoPass, true, RegionInfo *,
+ RegionInfoPassGraphTraits>("regonly", ID) {
+ initializeRegionOnlyViewerPass(*PassRegistry::getPassRegistry());
+ }
+};
+char RegionOnlyViewer::ID = 0;
+
+} //end anonymous namespace
+
+INITIALIZE_PASS(RegionPrinter, "dot-regions",
+ "Print regions of function to 'dot' file", true, true)
+
+INITIALIZE_PASS(
+ RegionOnlyPrinter, "dot-regions-only",
+ "Print regions of function to 'dot' file (with no function bodies)", true,
+ true)
+
+INITIALIZE_PASS(RegionViewer, "view-regions", "View regions of function",
+ true, true)
+
+INITIALIZE_PASS(RegionOnlyViewer, "view-regions-only",
+ "View regions of function (with no function bodies)",
+ true, true)
+
+FunctionPass *llvm::createRegionPrinterPass() { return new RegionPrinter(); }
+
+FunctionPass *llvm::createRegionOnlyPrinterPass() {
+ return new RegionOnlyPrinter();
+}
+
+FunctionPass* llvm::createRegionViewerPass() {
+ return new RegionViewer();
+}
+
+FunctionPass* llvm::createRegionOnlyViewerPass() {
+ return new RegionOnlyViewer();
+}
+
+#ifndef NDEBUG
+static void viewRegionInfo(RegionInfo *RI, bool ShortNames) {
+ assert(RI && "Argument must be non-null");
+
+ llvm::Function *F = RI->getTopLevelRegion()->getEntry()->getParent();
+ std::string GraphName = DOTGraphTraits<RegionInfo *>::getGraphName(RI);
+
+ llvm::ViewGraph(RI, "reg", ShortNames,
+ Twine(GraphName) + " for '" + F->getName() + "' function");
+}
+
+static void invokeFunctionPass(const Function *F, FunctionPass *ViewerPass) {
+ assert(F && "Argument must be non-null");
+ assert(!F->isDeclaration() && "Function must have an implementation");
+
+ // The viewer and analysis passes do not modify anything, so we can safely
+  // remove the const qualifier.
+ auto NonConstF = const_cast<Function *>(F);
+
+ llvm::legacy::FunctionPassManager FPM(NonConstF->getParent());
+ FPM.add(ViewerPass);
+ FPM.doInitialization();
+ FPM.run(*NonConstF);
+ FPM.doFinalization();
+}
+
+void llvm::viewRegion(RegionInfo *RI) { viewRegionInfo(RI, false); }
+
+void llvm::viewRegion(const Function *F) {
+ invokeFunctionPass(F, createRegionViewerPass());
+}
+
+void llvm::viewRegionOnly(RegionInfo *RI) { viewRegionInfo(RI, true); }
+
+void llvm::viewRegionOnly(const Function *F) {
+ invokeFunctionPass(F, createRegionOnlyViewerPass());
+}
+#endif
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
new file mode 100644
index 000000000000..5ce0a1adeaa0
--- /dev/null
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -0,0 +1,12530 @@
+//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementation of the scalar evolution analysis
+// engine, which is used primarily to analyze expressions involving induction
+// variables in loops.
+//
+// There are several aspects to this library. First is the representation of
+// scalar expressions, which are represented as subclasses of the SCEV class.
+// These classes are used to represent certain types of subexpressions that we
+// can handle. We only create one SCEV of a particular shape, so
+// pointer-comparisons for equality are legal.
+//
+// One important aspect of the SCEV objects is that they are never cyclic, even
+// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
+// the PHI node is one of the idioms that we can represent (e.g., a polynomial
+// recurrence) then we represent it directly as a recurrence node, otherwise we
+// represent it as a SCEVUnknown node.
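+// For example, a canonical induction variable that starts at zero and is
+// incremented by one on each iteration of a loop %L is represented by the
+// add recurrence printed as {0,+,1}<%L> (see SCEVAddRecExpr and SCEV::print).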
+//
+// In addition to being able to represent expressions of various types, we also
+// have folders that are used to build the *canonical* representation for a
+// particular expression. These folders are capable of using a variety of
+// rewrite rules to simplify the expressions.
+//
+// Once the folders are defined, we can implement the more interesting
+// higher-level code, such as the code that recognizes PHI nodes of various
+// types, computes the execution count of a loop, etc.
+//
+// TODO: We should use these routines and value representations to implement
+// dependence analysis!
+//
+//===----------------------------------------------------------------------===//
+//
+// There are several good references for the techniques used in this analysis.
+//
+// Chains of recurrences -- a method to expedite the evaluation
+// of closed-form functions
+// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
+//
+// On computational properties of chains of recurrences
+// Eugene V. Zima
+//
+// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
+// Robert A. van Engelen
+//
+// Efficient Symbolic Analysis for Optimizing Compilers
+// Robert A. van Engelen
+//
+// Using the chains of recurrences algebra for data dependence testing and
+// induction variable substitution
+// MS Thesis, Johnie Birch
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/ADT/FoldingSet.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalAlias.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/KnownBits.h"
+#include "llvm/Support/SaveAndRestore.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <climits>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <map>
+#include <memory>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "scalar-evolution"
+
+STATISTIC(NumArrayLenItCounts,
+ "Number of trip counts computed with array length");
+STATISTIC(NumTripCountsComputed,
+ "Number of loops with predictable loop counts");
+STATISTIC(NumTripCountsNotComputed,
+ "Number of loops without predictable loop counts");
+STATISTIC(NumBruteForceTripCountsComputed,
+ "Number of loops with trip counts computed by force");
+
+static cl::opt<unsigned>
+MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
+ cl::ZeroOrMore,
+ cl::desc("Maximum number of iterations SCEV will "
+ "symbolically execute a constant "
+ "derived loop"),
+ cl::init(100));
+
+// FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
+static cl::opt<bool> VerifySCEV(
+ "verify-scev", cl::Hidden,
+ cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
+static cl::opt<bool> VerifySCEVStrict(
+ "verify-scev-strict", cl::Hidden,
+    cl::desc("Enable stricter verification when -verify-scev is passed"));
+static cl::opt<bool>
+ VerifySCEVMap("verify-scev-maps", cl::Hidden,
+ cl::desc("Verify no dangling value in ScalarEvolution's "
+ "ExprValueMap (slow)"));
+
+static cl::opt<bool> VerifyIR(
+ "scev-verify-ir", cl::Hidden,
+ cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
+ cl::init(false));
+
+static cl::opt<unsigned> MulOpsInlineThreshold(
+ "scev-mulops-inline-threshold", cl::Hidden,
+ cl::desc("Threshold for inlining multiplication operands into a SCEV"),
+ cl::init(32));
+
+static cl::opt<unsigned> AddOpsInlineThreshold(
+ "scev-addops-inline-threshold", cl::Hidden,
+ cl::desc("Threshold for inlining addition operands into a SCEV"),
+ cl::init(500));
+
+static cl::opt<unsigned> MaxSCEVCompareDepth(
+ "scalar-evolution-max-scev-compare-depth", cl::Hidden,
+ cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
+ cl::init(32));
+
+static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
+ "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
+ cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
+ cl::init(2));
+
+static cl::opt<unsigned> MaxValueCompareDepth(
+ "scalar-evolution-max-value-compare-depth", cl::Hidden,
+ cl::desc("Maximum depth of recursive value complexity comparisons"),
+ cl::init(2));
+
+static cl::opt<unsigned>
+ MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
+ cl::desc("Maximum depth of recursive arithmetics"),
+ cl::init(32));
+
+static cl::opt<unsigned> MaxConstantEvolvingDepth(
+ "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
+ cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
+
+static cl::opt<unsigned>
+ MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
+ cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
+ cl::init(8));
+
+static cl::opt<unsigned>
+ MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
+ cl::desc("Max coefficients in AddRec during evolving"),
+ cl::init(8));
+
+static cl::opt<unsigned>
+ HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
+ cl::desc("Size of the expression which is considered huge"),
+ cl::init(4096));
+
+//===----------------------------------------------------------------------===//
+// SCEV class definitions
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Implementation of the SCEV class.
+//
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void SCEV::dump() const {
+ print(dbgs());
+ dbgs() << '\n';
+}
+#endif
+
+void SCEV::print(raw_ostream &OS) const {
+ switch (static_cast<SCEVTypes>(getSCEVType())) {
+ case scConstant:
+ cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
+ return;
+ case scTruncate: {
+ const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
+ const SCEV *Op = Trunc->getOperand();
+ OS << "(trunc " << *Op->getType() << " " << *Op << " to "
+ << *Trunc->getType() << ")";
+ return;
+ }
+ case scZeroExtend: {
+ const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
+ const SCEV *Op = ZExt->getOperand();
+ OS << "(zext " << *Op->getType() << " " << *Op << " to "
+ << *ZExt->getType() << ")";
+ return;
+ }
+ case scSignExtend: {
+ const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
+ const SCEV *Op = SExt->getOperand();
+ OS << "(sext " << *Op->getType() << " " << *Op << " to "
+ << *SExt->getType() << ")";
+ return;
+ }
+ case scAddRecExpr: {
+ const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
+ OS << "{" << *AR->getOperand(0);
+ for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
+ OS << ",+," << *AR->getOperand(i);
+ OS << "}<";
+ if (AR->hasNoUnsignedWrap())
+ OS << "nuw><";
+ if (AR->hasNoSignedWrap())
+ OS << "nsw><";
+ if (AR->hasNoSelfWrap() &&
+ !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
+ OS << "nw><";
+ AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+ OS << ">";
+ return;
+ }
+ case scAddExpr:
+ case scMulExpr:
+ case scUMaxExpr:
+ case scSMaxExpr:
+ case scUMinExpr:
+ case scSMinExpr: {
+ const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
+ const char *OpStr = nullptr;
+ switch (NAry->getSCEVType()) {
+ case scAddExpr: OpStr = " + "; break;
+ case scMulExpr: OpStr = " * "; break;
+ case scUMaxExpr: OpStr = " umax "; break;
+ case scSMaxExpr: OpStr = " smax "; break;
+ case scUMinExpr:
+ OpStr = " umin ";
+ break;
+ case scSMinExpr:
+ OpStr = " smin ";
+ break;
+ }
+ OS << "(";
+ for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
+ I != E; ++I) {
+ OS << **I;
+ if (std::next(I) != E)
+ OS << OpStr;
+ }
+ OS << ")";
+ switch (NAry->getSCEVType()) {
+ case scAddExpr:
+ case scMulExpr:
+ if (NAry->hasNoUnsignedWrap())
+ OS << "<nuw>";
+ if (NAry->hasNoSignedWrap())
+ OS << "<nsw>";
+ }
+ return;
+ }
+ case scUDivExpr: {
+ const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
+ OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
+ return;
+ }
+ case scUnknown: {
+ const SCEVUnknown *U = cast<SCEVUnknown>(this);
+ Type *AllocTy;
+ if (U->isSizeOf(AllocTy)) {
+ OS << "sizeof(" << *AllocTy << ")";
+ return;
+ }
+ if (U->isAlignOf(AllocTy)) {
+ OS << "alignof(" << *AllocTy << ")";
+ return;
+ }
+
+ Type *CTy;
+ Constant *FieldNo;
+ if (U->isOffsetOf(CTy, FieldNo)) {
+ OS << "offsetof(" << *CTy << ", ";
+ FieldNo->printAsOperand(OS, false);
+ OS << ")";
+ return;
+ }
+
+ // Otherwise just print it normally.
+ U->getValue()->printAsOperand(OS, false);
+ return;
+ }
+ case scCouldNotCompute:
+ OS << "***COULDNOTCOMPUTE***";
+ return;
+ }
+ llvm_unreachable("Unknown SCEV kind!");
+}
+
+Type *SCEV::getType() const {
+ switch (static_cast<SCEVTypes>(getSCEVType())) {
+ case scConstant:
+ return cast<SCEVConstant>(this)->getType();
+ case scTruncate:
+ case scZeroExtend:
+ case scSignExtend:
+ return cast<SCEVCastExpr>(this)->getType();
+ case scAddRecExpr:
+ case scMulExpr:
+ case scUMaxExpr:
+ case scSMaxExpr:
+ case scUMinExpr:
+ case scSMinExpr:
+ return cast<SCEVNAryExpr>(this)->getType();
+ case scAddExpr:
+ return cast<SCEVAddExpr>(this)->getType();
+ case scUDivExpr:
+ return cast<SCEVUDivExpr>(this)->getType();
+ case scUnknown:
+ return cast<SCEVUnknown>(this)->getType();
+ case scCouldNotCompute:
+ llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
+ }
+ llvm_unreachable("Unknown SCEV kind!");
+}
+
+bool SCEV::isZero() const {
+ if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
+ return SC->getValue()->isZero();
+ return false;
+}
+
+bool SCEV::isOne() const {
+ if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
+ return SC->getValue()->isOne();
+ return false;
+}
+
+bool SCEV::isAllOnesValue() const {
+ if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
+ return SC->getValue()->isMinusOne();
+ return false;
+}
+
+bool SCEV::isNonConstantNegative() const {
+ const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
+ if (!Mul) return false;
+
+ // If there is a constant factor, it will be first.
+ const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
+ if (!SC) return false;
+
+  // Return true if the value is negative; this matches things like (-42 * V).
+ return SC->getAPInt().isNegative();
+}
+
+SCEVCouldNotCompute::SCEVCouldNotCompute() :
+ SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
+
+bool SCEVCouldNotCompute::classof(const SCEV *S) {
+ return S->getSCEVType() == scCouldNotCompute;
+}
+
+const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
+ FoldingSetNodeID ID;
+ ID.AddInteger(scConstant);
+ ID.AddPointer(V);
+ void *IP = nullptr;
+ if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+ SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
+ UniqueSCEVs.InsertNode(S, IP);
+ return S;
+}
+
+const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
+ return getConstant(ConstantInt::get(getContext(), Val));
+}
+
+const SCEV *
+ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
+ IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
+ return getConstant(ConstantInt::get(ITy, V, isSigned));
+}
+
+SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
+ unsigned SCEVTy, const SCEV *op, Type *ty)
+ : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
+
+SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
+ const SCEV *op, Type *ty)
+ : SCEVCastExpr(ID, scTruncate, op, ty) {
+ assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+ "Cannot truncate non-integer value!");
+}
+
+SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
+ const SCEV *op, Type *ty)
+ : SCEVCastExpr(ID, scZeroExtend, op, ty) {
+ assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+ "Cannot zero extend non-integer value!");
+}
+
+SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
+ const SCEV *op, Type *ty)
+ : SCEVCastExpr(ID, scSignExtend, op, ty) {
+ assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+ "Cannot sign extend non-integer value!");
+}
+
+void SCEVUnknown::deleted() {
+ // Clear this SCEVUnknown from various maps.
+ SE->forgetMemoizedResults(this);
+
+ // Remove this SCEVUnknown from the uniquing map.
+ SE->UniqueSCEVs.RemoveNode(this);
+
+ // Release the value.
+ setValPtr(nullptr);
+}
+
+void SCEVUnknown::allUsesReplacedWith(Value *New) {
+ // Remove this SCEVUnknown from the uniquing map.
+ SE->UniqueSCEVs.RemoveNode(this);
+
+ // Update this SCEVUnknown to point to the new value. This is needed
+ // because there may still be outstanding SCEVs which still point to
+ // this SCEVUnknown.
+ setValPtr(New);
+}
+
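+// The following helpers recognize the canonical constant expressions used to
+// encode sizeof, alignof and offsetof. For instance, isSizeOf matches
+// "ptrtoint(getelementptr(T, T* null, 1))": the byte offset of element 1 from
+// a null T* is exactly sizeof(T).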
+bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
+ if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
+ if (VCE->getOpcode() == Instruction::PtrToInt)
+ if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
+ if (CE->getOpcode() == Instruction::GetElementPtr &&
+ CE->getOperand(0)->isNullValue() &&
+ CE->getNumOperands() == 2)
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
+ if (CI->isOne()) {
+ AllocTy = cast<PointerType>(CE->getOperand(0)->getType())
+ ->getElementType();
+ return true;
+ }
+
+ return false;
+}
+
+bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
+ if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
+ if (VCE->getOpcode() == Instruction::PtrToInt)
+ if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
+ if (CE->getOpcode() == Instruction::GetElementPtr &&
+ CE->getOperand(0)->isNullValue()) {
+ Type *Ty =
+ cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
+ if (StructType *STy = dyn_cast<StructType>(Ty))
+ if (!STy->isPacked() &&
+ CE->getNumOperands() == 3 &&
+ CE->getOperand(1)->isNullValue()) {
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
+ if (CI->isOne() &&
+ STy->getNumElements() == 2 &&
+ STy->getElementType(0)->isIntegerTy(1)) {
+ AllocTy = STy->getElementType(1);
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
+bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
+ if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
+ if (VCE->getOpcode() == Instruction::PtrToInt)
+ if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
+ if (CE->getOpcode() == Instruction::GetElementPtr &&
+ CE->getNumOperands() == 3 &&
+ CE->getOperand(0)->isNullValue() &&
+ CE->getOperand(1)->isNullValue()) {
+ Type *Ty =
+ cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
+ // Ignore vector types here so that ScalarEvolutionExpander doesn't
+ // emit getelementptrs that index into vectors.
+ if (Ty->isStructTy() || Ty->isArrayTy()) {
+ CTy = Ty;
+ FieldNo = CE->getOperand(2);
+ return true;
+ }
+ }
+
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// SCEV Utilities
+//===----------------------------------------------------------------------===//
+
+/// Compare the two values \p LV and \p RV in terms of their "complexity" where
+/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
+/// operands in SCEV expressions. \p EqCache is a set of pairs of values that
+/// have been previously deemed to be "equally complex" by this routine. It is
+/// intended to avoid exponential time complexity in cases like:
+///
+/// %a = f(%x, %y)
+/// %b = f(%a, %a)
+/// %c = f(%b, %b)
+///
+/// %d = f(%x, %y)
+/// %e = f(%d, %d)
+/// %f = f(%e, %e)
+///
+/// CompareValueComplexity(%f, %c)
+///
+/// Since we do not continue running this routine on expression trees once we
+/// have seen unequal values, there is no need to track them in the cache.
+static int
+CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
+ const LoopInfo *const LI, Value *LV, Value *RV,
+ unsigned Depth) {
+ if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
+ return 0;
+
+ // Order pointer values after integer values. This helps SCEVExpander form
+ // GEPs.
+ bool LIsPointer = LV->getType()->isPointerTy(),
+ RIsPointer = RV->getType()->isPointerTy();
+ if (LIsPointer != RIsPointer)
+ return (int)LIsPointer - (int)RIsPointer;
+
+ // Compare getValueID values.
+ unsigned LID = LV->getValueID(), RID = RV->getValueID();
+ if (LID != RID)
+ return (int)LID - (int)RID;
+
+ // Sort arguments by their position.
+ if (const auto *LA = dyn_cast<Argument>(LV)) {
+ const auto *RA = cast<Argument>(RV);
+ unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
+ return (int)LArgNo - (int)RArgNo;
+ }
+
+ if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
+ const auto *RGV = cast<GlobalValue>(RV);
+
+ const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
+ auto LT = GV->getLinkage();
+ return !(GlobalValue::isPrivateLinkage(LT) ||
+ GlobalValue::isInternalLinkage(LT));
+ };
+
+ // Use the names to distinguish the two values, but only if the
+ // names are semantically important.
+ if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
+ return LGV->getName().compare(RGV->getName());
+ }
+
+ // For instructions, compare their loop depth, and their operand count. This
+ // is pretty loose.
+ if (const auto *LInst = dyn_cast<Instruction>(LV)) {
+ const auto *RInst = cast<Instruction>(RV);
+
+ // Compare loop depths.
+ const BasicBlock *LParent = LInst->getParent(),
+ *RParent = RInst->getParent();
+ if (LParent != RParent) {
+ unsigned LDepth = LI->getLoopDepth(LParent),
+ RDepth = LI->getLoopDepth(RParent);
+ if (LDepth != RDepth)
+ return (int)LDepth - (int)RDepth;
+ }
+
+ // Compare the number of operands.
+ unsigned LNumOps = LInst->getNumOperands(),
+ RNumOps = RInst->getNumOperands();
+ if (LNumOps != RNumOps)
+ return (int)LNumOps - (int)RNumOps;
+
+ for (unsigned Idx : seq(0u, LNumOps)) {
+ int Result =
+ CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
+ RInst->getOperand(Idx), Depth + 1);
+ if (Result != 0)
+ return Result;
+ }
+ }
+
+ EqCacheValue.unionSets(LV, RV);
+ return 0;
+}
+
+// Return negative, zero, or positive, if LHS is less than, equal to, or greater
+// than RHS, respectively. A three-way result allows recursive comparisons to be
+// more efficient.
+static int CompareSCEVComplexity(
+ EquivalenceClasses<const SCEV *> &EqCacheSCEV,
+ EquivalenceClasses<const Value *> &EqCacheValue,
+ const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS,
+ DominatorTree &DT, unsigned Depth = 0) {
+ // Fast-path: SCEVs are uniqued so we can do a quick equality check.
+ if (LHS == RHS)
+ return 0;
+
+ // Primarily, sort the SCEVs by their getSCEVType().
+ unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
+ if (LType != RType)
+ return (int)LType - (int)RType;
+
+ if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.isEquivalent(LHS, RHS))
+ return 0;
+ // Aside from the getSCEVType() ordering, the particular ordering
+ // isn't very important except that it's beneficial to be consistent,
+ // so that (a + b) and (b + a) don't end up as different expressions.
+ switch (static_cast<SCEVTypes>(LType)) {
+ case scUnknown: {
+ const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
+ const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
+
+ int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
+ RU->getValue(), Depth + 1);
+ if (X == 0)
+ EqCacheSCEV.unionSets(LHS, RHS);
+ return X;
+ }
+
+ case scConstant: {
+ const SCEVConstant *LC = cast<SCEVConstant>(LHS);
+ const SCEVConstant *RC = cast<SCEVConstant>(RHS);
+
+ // Compare constant values.
+ const APInt &LA = LC->getAPInt();
+ const APInt &RA = RC->getAPInt();
+ unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
+ if (LBitWidth != RBitWidth)
+ return (int)LBitWidth - (int)RBitWidth;
+ return LA.ult(RA) ? -1 : 1;
+ }
+
+ case scAddRecExpr: {
+ const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
+ const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
+
+    // There is always a dominance relation between two recurrences that are
+    // used by one SCEV, so we can safely sort them by loop header dominance.
+    // We require such an order in getAddExpr.
+ const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
+ if (LLoop != RLoop) {
+ const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
+ assert(LHead != RHead && "Two loops share the same header?");
+ if (DT.dominates(LHead, RHead))
+ return 1;
+ else
+ assert(DT.dominates(RHead, LHead) &&
+ "No dominance between recurrences used by one SCEV?");
+ return -1;
+ }
+
+ // Addrec complexity grows with operand count.
+ unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
+ if (LNumOps != RNumOps)
+ return (int)LNumOps - (int)RNumOps;
+
+ // Lexicographically compare.
+ for (unsigned i = 0; i != LNumOps; ++i) {
+ int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+ LA->getOperand(i), RA->getOperand(i), DT,
+ Depth + 1);
+ if (X != 0)
+ return X;
+ }
+ EqCacheSCEV.unionSets(LHS, RHS);
+ return 0;
+ }
+
+ case scAddExpr:
+ case scMulExpr:
+ case scSMaxExpr:
+ case scUMaxExpr:
+ case scSMinExpr:
+ case scUMinExpr: {
+ const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
+ const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
+
+ // Lexicographically compare n-ary expressions.
+ unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
+ if (LNumOps != RNumOps)
+ return (int)LNumOps - (int)RNumOps;
+
+ for (unsigned i = 0; i != LNumOps; ++i) {
+ int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+ LC->getOperand(i), RC->getOperand(i), DT,
+ Depth + 1);
+ if (X != 0)
+ return X;
+ }
+ EqCacheSCEV.unionSets(LHS, RHS);
+ return 0;
+ }
+
+ case scUDivExpr: {
+ const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
+ const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
+
+ // Lexicographically compare udiv expressions.
+ int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
+ RC->getLHS(), DT, Depth + 1);
+ if (X != 0)
+ return X;
+ X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
+ RC->getRHS(), DT, Depth + 1);
+ if (X == 0)
+ EqCacheSCEV.unionSets(LHS, RHS);
+ return X;
+ }
+
+ case scTruncate:
+ case scZeroExtend:
+ case scSignExtend: {
+ const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
+ const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
+
+ // Compare cast expressions by operand.
+ int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+ LC->getOperand(), RC->getOperand(), DT,
+ Depth + 1);
+ if (X == 0)
+ EqCacheSCEV.unionSets(LHS, RHS);
+ return X;
+ }
+
+ case scCouldNotCompute:
+ llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
+ }
+ llvm_unreachable("Unknown SCEV kind!");
+}
+
+/// Given a list of SCEV objects, order them by their complexity, and group
+/// objects of the same complexity together by value. When this routine is
+/// finished, we know that any duplicates in the vector are consecutive and that
+/// complexity is monotonically increasing.
+///
+/// Note that we take special precautions to ensure that we get deterministic
+/// results from this routine. In other words, we don't want the results of
+/// this to depend on where the addresses of various SCEV objects happened to
+/// land in memory.
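+///
+/// For example, after grouping, the operand list of (%a + %b + %a) has the two
+/// occurrences of %a adjacent, which makes it easy for callers such as
+/// getAddExpr to fold the duplicates together.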
+static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
+ LoopInfo *LI, DominatorTree &DT) {
+ if (Ops.size() < 2) return; // Noop
+
+ EquivalenceClasses<const SCEV *> EqCacheSCEV;
+ EquivalenceClasses<const Value *> EqCacheValue;
+ if (Ops.size() == 2) {
+ // This is the common case, which also happens to be trivially simple.
+ // Special case it.
+ const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
+ if (CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, RHS, LHS, DT) < 0)
+ std::swap(LHS, RHS);
+ return;
+ }
+
+ // Do the rough sort by complexity.
+ llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
+ return CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT) <
+ 0;
+ });
+
+ // Now that we are sorted by complexity, group elements of the same
+ // complexity. Note that this is, at worst, N^2, but the vector is likely to
+ // be extremely short in practice. Note that we take this approach because we
+ // do not want to depend on the addresses of the objects we are grouping.
+ for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
+ const SCEV *S = Ops[i];
+ unsigned Complexity = S->getSCEVType();
+
+ // If there are any objects of the same complexity and same value as this
+ // one, group them.
+ for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
+ if (Ops[j] == S) { // Found a duplicate.
+ // Move it to immediately after i'th element.
+ std::swap(Ops[i+1], Ops[j]);
+ ++i; // no need to rescan it.
+ if (i == e-2) return; // Done!
+ }
+ }
+ }
+}
+
+// Returns the size of the SCEV S.
+static inline int sizeOfSCEV(const SCEV *S) {
+ struct FindSCEVSize {
+ int Size = 0;
+
+ FindSCEVSize() = default;
+
+ bool follow(const SCEV *S) {
+ ++Size;
+ // Keep looking at all operands of S.
+ return true;
+ }
+
+ bool isDone() const {
+ return false;
+ }
+ };
+
+ FindSCEVSize F;
+ SCEVTraversal<FindSCEVSize> ST(F);
+ ST.visitAll(S);
+ return F.Size;
+}
+
+/// Returns true if the subtree of \p S contains at least HugeExprThreshold
+/// nodes.
+static bool isHugeExpression(const SCEV *S) {
+ return S->getExpressionSize() >= HugeExprThreshold;
+}
+
+/// Returns true if \p Ops contains a huge SCEV (see definition above).
+static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
+ return any_of(Ops, isHugeExpression);
+}
+
+namespace {
+
+struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {
+public:
+ // Computes the Quotient and Remainder of the division of Numerator by
+ // Denominator.
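+  // For example, dividing the affine recurrence {0,+,4} by the constant 4
+  // yields the quotient {0,+,1} and a zero remainder; visitAddRecExpr divides
+  // the start and the step separately.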
+ static void divide(ScalarEvolution &SE, const SCEV *Numerator,
+ const SCEV *Denominator, const SCEV **Quotient,
+ const SCEV **Remainder) {
+ assert(Numerator && Denominator && "Uninitialized SCEV");
+
+ SCEVDivision D(SE, Numerator, Denominator);
+
+ // Check for the trivial case here to avoid having to check for it in the
+ // rest of the code.
+ if (Numerator == Denominator) {
+ *Quotient = D.One;
+ *Remainder = D.Zero;
+ return;
+ }
+
+ if (Numerator->isZero()) {
+ *Quotient = D.Zero;
+ *Remainder = D.Zero;
+ return;
+ }
+
+    // A simple case: N/1. The quotient is N.
+ if (Denominator->isOne()) {
+ *Quotient = Numerator;
+ *Remainder = D.Zero;
+ return;
+ }
+
+ // Split the Denominator when it is a product.
+ if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
+ const SCEV *Q, *R;
+ *Quotient = Numerator;
+ for (const SCEV *Op : T->operands()) {
+ divide(SE, *Quotient, Op, &Q, &R);
+ *Quotient = Q;
+
+ // Bail out when the Numerator is not divisible by one of the terms of
+ // the Denominator.
+ if (!R->isZero()) {
+ *Quotient = D.Zero;
+ *Remainder = Numerator;
+ return;
+ }
+ }
+ *Remainder = D.Zero;
+ return;
+ }
+
+ D.visit(Numerator);
+ *Quotient = D.Quotient;
+ *Remainder = D.Remainder;
+ }
+
+  // Except in the trivial case described above, we do not know how to divide
+  // Expr by Denominator for the following kinds, so their visitors are empty.
+ void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {}
+ void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {}
+ void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {}
+ void visitUDivExpr(const SCEVUDivExpr *Numerator) {}
+ void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {}
+ void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {}
+ void visitSMinExpr(const SCEVSMinExpr *Numerator) {}
+ void visitUMinExpr(const SCEVUMinExpr *Numerator) {}
+ void visitUnknown(const SCEVUnknown *Numerator) {}
+ void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {}
+
+ void visitConstant(const SCEVConstant *Numerator) {
+ if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
+ APInt NumeratorVal = Numerator->getAPInt();
+ APInt DenominatorVal = D->getAPInt();
+ uint32_t NumeratorBW = NumeratorVal.getBitWidth();
+ uint32_t DenominatorBW = DenominatorVal.getBitWidth();
+
+ if (NumeratorBW > DenominatorBW)
+ DenominatorVal = DenominatorVal.sext(NumeratorBW);
+ else if (NumeratorBW < DenominatorBW)
+ NumeratorVal = NumeratorVal.sext(DenominatorBW);
+
+ APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
+ APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
+ APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
+ Quotient = SE.getConstant(QuotientVal);
+ Remainder = SE.getConstant(RemainderVal);
+ return;
+ }
+ }
+
+ void visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
+ const SCEV *StartQ, *StartR, *StepQ, *StepR;
+ if (!Numerator->isAffine())
+ return cannotDivide(Numerator);
+ divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
+ divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
+ // Bail out if the types do not match.
+ Type *Ty = Denominator->getType();
+ if (Ty != StartQ->getType() || Ty != StartR->getType() ||
+ Ty != StepQ->getType() || Ty != StepR->getType())
+ return cannotDivide(Numerator);
+ Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
+ Numerator->getNoWrapFlags());
+ Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
+ Numerator->getNoWrapFlags());
+ }
+
+ void visitAddExpr(const SCEVAddExpr *Numerator) {
+ SmallVector<const SCEV *, 2> Qs, Rs;
+ Type *Ty = Denominator->getType();
+
+ for (const SCEV *Op : Numerator->operands()) {
+ const SCEV *Q, *R;
+ divide(SE, Op, Denominator, &Q, &R);
+
+ // Bail out if types do not match.
+ if (Ty != Q->getType() || Ty != R->getType())
+ return cannotDivide(Numerator);
+
+ Qs.push_back(Q);
+ Rs.push_back(R);
+ }
+
+ if (Qs.size() == 1) {
+ Quotient = Qs[0];
+ Remainder = Rs[0];
+ return;
+ }
+
+ Quotient = SE.getAddExpr(Qs);
+ Remainder = SE.getAddExpr(Rs);
+ }
+
+ void visitMulExpr(const SCEVMulExpr *Numerator) {
+ SmallVector<const SCEV *, 2> Qs;
+ Type *Ty = Denominator->getType();
+
+ bool FoundDenominatorTerm = false;
+ for (const SCEV *Op : Numerator->operands()) {
+ // Bail out if types do not match.
+ if (Ty != Op->getType())
+ return cannotDivide(Numerator);
+
+ if (FoundDenominatorTerm) {
+ Qs.push_back(Op);
+ continue;
+ }
+
+ // Check whether Denominator divides one of the product operands.
+ const SCEV *Q, *R;
+ divide(SE, Op, Denominator, &Q, &R);
+ if (!R->isZero()) {
+ Qs.push_back(Op);
+ continue;
+ }
+
+ // Bail out if types do not match.
+ if (Ty != Q->getType())
+ return cannotDivide(Numerator);
+
+ FoundDenominatorTerm = true;
+ Qs.push_back(Q);
+ }
+
+ if (FoundDenominatorTerm) {
+ Remainder = Zero;
+ if (Qs.size() == 1)
+ Quotient = Qs[0];
+ else
+ Quotient = SE.getMulExpr(Qs);
+ return;
+ }
+
+ if (!isa<SCEVUnknown>(Denominator))
+ return cannotDivide(Numerator);
+
+ // The Remainder is obtained by replacing Denominator by 0 in Numerator.
+ ValueToValueMap RewriteMap;
+ RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
+ cast<SCEVConstant>(Zero)->getValue();
+ Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
+
+ if (Remainder->isZero()) {
+ // The Quotient is obtained by replacing Denominator by 1 in Numerator.
+ RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
+ cast<SCEVConstant>(One)->getValue();
+ Quotient =
+ SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
+ return;
+ }
+
+ // Quotient is (Numerator - Remainder) divided by Denominator.
+ const SCEV *Q, *R;
+ const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
+ // This SCEV does not seem to simplify: fail the division here.
+ if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
+ return cannotDivide(Numerator);
+ divide(SE, Diff, Denominator, &Q, &R);
+ if (R != Zero)
+ return cannotDivide(Numerator);
+ Quotient = Q;
+ }
+
+private:
+ SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
+ const SCEV *Denominator)
+ : SE(S), Denominator(Denominator) {
+ Zero = SE.getZero(Denominator->getType());
+ One = SE.getOne(Denominator->getType());
+
+ // We generally do not know how to divide Expr by Denominator. We
+ // initialize the division to a "cannot divide" state to simplify the rest
+ // of the code.
+ cannotDivide(Numerator);
+ }
+
+ // Convenience function for giving up on the division. We set the quotient to
+ // be equal to zero and the remainder to be equal to the numerator.
+ void cannotDivide(const SCEV *Numerator) {
+ Quotient = Zero;
+ Remainder = Numerator;
+ }
+
+ ScalarEvolution &SE;
+ const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Simple SCEV method implementations
+//===----------------------------------------------------------------------===//
+
+/// Compute BC(It, K). The result has width W. Assume K > 0.
+static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
+ ScalarEvolution &SE,
+ Type *ResultTy) {
+ // Handle the simplest case efficiently.
+ if (K == 1)
+ return SE.getTruncateOrZeroExtend(It, ResultTy);
+
+ // We are using the following formula for BC(It, K):
+ //
+ // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
+ //
+ // Suppose, W is the bitwidth of the return value. We must be prepared for
+ // overflow. Hence, we must assure that the result of our computation is
+ // equal to the accurate one modulo 2^W. Unfortunately, division isn't
+ // safe in modular arithmetic.
+ //
+ // However, this code doesn't use exactly that formula; the formula it uses
+ // is something like the following, where T is the number of factors of 2 in
+ // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
+ // exponentiation:
+ //
+ // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
+ //
+ // This formula is trivially equivalent to the previous formula. However,
+ // this formula can be implemented much more efficiently. The trick is that
+ // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
+ // arithmetic. To do exact division in modular arithmetic, all we have
+ // to do is multiply by the inverse. Therefore, this step can be done at
+ // width W.
+ //
+ // The next issue is how to safely do the division by 2^T. The way this
+ // is done is by doing the multiplication step at a width of at least W + T
+ // bits. This way, the bottom W+T bits of the product are accurate. Then,
+ // when we perform the division by 2^T (which is equivalent to a right shift
+ // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
+ // truncated out after the division by 2^T.
+ //
+ // In comparison to just directly using the first formula, this technique
+ // is much more efficient; using the first formula requires W * K bits,
+ // but this formula requires less than W + K bits. Also, the first formula requires
+ // a division step, whereas this formula only requires multiplies and shifts.
+ //
+ // It doesn't matter whether the subtraction step is done in the calculation
+ // width or the input iteration count's width; if the subtraction overflows,
+ // the result must be zero anyway. We prefer here to do it in the width of
+ // the induction variable because it helps a lot for certain cases; CodeGen
+ // isn't smart enough to ignore the overflow, which leads to much less
+ // efficient code if the width of the subtraction is wider than the native
+ // register width.
+ //
+ // (It's possible to not widen at all by pulling out factors of 2 before
+ // the multiplication; for example, K=2 can be calculated as
+ // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
+ // extra arithmetic, so it's not an obvious win, and it gets
+ // much more complicated for K > 3.)
+
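+ // A small worked example (illustrative only): for K = 3 and W = 8 the loop
+ // below computes OddFactorial = 3 and T = 1 (since 3! = 6 = 2^1 * 3). The
+ // multiplicative inverse of 3 modulo 2^8 is 171 (3 * 171 = 513 = 2*256 + 1).
+ // For It = 5, the product 5*4*3 = 60 is formed at width 9, divided by 2^1 to
+ // get 30, truncated to 8 bits, and multiplied by 171: 30*171 = 5130, which is
+ // 10 modulo 256 and matches BC(5, 3) = 10.
+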
+ // Protection from insane SCEVs; this bound is conservative,
+ // but it probably doesn't matter.
+ if (K > 1000)
+ return SE.getCouldNotCompute();
+
+ unsigned W = SE.getTypeSizeInBits(ResultTy);
+
+ // Calculate K! / 2^T and T; we divide out the factors of two before
+ // multiplying when calculating K! / 2^T to avoid overflow.
+ // Other overflow doesn't matter because we only care about the bottom
+ // W bits of the result.
+ APInt OddFactorial(W, 1);
+ unsigned T = 1;
+ for (unsigned i = 3; i <= K; ++i) {
+ APInt Mult(W, i);
+ unsigned TwoFactors = Mult.countTrailingZeros();
+ T += TwoFactors;
+ Mult.lshrInPlace(TwoFactors);
+ OddFactorial *= Mult;
+ }
+
+ // We need at least W + T bits for the multiplication step
+ unsigned CalculationBits = W + T;
+
+ // Calculate 2^T, at width T+W.
+ APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
+
+ // Calculate the multiplicative inverse of K! / 2^T;
+ // this multiplication factor will perform the exact division by
+ // K! / 2^T.
+ APInt Mod = APInt::getSignedMinValue(W+1);
+ APInt MultiplyFactor = OddFactorial.zext(W+1);
+ MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
+ MultiplyFactor = MultiplyFactor.trunc(W);
+
+ // Calculate the product, at width T+W
+ IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
+ CalculationBits);
+ const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
+ for (unsigned i = 1; i != K; ++i) {
+ const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
+ Dividend = SE.getMulExpr(Dividend,
+ SE.getTruncateOrZeroExtend(S, CalculationTy));
+ }
+
+ // Divide by 2^T
+ const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
+
+ // Truncate the result, and divide by K! / 2^T.
+
+ return SE.getMulExpr(SE.getConstant(MultiplyFactor),
+ SE.getTruncateOrZeroExtend(DivResult, ResultTy));
+}
+
+/// Return the value of this chain of recurrences at the specified iteration
+/// number. We can evaluate this recurrence by multiplying each element in the
+/// chain by the binomial coefficient corresponding to it. In other words, we
+/// can evaluate {A,+,B,+,C,+,D} as:
+///
+/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
+///
+/// where BC(It, k) stands for binomial coefficient.
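+///
+/// For example (illustrative), {0,+,3,+,2} evaluated at iteration It is
+/// 3*BC(It,1) + 2*BC(It,2) = 3*It + It*(It-1) = It^2 + 2*It, which gives 24
+/// at It = 4.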
+const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
+ ScalarEvolution &SE) const {
+ const SCEV *Result = getStart();
+ for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
+ // The computation is correct in the face of overflow provided that the
+ // multiplication is performed _after_ the evaluation of the binomial
+ // coefficient.
+ const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
+ if (isa<SCEVCouldNotCompute>(Coeff))
+ return Coeff;
+
+ Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
+ }
+ return Result;
+}
+
+//===----------------------------------------------------------------------===//
+// SCEV Expression folder implementations
+//===----------------------------------------------------------------------===//
+
+const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
+ unsigned Depth) {
+ assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
+ "This is not a truncating conversion!");
+ assert(isSCEVable(Ty) &&
+ "This is not a conversion to a SCEVable type!");
+ Ty = getEffectiveSCEVType(Ty);
+
+ FoldingSetNodeID ID;
+ ID.AddInteger(scTruncate);
+ ID.AddPointer(Op);
+ ID.AddPointer(Ty);
+ void *IP = nullptr;
+ if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+
+ // Fold if the operand is constant.
+ if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
+ return getConstant(
+ cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
+
+ // trunc(trunc(x)) --> trunc(x)
+ if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
+ return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
+
+ // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
+ if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
+ return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
+
+ // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
+ if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
+ return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
+
+ if (Depth > MaxCastDepth) {
+ SCEV *S =
+ new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
+ UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
+ return S;
+ }
+
+ // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
+ // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
+ // if after transforming we have at most one truncate, not counting truncates
+ // that replace other casts.
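+ // For instance (illustrative), trunc((zext a) * (sext b)) becomes
+ // trunc(zext a) * trunc(sext b); both truncates replace existing casts, so
+ // numTruncs stays below the limit and the distributed form is kept.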
+ if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
+ auto *CommOp = cast<SCEVCommutativeExpr>(Op);
+ SmallVector<const SCEV *, 4> Operands;
+ unsigned numTruncs = 0;
+ for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
+ ++i) {
+ const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
+ if (!isa<SCEVCastExpr>(CommOp->getOperand(i)) && isa<SCEVTruncateExpr>(S))
+ numTruncs++;
+ Operands.push_back(S);
+ }
+ if (numTruncs < 2) {
+ if (isa<SCEVAddExpr>(Op))
+ return getAddExpr(Operands);
+ else if (isa<SCEVMulExpr>(Op))
+ return getMulExpr(Operands);
+ else
+ llvm_unreachable("Unexpected SCEV type for Op.");
+ }
+ // Although we checked at the beginning that ID was not in the cache, the
+ // recursive calls above may have inserted it into the cache. If we find it
+ // now, just return it.
+ if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
+ return S;
+ }
+
+ // If the input value is a chrec scev, truncate the chrec's operands.
+ if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
+ SmallVector<const SCEV *, 4> Operands;
+ for (const SCEV *Op : AddRec->operands())
+ Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
+ return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
+ }
+
+ // The cast wasn't folded; create an explicit cast node. We can reuse
+ // the existing insert position since if we get here, we won't have
+ // made any changes which would invalidate it.
+ SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
+ Op, Ty);
+ UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
+ return S;
+}
+
+// Get the limit of a recurrence such that incrementing by Step cannot cause
+// signed overflow as long as the value of the recurrence within the
+// loop does not exceed this limit before incrementing.
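+//
+// For example (illustrative), with an i8 recurrence whose step is known to be
+// exactly 1, the limit is -128 - 1 == 127 (mod 2^8) with predicate SLT: any
+// value strictly below 127 can be incremented by 1 without signed overflow.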
+static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
+ ICmpInst::Predicate *Pred,
+ ScalarEvolution *SE) {
+ unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
+ if (SE->isKnownPositive(Step)) {
+ *Pred = ICmpInst::ICMP_SLT;
+ return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
+ SE->getSignedRangeMax(Step));
+ }
+ if (SE->isKnownNegative(Step)) {
+ *Pred = ICmpInst::ICMP_SGT;
+ return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
+ SE->getSignedRangeMin(Step));
+ }
+ return nullptr;
+}
+
+// Get the limit of a recurrence such that incrementing by Step cannot cause
+// unsigned overflow as long as the value of the recurrence within the loop does
+// not exceed this limit before incrementing.
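+//
+// For example (illustrative), with an i8 step known to be at most 4, the limit
+// is 0 - 4 == 252 (mod 2^8) with predicate ULT: any value below 252 can be
+// incremented by up to 4 without unsigned wrap.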
+static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
+ ICmpInst::Predicate *Pred,
+ ScalarEvolution *SE) {
+ unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
+ *Pred = ICmpInst::ICMP_ULT;
+
+ return SE->getConstant(APInt::getMinValue(BitWidth) -
+ SE->getUnsignedRangeMax(Step));
+}
+
+namespace {
+
+struct ExtendOpTraitsBase {
+ typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
+ unsigned);
+};
+
+// Used to make code generic over signed and unsigned overflow.
+template <typename ExtendOp> struct ExtendOpTraits {
+ // Members present:
+ //
+ // static const SCEV::NoWrapFlags WrapType;
+ //
+ // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
+ //
+ // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
+ // ICmpInst::Predicate *Pred,
+ // ScalarEvolution *SE);
+};
+
+template <>
+struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
+ static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
+
+ static const GetExtendExprTy GetExtendExpr;
+
+ static const SCEV *getOverflowLimitForStep(const SCEV *Step,
+ ICmpInst::Predicate *Pred,
+ ScalarEvolution *SE) {
+ return getSignedOverflowLimitForStep(Step, Pred, SE);
+ }
+};
+
+const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
+ SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
+
+template <>
+struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
+ static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
+
+ static const GetExtendExprTy GetExtendExpr;
+
+ static const SCEV *getOverflowLimitForStep(const SCEV *Step,
+ ICmpInst::Predicate *Pred,
+ ScalarEvolution *SE) {
+ return getUnsignedOverflowLimitForStep(Step, Pred, SE);
+ }
+};
+
+const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
+ SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
+
+} // end anonymous namespace
+
+// The recurrence AR has been shown to have no signed/unsigned wrap or something
+// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
+// easily prove NSW/NUW for its preincrement or postincrement sibling. This
+// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
+// Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
+// expression "Step + sext/zext(PreIncAR)" is congruent with
+// "sext/zext(PostIncAR)".
+template <typename ExtendOpTy>
+static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
+ ScalarEvolution *SE, unsigned Depth) {
+ auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
+ auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
+
+ const Loop *L = AR->getLoop();
+ const SCEV *Start = AR->getStart();
+ const SCEV *Step = AR->getStepRecurrence(*SE);
+
+ // Check for a simple looking step prior to loop entry.
+ const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
+ if (!SA)
+ return nullptr;
+
+ // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
+ // subtraction is expensive. For this purpose, perform a quick and dirty
+ // difference, by checking for Step in the operand list.
+ SmallVector<const SCEV *, 4> DiffOps;
+ for (const SCEV *Op : SA->operands())
+ if (Op != Step)
+ DiffOps.push_back(Op);
+
+ if (DiffOps.size() == SA->getNumOperands())
+ return nullptr;
+
+ // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
+ // `Step`:
+
+ // 1. NSW/NUW flags on the step increment.
+ auto PreStartFlags =
+ ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
+ const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
+ const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
+ SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
+
+ // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
+ // "S+X does not sign/unsign-overflow".
+ //
+
+ const SCEV *BECount = SE->getBackedgeTakenCount(L);
+ if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
+ !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
+ return PreStart;
+
+ // 2. Direct overflow check on the step operation's expression.
+ unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
+ Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
+ const SCEV *OperandExtendedStart =
+ SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
+ (SE->*GetExtendExpr)(Step, WideTy, Depth));
+ if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
+ if (PreAR && AR->getNoWrapFlags(WrapType)) {
+ // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
+ // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
+ // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
+ const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(WrapType);
+ }
+ return PreStart;
+ }
+
+ // 3. Loop precondition.
+ ICmpInst::Predicate Pred;
+ const SCEV *OverflowLimit =
+ ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
+
+ if (OverflowLimit &&
+ SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
+ return PreStart;
+
+ return nullptr;
+}
+
+// Get the normalized zero or sign extended expression for this AddRec's Start.
+template <typename ExtendOpTy>
+static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
+ ScalarEvolution *SE,
+ unsigned Depth) {
+ auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
+
+ const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
+ if (!PreStart)
+ return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
+
+ return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
+ Depth),
+ (SE->*GetExtendExpr)(PreStart, Ty, Depth));
+}
+
+// Try to prove away overflow by looking at "nearby" add recurrences. A
+// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
+// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
+//
+// Formally:
+//
+// {S,+,X} == {S-T,+,X} + T
+// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
+//
+// If ({S-T,+,X} + T) does not overflow ... (1)
+//
+// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
+//
+// If {S-T,+,X} does not overflow ... (2)
+//
+// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
+// == {Ext(S-T)+Ext(T),+,Ext(X)}
+//
+// If (S-T)+T does not overflow ... (3)
+//
+// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
+// == {Ext(S),+,Ext(X)} == LHS
+//
+// Thus, if (1), (2) and (3) are true for some T, then
+// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
+//
+// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
+// does not overflow" restricted to the 0th iteration. Therefore we only need
+// to check for (1) and (2).
+//
+// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
+// is `Delta` (defined below).
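+//
+// For example (illustrative), with Start = 1, Step = 4 and Delta = 1 the
+// candidate PreStart is 0; if the cached {0,+,4} is already known <nuw>
+// (proving (2)) and is known `ult` `-1`, the unsigned limit for a step of 1
+// (proving (1)), the caller may mark {1,+,4} as <nuw>.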
+template <typename ExtendOpTy>
+bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
+ const SCEV *Step,
+ const Loop *L) {
+ auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
+
+ // We restrict `Start` to a constant to prevent SCEV from spending too much
+ // time here. It is correct (but more expensive) to continue with a
+ // non-constant `Start` and do a general SCEV subtraction to compute
+ // `PreStart` below.
+ const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
+ if (!StartC)
+ return false;
+
+ APInt StartAI = StartC->getAPInt();
+
+ for (unsigned Delta : {-2, -1, 1, 2}) {
+ const SCEV *PreStart = getConstant(StartAI - Delta);
+
+ FoldingSetNodeID ID;
+ ID.AddInteger(scAddRecExpr);
+ ID.AddPointer(PreStart);
+ ID.AddPointer(Step);
+ ID.AddPointer(L);
+ void *IP = nullptr;
+ const auto *PreAR =
+ static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+
+ // Give up if we don't already have the add recurrence we need because
+ // actually constructing an add recurrence is relatively expensive.
+ if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
+ const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
+ ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
+ const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
+ DeltaS, &Pred, this);
+ if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
+ return true;
+ }
+ }
+
+ return false;
+}
+
+// Finds an integer D for an expression (C + x + y + ...) such that the top
+// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
+// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
+// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
+// the (C + x + y + ...) expression is \p WholeAddExpr.
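+//
+// For example (illustrative), for (5 + 4*x) the non-constant part 4*x has at
+// least two trailing zero bits, so D = 5 & 3 = 1; the residual (4 + 4*x) keeps
+// those trailing zeros and adding D = 1 back on top cannot wrap.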
+static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
+ const SCEVConstant *ConstantTerm,
+ const SCEVAddExpr *WholeAddExpr) {
+ const APInt C = ConstantTerm->getAPInt();
+ const unsigned BitWidth = C.getBitWidth();
+ // Find number of trailing zeros of (x + y + ...) w/o the C first:
+ uint32_t TZ = BitWidth;
+ for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
+ TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
+ if (TZ) {
+ // Set D to be as many least significant bits of C as possible while still
+ // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
+ return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
+ }
+ return APInt(BitWidth, 0);
+}
+
+// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
+// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
+// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
+// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
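+//
+// For example (illustrative), for {43,+,8} the step is always a multiple of 8,
+// so D = 43 & 7 = 3 and the residual {40,+,8} stays a multiple of 8.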
+static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
+ const APInt &ConstantStart,
+ const SCEV *Step) {
+ const unsigned BitWidth = ConstantStart.getBitWidth();
+ const uint32_t TZ = SE.GetMinTrailingZeros(Step);
+ if (TZ)
+ return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
+ : ConstantStart;
+ return APInt(BitWidth, 0);
+}
+
+const SCEV *
+ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
+ assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
+ "This is not an extending conversion!");
+ assert(isSCEVable(Ty) &&
+ "This is not a conversion to a SCEVable type!");
+ Ty = getEffectiveSCEVType(Ty);
+
+ // Fold if the operand is constant.
+ if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
+ return getConstant(
+ cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
+
+ // zext(zext(x)) --> zext(x)
+ if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
+ return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
+
+ // Before doing any expensive analysis, check to see if we've already
+ // computed a SCEV for this Op and Ty.
+ FoldingSetNodeID ID;
+ ID.AddInteger(scZeroExtend);
+ ID.AddPointer(Op);
+ ID.AddPointer(Ty);
+ void *IP = nullptr;
+ if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+ if (Depth > MaxCastDepth) {
+ SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
+ Op, Ty);
+ UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
+ return S;
+ }
+
+ // zext(trunc(x)) --> zext(x) or x or trunc(x)
+ if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
+ // It's possible the bits taken off by the truncate were all zero bits. If
+ // so, we should be able to simplify this further.
+ const SCEV *X = ST->getOperand();
+ ConstantRange CR = getUnsignedRange(X);
+ unsigned TruncBits = getTypeSizeInBits(ST->getType());
+ unsigned NewBits = getTypeSizeInBits(Ty);
+ if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
+ CR.zextOrTrunc(NewBits)))
+ return getTruncateOrZeroExtend(X, Ty, Depth);
+ }
+
+ // If the input value is a chrec scev, and we can prove that the value
+ // did not overflow the old, smaller, value, we can zero extend all of the
+ // operands (often constants). This allows analysis of something like
+ // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
+ if (AR->isAffine()) {
+ const SCEV *Start = AR->getStart();
+ const SCEV *Step = AR->getStepRecurrence(*this);
+ unsigned BitWidth = getTypeSizeInBits(AR->getType());
+ const Loop *L = AR->getLoop();
+
+ if (!AR->hasNoUnsignedWrap()) {
+ auto NewFlags = proveNoWrapViaConstantRanges(AR);
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
+ }
+
+ // If we have special knowledge that this addrec won't overflow,
+ // we don't need to do any further analysis.
+ if (AR->hasNoUnsignedWrap())
+ return getAddRecExpr(
+ getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
+ getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
+
+ // Check whether the backedge-taken count is SCEVCouldNotCompute.
+ // Note that this serves two purposes: It filters out loops that are
+ // simply not analyzable, and it covers the case where this code is
+ // being called from within backedge-taken count analysis, such that
+ // attempting to ask for the backedge-taken count would likely result
+ // in infinite recursion. In the latter case, the analysis code will
+ // cope with a conservative value, and it will take care to purge
+ // that value once it has finished.
+ const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
+ if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
+ // Manually compute the final value for AR, checking for
+ // overflow.
+
+ // Check whether the backedge-taken count can be losslessly cast to
+ // the addrec's type. The count is always unsigned.
+ const SCEV *CastedMaxBECount =
+ getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
+ const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
+ CastedMaxBECount, MaxBECount->getType(), Depth);
+ if (MaxBECount == RecastedMaxBECount) {
+ Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
+ // Check whether Start+Step*MaxBECount has no unsigned overflow.
+ const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
+ SCEV::FlagAnyWrap, Depth + 1);
+ const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
+ SCEV::FlagAnyWrap,
+ Depth + 1),
+ WideTy, Depth + 1);
+ const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
+ const SCEV *WideMaxBECount =
+ getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
+ const SCEV *OperandExtendedAdd =
+ getAddExpr(WideStart,
+ getMulExpr(WideMaxBECount,
+ getZeroExtendExpr(Step, WideTy, Depth + 1),
+ SCEV::FlagAnyWrap, Depth + 1),
+ SCEV::FlagAnyWrap, Depth + 1);
+ if (ZAdd == OperandExtendedAdd) {
+ // Cache knowledge of AR NUW, which is propagated to this AddRec.
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
+ // Return the expression with the addrec on the outside.
+ return getAddRecExpr(
+ getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
+ Depth + 1),
+ getZeroExtendExpr(Step, Ty, Depth + 1), L,
+ AR->getNoWrapFlags());
+ }
+ // Similar to above, only this time treat the step value as signed.
+ // This covers loops that count down.
+ OperandExtendedAdd =
+ getAddExpr(WideStart,
+ getMulExpr(WideMaxBECount,
+ getSignExtendExpr(Step, WideTy, Depth + 1),
+ SCEV::FlagAnyWrap, Depth + 1),
+ SCEV::FlagAnyWrap, Depth + 1);
+ if (ZAdd == OperandExtendedAdd) {
+ // Cache knowledge of AR NW, which is propagated to this AddRec.
+ // Negative step causes unsigned wrap, but it still can't self-wrap.
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
+ // Return the expression with the addrec on the outside.
+ return getAddRecExpr(
+ getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
+ Depth + 1),
+ getSignExtendExpr(Step, Ty, Depth + 1), L,
+ AR->getNoWrapFlags());
+ }
+ }
+ }
+
+ // Normally, in the cases we can prove no-overflow via a
+ // backedge guarding condition, we can also compute a backedge
+ // taken count for the loop. The exceptions are assumptions and
+ // guards present in the loop -- SCEV is not great at exploiting
+ // these to compute max backedge taken counts, but can still use
+ // these to prove lack of overflow. Use this fact to avoid
+ // doing extra work that may not pay off.
+ if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
+ !AC.assumptions().empty()) {
+ // If the backedge is guarded by a comparison with the pre-inc
+ // value the addrec is safe. Also, if the entry is guarded by
+ // a comparison with the start value and the backedge is
+ // guarded by a comparison with the post-inc value, the addrec
+ // is safe.
+ if (isKnownPositive(Step)) {
+ const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
+ getUnsignedRangeMax(Step));
+ if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
+ isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
+ // Cache knowledge of AR NUW, which is propagated to this
+ // AddRec.
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
+ // Return the expression with the addrec on the outside.
+ return getAddRecExpr(
+ getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
+ Depth + 1),
+ getZeroExtendExpr(Step, Ty, Depth + 1), L,
+ AR->getNoWrapFlags());
+ }
+ } else if (isKnownNegative(Step)) {
+ const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
+ getSignedRangeMin(Step));
+ if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
+ isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
+ // Cache knowledge of AR NW, which is propagated to this
+ // AddRec. Negative step causes unsigned wrap, but it
+ // still can't self-wrap.
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
+ // Return the expression with the addrec on the outside.
+ return getAddRecExpr(
+ getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
+ Depth + 1),
+ getSignExtendExpr(Step, Ty, Depth + 1), L,
+ AR->getNoWrapFlags());
+ }
+ }
+ }
+
+ // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
+ // if D + (C - D + Step * n) could be proven to not unsigned wrap
+ // where D maximizes the number of trailing zeros of (C - D + Step * n)
+ if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
+ const APInt &C = SC->getAPInt();
+ const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
+ if (D != 0) {
+ const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
+ const SCEV *SResidual =
+ getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
+ const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
+ return getAddExpr(SZExtD, SZExtR,
+ (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+ Depth + 1);
+ }
+ }
+
+ if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
+ return getAddRecExpr(
+ getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
+ getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
+ }
+ }
+
+ // zext(A % B) --> zext(A) % zext(B)
+ {
+ const SCEV *LHS;
+ const SCEV *RHS;
+ if (matchURem(Op, LHS, RHS))
+ return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
+ getZeroExtendExpr(RHS, Ty, Depth + 1));
+ }
+
+ // zext(A / B) --> zext(A) / zext(B).
+ if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
+ return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
+ getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
+
+ if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
+ // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
+ if (SA->hasNoUnsignedWrap()) {
+ // If the addition does not unsign overflow then we can, by definition,
+ // commute the zero extension with the addition operation.
+ SmallVector<const SCEV *, 4> Ops;
+ for (const auto *Op : SA->operands())
+ Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
+ return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
+ }
+
+ // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
+ // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
+ // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
+ //
+ // Often address arithmetic contains expressions like
+ // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
+ // This transformation is useful while proving that such expressions are
+ // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
+ if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
+ const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
+ if (D != 0) {
+ const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
+ const SCEV *SResidual =
+ getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
+ const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
+ return getAddExpr(SZExtD, SZExtR,
+ (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+ Depth + 1);
+ }
+ }
+ }
+
+ if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
+ // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
+ if (SM->hasNoUnsignedWrap()) {
+ // If the multiply does not unsign overflow then we can, by definition,
+ // commute the zero extension with the multiply operation.
+ SmallVector<const SCEV *, 4> Ops;
+ for (const auto *Op : SM->operands())
+ Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
+ return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
+ }
+
+ // zext(2^K * (trunc X to iN)) to iM ->
+ // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
+ //
+ // Proof:
+ //
+ // zext(2^K * (trunc X to iN)) to iM
+ // = zext((trunc X to iN) << K) to iM
+ // = zext((trunc X to i{N-K}) << K)<nuw> to iM
+ // (because shl removes the top K bits)
+ // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
+ // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
+ //
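+ // For instance (illustrative), with N = 8, K = 2 and M = 32:
+ // zext(4 * (trunc X to i8)) to i32 == 4 * (zext(trunc X to i6) to i32),
+ // because the multiply by 4 discards the top two bits of the i8 value.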
+ if (SM->getNumOperands() == 2)
+ if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
+ if (MulLHS->getAPInt().isPowerOf2())
+ if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
+ int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
+ MulLHS->getAPInt().logBase2();
+ Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
+ return getMulExpr(
+ getZeroExtendExpr(MulLHS, Ty),
+ getZeroExtendExpr(
+ getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
+ SCEV::FlagNUW, Depth + 1);
+ }
+ }
+
+ // The cast wasn't folded; create an explicit cast node.
+ // Recompute the insert position, as it may have been invalidated.
+ if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+ SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
+ Op, Ty);
+ UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
+ return S;
+}
+
+const SCEV *
+ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
+ assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
+ "This is not an extending conversion!");
+ assert(isSCEVable(Ty) &&
+ "This is not a conversion to a SCEVable type!");
+ Ty = getEffectiveSCEVType(Ty);
+
+ // Fold if the operand is constant.
+ if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
+ return getConstant(
+ cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
+
+ // sext(sext(x)) --> sext(x)
+ if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
+ return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
+
+ // sext(zext(x)) --> zext(x)
+ if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
+ return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
+
+ // Before doing any expensive analysis, check to see if we've already
+ // computed a SCEV for this Op and Ty.
+ FoldingSetNodeID ID;
+ ID.AddInteger(scSignExtend);
+ ID.AddPointer(Op);
+ ID.AddPointer(Ty);
+ void *IP = nullptr;
+ if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+ // Limit recursion depth.
+ if (Depth > MaxCastDepth) {
+ SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
+ Op, Ty);
+ UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
+ return S;
+ }
+
+ // sext(trunc(x)) --> sext(x) or x or trunc(x)
+ if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
+ // It's possible the bits taken off by the truncate were all sign bits. If
+ // so, we should be able to simplify this further.
+ const SCEV *X = ST->getOperand();
+ ConstantRange CR = getSignedRange(X);
+ unsigned TruncBits = getTypeSizeInBits(ST->getType());
+ unsigned NewBits = getTypeSizeInBits(Ty);
+ if (CR.truncate(TruncBits).signExtend(NewBits).contains(
+ CR.sextOrTrunc(NewBits)))
+ return getTruncateOrSignExtend(X, Ty, Depth);
+ }
+
+ if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
+ // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
+ if (SA->hasNoSignedWrap()) {
+ // If the addition does not sign overflow then we can, by definition,
+ // commute the sign extension with the addition operation.
+ SmallVector<const SCEV *, 4> Ops;
+ for (const auto *Op : SA->operands())
+ Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
+ return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
+ }
+
+ // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
+ // if D + (C - D + x + y + ...) could be proven to not signed wrap
+ // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
+ //
+ // For instance, this will bring two seemingly different expressions:
+ // 1 + sext(5 + 20 * %x + 24 * %y) and
+ // sext(6 + 20 * %x + 24 * %y)
+ // to the same form:
+ // 2 + sext(4 + 20 * %x + 24 * %y)
+ if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
+ const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
+ if (D != 0) {
+ const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
+ const SCEV *SResidual =
+ getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
+ const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
+ return getAddExpr(SSExtD, SSExtR,
+ (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+ Depth + 1);
+ }
+ }
+ }
+ // If the input value is a chrec scev, and we can prove that the value
+ // did not overflow the old, smaller, value, we can sign extend all of the
+ // operands (often constants). This allows analysis of something like
+ // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
+ if (AR->isAffine()) {
+ const SCEV *Start = AR->getStart();
+ const SCEV *Step = AR->getStepRecurrence(*this);
+ unsigned BitWidth = getTypeSizeInBits(AR->getType());
+ const Loop *L = AR->getLoop();
+
+ if (!AR->hasNoSignedWrap()) {
+ auto NewFlags = proveNoWrapViaConstantRanges(AR);
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
+ }
+
+ // If we have special knowledge that this addrec won't overflow,
+ // we don't need to do any further analysis.
+ if (AR->hasNoSignedWrap())
+ return getAddRecExpr(
+ getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
+ getSignExtendExpr(Step, Ty, Depth + 1), L, SCEV::FlagNSW);
+
+ // Check whether the backedge-taken count is SCEVCouldNotCompute.
+ // Note that this serves two purposes: It filters out loops that are
+ // simply not analyzable, and it covers the case where this code is
+ // being called from within backedge-taken count analysis, such that
+ // attempting to ask for the backedge-taken count would likely result
+ // in infinite recursion. In the latter case, the analysis code will
+ // cope with a conservative value, and it will take care to purge
+ // that value once it has finished.
+ const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
+ if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
+ // Manually compute the final value for AR, checking for
+ // overflow.
+
+ // Check whether the backedge-taken count can be losslessly cast to
+ // the addrec's type. The count is always unsigned.
+ const SCEV *CastedMaxBECount =
+ getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
+ const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
+ CastedMaxBECount, MaxBECount->getType(), Depth);
+ if (MaxBECount == RecastedMaxBECount) {
+ Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
+ // Check whether Start+Step*MaxBECount has no signed overflow.
+ const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
+ SCEV::FlagAnyWrap, Depth + 1);
+ const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
+ SCEV::FlagAnyWrap,
+ Depth + 1),
+ WideTy, Depth + 1);
+ const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
+ const SCEV *WideMaxBECount =
+ getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
+ const SCEV *OperandExtendedAdd =
+ getAddExpr(WideStart,
+ getMulExpr(WideMaxBECount,
+ getSignExtendExpr(Step, WideTy, Depth + 1),
+ SCEV::FlagAnyWrap, Depth + 1),
+ SCEV::FlagAnyWrap, Depth + 1);
+ if (SAdd == OperandExtendedAdd) {
+ // Cache knowledge of AR NSW, which is propagated to this AddRec.
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
+ // Return the expression with the addrec on the outside.
+ return getAddRecExpr(
+ getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
+ Depth + 1),
+ getSignExtendExpr(Step, Ty, Depth + 1), L,
+ AR->getNoWrapFlags());
+ }
+ // Similar to above, only this time treat the step value as unsigned.
+ // This covers loops that count up with an unsigned step.
+ OperandExtendedAdd =
+ getAddExpr(WideStart,
+ getMulExpr(WideMaxBECount,
+ getZeroExtendExpr(Step, WideTy, Depth + 1),
+ SCEV::FlagAnyWrap, Depth + 1),
+ SCEV::FlagAnyWrap, Depth + 1);
+ if (SAdd == OperandExtendedAdd) {
+ // If AR wraps around then
+ //
+ // abs(Step) * MaxBECount > unsigned-max(AR->getType())
+ // => SAdd != OperandExtendedAdd
+ //
+ // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
+ // (SAdd == OperandExtendedAdd => AR is NW)
+
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
+
+ // Return the expression with the addrec on the outside.
+ return getAddRecExpr(
+ getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
+ Depth + 1),
+ getZeroExtendExpr(Step, Ty, Depth + 1), L,
+ AR->getNoWrapFlags());
+ }
+ }
+ }
+
+ // Normally, in the cases we can prove no-overflow via a
+ // backedge guarding condition, we can also compute a backedge
+ // taken count for the loop. The exceptions are assumptions and
+ // guards present in the loop -- SCEV is not great at exploiting
+ // these to compute max backedge taken counts, but can still use
+ // these to prove lack of overflow. Use this fact to avoid
+ // doing extra work that may not pay off.
+
+ if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
+ !AC.assumptions().empty()) {
+ // If the backedge is guarded by a comparison with the pre-inc
+ // value the addrec is safe. Also, if the entry is guarded by
+ // a comparison with the start value and the backedge is
+ // guarded by a comparison with the post-inc value, the addrec
+ // is safe.
+ ICmpInst::Predicate Pred;
+ const SCEV *OverflowLimit =
+ getSignedOverflowLimitForStep(Step, &Pred, this);
+ if (OverflowLimit &&
+ (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
+ isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
+ // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
+ return getAddRecExpr(
+ getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
+ getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
+ }
+ }
+
+ // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
+ // if D + (C - D + Step * n) could be proven to not signed wrap
+ // where D maximizes the number of trailing zeros of (C - D + Step * n)
+ if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
+ const APInt &C = SC->getAPInt();
+ const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
+ if (D != 0) {
+ const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
+ const SCEV *SResidual =
+ getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
+ const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
+ return getAddExpr(SSExtD, SSExtR,
+ (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+ Depth + 1);
+ }
+ }
+
+ if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
+ return getAddRecExpr(
+ getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
+ getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
+ }
+ }
+
+ // If the input value is provably positive and we could not simplify
+ // away the sext, build a zext instead.
+ if (isKnownNonNegative(Op))
+ return getZeroExtendExpr(Op, Ty, Depth + 1);
+
+ // The cast wasn't folded; create an explicit cast node.
+ // Recompute the insert position, as it may have been invalidated.
+ if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+ SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
+ Op, Ty);
+ UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
+ return S;
+}
+
+/// getAnyExtendExpr - Return a SCEV for the given operand extended with
+/// unspecified bits out to the given type.
+const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
+ Type *Ty) {
+ assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
+ "This is not an extending conversion!");
+ assert(isSCEVable(Ty) &&
+ "This is not a conversion to a SCEVable type!");
+ Ty = getEffectiveSCEVType(Ty);
+
+ // Sign-extend negative constants.
+ if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
+ if (SC->getAPInt().isNegative())
+ return getSignExtendExpr(Op, Ty);
+
+ // Peel off a truncate cast.
+ if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
+ const SCEV *NewOp = T->getOperand();
+ if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
+ return getAnyExtendExpr(NewOp, Ty);
+ return getTruncateOrNoop(NewOp, Ty);
+ }
+
+ // Next try a zext cast. If the cast is folded, use it.
+ const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
+ if (!isa<SCEVZeroExtendExpr>(ZExt))
+ return ZExt;
+
+ // Next try a sext cast. If the cast is folded, use it.
+ const SCEV *SExt = getSignExtendExpr(Op, Ty);
+ if (!isa<SCEVSignExtendExpr>(SExt))
+ return SExt;
+
+ // Force the cast to be folded into the operands of an addrec.
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
+ SmallVector<const SCEV *, 4> Ops;
+ for (const SCEV *Op : AR->operands())
+ Ops.push_back(getAnyExtendExpr(Op, Ty));
+ return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
+ }
+
+ // If the expression is obviously signed, use the sext cast value.
+ if (isa<SCEVSMaxExpr>(Op))
+ return SExt;
+
+ // Absent any other information, use the zext cast value.
+ return ZExt;
+}
+
+/// Process the given Ops list, which is a list of operands to be added under
+/// the given scale, update the given map. This is a helper function for
+/// getAddExpr. As an example of what it does, given a sequence of operands
+/// that would form an add expression like this:
+///
+/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
+///
+/// where A and B are constants, update the map with these values:
+///
+/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
+///
+/// and add 13 + A*B*29 to AccumulatedConstant.
+/// This will allow getAddExpr to produce this:
+///
+/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
+///
+/// This form often exposes folding opportunities that are hidden in
+/// the original operand list.
+///
+/// Return true iff it appears that any interesting folding opportunities
+/// may be exposed. This helps getAddExpr short-circuit extra work in
+/// the common case where no interesting opportunities are present, and
+/// is also used as a check to avoid infinite recursion.
+static bool
+CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
+ SmallVectorImpl<const SCEV *> &NewOps,
+ APInt &AccumulatedConstant,
+ const SCEV *const *Ops, size_t NumOperands,
+ const APInt &Scale,
+ ScalarEvolution &SE) {
+ bool Interesting = false;
+
+ // Iterate over the add operands. They are sorted, with constants first.
+ unsigned i = 0;
+ while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
+ ++i;
+ // Pull a buried constant out to the outside.
+ if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
+ Interesting = true;
+ AccumulatedConstant += Scale * C->getAPInt();
+ }
+
+ // Next comes everything else. We're especially interested in multiplies
+ // here, but they're in the middle, so just visit the rest with one loop.
+ for (; i != NumOperands; ++i) {
+ const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
+ if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
+ APInt NewScale =
+ Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
+ if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
+ // A multiplication of a constant with another add; recurse.
+ const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
+ Interesting |=
+ CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
+ Add->op_begin(), Add->getNumOperands(),
+ NewScale, SE);
+ } else {
+ // A multiplication of a constant with some other value. Update
+ // the map.
+ SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
+ const SCEV *Key = SE.getMulExpr(MulOps);
+ auto Pair = M.insert({Key, NewScale});
+ if (Pair.second) {
+ NewOps.push_back(Pair.first->first);
+ } else {
+ Pair.first->second += NewScale;
+ // The map already had an entry for this value, which may indicate
+ // a folding opportunity.
+ Interesting = true;
+ }
+ }
+ } else {
+ // An ordinary operand. Update the map.
+ std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
+ M.insert({Ops[i], Scale});
+ if (Pair.second) {
+ NewOps.push_back(Pair.first->first);
+ } else {
+ Pair.first->second += Scale;
+ // The map already had an entry for this value, which may indicate
+ // a folding opportunity.
+ Interesting = true;
+ }
+ }
+ }
+
+ return Interesting;
+}
+
+// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
+// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
+// can't-overflow flags for the operation if possible.
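+//
+// For example (illustrative), for an i8 add of the form (5 + x) with no flags
+// known, the guaranteed no-signed-wrap region for adding 5 is [-128, 123); if
+// the known signed range of x lies inside it, FlagNSW is inferred below.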
+static SCEV::NoWrapFlags
+StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
+ const ArrayRef<const SCEV *> Ops,
+ SCEV::NoWrapFlags Flags) {
+ using namespace std::placeholders;
+
+ using OBO = OverflowingBinaryOperator;
+
+ bool CanAnalyze =
+ Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
+ (void)CanAnalyze;
+ assert(CanAnalyze && "don't call from other places!");
+
+ int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
+ SCEV::NoWrapFlags SignOrUnsignWrap =
+ ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
+
+ // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
+ auto IsKnownNonNegative = [&](const SCEV *S) {
+ return SE->isKnownNonNegative(S);
+ };
+
+ if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
+ Flags =
+ ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
+
+ SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
+
+ if (SignOrUnsignWrap != SignOrUnsignMask &&
+ (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
+ isa<SCEVConstant>(Ops[0])) {
+
+ auto Opcode = [&] {
+ switch (Type) {
+ case scAddExpr:
+ return Instruction::Add;
+ case scMulExpr:
+ return Instruction::Mul;
+ default:
+ llvm_unreachable("Unexpected SCEV op.");
+ }
+ }();
+
+ const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
+
+ // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
+ if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
+ auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
+ Opcode, C, OBO::NoSignedWrap);
+ if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
+ Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
+ }
+
+ // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
+ if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
+ auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
+ Opcode, C, OBO::NoUnsignedWrap);
+ if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
+ Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
+ }
+ }
+
+ return Flags;
+}
+
+bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
+ return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
+}
+
+/// Get a canonical add expression, or something simpler if possible.
+const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
+ SCEV::NoWrapFlags Flags,
+ unsigned Depth) {
+ assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
+ "only nuw or nsw allowed");
+ assert(!Ops.empty() && "Cannot get empty add!");
+ if (Ops.size() == 1) return Ops[0];
+#ifndef NDEBUG
+ Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
+ for (unsigned i = 1, e = Ops.size(); i != e; ++i)
+ assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
+ "SCEVAddExpr operand types don't match!");
+#endif
+
+ // Sort by complexity; this groups all similar expression types together.
+ GroupByComplexity(Ops, &LI, DT);
+
+ Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags);
+
+ // If there are any constants, fold them together.
+ unsigned Idx = 0;
+ if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
+ ++Idx;
+ assert(Idx < Ops.size());
+ while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
+ // We found two constants, fold them together!
+ Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
+ if (Ops.size() == 2) return Ops[0];
+ Ops.erase(Ops.begin()+1); // Erase the folded element
+ LHSC = cast<SCEVConstant>(Ops[0]);
+ }
+
+ // If we are left with a constant zero being added, strip it off.
+ if (LHSC->getValue()->isZero()) {
+ Ops.erase(Ops.begin());
+ --Idx;
+ }
+
+ if (Ops.size() == 1) return Ops[0];
+ }
+
+ // Limit recursion depth.
+ if (Depth > MaxArithDepth || hasHugeExpression(Ops))
+ return getOrCreateAddExpr(Ops, Flags);
+
+ // Okay, check to see if the same value occurs in the operand list more than
+ // once. If so, merge them together into a multiply expression. Since we
+ // sorted the list, these values are required to be adjacent.
+ Type *Ty = Ops[0]->getType();
+ bool FoundMatch = false;
+ for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
+ if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
+ // Scan ahead to count how many equal operands there are.
+ unsigned Count = 2;
+ while (i+Count != e && Ops[i+Count] == Ops[i])
+ ++Count;
+ // Merge the values into a multiply.
+ const SCEV *Scale = getConstant(Ty, Count);
+ const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
+ if (Ops.size() == Count)
+ return Mul;
+ Ops[i] = Mul;
+ Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
+ --i; e -= Count - 1;
+ FoundMatch = true;
+ }
+ if (FoundMatch)
+ return getAddExpr(Ops, Flags, Depth + 1);
+
+ // Check for truncates. If all the operands are truncated from the same
+ // type, see if factoring out the truncate would permit the result to be
+ // folded. E.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
+ // if the contents of the resulting outer trunc fold to something simple.
+ auto FindTruncSrcType = [&]() -> Type * {
+ // We're ultimately looking to fold an add of truncs and muls of only
+ // constants and truncs, so if we find any other types of SCEV
+ // as operands of the add then we bail and return nullptr here.
+ // Otherwise, we return the type of the operand of a trunc that we find.
+ if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
+ return T->getOperand()->getType();
+ if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
+ const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
+ if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
+ return T->getOperand()->getType();
+ }
+ return nullptr;
+ };
+ if (auto *SrcType = FindTruncSrcType()) {
+ SmallVector<const SCEV *, 8> LargeOps;
+ bool Ok = true;
+ // Check all the operands to see if they can be represented in the
+ // source type of the truncate.
+ for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
+ if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
+ if (T->getOperand()->getType() != SrcType) {
+ Ok = false;
+ break;
+ }
+ LargeOps.push_back(T->getOperand());
+ } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
+ LargeOps.push_back(getAnyExtendExpr(C, SrcType));
+ } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
+ SmallVector<const SCEV *, 8> LargeMulOps;
+ for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
+ if (const SCEVTruncateExpr *T =
+ dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
+ if (T->getOperand()->getType() != SrcType) {
+ Ok = false;
+ break;
+ }
+ LargeMulOps.push_back(T->getOperand());
+ } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
+ LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
+ } else {
+ Ok = false;
+ break;
+ }
+ }
+ if (Ok)
+ LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
+ } else {
+ Ok = false;
+ break;
+ }
+ }
+ if (Ok) {
+ // Evaluate the expression in the larger type.
+ const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
+ // If it folds to something simple, use it. Otherwise, don't.
+ if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
+ return getTruncateExpr(Fold, Ty);
+ }
+ }
+
+ // Skip past any other cast SCEVs.
+ while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
+ ++Idx;
+
+ // If there are add operands they would be next.
+ if (Idx < Ops.size()) {
+ bool DeletedAdd = false;
+ while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
+ if (Ops.size() > AddOpsInlineThreshold ||
+ Add->getNumOperands() > AddOpsInlineThreshold)
+ break;
+ // If we have an add, expand the add operands onto the end of the operands
+ // list.
+ Ops.erase(Ops.begin()+Idx);
+ Ops.append(Add->op_begin(), Add->op_end());
+ DeletedAdd = true;
+ }
+
+ // If we deleted at least one add, we added operands to the end of the list,
+ // and they are not necessarily sorted. Recurse to resort and resimplify
+ // any operands we just acquired.
+ if (DeletedAdd)
+ return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
+ }
+
+ // Skip over the add expression until we get to a multiply.
+ while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
+ ++Idx;
+
+ // Check to see if there are any folding opportunities present with
+ // operands multiplied by constant values.
+ if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
+ uint64_t BitWidth = getTypeSizeInBits(Ty);
+ DenseMap<const SCEV *, APInt> M;
+ SmallVector<const SCEV *, 8> NewOps;
+ APInt AccumulatedConstant(BitWidth, 0);
+ if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
+ Ops.data(), Ops.size(),
+ APInt(BitWidth, 1), *this)) {
+ struct APIntCompare {
+ bool operator()(const APInt &LHS, const APInt &RHS) const {
+ return LHS.ult(RHS);
+ }
+ };
+
+ // Some interesting folding opportunity is present, so it's worthwhile to
+ // re-generate the operands list. Group the operands by constant scale,
+ // to avoid multiplying by the same constant scale multiple times.
+ std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
+ for (const SCEV *NewOp : NewOps)
+ MulOpLists[M.find(NewOp)->second].push_back(NewOp);
+ // Re-generate the operands list.
+ Ops.clear();
+ if (AccumulatedConstant != 0)
+ Ops.push_back(getConstant(AccumulatedConstant));
+ for (auto &MulOp : MulOpLists)
+ if (MulOp.first != 0)
+ Ops.push_back(getMulExpr(
+ getConstant(MulOp.first),
+ getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
+ SCEV::FlagAnyWrap, Depth + 1));
+ if (Ops.empty())
+ return getZero(Ty);
+ if (Ops.size() == 1)
+ return Ops[0];
+ return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
+ }
+ }
+
+ // If we are adding something to a multiply expression, make sure the
+ // something is not already an operand of the multiply. If so, merge it into
+ // the multiply.
+ for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
+ const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
+ for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
+ const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
+ if (isa<SCEVConstant>(MulOpSCEV))
+ continue;
+ for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
+ if (MulOpSCEV == Ops[AddOp]) {
+ // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
+ const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
+ if (Mul->getNumOperands() != 2) {
+ // If the multiply has more than two operands, we must get the
+ // Y*Z term.
+ SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
+ Mul->op_begin()+MulOp);
+ MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
+ InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
+ }
+ SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
+ const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
+ const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
+ SCEV::FlagAnyWrap, Depth + 1);
+ if (Ops.size() == 2) return OuterMul;
+ if (AddOp < Idx) {
+ Ops.erase(Ops.begin()+AddOp);
+ Ops.erase(Ops.begin()+Idx-1);
+ } else {
+ Ops.erase(Ops.begin()+Idx);
+ Ops.erase(Ops.begin()+AddOp-1);
+ }
+ Ops.push_back(OuterMul);
+ return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
+ }
+
+ // Check this multiply against other multiplies being added together.
+ for (unsigned OtherMulIdx = Idx+1;
+ OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
+ ++OtherMulIdx) {
+ const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
+ // If MulOp occurs in OtherMul, we can fold the two multiplies
+ // together.
+ for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
+ OMulOp != e; ++OMulOp)
+ if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
+ // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
+ const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
+ if (Mul->getNumOperands() != 2) {
+ SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
+ Mul->op_begin()+MulOp);
+ MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
+ InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
+ }
+ const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
+ if (OtherMul->getNumOperands() != 2) {
+ SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
+ OtherMul->op_begin()+OMulOp);
+ MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
+ InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
+ }
+ SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
+ const SCEV *InnerMulSum =
+ getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
+ const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
+ SCEV::FlagAnyWrap, Depth + 1);
+ if (Ops.size() == 2) return OuterMul;
+ Ops.erase(Ops.begin()+Idx);
+ Ops.erase(Ops.begin()+OtherMulIdx-1);
+ Ops.push_back(OuterMul);
+ return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
+ }
+ }
+ }
+ }
+
+ // If there are any add recurrences in the operands list, see if any other
+ // added values are loop invariant. If so, we can fold them into the
+ // recurrence.
+ while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
+ ++Idx;
+
+ // Scan over all recurrences, trying to fold loop invariants into them.
+ for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
+ // Scan all of the other operands to this add and add them to the vector if
+ // they are loop invariant w.r.t. the recurrence.
+ SmallVector<const SCEV *, 8> LIOps;
+ const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
+ const Loop *AddRecLoop = AddRec->getLoop();
+ for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+ if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
+ LIOps.push_back(Ops[i]);
+ Ops.erase(Ops.begin()+i);
+ --i; --e;
+ }
+
+ // If we found some loop invariants, fold them into the recurrence.
+ if (!LIOps.empty()) {
+ // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
+ LIOps.push_back(AddRec->getStart());
+
+ SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
+ AddRec->op_end());
+ // This follows from the fact that the no-wrap flags on the outer add
+ // expression are applicable on the 0th iteration, when the add recurrence
+ // will be equal to its start value.
+ AddRecOps[0] = getAddExpr(LIOps, Flags, Depth + 1);
+
+ // Build the new addrec. Propagate the NUW and NSW flags if both the
+ // outer add and the inner addrec are guaranteed to have no overflow.
+ // Always propagate NW.
+ Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
+ const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
+
+ // If all of the other operands were loop invariant, we are done.
+ if (Ops.size() == 1) return NewRec;
+
+ // Otherwise, add the folded AddRec by the non-invariant parts.
+ for (unsigned i = 0;; ++i)
+ if (Ops[i] == AddRec) {
+ Ops[i] = NewRec;
+ break;
+ }
+ return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
+ }
+
+ // Okay, if there weren't any loop invariants to be folded, check to see if
+ // there are multiple AddRecs with the same loop induction variable being
+ // added together. If so, we can fold them.
+ for (unsigned OtherIdx = Idx+1;
+ OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
+ ++OtherIdx) {
+ // We expect the AddRecExprs to be sorted in reverse dominance order,
+ // so that the first AddRecExpr found is dominated by all others.
+ assert(DT.dominates(
+ cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
+ AddRec->getLoop()->getHeader()) &&
+ "AddRecExprs are not sorted in reverse dominance order?");
+ if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
+ // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
+ SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
+ AddRec->op_end());
+ for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
+ ++OtherIdx) {
+ const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
+ if (OtherAddRec->getLoop() == AddRecLoop) {
+ for (unsigned i = 0, e = OtherAddRec->getNumOperands();
+ i != e; ++i) {
+ if (i >= AddRecOps.size()) {
+ AddRecOps.append(OtherAddRec->op_begin()+i,
+ OtherAddRec->op_end());
+ break;
+ }
+ SmallVector<const SCEV *, 2> TwoOps = {
+ AddRecOps[i], OtherAddRec->getOperand(i)};
+ AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
+ }
+ Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
+ }
+ }
+ // Step size has changed, so we cannot guarantee no self-wraparound.
+ Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
+ return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
+ }
+ }
+
+ // Otherwise couldn't fold anything into this recurrence. Move onto the
+ // next one.
+ }
+
+ // Okay, it looks like we really DO need an add expr. Check to see if we
+ // already have one, otherwise create a new one.
+ return getOrCreateAddExpr(Ops, Flags);
+}
+
+const SCEV *
+ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
+ SCEV::NoWrapFlags Flags) {
+ FoldingSetNodeID ID;
+ ID.AddInteger(scAddExpr);
+ for (const SCEV *Op : Ops)
+ ID.AddPointer(Op);
+ void *IP = nullptr;
+ SCEVAddExpr *S =
+ static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+ if (!S) {
+ const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
+ std::uninitialized_copy(Ops.begin(), Ops.end(), O);
+ S = new (SCEVAllocator)
+ SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
+ UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
+ }
+ S->setNoWrapFlags(Flags);
+ return S;
+}
+
+const SCEV *
+ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
+ const Loop *L, SCEV::NoWrapFlags Flags) {
+ FoldingSetNodeID ID;
+ ID.AddInteger(scAddRecExpr);
+ for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+ ID.AddPointer(Ops[i]);
+ ID.AddPointer(L);
+ void *IP = nullptr;
+ SCEVAddRecExpr *S =
+ static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+ if (!S) {
+ const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
+ std::uninitialized_copy(Ops.begin(), Ops.end(), O);
+ S = new (SCEVAllocator)
+ SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
+ UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
+ }
+ S->setNoWrapFlags(Flags);
+ return S;
+}
+
+const SCEV *
+ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
+ SCEV::NoWrapFlags Flags) {
+ FoldingSetNodeID ID;
+ ID.AddInteger(scMulExpr);
+ for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+ ID.AddPointer(Ops[i]);
+ void *IP = nullptr;
+ SCEVMulExpr *S =
+ static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+ if (!S) {
+ const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
+ std::uninitialized_copy(Ops.begin(), Ops.end(), O);
+ S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
+ O, Ops.size());
+ UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
+ }
+ S->setNoWrapFlags(Flags);
+ return S;
+}
+
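+// Multiply i by j, setting Overflow if the 64-bit product wraps. Overflow is
+// detected by checking whether dividing the product by j recovers i.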
+static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
+ uint64_t k = i*j;
+ if (j > 1 && k / j != i) Overflow = true;
+ return k;
+}
+
+/// Compute the result of "n choose k", the binomial coefficient. If an
+/// intermediate computation overflows, Overflow will be set and the return will
+/// be garbage. Overflow is not cleared on absence of overflow.
+static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
+ // We use the multiplicative formula:
+ // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
+ // At each iteration, we multiply by the next term of the numerator and divide
+ // by the next term of the denominator. This division will always produce an
+ // integral result, and helps reduce the chance of overflow in the
+ // intermediate computations. However, we can still overflow even when the
+ // final result would fit.
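+ // For example, Choose(5, 2) computes r = (1*5)/1 = 5, then r = (5*4)/2 = 10.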
+
+ if (n == 0 || n == k) return 1;
+ if (k > n) return 0;
+
+ if (k > n/2)
+ k = n-k;
+
+ uint64_t r = 1;
+ for (uint64_t i = 1; i <= k; ++i) {
+ r = umul_ov(r, n-(i-1), Overflow);
+ r /= i;
+ }
+ return r;
+}
+
+/// Determine if any of the operands in this SCEV are a constant or if
+/// any of the add or multiply expressions in this SCEV contain a constant.
+static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
+ struct FindConstantInAddMulChain {
+ bool FoundConstant = false;
+
+ bool follow(const SCEV *S) {
+ FoundConstant |= isa<SCEVConstant>(S);
+ return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
+ }
+
+ bool isDone() const {
+ return FoundConstant;
+ }
+ };
+
+ FindConstantInAddMulChain F;
+ SCEVTraversal<FindConstantInAddMulChain> ST(F);
+ ST.visitAll(StartExpr);
+ return F.FoundConstant;
+}
+
+/// Get a canonical multiply expression, or something simpler if possible.
+const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
+ SCEV::NoWrapFlags Flags,
+ unsigned Depth) {
+ assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) &&
+ "only nuw or nsw allowed");
+ assert(!Ops.empty() && "Cannot get empty mul!");
+ if (Ops.size() == 1) return Ops[0];
+#ifndef NDEBUG
+ Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
+ for (unsigned i = 1, e = Ops.size(); i != e; ++i)
+ assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
+ "SCEVMulExpr operand types don't match!");
+#endif
+
+ // Sort by complexity, this groups all similar expression types together.
+ GroupByComplexity(Ops, &LI, DT);
+
+ Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags);
+
+ // Limit recursion calls depth.
+ if (Depth > MaxArithDepth || hasHugeExpression(Ops))
+ return getOrCreateMulExpr(Ops, Flags);
+
+ // If there are any constants, fold them together.
+ unsigned Idx = 0;
+ if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
+
+ if (Ops.size() == 2)
+ // C1*(C2+V) -> C1*C2 + C1*V
+ if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
+ // If any of Add's ops are Adds or Muls with a constant, apply this
+ // transformation as well.
+ //
+ // TODO: There are some cases where this transformation is not
+ // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
+ // this transformation should be narrowed down.
+ if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add))
+ return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
+ SCEV::FlagAnyWrap, Depth + 1),
+ getMulExpr(LHSC, Add->getOperand(1),
+ SCEV::FlagAnyWrap, Depth + 1),
+ SCEV::FlagAnyWrap, Depth + 1);
+
+ ++Idx;
+ while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
+ // We found two constants, fold them together!
+ ConstantInt *Fold =
+ ConstantInt::get(getContext(), LHSC->getAPInt() * RHSC->getAPInt());
+ Ops[0] = getConstant(Fold);
+ Ops.erase(Ops.begin()+1); // Erase the folded element
+ if (Ops.size() == 1) return Ops[0];
+ LHSC = cast<SCEVConstant>(Ops[0]);
+ }
+
+ // If we are left with a constant one being multiplied, strip it off.
+ if (cast<SCEVConstant>(Ops[0])->getValue()->isOne()) {
+ Ops.erase(Ops.begin());
+ --Idx;
+ } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
+ // If we have a multiply of zero, it will always be zero.
+ return Ops[0];
+ } else if (Ops[0]->isAllOnesValue()) {
+ // If we have a mul by -1 of an add, try distributing the -1 among the
+ // add operands.
+ if (Ops.size() == 2) {
+ if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
+ SmallVector<const SCEV *, 4> NewOps;
+ bool AnyFolded = false;
+ for (const SCEV *AddOp : Add->operands()) {
+ const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
+ Depth + 1);
+ if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
+ NewOps.push_back(Mul);
+ }
+ if (AnyFolded)
+ return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
+ } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
+ // Negation preserves a recurrence's no self-wrap property.
+ SmallVector<const SCEV *, 4> Operands;
+ for (const SCEV *AddRecOp : AddRec->operands())
+ Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
+ Depth + 1));
+
+ return getAddRecExpr(Operands, AddRec->getLoop(),
+ AddRec->getNoWrapFlags(SCEV::FlagNW));
+ }
+ }
+ }
+
+ if (Ops.size() == 1)
+ return Ops[0];
+ }
+
+ // Skip over the add expression until we get to a multiply.
+ while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
+ ++Idx;
+
+ // If there are mul operands inline them all into this expression.
+ if (Idx < Ops.size()) {
+ bool DeletedMul = false;
+ while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
+ if (Ops.size() > MulOpsInlineThreshold)
+ break;
+ // If we have a mul, expand the mul operands onto the end of the
+ // operands list.
+ Ops.erase(Ops.begin()+Idx);
+ Ops.append(Mul->op_begin(), Mul->op_end());
+ DeletedMul = true;
+ }
+
+ // If we deleted at least one mul, we added operands to the end of the
+ // list, and they are not necessarily sorted. Recurse to resort and
+ // resimplify any operands we just acquired.
+ if (DeletedMul)
+ return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
+ }
+
+ // If there are any add recurrences in the operands list, see if any other
+ // added values are loop invariant. If so, we can fold them into the
+ // recurrence.
+ while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
+ ++Idx;
+
+ // Scan over all recurrences, trying to fold loop invariants into them.
+ for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
+ // Scan all of the other operands to this mul and add them to the vector
+ // if they are loop invariant w.r.t. the recurrence.
+ SmallVector<const SCEV *, 8> LIOps;
+ const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
+ const Loop *AddRecLoop = AddRec->getLoop();
+ for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+ if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
+ LIOps.push_back(Ops[i]);
+ Ops.erase(Ops.begin()+i);
+ --i; --e;
+ }
+
+ // If we found some loop invariants, fold them into the recurrence.
+ if (!LIOps.empty()) {
+ // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
+ SmallVector<const SCEV *, 4> NewOps;
+ NewOps.reserve(AddRec->getNumOperands());
+ const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
+ for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
+ NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
+ SCEV::FlagAnyWrap, Depth + 1));
+
+ // Build the new addrec. Propagate the NUW and NSW flags if both the
+ // outer mul and the inner addrec are guaranteed to have no overflow.
+ //
+ // No self-wrap cannot be guaranteed after changing the step size, but
+ // will be inferred if either NUW or NSW is true.
+ Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW));
+ const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags);
+
+ // If all of the other operands were loop invariant, we are done.
+ if (Ops.size() == 1) return NewRec;
+
+ // Otherwise, multiply the folded AddRec by the non-invariant parts.
+ for (unsigned i = 0;; ++i)
+ if (Ops[i] == AddRec) {
+ Ops[i] = NewRec;
+ break;
+ }
+ return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
+ }
+
+ // Okay, if there weren't any loop invariants to be folded, check to see
+ // if there are multiple AddRecs with the same loop induction variable
+ // being multiplied together. If so, we can fold them.
+
+ // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
+ // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
+ // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
+ // ]]],+,...up to x=2n}.
+ // Note that the arguments to choose() are always integers with values
+ // known at compile time, never SCEV objects.
+ //
+ // The implementation avoids pointless extra computations when the two
+ // addrec's are of different length (mathematically, it's equivalent to
+ // an infinite stream of zeros on the right).
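+ // For example, {a,+,b}<L> * {c,+,d}<L> = {a*c,+,a*d+b*c+b*d,+,2*b*d}<L>.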
+ bool OpsModified = false;
+ for (unsigned OtherIdx = Idx+1;
+ OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
+ ++OtherIdx) {
+ const SCEVAddRecExpr *OtherAddRec =
+ dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
+ if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
+ continue;
+
+ // Limit max number of arguments to avoid creation of unreasonably big
+ // SCEVAddRecs with very complex operands.
+ if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
+ MaxAddRecSize || isHugeExpression(AddRec) ||
+ isHugeExpression(OtherAddRec))
+ continue;
+
+ bool Overflow = false;
+ Type *Ty = AddRec->getType();
+ bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
+ SmallVector<const SCEV*, 7> AddRecOps;
+ for (int x = 0, xe = AddRec->getNumOperands() +
+ OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
+ SmallVector <const SCEV *, 7> SumOps;
+ for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
+ uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
+ for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
+ ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
+ z < ze && !Overflow; ++z) {
+ uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
+ uint64_t Coeff;
+ if (LargerThan64Bits)
+ Coeff = umul_ov(Coeff1, Coeff2, Overflow);
+ else
+ Coeff = Coeff1*Coeff2;
+ const SCEV *CoeffTerm = getConstant(Ty, Coeff);
+ const SCEV *Term1 = AddRec->getOperand(y-z);
+ const SCEV *Term2 = OtherAddRec->getOperand(z);
+ SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
+ SCEV::FlagAnyWrap, Depth + 1));
+ }
+ }
+ if (SumOps.empty())
+ SumOps.push_back(getZero(Ty));
+ AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
+ }
+ if (!Overflow) {
+ const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRecLoop,
+ SCEV::FlagAnyWrap);
+ if (Ops.size() == 2) return NewAddRec;
+ Ops[Idx] = NewAddRec;
+ Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
+ OpsModified = true;
+ AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
+ if (!AddRec)
+ break;
+ }
+ }
+ if (OpsModified)
+ return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
+
+ // Otherwise couldn't fold anything into this recurrence. Move onto the
+ // next one.
+ }
+
+ // Okay, it looks like we really DO need a mul expr. Check to see if we
+ // already have one, otherwise create a new one.
+ return getOrCreateMulExpr(Ops, Flags);
+}
+
+/// Represents an unsigned remainder expression based on unsigned division.
+const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
+ const SCEV *RHS) {
+ assert(getEffectiveSCEVType(LHS->getType()) ==
+ getEffectiveSCEVType(RHS->getType()) &&
+ "SCEVURemExpr operand types don't match!");
+
+ // Short-circuit easy cases
+ if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
+ // If constant is one, the result is trivial
+ if (RHSC->getValue()->isOne())
+ return getZero(LHS->getType()); // X urem 1 --> 0
+
+ // If constant is a power of two, fold into a zext(trunc(LHS)).
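+ // e.g. for a 32-bit x, x urem 8 becomes zext(trunc(x) to i3) back to i32.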
+ if (RHSC->getAPInt().isPowerOf2()) {
+ Type *FullTy = LHS->getType();
+ Type *TruncTy =
+ IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
+ return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
+ }
+ }
+
+ // Fall back to the definition: %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
+ const SCEV *UDiv = getUDivExpr(LHS, RHS);
+ const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
+ return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
+}
+
+/// Get a canonical unsigned division expression, or something simpler if
+/// possible.
+const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
+ const SCEV *RHS) {
+ assert(getEffectiveSCEVType(LHS->getType()) ==
+ getEffectiveSCEVType(RHS->getType()) &&
+ "SCEVUDivExpr operand types don't match!");
+
+ if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
+ if (RHSC->getValue()->isOne())
+ return LHS; // X udiv 1 --> x
+ // If the denominator is zero, the result of the udiv is undefined. Don't
+ // try to analyze it, because the resolution chosen here may differ from
+ // the resolution chosen in other parts of the compiler.
+ if (!RHSC->getValue()->isZero()) {
+ // Determine if the division can be folded into the operands of LHS.
+ // TODO: Generalize this to non-constants by using known-bits information.
+ Type *Ty = LHS->getType();
+ unsigned LZ = RHSC->getAPInt().countLeadingZeros();
+ unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
+ // For non-power-of-two values, effectively round the value up to the
+ // nearest power of two.
+ if (!RHSC->getAPInt().isPowerOf2())
+ ++MaxShiftAmt;
+ IntegerType *ExtTy =
+ IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
+ if (const SCEVConstant *Step =
+ dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
+ // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
+ const APInt &StepInt = Step->getAPInt();
+ const APInt &DivInt = RHSC->getAPInt();
+ if (!StepInt.urem(DivInt) &&
+ getZeroExtendExpr(AR, ExtTy) ==
+ getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
+ getZeroExtendExpr(Step, ExtTy),
+ AR->getLoop(), SCEV::FlagAnyWrap)) {
+ SmallVector<const SCEV *, 4> Operands;
+ for (const SCEV *Op : AR->operands())
+ Operands.push_back(getUDivExpr(Op, RHS));
+ return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
+ }
+ // Get a canonical UDivExpr for a recurrence.
+ // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
+ // We can currently only fold X%N if X is constant.
+ const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
+ if (StartC && !DivInt.urem(StepInt) &&
+ getZeroExtendExpr(AR, ExtTy) ==
+ getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
+ getZeroExtendExpr(Step, ExtTy),
+ AR->getLoop(), SCEV::FlagAnyWrap)) {
+ const APInt &StartInt = StartC->getAPInt();
+ const APInt &StartRem = StartInt.urem(StepInt);
+ if (StartRem != 0)
+ LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step,
+ AR->getLoop(), SCEV::FlagNW);
+ }
+ }
+ // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
+ if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
+ SmallVector<const SCEV *, 4> Operands;
+ for (const SCEV *Op : M->operands())
+ Operands.push_back(getZeroExtendExpr(Op, ExtTy));
+ if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
+ // Find an operand that's safely divisible.
+ for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
+ const SCEV *Op = M->getOperand(i);
+ const SCEV *Div = getUDivExpr(Op, RHSC);
+ if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
+ Operands = SmallVector<const SCEV *, 4>(M->op_begin(),
+ M->op_end());
+ Operands[i] = Div;
+ return getMulExpr(Operands);
+ }
+ }
+ }
+
+ // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
+ if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
+ if (auto *DivisorConstant =
+ dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
+ bool Overflow = false;
+ APInt NewRHS =
+ DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
+ if (Overflow) {
+ return getConstant(RHSC->getType(), 0, false);
+ }
+ return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
+ }
+ }
+
+ // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
+ if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
+ SmallVector<const SCEV *, 4> Operands;
+ for (const SCEV *Op : A->operands())
+ Operands.push_back(getZeroExtendExpr(Op, ExtTy));
+ if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
+ Operands.clear();
+ for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
+ const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
+ if (isa<SCEVUDivExpr>(Op) ||
+ getMulExpr(Op, RHS) != A->getOperand(i))
+ break;
+ Operands.push_back(Op);
+ }
+ if (Operands.size() == A->getNumOperands())
+ return getAddExpr(Operands);
+ }
+ }
+
+ // Fold if both operands are constant.
+ if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
+ Constant *LHSCV = LHSC->getValue();
+ Constant *RHSCV = RHSC->getValue();
+ return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
+ RHSCV)));
+ }
+ }
+ }
+
+ FoldingSetNodeID ID;
+ ID.AddInteger(scUDivExpr);
+ ID.AddPointer(LHS);
+ ID.AddPointer(RHS);
+ void *IP = nullptr;
+ if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+ SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
+ LHS, RHS);
+ UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
+ return S;
+}
+
+static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
+ APInt A = C1->getAPInt().abs();
+ APInt B = C2->getAPInt().abs();
+ uint32_t ABW = A.getBitWidth();
+ uint32_t BBW = B.getBitWidth();
+
+ if (ABW > BBW)
+ B = B.zext(ABW);
+ else if (ABW < BBW)
+ A = A.zext(BBW);
+
+ return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
+}
+
+/// Get a canonical unsigned division expression, or something simpler if
+/// possible. There is no representation for an exact udiv in SCEV IR, but we
+/// can attempt to remove factors from the LHS and RHS. We can't do this when
+/// it's not exact because the udiv may be clearing bits.
+const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
+ const SCEV *RHS) {
+ // TODO: we could try to find factors in all sorts of things, but for now we
+ // just deal with u/exact (multiply, constant). See SCEVDivision towards the
+ // end of this file for inspiration.
+
+ const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
+ if (!Mul || !Mul->hasNoUnsignedWrap())
+ return getUDivExpr(LHS, RHS);
+
+ if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
+ // If the mulexpr multiplies by a constant, then that constant must be the
+ // first element of the mulexpr.
+ if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
+ if (LHSCst == RHSCst) {
+ SmallVector<const SCEV *, 2> Operands;
+ Operands.append(Mul->op_begin() + 1, Mul->op_end());
+ return getMulExpr(Operands);
+ }
+
+ // We can't just assume that LHSCst divides RHSCst cleanly; it could be
+ // that there's a factor provided by one of the other terms. We need to
+ // check.
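+ // e.g. for (6 * x) /u 4, gcd(6, 4) = 2, so the division is reduced to
+ // (3 * x) /u 2 before continuing.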
+ APInt Factor = gcd(LHSCst, RHSCst);
+ if (!Factor.isIntN(1)) {
+ LHSCst =
+ cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
+ RHSCst =
+ cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
+ SmallVector<const SCEV *, 2> Operands;
+ Operands.push_back(LHSCst);
+ Operands.append(Mul->op_begin() + 1, Mul->op_end());
+ LHS = getMulExpr(Operands);
+ RHS = RHSCst;
+ Mul = dyn_cast<SCEVMulExpr>(LHS);
+ if (!Mul)
+ return getUDivExactExpr(LHS, RHS);
+ }
+ }
+ }
+
+ for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
+ if (Mul->getOperand(i) == RHS) {
+ SmallVector<const SCEV *, 2> Operands;
+ Operands.append(Mul->op_begin(), Mul->op_begin() + i);
+ Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
+ return getMulExpr(Operands);
+ }
+ }
+
+ return getUDivExpr(LHS, RHS);
+}
+
+/// Get an add recurrence expression for the specified loop. Simplify the
+/// expression as much as possible.
+const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
+ const Loop *L,
+ SCEV::NoWrapFlags Flags) {
+ SmallVector<const SCEV *, 4> Operands;
+ Operands.push_back(Start);
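+ // If the step is itself an addrec for this loop, flatten
+ // {Start,+,{Y,+,Z}<L>}<L> into the single recurrence {Start,+,Y,+,Z}<L>.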
+ if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
+ if (StepChrec->getLoop() == L) {
+ Operands.append(StepChrec->op_begin(), StepChrec->op_end());
+ return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
+ }
+
+ Operands.push_back(Step);
+ return getAddRecExpr(Operands, L, Flags);
+}
+
+/// Get an add recurrence expression for the specified loop. Simplify the
+/// expression as much as possible.
+const SCEV *
+ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
+ const Loop *L, SCEV::NoWrapFlags Flags) {
+ if (Operands.size() == 1) return Operands[0];
+#ifndef NDEBUG
+ Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
+ for (unsigned i = 1, e = Operands.size(); i != e; ++i)
+ assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
+ "SCEVAddRecExpr operand types don't match!");
+ for (unsigned i = 0, e = Operands.size(); i != e; ++i)
+ assert(isLoopInvariant(Operands[i], L) &&
+ "SCEVAddRecExpr operand is not loop-invariant!");
+#endif
+
+ if (Operands.back()->isZero()) {
+ Operands.pop_back();
+ return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
+ }
+
+ // It's tempting to call getConstantMaxBackedgeTakenCount here and
+ // use that information to infer NUW and NSW flags. However, computing a
+ // BE count requires calling getAddRecExpr, so we may not yet have a
+ // meaningful BE count at this point (and if we don't, we'd be stuck
+ // with a SCEVCouldNotCompute as the cached BE count).
+
+ Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
+
+ // Canonicalize nested AddRecs by nesting them in order of loop depth.
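+ // e.g. {{X,+,Y}<Inner>,+,Z}<L> is rewritten to {{X,+,Z}<L>,+,Y}<Inner> when
+ // L is the outer loop and the loop-invariance checks below succeed.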
+ if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
+ const Loop *NestedLoop = NestedAR->getLoop();
+ if (L->contains(NestedLoop)
+ ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
+ : (!NestedLoop->contains(L) &&
+ DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
+ SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
+ NestedAR->op_end());
+ Operands[0] = NestedAR->getStart();
+ // AddRecs require their operands be loop-invariant with respect to their
+ // loops. Don't perform this transformation if it would break this
+ // requirement.
+ bool AllInvariant = all_of(
+ Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
+
+ if (AllInvariant) {
+ // Create a recurrence for the outer loop with the same step size.
+ //
+ // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
+ // inner recurrence has the same property.
+ SCEV::NoWrapFlags OuterFlags =
+ maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
+
+ NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
+ AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
+ return isLoopInvariant(Op, NestedLoop);
+ });
+
+ if (AllInvariant) {
+ // Ok, both add recurrences are valid after the transformation.
+ //
+ // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
+ // the outer recurrence has the same property.
+ SCEV::NoWrapFlags InnerFlags =
+ maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
+ return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
+ }
+ }
+ // Reset Operands to its original state.
+ Operands[0] = NestedAR;
+ }
+ }
+
+ // Okay, it looks like we really DO need an addrec expr. Check to see if we
+ // already have one, otherwise create a new one.
+ return getOrCreateAddRecExpr(Operands, L, Flags);
+}
+
+const SCEV *
+ScalarEvolution::getGEPExpr(GEPOperator *GEP,
+ const SmallVectorImpl<const SCEV *> &IndexExprs) {
+ const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
+ // getSCEV(Base)->getType() has the same address space as Base->getType()
+ // because SCEV::getType() preserves the address space.
+ Type *IntPtrTy = getEffectiveSCEVType(BaseExpr->getType());
+ // FIXME(PR23527): Don't blindly transfer the inbounds flag from the GEP
+ // instruction to its SCEV, because the Instruction may be guarded by control
+ // flow and the no-overflow bits may not be valid for the expression in any
+ // context. This can be fixed similarly to how these flags are handled for
+ // adds.
+ SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW
+ : SCEV::FlagAnyWrap;
+
+ const SCEV *TotalOffset = getZero(IntPtrTy);
+ // The array size is unimportant. The first thing we do on CurTy is getting
+ // its element type.
+ Type *CurTy = ArrayType::get(GEP->getSourceElementType(), 0);
+ for (const SCEV *IndexExpr : IndexExprs) {
+ // Compute the (potentially symbolic) offset in bytes for this index.
+ if (StructType *STy = dyn_cast<StructType>(CurTy)) {
+ // For a struct, add the member offset.
+ ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
+ unsigned FieldNo = Index->getZExtValue();
+ const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo);
+
+ // Add the field offset to the running total offset.
+ TotalOffset = getAddExpr(TotalOffset, FieldOffset);
+
+ // Update CurTy to the type of the field at Index.
+ CurTy = STy->getTypeAtIndex(Index);
+ } else {
+ // Update CurTy to its element type.
+ CurTy = cast<SequentialType>(CurTy)->getElementType();
+ // For an array, add the element offset, explicitly scaled.
+ const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, CurTy);
+ // Getelementptr indices are signed.
+ IndexExpr = getTruncateOrSignExtend(IndexExpr, IntPtrTy);
+
+ // Multiply the index by the element size to compute the element offset.
+ const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, Wrap);
+
+ // Add the element offset to the running total offset.
+ TotalOffset = getAddExpr(TotalOffset, LocalOffset);
+ }
+ }
+
+ // Add the total offset from all the GEP indices to the base.
+ return getAddExpr(BaseExpr, TotalOffset, Wrap);
+}
+
+std::tuple<const SCEV *, FoldingSetNodeID, void *>
+ScalarEvolution::findExistingSCEVInCache(int SCEVType,
+ ArrayRef<const SCEV *> Ops) {
+ FoldingSetNodeID ID;
+ void *IP = nullptr;
+ ID.AddInteger(SCEVType);
+ for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+ ID.AddPointer(Ops[i]);
+ return std::tuple<const SCEV *, FoldingSetNodeID, void *>(
+ UniqueSCEVs.FindNodeOrInsertPos(ID, IP), std::move(ID), IP);
+}
+
+const SCEV *ScalarEvolution::getMinMaxExpr(unsigned Kind,
+ SmallVectorImpl<const SCEV *> &Ops) {
+ assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
+ if (Ops.size() == 1) return Ops[0];
+#ifndef NDEBUG
+ Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
+ for (unsigned i = 1, e = Ops.size(); i != e; ++i)
+ assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
+ "Operand types don't match!");
+#endif
+
+ bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
+ bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
+
+ // Sort by complexity, this groups all similar expression types together.
+ GroupByComplexity(Ops, &LI, DT);
+
+ // Check if we have created the same expression before.
+ if (const SCEV *S = std::get<0>(findExistingSCEVInCache(Kind, Ops))) {
+ return S;
+ }
+
+ // If there are any constants, fold them together.
+ unsigned Idx = 0;
+ if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
+ ++Idx;
+ assert(Idx < Ops.size());
+ auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
+ if (Kind == scSMaxExpr)
+ return APIntOps::smax(LHS, RHS);
+ else if (Kind == scSMinExpr)
+ return APIntOps::smin(LHS, RHS);
+ else if (Kind == scUMaxExpr)
+ return APIntOps::umax(LHS, RHS);
+ else if (Kind == scUMinExpr)
+ return APIntOps::umin(LHS, RHS);
+ llvm_unreachable("Unknown SCEV min/max opcode");
+ };
+
+ while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
+ // We found two constants, fold them together!
+ ConstantInt *Fold = ConstantInt::get(
+ getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
+ Ops[0] = getConstant(Fold);
+ Ops.erase(Ops.begin()+1); // Erase the folded element
+ if (Ops.size() == 1) return Ops[0];
+ LHSC = cast<SCEVConstant>(Ops[0]);
+ }
+
+ bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
+ bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
+
+ if (IsMax ? IsMinV : IsMaxV) {
+ // If we are left with a constant minimum(/maximum)-int, strip it off.
+ Ops.erase(Ops.begin());
+ --Idx;
+ } else if (IsMax ? IsMaxV : IsMinV) {
+ // If we have a max(/min) with a constant maximum(/minimum)-int,
+ // it will always be the extremum.
+ return LHSC;
+ }
+
+ if (Ops.size() == 1) return Ops[0];
+ }
+
+ // Find the first operation of the same kind
+ while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
+ ++Idx;
+
+ // Check to see if one of the operands is of the same kind. If so, expand its
+ // operands onto our operand list, and recurse to simplify.
+ if (Idx < Ops.size()) {
+ bool DeletedAny = false;
+ while (Ops[Idx]->getSCEVType() == Kind) {
+ const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
+ Ops.erase(Ops.begin()+Idx);
+ Ops.append(SMME->op_begin(), SMME->op_end());
+ DeletedAny = true;
+ }
+
+ if (DeletedAny)
+ return getMinMaxExpr(Kind, Ops);
+ }
+
+ // Okay, check to see if the same value occurs in the operand list twice. If
+ // so, delete one. Since we sorted the list, these values are required to
+ // be adjacent.
+ llvm::CmpInst::Predicate GEPred =
+ IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
+ llvm::CmpInst::Predicate LEPred =
+ IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
+ llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
+ llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
+ for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
+ if (Ops[i] == Ops[i + 1] ||
+ isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
+ // X op Y op Y --> X op Y
+ // X op Y --> X, if we know X, Y are ordered appropriately
+ Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
+ --i;
+ --e;
+ } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
+ Ops[i + 1])) {
+ // X op Y --> Y, if we know X, Y are ordered appropriately
+ Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
+ --i;
+ --e;
+ }
+ }
+
+ if (Ops.size() == 1) return Ops[0];
+
+ assert(!Ops.empty() && "Reduced smax down to nothing!");
+
+ // Okay, it looks like we really DO need an expr. Check to see if we
+ // already have one, otherwise create a new one.
+ const SCEV *ExistingSCEV;
+ FoldingSetNodeID ID;
+ void *IP;
+ std::tie(ExistingSCEV, ID, IP) = findExistingSCEVInCache(Kind, Ops);
+ if (ExistingSCEV)
+ return ExistingSCEV;
+ const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
+ std::uninitialized_copy(Ops.begin(), Ops.end(), O);
+ SCEV *S = new (SCEVAllocator) SCEVMinMaxExpr(
+ ID.Intern(SCEVAllocator), static_cast<SCEVTypes>(Kind), O, Ops.size());
+
+ UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
+ return S;
+}
+
+const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
+ SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
+ return getSMaxExpr(Ops);
+}
+
+const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
+ return getMinMaxExpr(scSMaxExpr, Ops);
+}
+
+const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
+ SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
+ return getUMaxExpr(Ops);
+}
+
+const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
+ return getMinMaxExpr(scUMaxExpr, Ops);
+}
+
+const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
+ const SCEV *RHS) {
+ SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+ return getSMinExpr(Ops);
+}
+
+const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
+ return getMinMaxExpr(scSMinExpr, Ops);
+}
+
+const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
+ const SCEV *RHS) {
+ SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+ return getUMinExpr(Ops);
+}
+
+const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
+ return getMinMaxExpr(scUMinExpr, Ops);
+}
+
+const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
+ // We can bypass creating a target-independent
+ // constant expression and then folding it back into a ConstantInt.
+ // This is just a compile-time optimization.
+ return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
+}
+
+const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
+ StructType *STy,
+ unsigned FieldNo) {
+ // We can bypass creating a target-independent
+ // constant expression and then folding it back into a ConstantInt.
+ // This is just a compile-time optimization.
+ return getConstant(
+ IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
+}
+
+const SCEV *ScalarEvolution::getUnknown(Value *V) {
+ // Don't attempt to do anything other than create a SCEVUnknown object
+ // here. createSCEV only calls getUnknown after checking for all other
+ // interesting possibilities, and any other code that calls getUnknown
+ // is doing so in order to hide a value from SCEV canonicalization.
+
+ FoldingSetNodeID ID;
+ ID.AddInteger(scUnknown);
+ ID.AddPointer(V);
+ void *IP = nullptr;
+ if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
+ assert(cast<SCEVUnknown>(S)->getValue() == V &&
+ "Stale SCEVUnknown in uniquing map!");
+ return S;
+ }
+ SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
+ FirstUnknown);
+ FirstUnknown = cast<SCEVUnknown>(S);
+ UniqueSCEVs.InsertNode(S, IP);
+ return S;
+}
+
+//===----------------------------------------------------------------------===//
+// Basic SCEV Analysis and PHI Idiom Recognition Code
+//
+
+/// Test if values of the given type are analyzable within the SCEV
+/// framework. This primarily includes integer types, and it can optionally
+/// include pointer types if the ScalarEvolution class has access to
+/// target-specific information.
+bool ScalarEvolution::isSCEVable(Type *Ty) const {
+ // Integers and pointers are always SCEVable.
+ return Ty->isIntOrPtrTy();
+}
+
+/// Return the size in bits of the specified type, for which isSCEVable must
+/// return true.
+uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
+ assert(isSCEVable(Ty) && "Type is not SCEVable!");
+ if (Ty->isPointerTy())
+ return getDataLayout().getIndexTypeSizeInBits(Ty);
+ return getDataLayout().getTypeSizeInBits(Ty);
+}
+
+/// Return a type with the same bitwidth as the given type and which represents
+/// how SCEV will treat the given type, for which isSCEVable must return
+/// true. For pointer types, this is the pointer-sized integer type.
+Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
+ assert(isSCEVable(Ty) && "Type is not SCEVable!");
+
+ if (Ty->isIntegerTy())
+ return Ty;
+
+ // The only other supported type is pointer.
+ assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
+ return getDataLayout().getIntPtrType(Ty);
+}
+
+Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
+ return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
+}
+
+const SCEV *ScalarEvolution::getCouldNotCompute() {
+ return CouldNotCompute.get();
+}
+
+bool ScalarEvolution::checkValidity(const SCEV *S) const {
+ bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
+ auto *SU = dyn_cast<SCEVUnknown>(S);
+ return SU && SU->getValue() == nullptr;
+ });
+
+ return !ContainsNulls;
+}
+
+bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
+ HasRecMapType::iterator I = HasRecMap.find(S);
+ if (I != HasRecMap.end())
+ return I->second;
+
+ bool FoundAddRec = SCEVExprContains(S, isa<SCEVAddRecExpr, const SCEV *>);
+ HasRecMap.insert({S, FoundAddRec});
+ return FoundAddRec;
+}
+
+/// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}.
+/// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an
+/// offset I, then return {S', I}, else return {\p S, nullptr}.
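+/// e.g. the add expression (5 + %x) splits into {%x, 5}.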
+static std::pair<const SCEV *, ConstantInt *> splitAddExpr(const SCEV *S) {
+ const auto *Add = dyn_cast<SCEVAddExpr>(S);
+ if (!Add)
+ return {S, nullptr};
+
+ if (Add->getNumOperands() != 2)
+ return {S, nullptr};
+
+ auto *ConstOp = dyn_cast<SCEVConstant>(Add->getOperand(0));
+ if (!ConstOp)
+ return {S, nullptr};
+
+ return {Add->getOperand(1), ConstOp->getValue()};
+}
+
+/// Return the ValueOffsetPair set for \p S. \p S can be represented
+/// by the value and offset from any ValueOffsetPair in the set.
+SetVector<ScalarEvolution::ValueOffsetPair> *
+ScalarEvolution::getSCEVValues(const SCEV *S) {
+ ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
+ if (SI == ExprValueMap.end())
+ return nullptr;
+#ifndef NDEBUG
+ if (VerifySCEVMap) {
+ // Check there is no dangling Value in the set returned.
+ for (const auto &VE : SI->second)
+ assert(ValueExprMap.count(VE.first));
+ }
+#endif
+ return &SI->second;
+}
+
+/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
+/// cannot be used separately. eraseValueFromMap should be used to remove
+/// V from ValueExprMap and ExprValueMap at the same time.
+void ScalarEvolution::eraseValueFromMap(Value *V) {
+ ValueExprMapType::iterator I = ValueExprMap.find_as(V);
+ if (I != ValueExprMap.end()) {
+ const SCEV *S = I->second;
+ // Remove {V, 0} from the set of ExprValueMap[S]
+ if (SetVector<ValueOffsetPair> *SV = getSCEVValues(S))
+ SV->remove({V, nullptr});
+
+ // Remove {V, Offset} from the set of ExprValueMap[Stripped]
+ const SCEV *Stripped;
+ ConstantInt *Offset;
+ std::tie(Stripped, Offset) = splitAddExpr(S);
+ if (Offset != nullptr) {
+ if (SetVector<ValueOffsetPair> *SV = getSCEVValues(Stripped))
+ SV->remove({V, Offset});
+ }
+ ValueExprMap.erase(V);
+ }
+}
+
+/// Check whether value has nuw/nsw/exact set but SCEV does not.
+/// TODO: Ideally we would check for poison recursively, but this is
+/// better than nothing.
+static bool SCEVLostPoisonFlags(const SCEV *S, const Value *V) {
+ if (auto *I = dyn_cast<Instruction>(V)) {
+ if (isa<OverflowingBinaryOperator>(I)) {
+ if (auto *NS = dyn_cast<SCEVNAryExpr>(S)) {
+ if (I->hasNoSignedWrap() && !NS->hasNoSignedWrap())
+ return true;
+ if (I->hasNoUnsignedWrap() && !NS->hasNoUnsignedWrap())
+ return true;
+ }
+ } else if (isa<PossiblyExactOperator>(I) && I->isExact())
+ return true;
+ }
+ return false;
+}
+
+/// Return an existing SCEV if it exists, otherwise analyze the expression and
+/// create a new one.
+const SCEV *ScalarEvolution::getSCEV(Value *V) {
+ assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
+
+ const SCEV *S = getExistingSCEV(V);
+ if (S == nullptr) {
+ S = createSCEV(V);
+ // During PHI resolution, it is possible to create two SCEVs for the same
+ // V, so we need to double-check whether V->S was already inserted into
+ // ValueExprMap before inserting S->{V, 0} into ExprValueMap.
+ std::pair<ValueExprMapType::iterator, bool> Pair =
+ ValueExprMap.insert({SCEVCallbackVH(V, this), S});
+ if (Pair.second && !SCEVLostPoisonFlags(S, V)) {
+ ExprValueMap[S].insert({V, nullptr});
+
+ // If S == Stripped + Offset, add Stripped -> {V, Offset} into
+ // ExprValueMap.
+ const SCEV *Stripped = S;
+ ConstantInt *Offset = nullptr;
+ std::tie(Stripped, Offset) = splitAddExpr(S);
+ // If Stripped is a SCEVUnknown, don't bother to save
+ // Stripped -> {V, offset}. It doesn't simplify and sometimes even
+ // increases the complexity of the expansion code.
+ // If V is a GetElementPtrInst, don't save Stripped -> {V, offset}
+ // because it may generate add/sub instead of a GEP in SCEV expansion.
+ if (Offset != nullptr && !isa<SCEVUnknown>(Stripped) &&
+ !isa<GetElementPtrInst>(V))
+ ExprValueMap[Stripped].insert({V, Offset});
+ }
+ }
+ return S;
+}
+
+const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
+ assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
+
+ ValueExprMapType::iterator I = ValueExprMap.find_as(V);
+ if (I != ValueExprMap.end()) {
+ const SCEV *S = I->second;
+ if (checkValidity(S))
+ return S;
+ eraseValueFromMap(V);
+ forgetMemoizedResults(S);
+ }
+ return nullptr;
+}
+
+/// Return a SCEV corresponding to -V = -1*V
+const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
+ SCEV::NoWrapFlags Flags) {
+ if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
+ return getConstant(
+ cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
+
+ Type *Ty = V->getType();
+ Ty = getEffectiveSCEVType(Ty);
+ return getMulExpr(
+ V, getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))), Flags);
+}
+
+/// If Expr computes ~A, return A else return nullptr
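+/// ~A is -1 - A, which SCEV represents as (-1) + (-1 * A), so we match that form.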
+static const SCEV *MatchNotExpr(const SCEV *Expr) {
+ const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
+ if (!Add || Add->getNumOperands() != 2 ||
+ !Add->getOperand(0)->isAllOnesValue())
+ return nullptr;
+
+ const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
+ if (!AddRHS || AddRHS->getNumOperands() != 2 ||
+ !AddRHS->getOperand(0)->isAllOnesValue())
+ return nullptr;
+
+ return AddRHS->getOperand(1);
+}
+
+/// Return a SCEV corresponding to ~V = -1-V
+const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
+ if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
+ return getConstant(
+ cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
+
+ // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
+ if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
+ auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
+ SmallVector<const SCEV *, 2> MatchedOperands;
+ for (const SCEV *Operand : MME->operands()) {
+ const SCEV *Matched = MatchNotExpr(Operand);
+ if (!Matched)
+ return (const SCEV *)nullptr;
+ MatchedOperands.push_back(Matched);
+ }
+ return getMinMaxExpr(
+ SCEVMinMaxExpr::negate(static_cast<SCEVTypes>(MME->getSCEVType())),
+ MatchedOperands);
+ };
+ if (const SCEV *Replaced = MatchMinMaxNegation(MME))
+ return Replaced;
+ }
+
+ Type *Ty = V->getType();
+ Ty = getEffectiveSCEVType(Ty);
+ const SCEV *AllOnes =
+ getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty)));
+ return getMinusSCEV(AllOnes, V);
+}
+
+const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
+ SCEV::NoWrapFlags Flags,
+ unsigned Depth) {
+ // Fast path: X - X --> 0.
+ if (LHS == RHS)
+ return getZero(LHS->getType());
+
+ // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
+ // makes it so that we cannot make much use of NUW.
+ auto AddFlags = SCEV::FlagAnyWrap;
+ const bool RHSIsNotMinSigned =
+ !getSignedRangeMin(RHS).isMinSignedValue();
+ if (maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW) {
+ // Let M be the minimum representable signed value. Then (-1)*RHS
+ // signed-wraps if and only if RHS is M. That can happen even for
+ // a NSW subtraction because e.g. (-1)*M signed-wraps even though
+ // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
+ // (-1)*RHS, we need to prove that RHS != M.
+ //
+ // If LHS is non-negative and we know that LHS - RHS does not
+ // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
+ // either by proving that RHS > M or that LHS >= 0.
+ if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
+ AddFlags = SCEV::FlagNSW;
+ }
+ }
+
+ // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
+ // RHS is NSW and LHS >= 0.
+ //
+ // The difficulty here is that the NSW flag may have been proven
+ // relative to a loop that is to be found in a recurrence in LHS and
+ // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
+ // larger scope than intended.
+ auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
+
+ return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
+}
+
+const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
+ unsigned Depth) {
+ Type *SrcTy = V->getType();
+ assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+ "Cannot truncate or zero extend with non-integer arguments!");
+ if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
+ return V; // No conversion
+ if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
+ return getTruncateExpr(V, Ty, Depth);
+ return getZeroExtendExpr(V, Ty, Depth);
+}
+
+const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
+ unsigned Depth) {
+ Type *SrcTy = V->getType();
+ assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+ "Cannot truncate or zero extend with non-integer arguments!");
+ if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
+ return V; // No conversion
+ if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
+ return getTruncateExpr(V, Ty, Depth);
+ return getSignExtendExpr(V, Ty, Depth);
+}
+
+const SCEV *
+ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
+ Type *SrcTy = V->getType();
+ assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+ "Cannot noop or zero extend with non-integer arguments!");
+ assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
+ "getNoopOrZeroExtend cannot truncate!");
+ if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
+ return V; // No conversion
+ return getZeroExtendExpr(V, Ty);
+}
+
+const SCEV *
+ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
+ Type *SrcTy = V->getType();
+ assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+ "Cannot noop or sign extend with non-integer arguments!");
+ assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
+ "getNoopOrSignExtend cannot truncate!");
+ if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
+ return V; // No conversion
+ return getSignExtendExpr(V, Ty);
+}
+
+const SCEV *
+ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
+ Type *SrcTy = V->getType();
+ assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+ "Cannot noop or any extend with non-integer arguments!");
+ assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
+ "getNoopOrAnyExtend cannot truncate!");
+ if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
+ return V; // No conversion
+ return getAnyExtendExpr(V, Ty);
+}
+
+const SCEV *
+ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
+ Type *SrcTy = V->getType();
+ assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+ "Cannot truncate or noop with non-integer arguments!");
+ assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
+ "getTruncateOrNoop cannot extend!");
+ if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
+ return V; // No conversion
+ return getTruncateExpr(V, Ty);
+}
+
+const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
+ const SCEV *RHS) {
+ const SCEV *PromotedLHS = LHS;
+ const SCEV *PromotedRHS = RHS;
+
+ if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
+ PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
+ else
+ PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
+
+ return getUMaxExpr(PromotedLHS, PromotedRHS);
+}
+
+const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
+ const SCEV *RHS) {
+ SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+ return getUMinFromMismatchedTypes(Ops);
+}
+
+const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(
+ SmallVectorImpl<const SCEV *> &Ops) {
+ assert(!Ops.empty() && "At least one operand must be!");
+ // Trivial case.
+ if (Ops.size() == 1)
+ return Ops[0];
+
+ // Find the max type first.
+ Type *MaxType = nullptr;
+ for (auto *S : Ops)
+ if (MaxType)
+ MaxType = getWiderType(MaxType, S->getType());
+ else
+ MaxType = S->getType();
+
+ // Extend all ops to max type.
+ SmallVector<const SCEV *, 2> PromotedOps;
+ for (auto *S : Ops)
+ PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
+
+ // Generate umin.
+ return getUMinExpr(PromotedOps);
+}
+
+const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
+ // A pointer operand may evaluate to a nonpointer expression, such as null.
+ if (!V->getType()->isPointerTy())
+ return V;
+
+ if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
+ return getPointerBase(Cast->getOperand());
+ } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
+ const SCEV *PtrOp = nullptr;
+ for (const SCEV *NAryOp : NAry->operands()) {
+ if (NAryOp->getType()->isPointerTy()) {
+ // Cannot find the base of an expression with multiple pointer operands.
+ if (PtrOp)
+ return V;
+ PtrOp = NAryOp;
+ }
+ }
+ if (!PtrOp)
+ return V;
+ return getPointerBase(PtrOp);
+ }
+ return V;
+}
+
+/// Push users of the given Instruction onto the given Worklist.
+static void
+PushDefUseChildren(Instruction *I,
+ SmallVectorImpl<Instruction *> &Worklist) {
+ // Push the def-use children onto the Worklist stack.
+ for (User *U : I->users())
+ Worklist.push_back(cast<Instruction>(U));
+}
+
+void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) {
+ SmallVector<Instruction *, 16> Worklist;
+ PushDefUseChildren(PN, Worklist);
+
+ SmallPtrSet<Instruction *, 8> Visited;
+ Visited.insert(PN);
+ while (!Worklist.empty()) {
+ Instruction *I = Worklist.pop_back_val();
+ if (!Visited.insert(I).second)
+ continue;
+
+ auto It = ValueExprMap.find_as(static_cast<Value *>(I));
+ if (It != ValueExprMap.end()) {
+ const SCEV *Old = It->second;
+
+ // Short-circuit the def-use traversal if the symbolic name
+ // ceases to appear in expressions.
+ if (Old != SymName && !hasOperand(Old, SymName))
+ continue;
+
+ // SCEVUnknown for a PHI either means that it has an unrecognized
+      // structure, it's a PHI that's in the process of being computed
+ // by createNodeForPHI, or it's a single-value PHI. In the first case,
+ // additional loop trip count information isn't going to change anything.
+ // In the second case, createNodeForPHI will perform the necessary
+ // updates on its own when it gets to that point. In the third, we do
+ // want to forget the SCEVUnknown.
+ if (!isa<PHINode>(I) ||
+ !isa<SCEVUnknown>(Old) ||
+ (I != PN && Old == SymName)) {
+ eraseValueFromMap(It->first);
+ forgetMemoizedResults(Old);
+ }
+ }
+
+ PushDefUseChildren(I, Worklist);
+ }
+}
+
+namespace {
+
+/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
+/// use its start expression. For an AddRec of any other loop, use the AddRec
+/// itself if IgnoreOtherLoops is true; otherwise the rewrite cannot be done.
+/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be
+/// done either.
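+///
+/// For illustration, rewriting {%start,+,%step}<L> for loop L yields %start,
+/// i.e. the value of the recurrence on entry to L.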
+class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
+public:
+ static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
+ bool IgnoreOtherLoops = true) {
+ SCEVInitRewriter Rewriter(L, SE);
+ const SCEV *Result = Rewriter.visit(S);
+ if (Rewriter.hasSeenLoopVariantSCEVUnknown())
+ return SE.getCouldNotCompute();
+ return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
+ ? SE.getCouldNotCompute()
+ : Result;
+ }
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ if (!SE.isLoopInvariant(Expr, L))
+ SeenLoopVariantSCEVUnknown = true;
+ return Expr;
+ }
+
+ const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+ // Only re-write AddRecExprs for this loop.
+ if (Expr->getLoop() == L)
+ return Expr->getStart();
+ SeenOtherLoops = true;
+ return Expr;
+ }
+
+ bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
+
+ bool hasSeenOtherLoops() { return SeenOtherLoops; }
+
+private:
+ explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
+ : SCEVRewriteVisitor(SE), L(L) {}
+
+ const Loop *L;
+ bool SeenLoopVariantSCEVUnknown = false;
+ bool SeenOtherLoops = false;
+};
+
+/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
+/// use its post-increment expression; for an AddRec of any other loop, use
+/// the AddRec itself.
+/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be
+/// done.
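+///
+/// For illustration, rewriting {%start,+,%step}<L> for loop L yields
+/// {%start + %step,+,%step}<L>, i.e. the value after one increment.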
+class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
+public:
+ static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
+ SCEVPostIncRewriter Rewriter(L, SE);
+ const SCEV *Result = Rewriter.visit(S);
+ return Rewriter.hasSeenLoopVariantSCEVUnknown()
+ ? SE.getCouldNotCompute()
+ : Result;
+ }
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ if (!SE.isLoopInvariant(Expr, L))
+ SeenLoopVariantSCEVUnknown = true;
+ return Expr;
+ }
+
+ const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+ // Only re-write AddRecExprs for this loop.
+ if (Expr->getLoop() == L)
+ return Expr->getPostIncExpr(SE);
+ SeenOtherLoops = true;
+ return Expr;
+ }
+
+ bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
+
+ bool hasSeenOtherLoops() { return SeenOtherLoops; }
+
+private:
+ explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
+ : SCEVRewriteVisitor(SE), L(L) {}
+
+ const Loop *L;
+ bool SeenLoopVariantSCEVUnknown = false;
+ bool SeenOtherLoops = false;
+};
+
+/// This class evaluates the compare condition by matching it against the
+/// condition of the loop latch. If there is a match, we assume a true value
+/// for the condition while building SCEV nodes.
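+///
+/// For illustration, if the latch branches back to the header when %cond is
+/// true, then a loop-variant "select %cond, %a, %b" seen while building SCEV
+/// nodes is folded to the SCEV of %a (and to %b if the header is reached on
+/// the false edge).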
+class SCEVBackedgeConditionFolder
+ : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
+public:
+ static const SCEV *rewrite(const SCEV *S, const Loop *L,
+ ScalarEvolution &SE) {
+ bool IsPosBECond = false;
+ Value *BECond = nullptr;
+ if (BasicBlock *Latch = L->getLoopLatch()) {
+ BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
+ if (BI && BI->isConditional()) {
+ assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
+ "Both outgoing branches should not target same header!");
+ BECond = BI->getCondition();
+ IsPosBECond = BI->getSuccessor(0) == L->getHeader();
+ } else {
+ return S;
+ }
+ }
+ SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
+ return Rewriter.visit(S);
+ }
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ const SCEV *Result = Expr;
+ bool InvariantF = SE.isLoopInvariant(Expr, L);
+
+ if (!InvariantF) {
+ Instruction *I = cast<Instruction>(Expr->getValue());
+ switch (I->getOpcode()) {
+ case Instruction::Select: {
+ SelectInst *SI = cast<SelectInst>(I);
+ Optional<const SCEV *> Res =
+ compareWithBackedgeCondition(SI->getCondition());
+ if (Res.hasValue()) {
+ bool IsOne = cast<SCEVConstant>(Res.getValue())->getValue()->isOne();
+ Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
+ }
+ break;
+ }
+ default: {
+ Optional<const SCEV *> Res = compareWithBackedgeCondition(I);
+ if (Res.hasValue())
+ Result = Res.getValue();
+ break;
+ }
+ }
+ }
+ return Result;
+ }
+
+private:
+ explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
+ bool IsPosBECond, ScalarEvolution &SE)
+ : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
+ IsPositiveBECond(IsPosBECond) {}
+
+ Optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
+
+ const Loop *L;
+ /// Loop back condition.
+ Value *BackedgeCond = nullptr;
+ /// Set to true if loop back is on positive branch condition.
+ bool IsPositiveBECond;
+};
+
+Optional<const SCEV *>
+SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
+
+ // If value matches the backedge condition for loop latch,
+ // then return a constant evolution node based on loopback
+ // branch taken.
+ if (BackedgeCond == IC)
+ return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
+ : SE.getZero(Type::getInt1Ty(SE.getContext()));
+ return None;
+}
+
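+/// Shift an AddRec of loop L back by one iteration. For illustration,
+/// {A,+,B}<L> is rewritten to {A-B,+,B}<L>. The rewrite gives up (isValid()
+/// returns false) on loop-variant unknowns, on AddRecs of other loops, and on
+/// non-affine AddRecs.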
+class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
+public:
+ static const SCEV *rewrite(const SCEV *S, const Loop *L,
+ ScalarEvolution &SE) {
+ SCEVShiftRewriter Rewriter(L, SE);
+ const SCEV *Result = Rewriter.visit(S);
+ return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
+ }
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ // Only allow AddRecExprs for this loop.
+ if (!SE.isLoopInvariant(Expr, L))
+ Valid = false;
+ return Expr;
+ }
+
+ const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+ if (Expr->getLoop() == L && Expr->isAffine())
+ return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
+ Valid = false;
+ return Expr;
+ }
+
+ bool isValid() { return Valid; }
+
+private:
+ explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
+ : SCEVRewriteVisitor(SE), L(L) {}
+
+ const Loop *L;
+ bool Valid = true;
+};
+
+} // end anonymous namespace
+
+SCEV::NoWrapFlags
+ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
+ if (!AR->isAffine())
+ return SCEV::FlagAnyWrap;
+
+ using OBO = OverflowingBinaryOperator;
+
+ SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
+
+ if (!AR->hasNoSignedWrap()) {
+ ConstantRange AddRecRange = getSignedRange(AR);
+ ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
+
+ auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
+ Instruction::Add, IncRange, OBO::NoSignedWrap);
+ if (NSWRegion.contains(AddRecRange))
+ Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
+ }
+
+ if (!AR->hasNoUnsignedWrap()) {
+ ConstantRange AddRecRange = getUnsignedRange(AR);
+ ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
+
+ auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
+ Instruction::Add, IncRange, OBO::NoUnsignedWrap);
+ if (NUWRegion.contains(AddRecRange))
+ Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
+ }
+
+ return Result;
+}
+
+namespace {
+
+/// Represents an abstract binary operation. This may exist as a
+/// normal instruction or constant expression, or may have been
+/// derived from an expression tree.
+struct BinaryOp {
+ unsigned Opcode;
+ Value *LHS;
+ Value *RHS;
+ bool IsNSW = false;
+ bool IsNUW = false;
+
+ /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
+ /// constant expression.
+ Operator *Op = nullptr;
+
+ explicit BinaryOp(Operator *Op)
+ : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
+ Op(Op) {
+ if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
+ IsNSW = OBO->hasNoSignedWrap();
+ IsNUW = OBO->hasNoUnsignedWrap();
+ }
+ }
+
+ explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
+ bool IsNUW = false)
+ : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
+};
+
+} // end anonymous namespace
+
+/// Try to map \p V into a BinaryOp, and return \c None on failure.
+static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
+ auto *Op = dyn_cast<Operator>(V);
+ if (!Op)
+ return None;
+
+ // Implementation detail: all the cleverness here should happen without
+  // creating new SCEV expressions -- our caller knows tricks to avoid creating
+ // SCEV expressions when possible, and we should not break that.
+
+ switch (Op->getOpcode()) {
+ case Instruction::Add:
+ case Instruction::Sub:
+ case Instruction::Mul:
+ case Instruction::UDiv:
+ case Instruction::URem:
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::AShr:
+ case Instruction::Shl:
+ return BinaryOp(Op);
+
+ case Instruction::Xor:
+ if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
+ // If the RHS of the xor is a signmask, then this is just an add.
+ // Instcombine turns add of signmask into xor as a strength reduction step.
+ if (RHSC->getValue().isSignMask())
+ return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
+ return BinaryOp(Op);
+
+ case Instruction::LShr:
+    // Turn logical shift right of a constant into an unsigned divide.
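+    // For example, (lshr i32 %x, 3) is treated as (udiv i32 %x, 8).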
+ if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
+ uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
+
+ // If the shift count is not less than the bitwidth, the result of
+ // the shift is undefined. Don't try to analyze it, because the
+ // resolution chosen here may differ from the resolution chosen in
+ // other parts of the compiler.
+ if (SA->getValue().ult(BitWidth)) {
+ Constant *X =
+ ConstantInt::get(SA->getContext(),
+ APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
+ return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
+ }
+ }
+ return BinaryOp(Op);
+
+ case Instruction::ExtractValue: {
+ auto *EVI = cast<ExtractValueInst>(Op);
+ if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
+ break;
+
+ auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
+ if (!WO)
+ break;
+
+ Instruction::BinaryOps BinOp = WO->getBinaryOp();
+ bool Signed = WO->isSigned();
+ // TODO: Should add nuw/nsw flags for mul as well.
+ if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
+ return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
+
+ // Now that we know that all uses of the arithmetic-result component of
+ // CI are guarded by the overflow check, we can go ahead and pretend
+ // that the arithmetic is non-overflowing.
+ return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
+ /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
+ }
+
+ default:
+ break;
+ }
+
+ return None;
+}
+
+/// Helper function for createAddRecFromPHIWithCasts. We have a phi
+/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
+/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
+/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
+/// follows one of the following patterns:
+/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
+/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
+/// If the SCEV expression of \p Op conforms with one of the expected patterns
+/// we return the type of the truncation operation, and indicate whether the
+/// truncated type should be treated as signed/unsigned by setting
+/// \p Signed to true/false, respectively.
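+///
+/// For illustration, with %SymbolicPHI of type i64,
+///   Op == (sext i32 (trunc i64 %SymbolicPHI to i32) to i64)
+/// yields a return value of i32 with \p Signed set to true.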
+static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
+ bool &Signed, ScalarEvolution &SE) {
+ // The case where Op == SymbolicPHI (that is, with no type conversions on
+ // the way) is handled by the regular add recurrence creating logic and
+  // would have already been triggered in createAddRecFromPHI. Reaching it here
+ // means that createAddRecFromPHI had failed for this PHI before (e.g.,
+ // because one of the other operands of the SCEVAddExpr updating this PHI is
+ // not invariant).
+ //
+ // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
+ // this case predicates that allow us to prove that Op == SymbolicPHI will
+ // be added.
+ if (Op == SymbolicPHI)
+ return nullptr;
+
+ unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
+ unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
+ if (SourceBits != NewBits)
+ return nullptr;
+
+ const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
+ const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
+ if (!SExt && !ZExt)
+ return nullptr;
+ const SCEVTruncateExpr *Trunc =
+ SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
+ : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
+ if (!Trunc)
+ return nullptr;
+ const SCEV *X = Trunc->getOperand();
+ if (X != SymbolicPHI)
+ return nullptr;
+ Signed = SExt != nullptr;
+ return Trunc->getType();
+}
+
+static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
+ if (!PN->getType()->isIntegerTy())
+ return nullptr;
+ const Loop *L = LI.getLoopFor(PN->getParent());
+ if (!L || L->getHeader() != PN->getParent())
+ return nullptr;
+ return L;
+}
+
+// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
+// computation that updates the phi follows the following pattern:
+// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
+// which correspond to a phi->trunc->sext/zext->add->phi update chain.
+// If so, try to see if it can be rewritten as an AddRecExpr under some
+// Predicates. If successful, return them as a pair. Also cache the results
+// of the analysis.
+//
+// Example usage scenario:
+// Say the Rewriter is called for the following SCEV:
+// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
+// where:
+// %X = phi i64 (%Start, %BEValue)
+// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
+// and call this function with %SymbolicPHI = %X.
+//
+// The analysis will find that the value coming around the backedge has
+// the following SCEV:
+// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
+// Upon concluding that this matches the desired pattern, the function
+// will return the pair {NewAddRec, SmallPredsVec} where:
+// NewAddRec = {%Start,+,%Step}
+// SmallPredsVec = {P1, P2, P3} as follows:
+// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
+// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
+// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
+// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
+// under the predicates {P1,P2,P3}.
+// This predicated rewrite will be cached in PredicatedSCEVRewrites:
+//     PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
+//
+// TODO's:
+//
+// 1) Extend the Induction descriptor to also support inductions that involve
+// casts: When needed (namely, when we are called in the context of the
+// vectorizer induction analysis), a Set of cast instructions will be
+// populated by this method, and provided back to isInductionPHI. This is
+// needed to allow the vectorizer to properly record them to be ignored by
+// the cost model and to avoid vectorizing them (otherwise these casts,
+// which are redundant under the runtime overflow checks, will be
+// vectorized, which can be costly).
+//
+// 2) Support additional induction/PHISCEV patterns: We also want to support
+// inductions where the sext-trunc / zext-trunc operations (partly) occur
+// after the induction update operation (the induction increment):
+//
+// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
+// which correspond to a phi->add->trunc->sext/zext->phi update chain.
+//
+// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
+// which correspond to a phi->trunc->add->sext/zext->phi update chain.
+//
+// 3) Outline common code with createAddRecFromPHI to avoid duplication.
+Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
+ SmallVector<const SCEVPredicate *, 3> Predicates;
+
+ // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
+ // return an AddRec expression under some predicate.
+
+ auto *PN = cast<PHINode>(SymbolicPHI->getValue());
+ const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
+ assert(L && "Expecting an integer loop header phi");
+
+ // The loop may have multiple entrances or multiple exits; we can analyze
+ // this phi as an addrec if it has a unique entry value and a unique
+ // backedge value.
+ Value *BEValueV = nullptr, *StartValueV = nullptr;
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+ Value *V = PN->getIncomingValue(i);
+ if (L->contains(PN->getIncomingBlock(i))) {
+ if (!BEValueV) {
+ BEValueV = V;
+ } else if (BEValueV != V) {
+ BEValueV = nullptr;
+ break;
+ }
+ } else if (!StartValueV) {
+ StartValueV = V;
+ } else if (StartValueV != V) {
+ StartValueV = nullptr;
+ break;
+ }
+ }
+ if (!BEValueV || !StartValueV)
+ return None;
+
+ const SCEV *BEValue = getSCEV(BEValueV);
+
+ // If the value coming around the backedge is an add with the symbolic
+ // value we just inserted, possibly with casts that we can ignore under
+ // an appropriate runtime guard, then we found a simple induction variable!
+ const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
+ if (!Add)
+ return None;
+
+ // If there is a single occurrence of the symbolic value, possibly
+ // casted, replace it with a recurrence.
+ unsigned FoundIndex = Add->getNumOperands();
+ Type *TruncTy = nullptr;
+ bool Signed;
+ for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+ if ((TruncTy =
+ isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
+ if (FoundIndex == e) {
+ FoundIndex = i;
+ break;
+ }
+
+ if (FoundIndex == Add->getNumOperands())
+ return None;
+
+ // Create an add with everything but the specified operand.
+ SmallVector<const SCEV *, 8> Ops;
+ for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+ if (i != FoundIndex)
+ Ops.push_back(Add->getOperand(i));
+ const SCEV *Accum = getAddExpr(Ops);
+
+ // The runtime checks will not be valid if the step amount is
+ // varying inside the loop.
+ if (!isLoopInvariant(Accum, L))
+ return None;
+
+ // *** Part2: Create the predicates
+
+ // Analysis was successful: we have a phi-with-cast pattern for which we
+ // can return an AddRec expression under the following predicates:
+ //
+ // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
+ // fits within the truncated type (does not overflow) for i = 0 to n-1.
+ // P2: An Equal predicate that guarantees that
+ // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
+ // P3: An Equal predicate that guarantees that
+ // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
+ //
+ // As we next prove, the above predicates guarantee that:
+ // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
+ //
+ //
+ // More formally, we want to prove that:
+ // Expr(i+1) = Start + (i+1) * Accum
+ // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
+ //
+ // Given that:
+ // 1) Expr(0) = Start
+ // 2) Expr(1) = Start + Accum
+ // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
+ // 3) Induction hypothesis (step i):
+ // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
+ //
+ // Proof:
+ // Expr(i+1) =
+ // = Start + (i+1)*Accum
+ // = (Start + i*Accum) + Accum
+ // = Expr(i) + Accum
+ // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
+ // :: from step i
+ //
+ // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
+ //
+ // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
+ // + (Ext ix (Trunc iy (Accum) to ix) to iy)
+ // + Accum :: from P3
+ //
+ // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
+ // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
+ //
+ // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
+ // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
+ //
+ // By induction, the same applies to all iterations 1<=i<n:
+ //
+
+ // Create a truncated addrec for which we will add a no overflow check (P1).
+ const SCEV *StartVal = getSCEV(StartValueV);
+ const SCEV *PHISCEV =
+ getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
+ getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
+
+ // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
+ // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
+ // will be constant.
+ //
+ // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
+ // add P1.
+ if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
+ SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
+ Signed ? SCEVWrapPredicate::IncrementNSSW
+ : SCEVWrapPredicate::IncrementNUSW;
+ const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
+ Predicates.push_back(AddRecPred);
+ }
+
+ // Create the Equal Predicates P2,P3:
+
+ // It is possible that the predicates P2 and/or P3 are computable at
+ // compile time due to StartVal and/or Accum being constants.
+ // If either one is, then we can check that now and escape if either P2
+ // or P3 is false.
+
+ // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
+ // for each of StartVal and Accum
+ auto getExtendedExpr = [&](const SCEV *Expr,
+ bool CreateSignExtend) -> const SCEV * {
+ assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
+ const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
+ const SCEV *ExtendedExpr =
+ CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
+ : getZeroExtendExpr(TruncatedExpr, Expr->getType());
+ return ExtendedExpr;
+ };
+
+ // Given:
+  //     ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy)
+ // = getExtendedExpr(Expr)
+ // Determine whether the predicate P: Expr == ExtendedExpr
+ // is known to be false at compile time
+ auto PredIsKnownFalse = [&](const SCEV *Expr,
+ const SCEV *ExtendedExpr) -> bool {
+ return Expr != ExtendedExpr &&
+ isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
+ };
+
+ const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
+ if (PredIsKnownFalse(StartVal, StartExtended)) {
+ LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
+ return None;
+ }
+
+ // The Step is always Signed (because the overflow checks are either
+ // NSSW or NUSW)
+ const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
+ if (PredIsKnownFalse(Accum, AccumExtended)) {
+ LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
+ return None;
+ }
+
+ auto AppendPredicate = [&](const SCEV *Expr,
+ const SCEV *ExtendedExpr) -> void {
+ if (Expr != ExtendedExpr &&
+ !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
+ const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
+ LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
+ Predicates.push_back(Pred);
+ }
+ };
+
+ AppendPredicate(StartVal, StartExtended);
+ AppendPredicate(Accum, AccumExtended);
+
+ // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
+ // which the casts had been folded away. The caller can rewrite SymbolicPHI
+ // into NewAR if it will also add the runtime overflow checks specified in
+ // Predicates.
+ auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
+
+ std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
+ std::make_pair(NewAR, Predicates);
+  // Remember the result of the analysis for this SCEV at this location.
+ PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
+ return PredRewrite;
+}
+
+Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
+ auto *PN = cast<PHINode>(SymbolicPHI->getValue());
+ const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
+ if (!L)
+ return None;
+
+ // Check to see if we already analyzed this PHI.
+ auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
+ if (I != PredicatedSCEVRewrites.end()) {
+ std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
+ I->second;
+ // Analysis was done before and failed to create an AddRec:
+ if (Rewrite.first == SymbolicPHI)
+ return None;
+    // Analysis was done before and succeeded in creating an AddRec under
+ // a predicate:
+ assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
+ assert(!(Rewrite.second).empty() && "Expected to find Predicates");
+ return Rewrite;
+ }
+
+ Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
+
+ // Record in the cache that the analysis failed
+ if (!Rewrite) {
+ SmallVector<const SCEVPredicate *, 3> Predicates;
+ PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
+ return None;
+ }
+
+ return Rewrite;
+}
+
+// FIXME: This utility is currently required because the Rewriter currently
+// does not rewrite this expression:
+// {0, +, (sext ix (trunc iy to ix) to iy)}
+// into {0, +, %step},
+// even when the following Equal predicate exists:
+// "%step == (sext ix (trunc iy to ix) to iy)".
+bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
+ const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
+ if (AR1 == AR2)
+ return true;
+
+ auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
+ if (Expr1 != Expr2 && !Preds.implies(SE.getEqualPredicate(Expr1, Expr2)) &&
+ !Preds.implies(SE.getEqualPredicate(Expr2, Expr1)))
+ return false;
+ return true;
+ };
+
+ if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
+ !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
+ return false;
+ return true;
+}
+
+/// A helper function for createAddRecFromPHI to handle simple cases.
+///
+/// This function tries to find an AddRec expression for the simplest (yet most
+/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
+/// If it fails, createAddRecFromPHI will use a more general, but slow,
+/// technique for finding the AddRec expression.
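+///
+/// For illustration, given the loop-header phi
+///   %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
+///   %iv.next = add nuw nsw i64 %iv, 1
+/// this returns {0,+,1}<nuw><nsw> for the loop %loop.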
+const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
+ Value *BEValueV,
+ Value *StartValueV) {
+ const Loop *L = LI.getLoopFor(PN->getParent());
+ assert(L && L->getHeader() == PN->getParent());
+ assert(BEValueV && StartValueV);
+
+ auto BO = MatchBinaryOp(BEValueV, DT);
+ if (!BO)
+ return nullptr;
+
+ if (BO->Opcode != Instruction::Add)
+ return nullptr;
+
+ const SCEV *Accum = nullptr;
+ if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
+ Accum = getSCEV(BO->RHS);
+ else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
+ Accum = getSCEV(BO->LHS);
+
+ if (!Accum)
+ return nullptr;
+
+ SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
+ if (BO->IsNUW)
+ Flags = setFlags(Flags, SCEV::FlagNUW);
+ if (BO->IsNSW)
+ Flags = setFlags(Flags, SCEV::FlagNSW);
+
+ const SCEV *StartVal = getSCEV(StartValueV);
+ const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
+
+ ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
+
+ // We can add Flags to the post-inc expression only if we
+ // know that it is *undefined behavior* for BEValueV to
+ // overflow.
+ if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
+ if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
+ (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
+
+ return PHISCEV;
+}
+
+const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
+ const Loop *L = LI.getLoopFor(PN->getParent());
+ if (!L || L->getHeader() != PN->getParent())
+ return nullptr;
+
+ // The loop may have multiple entrances or multiple exits; we can analyze
+ // this phi as an addrec if it has a unique entry value and a unique
+ // backedge value.
+ Value *BEValueV = nullptr, *StartValueV = nullptr;
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+ Value *V = PN->getIncomingValue(i);
+ if (L->contains(PN->getIncomingBlock(i))) {
+ if (!BEValueV) {
+ BEValueV = V;
+ } else if (BEValueV != V) {
+ BEValueV = nullptr;
+ break;
+ }
+ } else if (!StartValueV) {
+ StartValueV = V;
+ } else if (StartValueV != V) {
+ StartValueV = nullptr;
+ break;
+ }
+ }
+ if (!BEValueV || !StartValueV)
+ return nullptr;
+
+ assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
+ "PHI node already processed?");
+
+  // First, try to find an AddRec expression without creating a fictitious
+  // symbolic value for PN.
+ if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
+ return S;
+
+ // Handle PHI node value symbolically.
+ const SCEV *SymbolicName = getUnknown(PN);
+ ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName});
+
+ // Using this symbolic name for the PHI, analyze the value coming around
+ // the back-edge.
+ const SCEV *BEValue = getSCEV(BEValueV);
+
+ // NOTE: If BEValue is loop invariant, we know that the PHI node just
+ // has a special value for the first iteration of the loop.
+
+ // If the value coming around the backedge is an add with the symbolic
+ // value we just inserted, then we found a simple induction variable!
+ if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
+ // If there is a single occurrence of the symbolic value, replace it
+ // with a recurrence.
+ unsigned FoundIndex = Add->getNumOperands();
+ for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+ if (Add->getOperand(i) == SymbolicName)
+ if (FoundIndex == e) {
+ FoundIndex = i;
+ break;
+ }
+
+ if (FoundIndex != Add->getNumOperands()) {
+ // Create an add with everything but the specified operand.
+ SmallVector<const SCEV *, 8> Ops;
+ for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+ if (i != FoundIndex)
+ Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
+ L, *this));
+ const SCEV *Accum = getAddExpr(Ops);
+
+ // This is not a valid addrec if the step amount is varying each
+ // loop iteration, but is not itself an addrec in this loop.
+ if (isLoopInvariant(Accum, L) ||
+ (isa<SCEVAddRecExpr>(Accum) &&
+ cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
+ SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
+
+ if (auto BO = MatchBinaryOp(BEValueV, DT)) {
+ if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
+ if (BO->IsNUW)
+ Flags = setFlags(Flags, SCEV::FlagNUW);
+ if (BO->IsNSW)
+ Flags = setFlags(Flags, SCEV::FlagNSW);
+ }
+ } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
+ // If the increment is an inbounds GEP, then we know the address
+ // space cannot be wrapped around. We cannot make any guarantee
+ // about signed or unsigned overflow because pointers are
+ // unsigned but we may have a negative index from the base
+ // pointer. We can guarantee that no unsigned wrap occurs if the
+ // indices form a positive value.
+ if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
+ Flags = setFlags(Flags, SCEV::FlagNW);
+
+ const SCEV *Ptr = getSCEV(GEP->getPointerOperand());
+ if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
+ Flags = setFlags(Flags, SCEV::FlagNUW);
+ }
+
+ // We cannot transfer nuw and nsw flags from subtraction
+ // operations -- sub nuw X, Y is not the same as add nuw X, -Y
+ // for instance.
+ }
+
+ const SCEV *StartVal = getSCEV(StartValueV);
+ const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
+
+ // Okay, for the entire analysis of this edge we assumed the PHI
+ // to be symbolic. We now need to go back and purge all of the
+ // entries for the scalars that use the symbolic expression.
+ forgetSymbolicName(PN, SymbolicName);
+ ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
+
+ // We can add Flags to the post-inc expression only if we
+ // know that it is *undefined behavior* for BEValueV to
+ // overflow.
+ if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
+ if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
+ (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
+
+ return PHISCEV;
+ }
+ }
+ } else {
+ // Otherwise, this could be a loop like this:
+ // i = 0; for (j = 1; ..; ++j) { .... i = j; }
+ // In this case, j = {1,+,1} and BEValue is j.
+    // Because the other in-value of i (0) fits the evolution of BEValue,
+ // i really is an addrec evolution.
+ //
+ // We can generalize this saying that i is the shifted value of BEValue
+ // by one iteration:
+ // PHI(f(0), f({1,+,1})) --> f({0,+,1})
+ const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
+ const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
+ if (Shifted != getCouldNotCompute() &&
+ Start != getCouldNotCompute()) {
+ const SCEV *StartVal = getSCEV(StartValueV);
+ if (Start == StartVal) {
+ // Okay, for the entire analysis of this edge we assumed the PHI
+ // to be symbolic. We now need to go back and purge all of the
+ // entries for the scalars that use the symbolic expression.
+ forgetSymbolicName(PN, SymbolicName);
+ ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted;
+ return Shifted;
+ }
+ }
+ }
+
+ // Remove the temporary PHI node SCEV that has been inserted while intending
+  // to create an AddRecExpr for this PHI node. We cannot keep this temporary,
+  // as it would prevent later (possibly simpler) SCEV expressions from being
+  // added to the ValueExprMap.
+ eraseValueFromMap(PN);
+
+ return nullptr;
+}
+
+// Checks if the SCEV S is available at BB. S is considered available at BB
+// if S can be materialized at BB without introducing a fault.
+static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
+ BasicBlock *BB) {
+ struct CheckAvailable {
+ bool TraversalDone = false;
+ bool Available = true;
+
+ const Loop *L = nullptr; // The loop BB is in (can be nullptr)
+ BasicBlock *BB = nullptr;
+ DominatorTree &DT;
+
+ CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT)
+ : L(L), BB(BB), DT(DT) {}
+
+ bool setUnavailable() {
+ TraversalDone = true;
+ Available = false;
+ return false;
+ }
+
+ bool follow(const SCEV *S) {
+ switch (S->getSCEVType()) {
+ case scConstant: case scTruncate: case scZeroExtend: case scSignExtend:
+ case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr:
+ case scUMinExpr:
+ case scSMinExpr:
+ // These expressions are available if their operand(s) is/are.
+ return true;
+
+ case scAddRecExpr: {
+ // We allow add recurrences that are on the loop BB is in, or some
+ // outer loop. This guarantees availability because the value of the
+ // add recurrence at BB is simply the "current" value of the induction
+ // variable. We can relax this in the future; for instance an add
+ // recurrence on a sibling dominating loop is also available at BB.
+ const auto *ARLoop = cast<SCEVAddRecExpr>(S)->getLoop();
+ if (L && (ARLoop == L || ARLoop->contains(L)))
+ return true;
+
+ return setUnavailable();
+ }
+
+ case scUnknown: {
+ // For SCEVUnknown, we check for simple dominance.
+ const auto *SU = cast<SCEVUnknown>(S);
+ Value *V = SU->getValue();
+
+ if (isa<Argument>(V))
+ return false;
+
+ if (isa<Instruction>(V) && DT.dominates(cast<Instruction>(V), BB))
+ return false;
+
+ return setUnavailable();
+ }
+
+ case scUDivExpr:
+ case scCouldNotCompute:
+        // We do not try to be smart about these at all.
+ return setUnavailable();
+ }
+ llvm_unreachable("switch should be fully covered!");
+ }
+
+ bool isDone() { return TraversalDone; }
+ };
+
+ CheckAvailable CA(L, BB, DT);
+ SCEVTraversal<CheckAvailable> ST(CA);
+
+ ST.visitAll(S);
+ return CA.Available;
+}
+
+// Try to match a control flow sequence that branches out at BI and merges back
+// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
+// match.
+static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
+ Value *&C, Value *&LHS, Value *&RHS) {
+ C = BI->getCondition();
+
+ BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
+ BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
+
+ if (!LeftEdge.isSingleEdge())
+ return false;
+
+ assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
+
+ Use &LeftUse = Merge->getOperandUse(0);
+ Use &RightUse = Merge->getOperandUse(1);
+
+ if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
+ LHS = LeftUse;
+ RHS = RightUse;
+ return true;
+ }
+
+ if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
+ LHS = RightUse;
+ RHS = LeftUse;
+ return true;
+ }
+
+ return false;
+}
+
+const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
+ auto IsReachable =
+ [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
+ if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
+ const Loop *L = LI.getLoopFor(PN->getParent());
+
+ // We don't want to break LCSSA, even in a SCEV expression tree.
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+ if (LI.getLoopFor(PN->getIncomingBlock(i)) != L)
+ return nullptr;
+
+ // Try to match
+ //
+ // br %cond, label %left, label %right
+ // left:
+ // br label %merge
+ // right:
+ // br label %merge
+ // merge:
+ // V = phi [ %x, %left ], [ %y, %right ]
+ //
+ // as "select %cond, %x, %y"
+
+ BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
+ assert(IDom && "At least the entry block should dominate PN");
+
+ auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
+ Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
+
+ if (BI && BI->isConditional() &&
+ BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
+ IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) &&
+ IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent()))
+ return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
+ }
+
+ return nullptr;
+}
+
+const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
+ if (const SCEV *S = createAddRecFromPHI(PN))
+ return S;
+
+ if (const SCEV *S = createNodeFromSelectLikePHI(PN))
+ return S;
+
+ // If the PHI has a single incoming value, follow that value, unless the
+ // PHI's incoming blocks are in a different loop, in which case doing so
+ // risks breaking LCSSA form. Instcombine would normally zap these, but
+ // it doesn't have DominatorTree information, so it may miss cases.
+ if (Value *V = SimplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC}))
+ if (LI.replacementPreservesLCSSAForm(PN, V))
+ return getSCEV(V);
+
+ // If it's not a loop phi, we can't handle it yet.
+ return getUnknown(PN);
+}
+
+const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,
+ Value *Cond,
+ Value *TrueVal,
+ Value *FalseVal) {
+ // Handle "constant" branch or select. This can occur for instance when a
+ // loop pass transforms an inner loop and moves on to process the outer loop.
+ if (auto *CI = dyn_cast<ConstantInt>(Cond))
+ return getSCEV(CI->isOne() ? TrueVal : FalseVal);
+
+ // Try to match some simple smax or umax patterns.
+ auto *ICI = dyn_cast<ICmpInst>(Cond);
+ if (!ICI)
+ return getUnknown(I);
+
+ Value *LHS = ICI->getOperand(0);
+ Value *RHS = ICI->getOperand(1);
+
+ switch (ICI->getPredicate()) {
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_SLE:
+ std::swap(LHS, RHS);
+ LLVM_FALLTHROUGH;
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_SGE:
+ // a >s b ? a+x : b+x -> smax(a, b)+x
+ // a >s b ? b+x : a+x -> smin(a, b)+x
+ if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
+ const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType());
+ const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType());
+ const SCEV *LA = getSCEV(TrueVal);
+ const SCEV *RA = getSCEV(FalseVal);
+ const SCEV *LDiff = getMinusSCEV(LA, LS);
+ const SCEV *RDiff = getMinusSCEV(RA, RS);
+ if (LDiff == RDiff)
+ return getAddExpr(getSMaxExpr(LS, RS), LDiff);
+ LDiff = getMinusSCEV(LA, RS);
+ RDiff = getMinusSCEV(RA, LS);
+ if (LDiff == RDiff)
+ return getAddExpr(getSMinExpr(LS, RS), LDiff);
+ }
+ break;
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULE:
+ std::swap(LHS, RHS);
+ LLVM_FALLTHROUGH;
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_UGE:
+ // a >u b ? a+x : b+x -> umax(a, b)+x
+ // a >u b ? b+x : a+x -> umin(a, b)+x
+ if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
+ const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
+ const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType());
+ const SCEV *LA = getSCEV(TrueVal);
+ const SCEV *RA = getSCEV(FalseVal);
+ const SCEV *LDiff = getMinusSCEV(LA, LS);
+ const SCEV *RDiff = getMinusSCEV(RA, RS);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMaxExpr(LS, RS), LDiff);
+ LDiff = getMinusSCEV(LA, RS);
+ RDiff = getMinusSCEV(RA, LS);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMinExpr(LS, RS), LDiff);
+ }
+ break;
+ case ICmpInst::ICMP_NE:
+ // n != 0 ? n+x : 1+x -> umax(n, 1)+x
+ if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
+ isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
+ const SCEV *One = getOne(I->getType());
+ const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
+ const SCEV *LA = getSCEV(TrueVal);
+ const SCEV *RA = getSCEV(FalseVal);
+ const SCEV *LDiff = getMinusSCEV(LA, LS);
+ const SCEV *RDiff = getMinusSCEV(RA, One);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMaxExpr(One, LS), LDiff);
+ }
+ break;
+ case ICmpInst::ICMP_EQ:
+ // n == 0 ? 1+x : n+x -> umax(n, 1)+x
+ if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
+ isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
+ const SCEV *One = getOne(I->getType());
+ const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
+ const SCEV *LA = getSCEV(TrueVal);
+ const SCEV *RA = getSCEV(FalseVal);
+ const SCEV *LDiff = getMinusSCEV(LA, One);
+ const SCEV *RDiff = getMinusSCEV(RA, LS);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMaxExpr(One, LS), LDiff);
+ }
+ break;
+ default:
+ break;
+ }
+
+ return getUnknown(I);
+}
+
+/// Expand GEP instructions into add and multiply operations. This allows them
+/// to be analyzed by regular SCEV code.
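+///
+/// For illustration, a GEP such as (getelementptr i32, i32* %p, i64 %i) is
+/// expanded to roughly %p + 4 * %i, assuming a 4-byte i32 in the target data
+/// layout.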
+const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
+ // Don't attempt to analyze GEPs over unsized objects.
+ if (!GEP->getSourceElementType()->isSized())
+ return getUnknown(GEP);
+
+ SmallVector<const SCEV *, 4> IndexExprs;
+ for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index)
+ IndexExprs.push_back(getSCEV(*Index));
+ return getGEPExpr(GEP, IndexExprs);
+}
+
+uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
+ if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
+ return C->getAPInt().countTrailingZeros();
+
+ if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
+ return std::min(GetMinTrailingZeros(T->getOperand()),
+ (uint32_t)getTypeSizeInBits(T->getType()));
+
+ if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
+ uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
+ return OpRes == getTypeSizeInBits(E->getOperand()->getType())
+ ? getTypeSizeInBits(E->getType())
+ : OpRes;
+ }
+
+ if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
+ uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
+ return OpRes == getTypeSizeInBits(E->getOperand()->getType())
+ ? getTypeSizeInBits(E->getType())
+ : OpRes;
+ }
+
+ if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
+    // The result is the min of all operands' results.
+ uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
+ for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
+ MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
+ return MinOpRes;
+ }
+
+ if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
+    // The result is the sum of all operands' results.
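+    // For example, 12 (two trailing zeros) times 10 (one trailing zero) gives
+    // 120, which has 2 + 1 = 3 trailing zeros.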
+ uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
+ uint32_t BitWidth = getTypeSizeInBits(M->getType());
+ for (unsigned i = 1, e = M->getNumOperands();
+ SumOpRes != BitWidth && i != e; ++i)
+ SumOpRes =
+ std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth);
+ return SumOpRes;
+ }
+
+ if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
+    // The result is the min of all operands' results.
+ uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
+ for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
+ MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
+ return MinOpRes;
+ }
+
+ if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
+    // The result is the min of all operands' results.
+ uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
+ for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
+ MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
+ return MinOpRes;
+ }
+
+ if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
+    // The result is the min of all operands' results.
+ uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
+ for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
+ MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
+ return MinOpRes;
+ }
+
+ if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
+ // For a SCEVUnknown, ask ValueTracking.
+    KnownBits Known =
+        computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT);
+ return Known.countMinTrailingZeros();
+ }
+
+ // SCEVUDivExpr
+ return 0;
+}
+
+uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
+ auto I = MinTrailingZerosCache.find(S);
+ if (I != MinTrailingZerosCache.end())
+ return I->second;
+
+ uint32_t Result = GetMinTrailingZerosImpl(S);
+ auto InsertPair = MinTrailingZerosCache.insert({S, Result});
+ assert(InsertPair.second && "Should insert a new key");
+ return InsertPair.first->second;
+}
+
+/// Helper method to assign a range to V from metadata present in the IR.
+static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
+ return getConstantRangeFromMetadata(*MD);
+
+ return None;
+}
+
+/// Determine the range for a particular SCEV. If SignHint is
+/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
+/// with a "cleaner" unsigned (resp. signed) representation.
+const ConstantRange &
+ScalarEvolution::getRangeRef(const SCEV *S,
+ ScalarEvolution::RangeSignHint SignHint) {
+ DenseMap<const SCEV *, ConstantRange> &Cache =
+ SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
+ : SignedRanges;
+ ConstantRange::PreferredRangeType RangeType =
+ SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED
+ ? ConstantRange::Unsigned : ConstantRange::Signed;
+
+ // See if we've computed this range already.
+ DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
+ if (I != Cache.end())
+ return I->second;
+
+ if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
+ return setRange(C, SignHint, ConstantRange(C->getAPInt()));
+
+ unsigned BitWidth = getTypeSizeInBits(S->getType());
+ ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
+
+ // If the value has known zeros, the maximum value will have those known zeros
+ // as well.
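+  // For example, with TZ == 2 and an i8 value, the unsigned range is
+  // narrowed to [0, 253), since 252 is the largest i8 multiple of 4.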
+ uint32_t TZ = GetMinTrailingZeros(S);
+ if (TZ != 0) {
+ if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED)
+ ConservativeResult =
+ ConstantRange(APInt::getMinValue(BitWidth),
+ APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
+ else
+ ConservativeResult = ConstantRange(
+ APInt::getSignedMinValue(BitWidth),
+ APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
+ }
+
+ if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
+ ConstantRange X = getRangeRef(Add->getOperand(0), SignHint);
+ for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
+ X = X.add(getRangeRef(Add->getOperand(i), SignHint));
+ return setRange(Add, SignHint,
+ ConservativeResult.intersectWith(X, RangeType));
+ }
+
+ if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
+ ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint);
+ for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
+ X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint));
+ return setRange(Mul, SignHint,
+ ConservativeResult.intersectWith(X, RangeType));
+ }
+
+ if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
+ ConstantRange X = getRangeRef(SMax->getOperand(0), SignHint);
+ for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
+ X = X.smax(getRangeRef(SMax->getOperand(i), SignHint));
+ return setRange(SMax, SignHint,
+ ConservativeResult.intersectWith(X, RangeType));
+ }
+
+ if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
+ ConstantRange X = getRangeRef(UMax->getOperand(0), SignHint);
+ for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
+ X = X.umax(getRangeRef(UMax->getOperand(i), SignHint));
+ return setRange(UMax, SignHint,
+ ConservativeResult.intersectWith(X, RangeType));
+ }
+
+ if (const SCEVSMinExpr *SMin = dyn_cast<SCEVSMinExpr>(S)) {
+ ConstantRange X = getRangeRef(SMin->getOperand(0), SignHint);
+ for (unsigned i = 1, e = SMin->getNumOperands(); i != e; ++i)
+ X = X.smin(getRangeRef(SMin->getOperand(i), SignHint));
+ return setRange(SMin, SignHint,
+ ConservativeResult.intersectWith(X, RangeType));
+ }
+
+ if (const SCEVUMinExpr *UMin = dyn_cast<SCEVUMinExpr>(S)) {
+ ConstantRange X = getRangeRef(UMin->getOperand(0), SignHint);
+ for (unsigned i = 1, e = UMin->getNumOperands(); i != e; ++i)
+ X = X.umin(getRangeRef(UMin->getOperand(i), SignHint));
+ return setRange(UMin, SignHint,
+ ConservativeResult.intersectWith(X, RangeType));
+ }
+
+ if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
+ ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint);
+ ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint);
+ return setRange(UDiv, SignHint,
+ ConservativeResult.intersectWith(X.udiv(Y), RangeType));
+ }
+
+ if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
+ ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint);
+ return setRange(ZExt, SignHint,
+ ConservativeResult.intersectWith(X.zeroExtend(BitWidth),
+ RangeType));
+ }
+
+ if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
+ ConstantRange X = getRangeRef(SExt->getOperand(), SignHint);
+ return setRange(SExt, SignHint,
+ ConservativeResult.intersectWith(X.signExtend(BitWidth),
+ RangeType));
+ }
+
+ if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
+ ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint);
+ return setRange(Trunc, SignHint,
+ ConservativeResult.intersectWith(X.truncate(BitWidth),
+ RangeType));
+ }
+
+ if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
+ // If there's no unsigned wrap, the value will never be less than its
+ // initial value.
+ if (AddRec->hasNoUnsignedWrap())
+ if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
+ if (!C->getValue()->isZero())
+ ConservativeResult = ConservativeResult.intersectWith(
+ ConstantRange(C->getAPInt(), APInt(BitWidth, 0)), RangeType);
+
+ // If there's no signed wrap, and all the operands have the same sign or
+ // zero, the value won't ever change sign.
+ if (AddRec->hasNoSignedWrap()) {
+ bool AllNonNeg = true;
+ bool AllNonPos = true;
+ for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
+ if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false;
+ if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false;
+ }
+ if (AllNonNeg)
+ ConservativeResult = ConservativeResult.intersectWith(
+ ConstantRange(APInt(BitWidth, 0),
+ APInt::getSignedMinValue(BitWidth)), RangeType);
+ else if (AllNonPos)
+ ConservativeResult = ConservativeResult.intersectWith(
+ ConstantRange(APInt::getSignedMinValue(BitWidth),
+ APInt(BitWidth, 1)), RangeType);
+ }
+
+ // TODO: non-affine addrec
+ if (AddRec->isAffine()) {
+ const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(AddRec->getLoop());
+ if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
+ getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
+ auto RangeFromAffine = getRangeForAffineAR(
+ AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
+ BitWidth);
+ if (!RangeFromAffine.isFullSet())
+ ConservativeResult =
+ ConservativeResult.intersectWith(RangeFromAffine, RangeType);
+
+ auto RangeFromFactoring = getRangeViaFactoring(
+ AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
+ BitWidth);
+ if (!RangeFromFactoring.isFullSet())
+ ConservativeResult =
+ ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
+ }
+ }
+
+ return setRange(AddRec, SignHint, std::move(ConservativeResult));
+ }
+
+ if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
+ // Check if the IR explicitly contains !range metadata.
+ Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
+ if (MDRange.hasValue())
+ ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue(),
+ RangeType);
+
+ // Split here to avoid paying the compile-time cost of calling both
+ // computeKnownBits and ComputeNumSignBits. This restriction can be lifted
+ // if needed.
+ const DataLayout &DL = getDataLayout();
+ if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
+ // For a SCEVUnknown, ask ValueTracking.
+      KnownBits Known =
+          computeKnownBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
+ if (Known.One != ~Known.Zero + 1)
+ ConservativeResult =
+ ConservativeResult.intersectWith(
+ ConstantRange(Known.One, ~Known.Zero + 1), RangeType);
+ } else {
+ assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED &&
+ "generalize as needed!");
+ unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
+ if (NS > 1)
+ ConservativeResult = ConservativeResult.intersectWith(
+ ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
+ APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
+ RangeType);
+ }
+
+    // The range of a Phi is a subset of the union of the ranges of its
+    // inputs.
+ if (const PHINode *Phi = dyn_cast<PHINode>(U->getValue())) {
+ // Make sure that we do not run over cycled Phis.
+ if (PendingPhiRanges.insert(Phi).second) {
+ ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
+ for (auto &Op : Phi->operands()) {
+ auto OpRange = getRangeRef(getSCEV(Op), SignHint);
+ RangeFromOps = RangeFromOps.unionWith(OpRange);
+ // No point to continue if we already have a full set.
+ if (RangeFromOps.isFullSet())
+ break;
+ }
+ ConservativeResult =
+ ConservativeResult.intersectWith(RangeFromOps, RangeType);
+ bool Erased = PendingPhiRanges.erase(Phi);
+ assert(Erased && "Failed to erase Phi properly?");
+ (void) Erased;
+ }
+ }
+
+ return setRange(U, SignHint, std::move(ConservativeResult));
+ }
+
+ return setRange(S, SignHint, std::move(ConservativeResult));
+}
+
+// Given a StartRange, Step and MaxBECount for an expression compute a range of
+// values that the expression can take. Initially, the expression has a value
+// from StartRange and then is changed by Step up to MaxBECount times. Signed
+// argument defines if we treat Step as signed or unsigned.
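+// Illustrative example: with i8 values, StartRange = [0, 10), Step = 3 and
+// MaxBECount = 4, the expression can change by at most 3 * 4 = 12, so the
+// resulting range is [0, 10 + 12) = [0, 22).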
+static ConstantRange getRangeForAffineARHelper(APInt Step,
+ const ConstantRange &StartRange,
+ const APInt &MaxBECount,
+ unsigned BitWidth, bool Signed) {
+ // If either Step or MaxBECount is 0, then the expression won't change, and we
+ // just need to return the initial range.
+ if (Step == 0 || MaxBECount == 0)
+ return StartRange;
+
+ // If we don't know anything about the initial value (i.e. StartRange is
+ // FullRange), then we don't know anything about the final range either.
+ // Return FullRange.
+ if (StartRange.isFullSet())
+ return ConstantRange::getFull(BitWidth);
+
+ // If Step is signed and negative, then we use its absolute value, but we also
+ // note that we're moving in the opposite direction.
+ bool Descending = Signed && Step.isNegative();
+
+ if (Signed)
+ // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
+ // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
+ // These equations hold true due to the well-defined wrap-around behavior of
+ // APInt.
+ Step = Step.abs();
+
+ // Check if the total Offset (Step * MaxBECount) would exceed the full span of
+ // BitWidth. If it would, the expression is guaranteed to overflow.
+ if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
+ return ConstantRange::getFull(BitWidth);
+
+ // Offset is the total amount by which the expression can change. The check
+ // above guarantees that computing it does not overflow.
+ APInt Offset = Step * MaxBECount;
+
+ // The minimum value of the final range matches the minimum of StartRange if
+ // the expression is increasing, and is decreased by Offset otherwise.
+ // The maximum value of the final range matches the maximum of StartRange if
+ // the expression is decreasing, and is increased by Offset otherwise.
+ APInt StartLower = StartRange.getLower();
+ APInt StartUpper = StartRange.getUpper() - 1;
+ APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
+ : (StartUpper + std::move(Offset));
+
+ // It's possible that the new minimum/maximum value will fall into the initial
+ // range (due to wrap around). This means that the expression can take any
+ // value in this bitwidth, and we have to return full range.
+ if (StartRange.contains(MovedBoundary))
+ return ConstantRange::getFull(BitWidth);
+
+ APInt NewLower =
+ Descending ? std::move(MovedBoundary) : std::move(StartLower);
+ APInt NewUpper =
+ Descending ? std::move(StartUpper) : std::move(MovedBoundary);
+ NewUpper += 1;
+
+ // No overflow detected, return the computed [NewLower, NewUpper) range.
+ return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
+}
+
+ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
+ const SCEV *Step,
+ const SCEV *MaxBECount,
+ unsigned BitWidth) {
+ assert(!isa<SCEVCouldNotCompute>(MaxBECount) &&
+ getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
+ "Precondition!");
+
+ MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType());
+ APInt MaxBECountValue = getUnsignedRangeMax(MaxBECount);
+
+ // First, consider step signed.
+ ConstantRange StartSRange = getSignedRange(Start);
+ ConstantRange StepSRange = getSignedRange(Step);
+
+ // If Step can be both positive and negative, we need to find ranges for the
+ // maximum absolute step values in both directions and union them.
+ ConstantRange SR =
+ getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange,
+ MaxBECountValue, BitWidth, /* Signed = */ true);
+ SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
+ StartSRange, MaxBECountValue,
+ BitWidth, /* Signed = */ true));
+
+ // Next, consider step unsigned.
+ ConstantRange UR = getRangeForAffineARHelper(
+ getUnsignedRangeMax(Step), getUnsignedRange(Start),
+ MaxBECountValue, BitWidth, /* Signed = */ false);
+
+ // Finally, intersect signed and unsigned ranges.
+ return SR.intersectWith(UR, ConstantRange::Smallest);
+}
+
+ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
+ const SCEV *Step,
+ const SCEV *MaxBECount,
+ unsigned BitWidth) {
+ // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
+ // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
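+ // For instance (illustrative), {c ? 2 : 10,+,c ? 1 : 3} is handled as the
+ // union of RangeOf({2,+,1}) and RangeOf({10,+,3}).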
+
+ struct SelectPattern {
+ Value *Condition = nullptr;
+ APInt TrueValue;
+ APInt FalseValue;
+
+ explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
+ const SCEV *S) {
+ Optional<unsigned> CastOp;
+ APInt Offset(BitWidth, 0);
+
+ assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
+ "Should be!");
+
+ // Peel off a constant offset:
+ if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
+ // In the future we could consider being smarter here and handle
+ // {Start+Step,+,Step} too.
+ if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
+ return;
+
+ Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
+ S = SA->getOperand(1);
+ }
+
+ // Peel off a cast operation
+ if (auto *SCast = dyn_cast<SCEVCastExpr>(S)) {
+ CastOp = SCast->getSCEVType();
+ S = SCast->getOperand();
+ }
+
+ using namespace llvm::PatternMatch;
+
+ auto *SU = dyn_cast<SCEVUnknown>(S);
+ const APInt *TrueVal, *FalseVal;
+ if (!SU ||
+ !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
+ m_APInt(FalseVal)))) {
+ Condition = nullptr;
+ return;
+ }
+
+ TrueValue = *TrueVal;
+ FalseValue = *FalseVal;
+
+ // Re-apply the cast we peeled off earlier
+ if (CastOp.hasValue())
+ switch (*CastOp) {
+ default:
+ llvm_unreachable("Unknown SCEV cast type!");
+
+ case scTruncate:
+ TrueValue = TrueValue.trunc(BitWidth);
+ FalseValue = FalseValue.trunc(BitWidth);
+ break;
+ case scZeroExtend:
+ TrueValue = TrueValue.zext(BitWidth);
+ FalseValue = FalseValue.zext(BitWidth);
+ break;
+ case scSignExtend:
+ TrueValue = TrueValue.sext(BitWidth);
+ FalseValue = FalseValue.sext(BitWidth);
+ break;
+ }
+
+ // Re-apply the constant offset we peeled off earlier
+ TrueValue += Offset;
+ FalseValue += Offset;
+ }
+
+ bool isRecognized() { return Condition != nullptr; }
+ };
+
+ SelectPattern StartPattern(*this, BitWidth, Start);
+ if (!StartPattern.isRecognized())
+ return ConstantRange::getFull(BitWidth);
+
+ SelectPattern StepPattern(*this, BitWidth, Step);
+ if (!StepPattern.isRecognized())
+ return ConstantRange::getFull(BitWidth);
+
+ if (StartPattern.Condition != StepPattern.Condition) {
+ // We don't handle this case today, but we could, by considering four
+ // possibilities below instead of two. I'm not sure whether there are cases
+ // where that would help over what getRange already does, though.
+ return ConstantRange::getFull(BitWidth);
+ }
+
+ // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
+ // construct arbitrary general SCEV expressions here. This function is called
+ // from deep in the call stack, and calling getSCEV (on a sext instruction,
+ // say) can end up caching a suboptimal value.
+
+ // FIXME: without the explicit `this` receiver below, MSVC errors out with
+ // C2352 and C2512 (otherwise it isn't needed).
+
+ const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
+ const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
+ const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
+ const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
+
+ ConstantRange TrueRange =
+ this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth);
+ ConstantRange FalseRange =
+ this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth);
+
+ return TrueRange.unionWith(FalseRange);
+}
+
+SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
+ if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
+ const BinaryOperator *BinOp = cast<BinaryOperator>(V);
+
+ // Return early if there are no flags to propagate to the SCEV.
+ SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
+ if (BinOp->hasNoUnsignedWrap())
+ Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
+ if (BinOp->hasNoSignedWrap())
+ Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
+ if (Flags == SCEV::FlagAnyWrap)
+ return SCEV::FlagAnyWrap;
+
+ return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
+}
+
+bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
+ // Here we check that I is in the header of the innermost loop containing I,
+ // since we only deal with instructions in the loop header. The actual loop we
+ // need to check later will come from an add recurrence, but getting that
+ // requires computing the SCEV of the operands, which can be expensive. This
+ // check is cheap and lets us rule out some cases early.
+ Loop *InnermostContainingLoop = LI.getLoopFor(I->getParent());
+ if (InnermostContainingLoop == nullptr ||
+ InnermostContainingLoop->getHeader() != I->getParent())
+ return false;
+
+ // Only proceed if we can prove that I does not yield poison.
+ if (!programUndefinedIfFullPoison(I))
+ return false;
+
+ // At this point we know that if I is executed, then it does not wrap
+ // according to at least one of NSW or NUW. If I is not executed, then we do
+ // not know if the calculation that I represents would wrap. Multiple
+ // instructions can map to the same SCEV. If we apply NSW or NUW from I to
+ // the SCEV, we must guarantee no wrapping for that SCEV also when it is
+ // derived from other instructions that map to the same SCEV. We cannot make
+ // that guarantee for cases where I is not executed. So we need to find the
+ // loop that I is considered in relation to and prove that I is executed for
+ // every iteration of that loop. That implies that the value that I
+ // calculates does not wrap anywhere in the loop, so then we can apply the
+ // flags to the SCEV.
+ //
+ // We check isLoopInvariant to disambiguate in case we are adding recurrences
+ // from different loops, so that we know which loop to prove that I is
+ // executed in.
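+ // Illustrative example: for "%add = add nsw i32 %iv, %inv" in the loop
+ // header, where %iv maps to an add recurrence of that loop and %inv is
+ // loop-invariant, showing that %add executes on every iteration allows the
+ // nsw flag to be applied to the corresponding SCEV.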
+ for (unsigned OpIndex = 0; OpIndex < I->getNumOperands(); ++OpIndex) {
+ // I could be an extractvalue from a call to an overflow intrinsic.
+ // TODO: We can do better here in some cases.
+ if (!isSCEVable(I->getOperand(OpIndex)->getType()))
+ return false;
+ const SCEV *Op = getSCEV(I->getOperand(OpIndex));
+ if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
+ bool AllOtherOpsLoopInvariant = true;
+ for (unsigned OtherOpIndex = 0; OtherOpIndex < I->getNumOperands();
+ ++OtherOpIndex) {
+ if (OtherOpIndex != OpIndex) {
+ const SCEV *OtherOp = getSCEV(I->getOperand(OtherOpIndex));
+ if (!isLoopInvariant(OtherOp, AddRec->getLoop())) {
+ AllOtherOpsLoopInvariant = false;
+ break;
+ }
+ }
+ }
+ if (AllOtherOpsLoopInvariant &&
+ isGuaranteedToExecuteForEveryIteration(I, AddRec->getLoop()))
+ return true;
+ }
+ }
+ return false;
+}
+
+bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
+ // If we know that \c I can never be poison, period, then that's enough.
+ if (isSCEVExprNeverPoison(I))
+ return true;
+
+ // For an add recurrence specifically, we assume that infinite loops without
+ // side effects are undefined behavior, and then reason as follows:
+ //
+ // If the add recurrence is poison in any iteration, it is poison on all
+ // future iterations (since incrementing poison yields poison). If the result
+ // of the add recurrence is fed into the loop latch condition and the loop
+ // does not contain any throws or exiting blocks other than the latch, we now
+ // have the ability to "choose" whether the backedge is taken or not (by
+ // choosing a sufficiently evil value for the poison feeding into the branch)
+ // for every iteration including and after the one in which \p I first became
+ // poison. There are two possibilities (let's call the iteration in which \p
+ // I first became poison as K):
+ //
+ // 1. In the set of iterations including and after K, the loop body executes
+ // no side effects. In this case, executing the backedge an infinite number
+ // of times will yield undefined behavior.
+ //
+ // 2. In the set of iterations including and after K, the loop body executes
+ // at least one side effect. In this case, that specific instance of side
+ // effect is control dependent on poison, which also yields undefined
+ // behavior.
+
+ auto *ExitingBB = L->getExitingBlock();
+ auto *LatchBB = L->getLoopLatch();
+ if (!ExitingBB || !LatchBB || ExitingBB != LatchBB)
+ return false;
+
+ SmallPtrSet<const Instruction *, 16> Pushed;
+ SmallVector<const Instruction *, 8> PoisonStack;
+
+ // We start by assuming \c I, the post-inc add recurrence, is poison. Only
+ // things that are known to be fully poison under that assumption go on the
+ // PoisonStack.
+ Pushed.insert(I);
+ PoisonStack.push_back(I);
+
+ bool LatchControlDependentOnPoison = false;
+ while (!PoisonStack.empty() && !LatchControlDependentOnPoison) {
+ const Instruction *Poison = PoisonStack.pop_back_val();
+
+ for (auto *PoisonUser : Poison->users()) {
+ if (propagatesFullPoison(cast<Instruction>(PoisonUser))) {
+ if (Pushed.insert(cast<Instruction>(PoisonUser)).second)
+ PoisonStack.push_back(cast<Instruction>(PoisonUser));
+ } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) {
+ assert(BI->isConditional() && "Only possibility!");
+ if (BI->getParent() == LatchBB) {
+ LatchControlDependentOnPoison = true;
+ break;
+ }
+ }
+ }
+ }
+
+ return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L);
+}
+
+ScalarEvolution::LoopProperties
+ScalarEvolution::getLoopProperties(const Loop *L) {
+ using LoopProperties = ScalarEvolution::LoopProperties;
+
+ auto Itr = LoopPropertiesCache.find(L);
+ if (Itr == LoopPropertiesCache.end()) {
+ auto HasSideEffects = [](Instruction *I) {
+ if (auto *SI = dyn_cast<StoreInst>(I))
+ return !SI->isSimple();
+
+ return I->mayHaveSideEffects();
+ };
+
+ LoopProperties LP = {/* HasNoAbnormalExits */ true,
+ /*HasNoSideEffects*/ true};
+
+ for (auto *BB : L->getBlocks())
+ for (auto &I : *BB) {
+ if (!isGuaranteedToTransferExecutionToSuccessor(&I))
+ LP.HasNoAbnormalExits = false;
+ if (HasSideEffects(&I))
+ LP.HasNoSideEffects = false;
+ if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
+ break; // We're already as pessimistic as we can get.
+ }
+
+ auto InsertPair = LoopPropertiesCache.insert({L, LP});
+ assert(InsertPair.second && "We just checked!");
+ Itr = InsertPair.first;
+ }
+
+ return Itr->second;
+}
+
+const SCEV *ScalarEvolution::createSCEV(Value *V) {
+ if (!isSCEVable(V->getType()))
+ return getUnknown(V);
+
+ if (Instruction *I = dyn_cast<Instruction>(V)) {
+ // Don't attempt to analyze instructions in blocks that aren't
+ // reachable. Such instructions don't matter, and they aren't required
+ // to obey basic rules for definitions dominating uses which this
+ // analysis depends on.
+ if (!DT.isReachableFromEntry(I->getParent()))
+ return getUnknown(UndefValue::get(V->getType()));
+ } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
+ return getConstant(CI);
+ else if (isa<ConstantPointerNull>(V))
+ return getZero(V->getType());
+ else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
+ return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee());
+ else if (!isa<ConstantExpr>(V))
+ return getUnknown(V);
+
+ Operator *U = cast<Operator>(V);
+ if (auto BO = MatchBinaryOp(U, DT)) {
+ switch (BO->Opcode) {
+ case Instruction::Add: {
+ // The simple thing to do would be to just call getSCEV on both operands
+ // and call getAddExpr with the result. However if we're looking at a
+ // bunch of things all added together, this can be quite inefficient,
+ // because it leads to N-1 getAddExpr calls for N ultimate operands.
+ // Instead, gather up all the operands and make a single getAddExpr call.
+ // LLVM IR canonical form means we need only traverse the left operands.
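+ // E.g. (illustrative) for ((a + b) + c) + d this gathers a, b, c and d and
+ // issues one getAddExpr call instead of three nested ones.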
+ SmallVector<const SCEV *, 4> AddOps;
+ do {
+ if (BO->Op) {
+ if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
+ AddOps.push_back(OpSCEV);
+ break;
+ }
+
+ // If a NUW or NSW flag can be applied to the SCEV for this
+ // addition, then compute the SCEV for this addition by itself
+ // with a separate call to getAddExpr. We need to do that
+ // instead of pushing the operands of the addition onto AddOps,
+ // since the flags are only known to apply to this particular
+ // addition - they may not apply to other additions that can be
+ // formed with operands from AddOps.
+ const SCEV *RHS = getSCEV(BO->RHS);
+ SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
+ if (Flags != SCEV::FlagAnyWrap) {
+ const SCEV *LHS = getSCEV(BO->LHS);
+ if (BO->Opcode == Instruction::Sub)
+ AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
+ else
+ AddOps.push_back(getAddExpr(LHS, RHS, Flags));
+ break;
+ }
+ }
+
+ if (BO->Opcode == Instruction::Sub)
+ AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
+ else
+ AddOps.push_back(getSCEV(BO->RHS));
+
+ auto NewBO = MatchBinaryOp(BO->LHS, DT);
+ if (!NewBO || (NewBO->Opcode != Instruction::Add &&
+ NewBO->Opcode != Instruction::Sub)) {
+ AddOps.push_back(getSCEV(BO->LHS));
+ break;
+ }
+ BO = NewBO;
+ } while (true);
+
+ return getAddExpr(AddOps);
+ }
+
+ case Instruction::Mul: {
+ SmallVector<const SCEV *, 4> MulOps;
+ do {
+ if (BO->Op) {
+ if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
+ MulOps.push_back(OpSCEV);
+ break;
+ }
+
+ SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
+ if (Flags != SCEV::FlagAnyWrap) {
+ MulOps.push_back(
+ getMulExpr(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags));
+ break;
+ }
+ }
+
+ MulOps.push_back(getSCEV(BO->RHS));
+ auto NewBO = MatchBinaryOp(BO->LHS, DT);
+ if (!NewBO || NewBO->Opcode != Instruction::Mul) {
+ MulOps.push_back(getSCEV(BO->LHS));
+ break;
+ }
+ BO = NewBO;
+ } while (true);
+
+ return getMulExpr(MulOps);
+ }
+ case Instruction::UDiv:
+ return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS));
+ case Instruction::URem:
+ return getURemExpr(getSCEV(BO->LHS), getSCEV(BO->RHS));
+ case Instruction::Sub: {
+ SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
+ if (BO->Op)
+ Flags = getNoWrapFlagsFromUB(BO->Op);
+ return getMinusSCEV(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags);
+ }
+ case Instruction::And:
+ // For an expression like x&255 that merely masks off the high bits,
+ // use zext(trunc(x)) as the SCEV expression.
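+ // E.g. (illustrative) for i32 %x, "%x & 255" is modeled as
+ // (zext i8 (trunc %x) to i32).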
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
+ if (CI->isZero())
+ return getSCEV(BO->RHS);
+ if (CI->isMinusOne())
+ return getSCEV(BO->LHS);
+ const APInt &A = CI->getValue();
+
+ // Instcombine's ShrinkDemandedConstant may strip bits out of
+ // constants, obscuring what would otherwise be a low-bits mask.
+ // Use computeKnownBits to compute what ShrinkDemandedConstant
+ // knew about to reconstruct a low-bits mask value.
+ unsigned LZ = A.countLeadingZeros();
+ unsigned TZ = A.countTrailingZeros();
+ unsigned BitWidth = A.getBitWidth();
+ KnownBits Known(BitWidth);
+ computeKnownBits(BO->LHS, Known, getDataLayout(),
+ 0, &AC, nullptr, &DT);
+
+ APInt EffectiveMask =
+ APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
+ if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
+ const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
+ const SCEV *LHS = getSCEV(BO->LHS);
+ const SCEV *ShiftedLHS = nullptr;
+ if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
+ if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
+ // For an expression like (x * 8) & 8, simplify the multiply.
+ unsigned MulZeros = OpC->getAPInt().countTrailingZeros();
+ unsigned GCD = std::min(MulZeros, TZ);
+ APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
+ SmallVector<const SCEV*, 4> MulOps;
+ MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
+ MulOps.append(LHSMul->op_begin() + 1, LHSMul->op_end());
+ auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
+ ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
+ }
+ }
+ if (!ShiftedLHS)
+ ShiftedLHS = getUDivExpr(LHS, MulCount);
+ return getMulExpr(
+ getZeroExtendExpr(
+ getTruncateExpr(ShiftedLHS,
+ IntegerType::get(getContext(), BitWidth - LZ - TZ)),
+ BO->LHS->getType()),
+ MulCount);
+ }
+ }
+ break;
+
+ case Instruction::Or:
+ // If the RHS of the Or is a constant, we may have something like:
+ // X*4+1 which got turned into X*4|1. Handle this as an Add so loop
+ // optimizations will transparently handle this case.
+ //
+ // In order for this transformation to be safe, the LHS must be of the
+ // form X*(2^n) and the Or constant must be less than 2^n.
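+ // E.g. (illustrative) (4 * %x) | 1 is modeled as (4 * %x) + 1, and
+ // {0,+,4} | 1 as {1,+,4}.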
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
+ const SCEV *LHS = getSCEV(BO->LHS);
+ const APInt &CIVal = CI->getValue();
+ if (GetMinTrailingZeros(LHS) >=
+ (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
+ // Build a plain add SCEV.
+ const SCEV *S = getAddExpr(LHS, getSCEV(CI));
+ // If the LHS of the add was an addrec and it has no-wrap flags,
+ // transfer the no-wrap flags, since an or won't introduce a wrap.
+ if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
+ const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
+ const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags(
+ OldAR->getNoWrapFlags());
+ }
+ return S;
+ }
+ }
+ break;
+
+ case Instruction::Xor:
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
+ // If the RHS of xor is -1, then this is a not operation.
+ if (CI->isMinusOne())
+ return getNotSCEV(getSCEV(BO->LHS));
+
+ // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
+ // This is a variant of the check for xor with -1, and it handles
+ // the case where instcombine has trimmed non-demanded bits out
+ // of an xor with -1.
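+ // E.g. (illustrative) for i32 %x, xor(and(%x, 255), 255) is modeled as
+ // (zext i8 (not (trunc %x)) to i32).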
+ if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
+ if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
+ if (LBO->getOpcode() == Instruction::And &&
+ LCI->getValue() == CI->getValue())
+ if (const SCEVZeroExtendExpr *Z =
+ dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
+ Type *UTy = BO->LHS->getType();
+ const SCEV *Z0 = Z->getOperand();
+ Type *Z0Ty = Z0->getType();
+ unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
+
+ // If C is a low-bits mask, the zero extend is serving to
+ // mask off the high bits. Complement the operand and
+ // re-apply the zext.
+ if (CI->getValue().isMask(Z0TySize))
+ return getZeroExtendExpr(getNotSCEV(Z0), UTy);
+
+ // If C is a single bit, it may be in the sign-bit position
+ // before the zero-extend. In this case, represent the xor
+ // using an add, which is equivalent, and re-apply the zext.
+ APInt Trunc = CI->getValue().trunc(Z0TySize);
+ if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
+ Trunc.isSignMask())
+ return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
+ UTy);
+ }
+ }
+ break;
+
+ case Instruction::Shl:
+ // Turn shift left of a constant amount into a multiply.
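+ // E.g. (illustrative) "shl i32 %x, 3" is modeled as (%x * 8).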
+ if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
+ uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
+
+ // If the shift count is not less than the bitwidth, the result of
+ // the shift is undefined. Don't try to analyze it, because the
+ // resolution chosen here may differ from the resolution chosen in
+ // other parts of the compiler.
+ if (SA->getValue().uge(BitWidth))
+ break;
+
+ // It is currently not resolved how to interpret NSW for left
+ // shift by BitWidth - 1, so we avoid applying flags in that
+ // case. Remove this check (or this comment) once the situation
+ // is resolved. See
+ // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html
+ // and http://reviews.llvm.org/D8890 .
+ auto Flags = SCEV::FlagAnyWrap;
+ if (BO->Op && SA->getValue().ult(BitWidth - 1))
+ Flags = getNoWrapFlagsFromUB(BO->Op);
+
+ Constant *X = ConstantInt::get(
+ getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
+ return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags);
+ }
+ break;
+
+ case Instruction::AShr: {
+ // AShr X, C, where C is a constant.
+ ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
+ if (!CI)
+ break;
+
+ Type *OuterTy = BO->LHS->getType();
+ uint64_t BitWidth = getTypeSizeInBits(OuterTy);
+ // If the shift count is not less than the bitwidth, the result of
+ // the shift is undefined. Don't try to analyze it, because the
+ // resolution chosen here may differ from the resolution chosen in
+ // other parts of the compiler.
+ if (CI->getValue().uge(BitWidth))
+ break;
+
+ if (CI->isZero())
+ return getSCEV(BO->LHS); // shift by zero --> noop
+
+ uint64_t AShrAmt = CI->getZExtValue();
+ Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
+
+ Operator *L = dyn_cast<Operator>(BO->LHS);
+ if (L && L->getOpcode() == Instruction::Shl) {
+ // X = Shl A, n
+ // Y = AShr X, m
+ // Both n and m are constant.
+
+ const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
+ if (L->getOperand(1) == BO->RHS)
+ // For a two-shift sext-inreg, i.e. n = m,
+ // use sext(trunc(x)) as the SCEV expression.
+ return getSignExtendExpr(
+ getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy);
+
+ ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
+ if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) {
+ uint64_t ShlAmt = ShlAmtCI->getZExtValue();
+ if (ShlAmt > AShrAmt) {
+ // When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV
+ // expression. We already checked that ShlAmt < BitWidth, so
+ // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
+ // ShlAmt - AShrAmt < BitWidth - AShrAmt.
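+ // E.g. (illustrative) on i32, (ashr (shl %a, 5), 3) is modeled as
+ // (sext i29 (4 * (trunc %a)) to i32).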
+ APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
+ ShlAmt - AShrAmt);
+ return getSignExtendExpr(
+ getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy),
+ getConstant(Mul)), OuterTy);
+ }
+ }
+ }
+ break;
+ }
+ }
+ }
+
+ switch (U->getOpcode()) {
+ case Instruction::Trunc:
+ return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
+
+ case Instruction::ZExt:
+ return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
+
+ case Instruction::SExt:
+ if (auto BO = MatchBinaryOp(U->getOperand(0), DT)) {
+ // The NSW flag of a subtract does not always survive the conversion to
+ // A + (-1)*B. By pushing sign extension onto its operands we are much
+ // more likely to preserve NSW and allow later AddRec optimisations.
+ //
+ // NOTE: This is effectively duplicating this logic from getSignExtend:
+ // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
+ // but by that point the NSW information has potentially been lost.
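+ // E.g. (illustrative) sext((%a - %b)<nsw> to i64) becomes
+ // (sext %a to i64) - (sext %b to i64) with nsw preserved.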
+ if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
+ Type *Ty = U->getType();
+ auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
+ auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
+ return getMinusSCEV(V1, V2, SCEV::FlagNSW);
+ }
+ }
+ return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
+
+ case Instruction::BitCast:
+ // BitCasts are no-op casts so we just eliminate the cast.
+ if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
+ return getSCEV(U->getOperand(0));
+ break;
+
+ // It's tempting to handle inttoptr and ptrtoint as no-ops, however this can
+ // lead to pointer expressions which cannot safely be expanded to GEPs,
+ // because ScalarEvolution doesn't respect the GEP aliasing rules when
+ // simplifying integer expressions.
+
+ case Instruction::GetElementPtr:
+ return createNodeForGEP(cast<GEPOperator>(U));
+
+ case Instruction::PHI:
+ return createNodeForPHI(cast<PHINode>(U));
+
+ case Instruction::Select:
+ // U can also be a select constant expr, which we let fall through. Since
+ // createNodeForSelect only works for a condition that is an `ICmpInst`, and
+ // constant expressions cannot have instructions as operands, we'd have
+ // returned getUnknown for a select constant expression anyway.
+ if (isa<Instruction>(U))
+ return createNodeForSelectOrPHI(cast<Instruction>(U), U->getOperand(0),
+ U->getOperand(1), U->getOperand(2));
+ break;
+
+ case Instruction::Call:
+ case Instruction::Invoke:
+ if (Value *RV = CallSite(U).getReturnedArgOperand())
+ return getSCEV(RV);
+ break;
+ }
+
+ return getUnknown(V);
+}
+
+//===----------------------------------------------------------------------===//
+// Iteration Count Computation Code
+//
+
+static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
+ if (!ExitCount)
+ return 0;
+
+ ConstantInt *ExitConst = ExitCount->getValue();
+
+ // Guard against huge trip counts.
+ if (ExitConst->getValue().getActiveBits() > 32)
+ return 0;
+
+ // In case of integer overflow, this returns 0, which is correct.
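+ // E.g. (illustrative) a backedge-taken count of 0xFFFFFFFF yields
+ // (unsigned)0xFFFFFFFF + 1 == 0, i.e. an unknown trip count.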
+ return ((unsigned)ExitConst->getZExtValue()) + 1;
+}
+
+unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
+ if (BasicBlock *ExitingBB = L->getExitingBlock())
+ return getSmallConstantTripCount(L, ExitingBB);
+
+ // No trip count information for multiple exits.
+ return 0;
+}
+
+unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L,
+ BasicBlock *ExitingBlock) {
+ assert(ExitingBlock && "Must pass a non-null exiting block!");
+ assert(L->isLoopExiting(ExitingBlock) &&
+ "Exiting block must actually branch out of the loop!");
+ const SCEVConstant *ExitCount =
+ dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
+ return getConstantTripCount(ExitCount);
+}
+
+unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
+ const auto *MaxExitCount =
+ dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
+ return getConstantTripCount(MaxExitCount);
+}
+
+unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
+ if (BasicBlock *ExitingBB = L->getExitingBlock())
+ return getSmallConstantTripMultiple(L, ExitingBB);
+
+ // No trip multiple information for multiple exits.
+ return 0;
+}
+
+/// Returns the largest constant divisor of the trip count of this loop as a
+/// normal unsigned value, if possible. This means that the actual trip count is
+/// always a multiple of the returned value (don't forget the trip count could
+/// very well be zero as well!).
+///
+/// Returns 1 if the trip count is unknown or not guaranteed to be a
+/// multiple of a constant (which is also the case if the trip count is simply
+/// constant; use getSmallConstantTripCount for that case). It will also
+/// return 1 if the trip count is very large (>= 2^32).
+///
+/// As explained in the comments for getSmallConstantTripCount, this assumes
+/// that control exits the loop via ExitingBlock.
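+/// For example (illustrative), if the trip count expression is (4 * %n), the
+/// greatest known constant divisor is 4, so 4 is returned.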
+unsigned
+ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
+ BasicBlock *ExitingBlock) {
+ assert(ExitingBlock && "Must pass a non-null exiting block!");
+ assert(L->isLoopExiting(ExitingBlock) &&
+ "Exiting block must actually branch out of the loop!");
+ const SCEV *ExitCount = getExitCount(L, ExitingBlock);
+ if (ExitCount == getCouldNotCompute())
+ return 1;
+
+ // Get the trip count from the BE count by adding 1.
+ const SCEV *TCExpr = getAddExpr(ExitCount, getOne(ExitCount->getType()));
+
+ const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr);
+ if (!TC)
+ // Attempt to factor more general cases. Returns the greatest power of
+ // two divisor. If overflow happens, the trip count expression is still
+ // divisible by the greatest power of 2 divisor returned.
+ return 1U << std::min((uint32_t)31, GetMinTrailingZeros(TCExpr));
+
+ ConstantInt *Result = TC->getValue();
+
+ // Guard against huge trip counts (this requires checking
+ // for zero to handle the case where the exit count == -1 and the
+ // addition wraps).
+ if (!Result || Result->getValue().getActiveBits() > 32 ||
+ Result->getValue().getActiveBits() == 0)
+ return 1;
+
+ return (unsigned)Result->getZExtValue();
+}
+
+/// Get the expression for the number of loop iterations for which this loop is
+/// guaranteed not to exit via ExitingBlock. Otherwise return
+/// SCEVCouldNotCompute.
+const SCEV *ScalarEvolution::getExitCount(const Loop *L,
+ BasicBlock *ExitingBlock) {
+ return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
+}
+
+const SCEV *
+ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
+ SCEVUnionPredicate &Preds) {
+ return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
+}
+
+const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
+ return getBackedgeTakenInfo(L).getExact(L, this);
+}
+
+/// Similar to getBackedgeTakenCount, except return the least SCEV value that is
+/// known never to be less than the actual backedge taken count.
+const SCEV *ScalarEvolution::getConstantMaxBackedgeTakenCount(const Loop *L) {
+ return getBackedgeTakenInfo(L).getMax(this);
+}
+
+bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
+ return getBackedgeTakenInfo(L).isMaxOrZero(this);
+}
+
+/// Push PHI nodes in the header of the given loop onto the given Worklist.
+static void
+PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
+ BasicBlock *Header = L->getHeader();
+
+ // Push all Loop-header PHIs onto the Worklist stack.
+ for (PHINode &PN : Header->phis())
+ Worklist.push_back(&PN);
+}
+
+const ScalarEvolution::BackedgeTakenInfo &
+ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
+ auto &BTI = getBackedgeTakenInfo(L);
+ if (BTI.hasFullInfo())
+ return BTI;
+
+ auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
+
+ if (!Pair.second)
+ return Pair.first->second;
+
+ BackedgeTakenInfo Result =
+ computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
+
+ return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
+}
+
+const ScalarEvolution::BackedgeTakenInfo &
+ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
+ // Initially insert an invalid entry for this loop. If the insertion
+ // succeeds, proceed to actually compute a backedge-taken count and
+ // update the value. The temporary CouldNotCompute value tells SCEV
+ // code elsewhere that it shouldn't attempt to request a new
+ // backedge-taken count, which could result in infinite recursion.
+ std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
+ BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
+ if (!Pair.second)
+ return Pair.first->second;
+
+ // computeBackedgeTakenCount may allocate memory for its result. Inserting it
+ // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
+ // must be cleared in this scope.
+ BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
+
+ // In a product build, the statistics are not used.
+ (void)NumTripCountsComputed;
+ (void)NumTripCountsNotComputed;
+#if LLVM_ENABLE_STATS || !defined(NDEBUG)
+ const SCEV *BEExact = Result.getExact(L, this);
+ if (BEExact != getCouldNotCompute()) {
+ assert(isLoopInvariant(BEExact, L) &&
+ isLoopInvariant(Result.getMax(this), L) &&
+ "Computed backedge-taken count isn't loop invariant for loop!");
+ ++NumTripCountsComputed;
+ }
+ else if (Result.getMax(this) == getCouldNotCompute() &&
+ isa<PHINode>(L->getHeader()->begin())) {
+ // Only count loops that have phi nodes as not being computable.
+ ++NumTripCountsNotComputed;
+ }
+#endif // LLVM_ENABLE_STATS || !defined(NDEBUG)
+
+ // Now that we know more about the trip count for this loop, forget any
+ // existing SCEV values for PHI nodes in this loop since they are only
+ // conservative estimates made without the benefit of trip count
+ // information. This is similar to the code in forgetLoop, except that
+ // it handles SCEVUnknown PHI nodes specially.
+ if (Result.hasAnyInfo()) {
+ SmallVector<Instruction *, 16> Worklist;
+ PushLoopPHIs(L, Worklist);
+
+ SmallPtrSet<Instruction *, 8> Discovered;
+ while (!Worklist.empty()) {
+ Instruction *I = Worklist.pop_back_val();
+
+ ValueExprMapType::iterator It =
+ ValueExprMap.find_as(static_cast<Value *>(I));
+ if (It != ValueExprMap.end()) {
+ const SCEV *Old = It->second;
+
+ // SCEVUnknown for a PHI either means that it has an unrecognized
+ // structure, or it's a PHI that's in the process of being computed
+ // by createNodeForPHI. In the former case, additional loop trip
+ // count information isn't going to change anything. In the latter
+ // case, createNodeForPHI will perform the necessary updates on its
+ // own when it gets to that point.
+ if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
+ eraseValueFromMap(It->first);
+ forgetMemoizedResults(Old);
+ }
+ if (PHINode *PN = dyn_cast<PHINode>(I))
+ ConstantEvolutionLoopExitValue.erase(PN);
+ }
+
+ // Since we don't need to invalidate anything for correctness and we're
+ // only invalidating to make SCEV's results more precise, we get to stop
+ // early to avoid invalidating too much. This is especially important in
+ // cases like:
+ //
+ // %v = f(pn0, pn1) // pn0 and pn1 used through some other phi node
+ // loop0:
+ // %pn0 = phi
+ // ...
+ // loop1:
+ // %pn1 = phi
+ // ...
+ //
+ // where both loop0 and loop1's backedge taken count uses the SCEV
+ // expression for %v. If we don't have the early stop below then in cases
+ // like the above, getBackedgeTakenInfo(loop1) will clear out the trip
+ // count for loop0 and getBackedgeTakenInfo(loop0) will clear out the trip
+ // count for loop1, effectively nullifying SCEV's trip count cache.
+ for (auto *U : I->users())
+ if (auto *I = dyn_cast<Instruction>(U)) {
+ auto *LoopForUser = LI.getLoopFor(I->getParent());
+ if (LoopForUser && L->contains(LoopForUser) &&
+ Discovered.insert(I).second)
+ Worklist.push_back(I);
+ }
+ }
+ }
+
+ // Re-lookup the insert position, since the call to
+ // computeBackedgeTakenCount above could result in a
+ // recursive call to getBackedgeTakenInfo (on a different
+ // loop), which would invalidate the iterator computed
+ // earlier.
+ return BackedgeTakenCounts.find(L)->second = std::move(Result);
+}
+
+void ScalarEvolution::forgetAllLoops() {
+ // This method is intended to forget all info about loops. It should
+ // invalidate caches as if the following happened:
+ // - The trip counts of all loops have changed arbitrarily
+ // - Every llvm::Value has been updated in place to produce a different
+ // result.
+ BackedgeTakenCounts.clear();
+ PredicatedBackedgeTakenCounts.clear();
+ LoopPropertiesCache.clear();
+ ConstantEvolutionLoopExitValue.clear();
+ ValueExprMap.clear();
+ ValuesAtScopes.clear();
+ LoopDispositions.clear();
+ BlockDispositions.clear();
+ UnsignedRanges.clear();
+ SignedRanges.clear();
+ ExprValueMap.clear();
+ HasRecMap.clear();
+ MinTrailingZerosCache.clear();
+ PredicatedSCEVRewrites.clear();
+}
+
+void ScalarEvolution::forgetLoop(const Loop *L) {
+ // Drop any stored trip count value.
+ auto RemoveLoopFromBackedgeMap =
+ [](DenseMap<const Loop *, BackedgeTakenInfo> &Map, const Loop *L) {
+ auto BTCPos = Map.find(L);
+ if (BTCPos != Map.end()) {
+ BTCPos->second.clear();
+ Map.erase(BTCPos);
+ }
+ };
+
+ SmallVector<const Loop *, 16> LoopWorklist(1, L);
+ SmallVector<Instruction *, 32> Worklist;
+ SmallPtrSet<Instruction *, 16> Visited;
+
+ // Iterate over all the loops and sub-loops to drop SCEV information.
+ while (!LoopWorklist.empty()) {
+ auto *CurrL = LoopWorklist.pop_back_val();
+
+ RemoveLoopFromBackedgeMap(BackedgeTakenCounts, CurrL);
+ RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts, CurrL);
+
+ // Drop information about predicated SCEV rewrites for this loop.
+ for (auto I = PredicatedSCEVRewrites.begin();
+ I != PredicatedSCEVRewrites.end();) {
+ std::pair<const SCEV *, const Loop *> Entry = I->first;
+ if (Entry.second == CurrL)
+ PredicatedSCEVRewrites.erase(I++);
+ else
+ ++I;
+ }
+
+ auto LoopUsersItr = LoopUsers.find(CurrL);
+ if (LoopUsersItr != LoopUsers.end()) {
+ for (auto *S : LoopUsersItr->second)
+ forgetMemoizedResults(S);
+ LoopUsers.erase(LoopUsersItr);
+ }
+
+ // Drop information about expressions based on loop-header PHIs.
+ PushLoopPHIs(CurrL, Worklist);
+
+ while (!Worklist.empty()) {
+ Instruction *I = Worklist.pop_back_val();
+ if (!Visited.insert(I).second)
+ continue;
+
+ ValueExprMapType::iterator It =
+ ValueExprMap.find_as(static_cast<Value *>(I));
+ if (It != ValueExprMap.end()) {
+ eraseValueFromMap(It->first);
+ forgetMemoizedResults(It->second);
+ if (PHINode *PN = dyn_cast<PHINode>(I))
+ ConstantEvolutionLoopExitValue.erase(PN);
+ }
+
+ PushDefUseChildren(I, Worklist);
+ }
+
+ LoopPropertiesCache.erase(CurrL);
+ // Forget all contained loops too, to avoid dangling entries in the
+ // ValuesAtScopes map.
+ LoopWorklist.append(CurrL->begin(), CurrL->end());
+ }
+}
+
+void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
+ while (Loop *Parent = L->getParentLoop())
+ L = Parent;
+ forgetLoop(L);
+}
+
+void ScalarEvolution::forgetValue(Value *V) {
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (!I) return;
+
+ // Drop information about expressions based on loop-header PHIs.
+ SmallVector<Instruction *, 16> Worklist;
+ Worklist.push_back(I);
+
+ SmallPtrSet<Instruction *, 8> Visited;
+ while (!Worklist.empty()) {
+ I = Worklist.pop_back_val();
+ if (!Visited.insert(I).second)
+ continue;
+
+ ValueExprMapType::iterator It =
+ ValueExprMap.find_as(static_cast<Value *>(I));
+ if (It != ValueExprMap.end()) {
+ eraseValueFromMap(It->first);
+ forgetMemoizedResults(It->second);
+ if (PHINode *PN = dyn_cast<PHINode>(I))
+ ConstantEvolutionLoopExitValue.erase(PN);
+ }
+
+ PushDefUseChildren(I, Worklist);
+ }
+}
+
+/// Get the exact loop backedge taken count considering all loop exits. A
+/// computable result can only be returned for loops with all exiting blocks
+/// dominating the latch. howFarToZero assumes that the limit of each loop test
+/// is never skipped. This is a valid assumption as long as the loop exits via
+/// that test. For precise results, it is the caller's responsibility to specify
+/// the relevant loop exiting block using getExact(ExitingBlock, SE).
+const SCEV *
+ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
+ SCEVUnionPredicate *Preds) const {
+ // If any exits were not computable, the loop is not computable.
+ if (!isComplete() || ExitNotTaken.empty())
+ return SE->getCouldNotCompute();
+
+ const BasicBlock *Latch = L->getLoopLatch();
+ // All exiting blocks we have collected must dominate the only backedge.
+ if (!Latch)
+ return SE->getCouldNotCompute();
+
+ // All exiting blocks we have gathered dominate the loop's latch, so the exact
+ // trip count is simply the minimum of all these calculated exit counts.
+ SmallVector<const SCEV *, 2> Ops;
+ for (auto &ENT : ExitNotTaken) {
+ const SCEV *BECount = ENT.ExactNotTaken;
+ assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
+ assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
+ "We should only have known counts for exiting blocks that dominate "
+ "latch!");
+
+ Ops.push_back(BECount);
+
+ if (Preds && !ENT.hasAlwaysTruePredicate())
+ Preds->add(ENT.Predicate.get());
+
+ assert((Preds || ENT.hasAlwaysTruePredicate()) &&
+ "Predicate should be always true!");
+ }
+
+ return SE->getUMinFromMismatchedTypes(Ops);
+}
+
+/// Get the exact not taken count for this loop exit.
+const SCEV *
+ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,
+ ScalarEvolution *SE) const {
+ for (auto &ENT : ExitNotTaken)
+ if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
+ return ENT.ExactNotTaken;
+
+ return SE->getCouldNotCompute();
+}
+
+/// getMax - Get the max backedge taken count for the loop.
+const SCEV *
+ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const {
+ auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
+ return !ENT.hasAlwaysTruePredicate();
+ };
+
+ if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getMax())
+ return SE->getCouldNotCompute();
+
+ assert((isa<SCEVCouldNotCompute>(getMax()) || isa<SCEVConstant>(getMax())) &&
+ "No point in having a non-constant max backedge taken count!");
+ return getMax();
+}
+
+bool ScalarEvolution::BackedgeTakenInfo::isMaxOrZero(ScalarEvolution *SE) const {
+ auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
+ return !ENT.hasAlwaysTruePredicate();
+ };
+ return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
+}
+
+bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
+ ScalarEvolution *SE) const {
+ if (getMax() && getMax() != SE->getCouldNotCompute() &&
+ SE->hasOperand(getMax(), S))
+ return true;
+
+ for (auto &ENT : ExitNotTaken)
+ if (ENT.ExactNotTaken != SE->getCouldNotCompute() &&
+ SE->hasOperand(ENT.ExactNotTaken, S))
+ return true;
+
+ return false;
+}
+
+ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
+ : ExactNotTaken(E), MaxNotTaken(E) {
+ assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
+ isa<SCEVConstant>(MaxNotTaken)) &&
+ "No point in having a non-constant max backedge taken count!");
+}
+
+ScalarEvolution::ExitLimit::ExitLimit(
+ const SCEV *E, const SCEV *M, bool MaxOrZero,
+ ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
+ : ExactNotTaken(E), MaxNotTaken(M), MaxOrZero(MaxOrZero) {
+ assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
+ !isa<SCEVCouldNotCompute>(MaxNotTaken)) &&
+ "Exact is not allowed to be less precise than Max");
+ assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
+ isa<SCEVConstant>(MaxNotTaken)) &&
+ "No point in having a non-constant max backedge taken count!");
+ for (auto *PredSet : PredSetList)
+ for (auto *P : *PredSet)
+ addPredicate(P);
+}
+
+ScalarEvolution::ExitLimit::ExitLimit(
+ const SCEV *E, const SCEV *M, bool MaxOrZero,
+ const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
+ : ExitLimit(E, M, MaxOrZero, {&PredSet}) {
+ assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
+ isa<SCEVConstant>(MaxNotTaken)) &&
+ "No point in having a non-constant max backedge taken count!");
+}
+
+ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M,
+ bool MaxOrZero)
+ : ExitLimit(E, M, MaxOrZero, None) {
+ assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
+ isa<SCEVConstant>(MaxNotTaken)) &&
+ "No point in having a non-constant max backedge taken count!");
+}
+
+/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
+/// computable exit into a persistent ExitNotTakenInfo array.
+ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
+ ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo>
+ ExitCounts,
+ bool Complete, const SCEV *MaxCount, bool MaxOrZero)
+ : MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) {
+ using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
+
+ ExitNotTaken.reserve(ExitCounts.size());
+ std::transform(
+ ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken),
+ [&](const EdgeExitInfo &EEI) {
+ BasicBlock *ExitBB = EEI.first;
+ const ExitLimit &EL = EEI.second;
+ if (EL.Predicates.empty())
+ return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, nullptr);
+
+ std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate);
+ for (auto *Pred : EL.Predicates)
+ Predicate->add(Pred);
+
+ return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, std::move(Predicate));
+ });
+ assert((isa<SCEVCouldNotCompute>(MaxCount) || isa<SCEVConstant>(MaxCount)) &&
+ "No point in having a non-constant max backedge taken count!");
+}
+
+/// Invalidate this result and free the ExitNotTakenInfo array.
+void ScalarEvolution::BackedgeTakenInfo::clear() {
+ ExitNotTaken.clear();
+}
+
+/// Compute the number of times the backedge of the specified loop will execute.
+ScalarEvolution::BackedgeTakenInfo
+ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
+ bool AllowPredicates) {
+ SmallVector<BasicBlock *, 8> ExitingBlocks;
+ L->getExitingBlocks(ExitingBlocks);
+
+ using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
+
+ SmallVector<EdgeExitInfo, 4> ExitCounts;
+ bool CouldComputeBECount = true;
+ BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
+ const SCEV *MustExitMaxBECount = nullptr;
+ const SCEV *MayExitMaxBECount = nullptr;
+ bool MustExitMaxOrZero = false;
+
+ // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
+ // and compute maxBECount.
+ // Do a union of all the predicates here.
+ for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
+ BasicBlock *ExitBB = ExitingBlocks[i];
+ ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
+
+ assert((AllowPredicates || EL.Predicates.empty()) &&
+ "Predicated exit limit when predicates are not allowed!");
+
+ // 1. For each exit that can be computed, add an entry to ExitCounts.
+ // CouldComputeBECount is true only if all exits can be computed.
+ if (EL.ExactNotTaken == getCouldNotCompute())
+ // We couldn't compute an exact value for this exit, so
+ // we won't be able to compute an exact value for the loop.
+ CouldComputeBECount = false;
+ else
+ ExitCounts.emplace_back(ExitBB, EL);
+
+ // 2. Derive the loop's MaxBECount from each exit's max number of
+ // non-exiting iterations. Partition the loop exits into two kinds:
+ // LoopMustExits and LoopMayExits.
+ //
+ // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
+ // is a LoopMayExit. If any computable LoopMustExit is found, then
+ // MaxBECount is the minimum EL.MaxNotTaken of computable
+ // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
+ // EL.MaxNotTaken, where CouldNotCompute is considered greater than any
+ // computable EL.MaxNotTaken.
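+ // E.g. (illustrative) with two computable exits whose MaxNotTaken values are
+ // 10 and 20, both dominating the latch, MaxBECount is umin(10, 20) = 10.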
+ if (EL.MaxNotTaken != getCouldNotCompute() && Latch &&
+ DT.dominates(ExitBB, Latch)) {
+ if (!MustExitMaxBECount) {
+ MustExitMaxBECount = EL.MaxNotTaken;
+ MustExitMaxOrZero = EL.MaxOrZero;
+ } else {
+ MustExitMaxBECount =
+ getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken);
+ }
+ } else if (MayExitMaxBECount != getCouldNotCompute()) {
+ if (!MayExitMaxBECount || EL.MaxNotTaken == getCouldNotCompute())
+ MayExitMaxBECount = EL.MaxNotTaken;
+ else {
+ MayExitMaxBECount =
+ getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.MaxNotTaken);
+ }
+ }
+ }
+ const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
+ (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
+ // The loop backedge will be taken the maximum or zero times if there's
+ // a single exit that must be taken the maximum or zero times.
+ bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
+ return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
+ MaxBECount, MaxOrZero);
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
+ bool AllowPredicates) {
+ assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
+ // If our exiting block does not dominate the latch, then its connection with
+ // the loop's exit limit may be far from trivial.
+ const BasicBlock *Latch = L->getLoopLatch();
+ if (!Latch || !DT.dominates(ExitingBlock, Latch))
+ return getCouldNotCompute();
+
+ bool IsOnlyExit = (L->getExitingBlock() != nullptr);
+ Instruction *Term = ExitingBlock->getTerminator();
+ if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
+ assert(BI->isConditional() && "If unconditional, it can't be in loop!");
+ bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
+ assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
+ "It should have one successor in loop and one exit block!");
+ // Proceed to the next level to examine the exit condition expression.
+ return computeExitLimitFromCond(
+ L, BI->getCondition(), ExitIfTrue,
+ /*ControlsExit=*/IsOnlyExit, AllowPredicates);
+ }
+
+ if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
+ // For switch, make sure that there is a single exit from the loop.
+ BasicBlock *Exit = nullptr;
+ for (auto *SBB : successors(ExitingBlock))
+ if (!L->contains(SBB)) {
+ if (Exit) // Multiple exit successors.
+ return getCouldNotCompute();
+ Exit = SBB;
+ }
+ assert(Exit && "Exiting block must have at least one exit");
+ return computeExitLimitFromSingleExitSwitch(L, SI, Exit,
+ /*ControlsExit=*/IsOnlyExit);
+ }
+
+ return getCouldNotCompute();
+}
+
+ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
+ const Loop *L, Value *ExitCond, bool ExitIfTrue,
+ bool ControlsExit, bool AllowPredicates) {
+ ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
+ return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
+ ControlsExit, AllowPredicates);
+}
+
+Optional<ScalarEvolution::ExitLimit>
+ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
+ bool ExitIfTrue, bool ControlsExit,
+ bool AllowPredicates) {
+ (void)this->L;
+ (void)this->ExitIfTrue;
+ (void)this->AllowPredicates;
+
+ assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
+ this->AllowPredicates == AllowPredicates &&
+ "Variance in assumed invariant key components!");
+ auto Itr = TripCountMap.find({ExitCond, ControlsExit});
+ if (Itr == TripCountMap.end())
+ return None;
+ return Itr->second;
+}
+
+void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
+ bool ExitIfTrue,
+ bool ControlsExit,
+ bool AllowPredicates,
+ const ExitLimit &EL) {
+ assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
+ this->AllowPredicates == AllowPredicates &&
+ "Variance in assumed invariant key components!");
+
+ auto InsertResult = TripCountMap.insert({{ExitCond, ControlsExit}, EL});
+ assert(InsertResult.second && "Expected successful insertion!");
+ (void)InsertResult;
+ (void)ExitIfTrue;
+}
+
+ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
+ ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+ bool ControlsExit, bool AllowPredicates) {
+
+ if (auto MaybeEL =
+ Cache.find(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates))
+ return *MaybeEL;
+
+ ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, ExitIfTrue,
+ ControlsExit, AllowPredicates);
+ Cache.insert(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates, EL);
+ return EL;
+}
+
+ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
+ ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+ bool ControlsExit, bool AllowPredicates) {
+ // Check if the controlling expression for this loop is an And or Or.
+ if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
+ if (BO->getOpcode() == Instruction::And) {
+ // Recurse on the operands of the and.
+ bool EitherMayExit = !ExitIfTrue;
+ ExitLimit EL0 = computeExitLimitFromCondCached(
+ Cache, L, BO->getOperand(0), ExitIfTrue,
+ ControlsExit && !EitherMayExit, AllowPredicates);
+ ExitLimit EL1 = computeExitLimitFromCondCached(
+ Cache, L, BO->getOperand(1), ExitIfTrue,
+ ControlsExit && !EitherMayExit, AllowPredicates);
+ const SCEV *BECount = getCouldNotCompute();
+ const SCEV *MaxBECount = getCouldNotCompute();
+ if (EitherMayExit) {
+ // Both conditions must be true for the loop to continue executing.
+ // Choose the less conservative count.
+ if (EL0.ExactNotTaken == getCouldNotCompute() ||
+ EL1.ExactNotTaken == getCouldNotCompute())
+ BECount = getCouldNotCompute();
+ else
+ BECount =
+ getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken);
+ if (EL0.MaxNotTaken == getCouldNotCompute())
+ MaxBECount = EL1.MaxNotTaken;
+ else if (EL1.MaxNotTaken == getCouldNotCompute())
+ MaxBECount = EL0.MaxNotTaken;
+ else
+ MaxBECount =
+ getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
+ } else {
+ // Both conditions must be true at the same time for the loop to exit.
+ // For now, be conservative.
+ if (EL0.MaxNotTaken == EL1.MaxNotTaken)
+ MaxBECount = EL0.MaxNotTaken;
+ if (EL0.ExactNotTaken == EL1.ExactNotTaken)
+ BECount = EL0.ExactNotTaken;
+ }
+
+ // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
+ // to be more aggressive when computing BECount than when computing
+ // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and
+ // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken
+ // to not.
+ if (isa<SCEVCouldNotCompute>(MaxBECount) &&
+ !isa<SCEVCouldNotCompute>(BECount))
+ MaxBECount = getConstant(getUnsignedRangeMax(BECount));
+
+ return ExitLimit(BECount, MaxBECount, false,
+ {&EL0.Predicates, &EL1.Predicates});
+ }
+ if (BO->getOpcode() == Instruction::Or) {
+ // Recurse on the operands of the or.
+ bool EitherMayExit = ExitIfTrue;
+ ExitLimit EL0 = computeExitLimitFromCondCached(
+ Cache, L, BO->getOperand(0), ExitIfTrue,
+ ControlsExit && !EitherMayExit, AllowPredicates);
+ ExitLimit EL1 = computeExitLimitFromCondCached(
+ Cache, L, BO->getOperand(1), ExitIfTrue,
+ ControlsExit && !EitherMayExit, AllowPredicates);
+ const SCEV *BECount = getCouldNotCompute();
+ const SCEV *MaxBECount = getCouldNotCompute();
+ if (EitherMayExit) {
+ // Both conditions must be false for the loop to continue executing.
+ // Choose the less conservative count.
+ if (EL0.ExactNotTaken == getCouldNotCompute() ||
+ EL1.ExactNotTaken == getCouldNotCompute())
+ BECount = getCouldNotCompute();
+ else
+ BECount =
+ getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken);
+ if (EL0.MaxNotTaken == getCouldNotCompute())
+ MaxBECount = EL1.MaxNotTaken;
+ else if (EL1.MaxNotTaken == getCouldNotCompute())
+ MaxBECount = EL0.MaxNotTaken;
+ else
+ MaxBECount =
+ getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
+ } else {
+ // Both conditions must be false at the same time for the loop to exit.
+ // For now, be conservative.
+ if (EL0.MaxNotTaken == EL1.MaxNotTaken)
+ MaxBECount = EL0.MaxNotTaken;
+ if (EL0.ExactNotTaken == EL1.ExactNotTaken)
+ BECount = EL0.ExactNotTaken;
+ }
+ // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
+ // to be more aggressive when computing BECount than when computing
+ // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and
+ // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken
+ // to not.
+ if (isa<SCEVCouldNotCompute>(MaxBECount) &&
+ !isa<SCEVCouldNotCompute>(BECount))
+ MaxBECount = getConstant(getUnsignedRangeMax(BECount));
+
+ return ExitLimit(BECount, MaxBECount, false,
+ {&EL0.Predicates, &EL1.Predicates});
+ }
+ }
+
+ // With an icmp, it may be feasible to compute an exact backedge-taken count.
+ // Proceed to the next level to examine the icmp.
+ if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
+ ExitLimit EL =
+ computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit);
+ if (EL.hasFullInfo() || !AllowPredicates)
+ return EL;
+
+ // Try again, but use SCEV predicates this time.
+ return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit,
+ /*AllowPredicates=*/true);
+ }
+
+ // Check for a constant condition. These are normally stripped out by
+ // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
+ // preserve the CFG and is temporarily leaving constant conditions
+ // in place.
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
+ if (ExitIfTrue == !CI->getZExtValue())
+ // The backedge is always taken.
+ return getCouldNotCompute();
+ else
+ // The backedge is never taken.
+ return getZero(CI->getType());
+ }
+
+ // If it's not an integer or pointer comparison then compute it the hard way.
+ return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
+ ICmpInst *ExitCond,
+ bool ExitIfTrue,
+ bool ControlsExit,
+ bool AllowPredicates) {
+ // If the condition was exit on true, convert the condition to exit on false
+ ICmpInst::Predicate Pred;
+ if (!ExitIfTrue)
+ Pred = ExitCond->getPredicate();
+ else
+ Pred = ExitCond->getInversePredicate();
+ const ICmpInst::Predicate OriginalPred = Pred;
+
+ // Handle common loops like: for (X = "string"; *X; ++X)
+ if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
+ if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
+ ExitLimit ItCnt =
+ computeLoadConstantCompareExitLimit(LI, RHS, L, Pred);
+ if (ItCnt.hasAnyInfo())
+ return ItCnt;
+ }
+
+ const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
+ const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
+
+ // Try to evaluate any dependencies out of the loop.
+ LHS = getSCEVAtScope(LHS, L);
+ RHS = getSCEVAtScope(RHS, L);
+
+ // At this point, we would like to compute how many iterations of the
+ // loop the predicate will return true for these inputs.
+ if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
+ // If there is a loop-invariant, force it into the RHS.
+ std::swap(LHS, RHS);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+
+ // Simplify the operands before analyzing them.
+ (void)SimplifyICmpOperands(Pred, LHS, RHS);
+
+ // If we have a comparison of a chrec against a constant, try to use value
+ // ranges to answer this query.
+ if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
+ if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
+ if (AddRec->getLoop() == L) {
+ // Form the constant range.
+ ConstantRange CompRange =
+ ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
+
+ const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
+ if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
+ }
+
+ switch (Pred) {
+ case ICmpInst::ICMP_NE: { // while (X != Y)
+ // Convert to: while (X-Y != 0)
+ ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit,
+ AllowPredicates);
+ if (EL.hasAnyInfo()) return EL;
+ break;
+ }
+ case ICmpInst::ICMP_EQ: { // while (X == Y)
+ // Convert to: while (X-Y == 0)
+ ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
+ if (EL.hasAnyInfo()) return EL;
+ break;
+ }
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_ULT: { // while (X < Y)
+ bool IsSigned = Pred == ICmpInst::ICMP_SLT;
+ ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit,
+ AllowPredicates);
+ if (EL.hasAnyInfo()) return EL;
+ break;
+ }
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_UGT: { // while (X > Y)
+ bool IsSigned = Pred == ICmpInst::ICMP_SGT;
+ ExitLimit EL =
+ howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit,
+ AllowPredicates);
+ if (EL.hasAnyInfo()) return EL;
+ break;
+ }
+ default:
+ break;
+ }
+
+ auto *ExhaustiveCount =
+ computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
+
+ if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
+ return ExhaustiveCount;
+
+ return computeShiftCompareExitLimit(ExitCond->getOperand(0),
+ ExitCond->getOperand(1), L, OriginalPred);
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
+ SwitchInst *Switch,
+ BasicBlock *ExitingBlock,
+ bool ControlsExit) {
+ assert(!L->contains(ExitingBlock) && "Not an exiting block!");
+
+ // Give up if the exit is the default dest of a switch.
+ if (Switch->getDefaultDest() == ExitingBlock)
+ return getCouldNotCompute();
+
+ assert(L->contains(Switch->getDefaultDest()) &&
+ "Default case must not exit the loop!");
+ const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
+ const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
+
+ // while (X != Y) --> while (X-Y != 0)
+ ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit);
+ if (EL.hasAnyInfo())
+ return EL;
+
+ return getCouldNotCompute();
+}
+
+static ConstantInt *
+EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
+ ScalarEvolution &SE) {
+ const SCEV *InVal = SE.getConstant(C);
+ const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
+ assert(isa<SCEVConstant>(Val) &&
+ "Evaluation of SCEV at constant didn't fold correctly?");
+ return cast<SCEVConstant>(Val)->getValue();
+}
+
+/// Given an exit condition of 'icmp op load X, cst', try to see if we can
+/// compute the backedge execution count.
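+/// A typical shape (illustrative only) is a loop such as
+/// "for (i = 0; ConstTable[i] != 0; ++i)" over a constant global array,
+/// handled by evaluating the affine GEP index chrec for successive
+/// iterations until the comparison folds to false.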
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeLoadConstantCompareExitLimit(
+ LoadInst *LI,
+ Constant *RHS,
+ const Loop *L,
+ ICmpInst::Predicate predicate) {
+ if (LI->isVolatile()) return getCouldNotCompute();
+
+ // Check to see if the loaded pointer is a getelementptr of a global.
+ // TODO: Use SCEV instead of manually grubbing with GEPs.
+ GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
+ if (!GEP) return getCouldNotCompute();
+
+ // Make sure that it is really a constant global we are gepping, with an
+ // initializer, and make sure the first IDX is really 0.
+ GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
+ if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() ||
+ GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
+ !cast<Constant>(GEP->getOperand(1))->isNullValue())
+ return getCouldNotCompute();
+
+ // Okay, we allow one non-constant index into the GEP instruction.
+ Value *VarIdx = nullptr;
+ std::vector<Constant*> Indexes;
+ unsigned VarIdxNum = 0;
+ for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
+ Indexes.push_back(CI);
+ } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
+ if (VarIdx) return getCouldNotCompute(); // Multiple non-constant idx's.
+ VarIdx = GEP->getOperand(i);
+ VarIdxNum = i-2;
+ Indexes.push_back(nullptr);
+ }
+
+ // Loop-invariant loads may be a byproduct of loop optimization. Skip them.
+ if (!VarIdx)
+ return getCouldNotCompute();
+
+ // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
+ // Check to see if X is a loop variant variable value now.
+ const SCEV *Idx = getSCEV(VarIdx);
+ Idx = getSCEVAtScope(Idx, L);
+
+ // We can only recognize very limited forms of loop index expressions, in
+ // particular, only affine AddRec's like {C1,+,C2}.
+ const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
+ if (!IdxExpr || !IdxExpr->isAffine() || isLoopInvariant(IdxExpr, L) ||
+ !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
+ !isa<SCEVConstant>(IdxExpr->getOperand(1)))
+ return getCouldNotCompute();
+
+ unsigned MaxSteps = MaxBruteForceIterations;
+ for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
+ ConstantInt *ItCst = ConstantInt::get(
+ cast<IntegerType>(IdxExpr->getType()), IterationNum);
+ ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
+
+ // Form the GEP offset.
+ Indexes[VarIdxNum] = Val;
+
+ Constant *Result = ConstantFoldLoadThroughGEPIndices(GV->getInitializer(),
+ Indexes);
+ if (!Result) break; // Cannot compute!
+
+ // Evaluate the condition for this iteration.
+ Result = ConstantExpr::getICmp(predicate, Result, RHS);
+ if (!isa<ConstantInt>(Result)) break; // Couldn't decide for sure
+ if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
+ ++NumArrayLenItCounts;
+ return getConstant(ItCst); // Found terminating iteration!
+ }
+ }
+ return getCouldNotCompute();
+}
+
+ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
+ Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
+ ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
+ if (!RHS)
+ return getCouldNotCompute();
+
+ const BasicBlock *Latch = L->getLoopLatch();
+ if (!Latch)
+ return getCouldNotCompute();
+
+ const BasicBlock *Predecessor = L->getLoopPredecessor();
+ if (!Predecessor)
+ return getCouldNotCompute();
+
+ // Return true if V is of the form "LHS `shift_op` <positive constant>".
+ // Return LHS in OutLHS and shift_op in OutOpCode.
+ auto MatchPositiveShift =
+ [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
+
+ using namespace PatternMatch;
+
+ ConstantInt *ShiftAmt;
+ if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+ OutOpCode = Instruction::LShr;
+ else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+ OutOpCode = Instruction::AShr;
+ else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+ OutOpCode = Instruction::Shl;
+ else
+ return false;
+
+ return ShiftAmt->getValue().isStrictlyPositive();
+ };
+
+ // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
+ //
+ // loop:
+ // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
+ // %iv.shifted = lshr i32 %iv, <positive constant>
+ //
+ // Return true on a successful match. Return the corresponding PHI node (%iv
+ // above) in PNOut and the opcode of the shift operation in OpCodeOut.
+ auto MatchShiftRecurrence =
+ [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
+ Optional<Instruction::BinaryOps> PostShiftOpCode;
+
+ {
+ Instruction::BinaryOps OpC;
+ Value *V;
+
+ // If we encounter a shift instruction, "peel off" the shift operation,
+ // and remember that we did so. Later when we inspect %iv's backedge
+ // value, we will make sure that the backedge value uses the same
+ // operation.
+ //
+ // Note: the peeled shift operation does not have to be the same
+ // instruction as the one feeding into the PHI's backedge value. We only
+ // really care about it being the same *kind* of shift instruction --
+ // that's all that is required for our later inferences to hold.
+ if (MatchPositiveShift(LHS, V, OpC)) {
+ PostShiftOpCode = OpC;
+ LHS = V;
+ }
+ }
+
+ PNOut = dyn_cast<PHINode>(LHS);
+ if (!PNOut || PNOut->getParent() != L->getHeader())
+ return false;
+
+ Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
+ Value *OpLHS;
+
+ return
+ // The backedge value for the PHI node must be a shift by a positive
+ // amount
+ MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
+
+ // of the PHI node itself
+ OpLHS == PNOut &&
+
+ // and the kind of shift should match the kind of shift we peeled
+ // off, if any.
+ (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut);
+ };
+
+ PHINode *PN;
+ Instruction::BinaryOps OpCode;
+ if (!MatchShiftRecurrence(LHS, PN, OpCode))
+ return getCouldNotCompute();
+
+ const DataLayout &DL = getDataLayout();
+
+ // The key rationale for this optimization is that for some kinds of shift
+ // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
+ // within a finite number of iterations. If the condition guarding the
+ // backedge (in the sense that the backedge is taken if the condition is true)
+ // is false for the value the shift recurrence stabilizes to, then we know
+ // that the backedge is taken only a finite number of times.
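+ //
+ // For example (illustrative only), in "do { K = K >> 1; } while (K != 0)"
+ // the lshr recurrence on K stabilizes to 0, and the guard "K != 0" is false
+ // at 0, so the backedge can be taken at most bitwidth(K) times.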
+
+ ConstantInt *StableValue = nullptr;
+ switch (OpCode) {
+ default:
+ llvm_unreachable("Impossible case!");
+
+ case Instruction::AShr: {
+ // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
+ // bitwidth(K) iterations.
+ Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
+ KnownBits Known = computeKnownBits(FirstValue, DL, 0, nullptr,
+ Predecessor->getTerminator(), &DT);
+ auto *Ty = cast<IntegerType>(RHS->getType());
+ if (Known.isNonNegative())
+ StableValue = ConstantInt::get(Ty, 0);
+ else if (Known.isNegative())
+ StableValue = ConstantInt::get(Ty, -1, true);
+ else
+ return getCouldNotCompute();
+
+ break;
+ }
+ case Instruction::LShr:
+ case Instruction::Shl:
+ // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
+ // stabilize to 0 in at most bitwidth(K) iterations.
+ StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
+ break;
+ }
+
+ auto *Result =
+ ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
+ assert(Result->getType()->isIntegerTy(1) &&
+ "Otherwise cannot be an operand to a branch instruction");
+
+ if (Result->isZeroValue()) {
+ unsigned BitWidth = getTypeSizeInBits(RHS->getType());
+ const SCEV *UpperBound =
+ getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
+ return ExitLimit(getCouldNotCompute(), UpperBound, false);
+ }
+
+ return getCouldNotCompute();
+}
+
+/// Return true if we can constant fold an instruction of the specified type,
+/// assuming that all operands were constants.
+static bool CanConstantFold(const Instruction *I) {
+ if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
+ isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
+ isa<LoadInst>(I) || isa<ExtractValueInst>(I))
+ return true;
+
+ if (const CallInst *CI = dyn_cast<CallInst>(I))
+ if (const Function *F = CI->getCalledFunction())
+ return canConstantFoldCallTo(CI, F);
+ return false;
+}
+
+/// Determine whether this instruction can constant evolve within this loop
+/// assuming its operands can all constant evolve.
+static bool canConstantEvolve(Instruction *I, const Loop *L) {
+ // An instruction outside of the loop can't be derived from a loop PHI.
+ if (!L->contains(I)) return false;
+
+ if (isa<PHINode>(I)) {
+ // We don't currently keep track of the control flow needed to evaluate
+ // PHIs, so we cannot handle PHIs inside of loops.
+ return L->getHeader() == I->getParent();
+ }
+
+ // If we won't be able to constant fold this expression even if the operands
+ // are constants, bail early.
+ return CanConstantFold(I);
+}
+
+/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
+/// recursing through each instruction operand until reaching a loop header phi.
+static PHINode *
+getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
+ DenseMap<Instruction *, PHINode *> &PHIMap,
+ unsigned Depth) {
+ if (Depth > MaxConstantEvolvingDepth)
+ return nullptr;
+
+ // Otherwise, we can evaluate this instruction if all of its operands are
+ // constant or derived from a PHI node themselves.
+ PHINode *PHI = nullptr;
+ for (Value *Op : UseInst->operands()) {
+ if (isa<Constant>(Op)) continue;
+
+ Instruction *OpInst = dyn_cast<Instruction>(Op);
+ if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
+
+ PHINode *P = dyn_cast<PHINode>(OpInst);
+ if (!P)
+ // If this operand is already visited, reuse the prior result.
+ // We may have P != PHI if this is the deepest point at which the
+ // inconsistent paths meet.
+ P = PHIMap.lookup(OpInst);
+ if (!P) {
+ // Recurse and memoize the results, whether a phi is found or not.
+ // This recursive call invalidates pointers into PHIMap.
+ P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
+ PHIMap[OpInst] = P;
+ }
+ if (!P)
+ return nullptr; // Not evolving from PHI
+ if (PHI && PHI != P)
+ return nullptr; // Evolving from multiple different PHIs.
+ PHI = P;
+ }
+ // This is an expression evolving from a constant PHI!
+ return PHI;
+}
+
+/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
+/// in the loop that V is derived from. We allow arbitrary operations along the
+/// way, but the operands of an operation must either be constants or a value
+/// derived from a constant PHI. If this expression does not fit with these
+/// constraints, return null.
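+///
+/// For example (hypothetical IR), given
+/// %iv = phi i32 [ 0, %preheader ], [ %iv.next, %latch ]
+/// %t = mul i32 %iv, 3
+/// %cmp = icmp ne i32 %t, 21
+/// both %t and %cmp evolve from the single header PHI %iv, so %iv is
+/// returned; an operand derived from two different header PHIs gives null.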
+static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (!I || !canConstantEvolve(I, L)) return nullptr;
+
+ if (PHINode *PN = dyn_cast<PHINode>(I))
+ return PN;
+
+ // Record non-constant instructions contained by the loop.
+ DenseMap<Instruction *, PHINode *> PHIMap;
+ return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
+}
+
+/// EvaluateExpression - Given an expression that passes the
+/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
+/// in the loop has the value PHIVal. If we can't fold this expression for some
+/// reason, return null.
+static Constant *EvaluateExpression(Value *V, const Loop *L,
+ DenseMap<Instruction *, Constant *> &Vals,
+ const DataLayout &DL,
+ const TargetLibraryInfo *TLI) {
+ // Convenient constant check, but redundant for recursive calls.
+ if (Constant *C = dyn_cast<Constant>(V)) return C;
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (!I) return nullptr;
+
+ if (Constant *C = Vals.lookup(I)) return C;
+
+ // An instruction inside the loop depends on a value outside the loop that we
+ // weren't given a mapping for, or a value such as a call inside the loop.
+ if (!canConstantEvolve(I, L)) return nullptr;
+
+ // An unmapped PHI can be due to a branch or another loop inside this loop,
+ // or due to this not being the initial iteration through a loop where we
+ // couldn't compute the evolution of this particular PHI last time.
+ if (isa<PHINode>(I)) return nullptr;
+
+ std::vector<Constant*> Operands(I->getNumOperands());
+
+ for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
+ Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
+ if (!Operand) {
+ Operands[i] = dyn_cast<Constant>(I->getOperand(i));
+ if (!Operands[i]) return nullptr;
+ continue;
+ }
+ Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
+ Vals[Operand] = C;
+ if (!C) return nullptr;
+ Operands[i] = C;
+ }
+
+ if (CmpInst *CI = dyn_cast<CmpInst>(I))
+ return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
+ Operands[1], DL, TLI);
+ if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
+ if (!LI->isVolatile())
+ return ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL);
+ }
+ return ConstantFoldInstOperands(I, Operands, DL, TLI);
+}
+
+
+// If every incoming value to PN except the one for BB is a specific Constant,
+// return that, else return nullptr.
+static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
+ Constant *IncomingVal = nullptr;
+
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+ if (PN->getIncomingBlock(i) == BB)
+ continue;
+
+ auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
+ if (!CurrentVal)
+ return nullptr;
+
+ if (IncomingVal != CurrentVal) {
+ if (IncomingVal)
+ return nullptr;
+ IncomingVal = CurrentVal;
+ }
+ }
+
+ return IncomingVal;
+}
+
+/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
+/// in the header of its containing loop, that the loop executes a constant
+/// number of times, and that the PHI node is just a recurrence involving
+/// constants, fold it.
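+///
+/// For example (hypothetical values), a header PHI that starts at 0 and whose
+/// backedge value is the PHI times 2 plus 1 evolves as 0, 1, 3, 7, ...; with a
+/// known backedge-taken count of 3 the folded exit value is 7.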
+Constant *
+ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
+ const APInt &BEs,
+ const Loop *L) {
+ auto I = ConstantEvolutionLoopExitValue.find(PN);
+ if (I != ConstantEvolutionLoopExitValue.end())
+ return I->second;
+
+ if (BEs.ugt(MaxBruteForceIterations))
+ return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it.
+
+ Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
+
+ DenseMap<Instruction *, Constant *> CurrentIterVals;
+ BasicBlock *Header = L->getHeader();
+ assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
+
+ BasicBlock *Latch = L->getLoopLatch();
+ if (!Latch)
+ return nullptr;
+
+ for (PHINode &PHI : Header->phis()) {
+ if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
+ CurrentIterVals[&PHI] = StartCST;
+ }
+ if (!CurrentIterVals.count(PN))
+ return RetVal = nullptr;
+
+ Value *BEValue = PN->getIncomingValueForBlock(Latch);
+
+ // Execute the loop symbolically to determine the exit value.
+ assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
+ "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
+
+ unsigned NumIterations = BEs.getZExtValue(); // must be in range
+ unsigned IterationNum = 0;
+ const DataLayout &DL = getDataLayout();
+ for (; ; ++IterationNum) {
+ if (IterationNum == NumIterations)
+ return RetVal = CurrentIterVals[PN]; // Got exit value!
+
+ // Compute the value of the PHIs for the next iteration.
+ // EvaluateExpression adds non-phi values to the CurrentIterVals map.
+ DenseMap<Instruction *, Constant *> NextIterVals;
+ Constant *NextPHI =
+ EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
+ if (!NextPHI)
+ return nullptr; // Couldn't evaluate!
+ NextIterVals[PN] = NextPHI;
+
+ bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
+
+ // Also evaluate the other PHI nodes. However, we don't get to stop if we
+ // cease to be able to evaluate one of them or if they stop evolving,
+ // because that doesn't necessarily prevent us from computing PN.
+ SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
+ for (const auto &I : CurrentIterVals) {
+ PHINode *PHI = dyn_cast<PHINode>(I.first);
+ if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
+ PHIsToCompute.emplace_back(PHI, I.second);
+ }
+ // We use two distinct loops because EvaluateExpression may invalidate any
+ // iterators into CurrentIterVals.
+ for (const auto &I : PHIsToCompute) {
+ PHINode *PHI = I.first;
+ Constant *&NextPHI = NextIterVals[PHI];
+ if (!NextPHI) { // Not already computed.
+ Value *BEValue = PHI->getIncomingValueForBlock(Latch);
+ NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
+ }
+ if (NextPHI != I.second)
+ StoppedEvolving = false;
+ }
+
+ // If all entries in CurrentIterVals == NextIterVals then we can stop
+ // iterating; the loop can't continue to change.
+ if (StoppedEvolving)
+ return RetVal = CurrentIterVals[PN];
+
+ CurrentIterVals.swap(NextIterVals);
+ }
+}
+
+const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
+ Value *Cond,
+ bool ExitWhen) {
+ PHINode *PN = getConstantEvolvingPHI(Cond, L);
+ if (!PN) return getCouldNotCompute();
+
+ // If the loop is canonicalized, the PHI will have exactly two entries.
+ // That's the only form we support here.
+ if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
+
+ DenseMap<Instruction *, Constant *> CurrentIterVals;
+ BasicBlock *Header = L->getHeader();
+ assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
+
+ BasicBlock *Latch = L->getLoopLatch();
+ assert(Latch && "Should follow from NumIncomingValues == 2!");
+
+ for (PHINode &PHI : Header->phis()) {
+ if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
+ CurrentIterVals[&PHI] = StartCST;
+ }
+ if (!CurrentIterVals.count(PN))
+ return getCouldNotCompute();
+
+ // Okay, we found a PHI node that defines the trip count of this loop. Execute
+ // the loop symbolically to determine when the condition gets a value of
+ // "ExitWhen".
+ unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
+ const DataLayout &DL = getDataLayout();
+ for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
+ auto *CondVal = dyn_cast_or_null<ConstantInt>(
+ EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
+
+ // Couldn't symbolically evaluate.
+ if (!CondVal) return getCouldNotCompute();
+
+ if (CondVal->getValue() == uint64_t(ExitWhen)) {
+ ++NumBruteForceTripCountsComputed;
+ return getConstant(Type::getInt32Ty(getContext()), IterationNum);
+ }
+
+ // Update all the PHI nodes for the next iteration.
+ DenseMap<Instruction *, Constant *> NextIterVals;
+
+ // Create a list of which PHIs we need to compute. We want to do this before
+ // calling EvaluateExpression on them because that may invalidate iterators
+ // into CurrentIterVals.
+ SmallVector<PHINode *, 8> PHIsToCompute;
+ for (const auto &I : CurrentIterVals) {
+ PHINode *PHI = dyn_cast<PHINode>(I.first);
+ if (!PHI || PHI->getParent() != Header) continue;
+ PHIsToCompute.push_back(PHI);
+ }
+ for (PHINode *PHI : PHIsToCompute) {
+ Constant *&NextPHI = NextIterVals[PHI];
+ if (NextPHI) continue; // Already computed!
+
+ Value *BEValue = PHI->getIncomingValueForBlock(Latch);
+ NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
+ }
+ CurrentIterVals.swap(NextIterVals);
+ }
+
+ // Too many iterations were needed to evaluate.
+ return getCouldNotCompute();
+}
+
+const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
+ SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
+ ValuesAtScopes[V];
+ // Check to see if we've folded this expression at this loop before.
+ for (auto &LS : Values)
+ if (LS.first == L)
+ return LS.second ? LS.second : V;
+
+ Values.emplace_back(L, nullptr);
+
+ // Otherwise compute it.
+ const SCEV *C = computeSCEVAtScope(V, L);
+ for (auto &LS : reverse(ValuesAtScopes[V]))
+ if (LS.first == L) {
+ LS.second = C;
+ break;
+ }
+ return C;
+}
+
+/// This builds up a Constant using the ConstantExpr interface. That way, we
+/// will return Constants for objects which aren't represented by a
+/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
+/// Returns NULL if the SCEV isn't representable as a Constant.
+static Constant *BuildConstantFromSCEV(const SCEV *V) {
+ switch (static_cast<SCEVTypes>(V->getSCEVType())) {
+ case scCouldNotCompute:
+ case scAddRecExpr:
+ break;
+ case scConstant:
+ return cast<SCEVConstant>(V)->getValue();
+ case scUnknown:
+ return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
+ case scSignExtend: {
+ const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
+ if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
+ return ConstantExpr::getSExt(CastOp, SS->getType());
+ break;
+ }
+ case scZeroExtend: {
+ const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
+ if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
+ return ConstantExpr::getZExt(CastOp, SZ->getType());
+ break;
+ }
+ case scTruncate: {
+ const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
+ if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
+ return ConstantExpr::getTrunc(CastOp, ST->getType());
+ break;
+ }
+ case scAddExpr: {
+ const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
+ if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) {
+ if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
+ unsigned AS = PTy->getAddressSpace();
+ Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
+ C = ConstantExpr::getBitCast(C, DestPtrTy);
+ }
+ for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) {
+ Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i));
+ if (!C2) return nullptr;
+
+ // First pointer!
+ if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) {
+ unsigned AS = C2->getType()->getPointerAddressSpace();
+ std::swap(C, C2);
+ Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
+ // The offsets have been converted to bytes. We can add bytes to an
+ // i8* by GEP with the byte count in the first index.
+ C = ConstantExpr::getBitCast(C, DestPtrTy);
+ }
+
+ // Don't bother trying to sum two pointers. We probably can't
+ // statically compute a load that results from it anyway.
+ if (C2->getType()->isPointerTy())
+ return nullptr;
+
+ if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
+ if (PTy->getElementType()->isStructTy())
+ C2 = ConstantExpr::getIntegerCast(
+ C2, Type::getInt32Ty(C->getContext()), true);
+ C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2);
+ } else
+ C = ConstantExpr::getAdd(C, C2);
+ }
+ return C;
+ }
+ break;
+ }
+ case scMulExpr: {
+ const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
+ if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) {
+ // Don't bother with pointers at all.
+ if (C->getType()->isPointerTy()) return nullptr;
+ for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) {
+ Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i));
+ if (!C2 || C2->getType()->isPointerTy()) return nullptr;
+ C = ConstantExpr::getMul(C, C2);
+ }
+ return C;
+ }
+ break;
+ }
+ case scUDivExpr: {
+ const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
+ if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS()))
+ if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
+ if (LHS->getType() == RHS->getType())
+ return ConstantExpr::getUDiv(LHS, RHS);
+ break;
+ }
+ case scSMaxExpr:
+ case scUMaxExpr:
+ case scSMinExpr:
+ case scUMinExpr:
+ break; // TODO: smax, umax, smin, umin.
+ }
+ return nullptr;
+}
+
+const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
+ if (isa<SCEVConstant>(V)) return V;
+
+ // If this instruction is evolved from a constant-evolving PHI, compute the
+ // exit value from the loop without using SCEVs.
+ if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
+ if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
+ if (PHINode *PN = dyn_cast<PHINode>(I)) {
+ const Loop *LI = this->LI[I->getParent()];
+ // Looking for loop exit value.
+ if (LI && LI->getParentLoop() == L &&
+ PN->getParent() == LI->getHeader()) {
+ // Okay, there is no closed form solution for the PHI node. Check
+ // to see if the loop that contains it has a known backedge-taken
+ // count. If so, we may be able to force computation of the exit
+ // value.
+ const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI);
+ // This trivial case can show up in some degenerate cases where
+ // the incoming IR has not yet been fully simplified.
+ if (BackedgeTakenCount->isZero()) {
+ Value *InitValue = nullptr;
+ bool MultipleInitValues = false;
+ for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
+ if (!LI->contains(PN->getIncomingBlock(i))) {
+ if (!InitValue)
+ InitValue = PN->getIncomingValue(i);
+ else if (InitValue != PN->getIncomingValue(i)) {
+ MultipleInitValues = true;
+ break;
+ }
+ }
+ }
+ if (!MultipleInitValues && InitValue)
+ return getSCEV(InitValue);
+ }
+ // Do we have a loop invariant value flowing around the backedge
+ // for a loop which must execute the backedge?
+ if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
+ isKnownPositive(BackedgeTakenCount) &&
+ PN->getNumIncomingValues() == 2) {
+ unsigned InLoopPred = LI->contains(PN->getIncomingBlock(0)) ? 0 : 1;
+ const SCEV *OnBackedge = getSCEV(PN->getIncomingValue(InLoopPred));
+ if (IsAvailableOnEntry(LI, DT, OnBackedge, PN->getParent()))
+ return OnBackedge;
+ }
+ if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
+ // Okay, we know how many times the containing loop executes. If
+ // this is a constant evolving PHI node, get the final value at
+ // the specified iteration number.
+ Constant *RV =
+ getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), LI);
+ if (RV) return getSCEV(RV);
+ }
+ }
+
+ // If there is a single-input Phi, evaluate it at our scope. If we can
+ // prove that this replacement does not break LCSSA form, use new value.
+ if (PN->getNumOperands() == 1) {
+ const SCEV *Input = getSCEV(PN->getOperand(0));
+ const SCEV *InputAtScope = getSCEVAtScope(Input, L);
+ // TODO: We can generalize it using LI.replacementPreservesLCSSAForm,
+ // for the simplest case just support constants.
+ if (isa<SCEVConstant>(InputAtScope)) return InputAtScope;
+ }
+ }
+
+ // Okay, this is an expression that we cannot symbolically evaluate
+ // into a SCEV. Check to see if it's possible to symbolically evaluate
+ // the arguments into constants, and if so, try to constant propagate the
+ // result. This is particularly useful for computing loop exit values.
+ if (CanConstantFold(I)) {
+ SmallVector<Constant *, 4> Operands;
+ bool MadeImprovement = false;
+ for (Value *Op : I->operands()) {
+ if (Constant *C = dyn_cast<Constant>(Op)) {
+ Operands.push_back(C);
+ continue;
+ }
+
+ // If any of the operands is non-constant and if they are
+ // non-integer and non-pointer, don't even try to analyze them
+ // with scev techniques.
+ if (!isSCEVable(Op->getType()))
+ return V;
+
+ const SCEV *OrigV = getSCEV(Op);
+ const SCEV *OpV = getSCEVAtScope(OrigV, L);
+ MadeImprovement |= OrigV != OpV;
+
+ Constant *C = BuildConstantFromSCEV(OpV);
+ if (!C) return V;
+ if (C->getType() != Op->getType())
+ C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
+ Op->getType(),
+ false),
+ C, Op->getType());
+ Operands.push_back(C);
+ }
+
+ // Check to see if getSCEVAtScope actually made an improvement.
+ if (MadeImprovement) {
+ Constant *C = nullptr;
+ const DataLayout &DL = getDataLayout();
+ if (const CmpInst *CI = dyn_cast<CmpInst>(I))
+ C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
+ Operands[1], DL, &TLI);
+ else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) {
+ if (!LI->isVolatile())
+ C = ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL);
+ } else
+ C = ConstantFoldInstOperands(I, Operands, DL, &TLI);
+ if (!C) return V;
+ return getSCEV(C);
+ }
+ }
+ }
+
+ // This is some other type of SCEVUnknown, just return it.
+ return V;
+ }
+
+ if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
+ // Avoid performing the look-up in the common case where the specified
+ // expression has no loop-variant portions.
+ for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
+ const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
+ if (OpAtScope != Comm->getOperand(i)) {
+ // Okay, at least one of these operands is loop variant but might be
+ // foldable. Build a new instance of the folded commutative expression.
+ SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
+ Comm->op_begin()+i);
+ NewOps.push_back(OpAtScope);
+
+ for (++i; i != e; ++i) {
+ OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
+ NewOps.push_back(OpAtScope);
+ }
+ if (isa<SCEVAddExpr>(Comm))
+ return getAddExpr(NewOps, Comm->getNoWrapFlags());
+ if (isa<SCEVMulExpr>(Comm))
+ return getMulExpr(NewOps, Comm->getNoWrapFlags());
+ if (isa<SCEVMinMaxExpr>(Comm))
+ return getMinMaxExpr(Comm->getSCEVType(), NewOps);
+ llvm_unreachable("Unknown commutative SCEV type!");
+ }
+ }
+ // If we got here, all operands are loop invariant.
+ return Comm;
+ }
+
+ if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
+ const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
+ const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
+ if (LHS == Div->getLHS() && RHS == Div->getRHS())
+ return Div; // must be loop invariant
+ return getUDivExpr(LHS, RHS);
+ }
+
+ // If this is a loop recurrence for a loop that does not contain L, then we
+ // are dealing with the final value computed by the loop.
+ if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
+ // First, attempt to evaluate each operand.
+ // Avoid performing the look-up in the common case where the specified
+ // expression has no loop-variant portions.
+ for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
+ const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
+ if (OpAtScope == AddRec->getOperand(i))
+ continue;
+
+ // Okay, at least one of these operands is loop variant but might be
+ // foldable. Build a new instance of the folded AddRec expression.
+ SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(),
+ AddRec->op_begin()+i);
+ NewOps.push_back(OpAtScope);
+ for (++i; i != e; ++i)
+ NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
+
+ const SCEV *FoldedRec =
+ getAddRecExpr(NewOps, AddRec->getLoop(),
+ AddRec->getNoWrapFlags(SCEV::FlagNW));
+ AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
+ // The addrec may be folded to a nonrecurrence, for example, if the
+ // induction variable is multiplied by zero after constant folding. Go
+ // ahead and return the folded value.
+ if (!AddRec)
+ return FoldedRec;
+ break;
+ }
+
+ // If the scope is outside the addrec's loop, evaluate it by using the
+ // loop exit value of the addrec.
+ if (!AddRec->getLoop()->contains(L)) {
+ // To evaluate this recurrence, we need to know how many times the AddRec
+ // loop iterates. Compute this now.
+ const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
+ if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
+
+ // Then, evaluate the AddRec.
+ return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
+ }
+
+ return AddRec;
+ }
+
+ if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) {
+ const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
+ if (Op == Cast->getOperand())
+ return Cast; // must be loop invariant
+ return getZeroExtendExpr(Op, Cast->getType());
+ }
+
+ if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) {
+ const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
+ if (Op == Cast->getOperand())
+ return Cast; // must be loop invariant
+ return getSignExtendExpr(Op, Cast->getType());
+ }
+
+ if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) {
+ const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
+ if (Op == Cast->getOperand())
+ return Cast; // must be loop invariant
+ return getTruncateExpr(Op, Cast->getType());
+ }
+
+ llvm_unreachable("Unknown SCEV type!");
+}
+
+const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
+ return getSCEVAtScope(getSCEV(V), L);
+}
+
+const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
+ if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
+ return stripInjectiveFunctions(ZExt->getOperand());
+ if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
+ return stripInjectiveFunctions(SExt->getOperand());
+ return S;
+}
+
+/// Finds the minimum unsigned root of the following equation:
+///
+/// A * X = B (mod N)
+///
+/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
+/// A and B isn't important.
+///
+/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
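+///
+/// As a worked illustration (values chosen arbitrarily, BW = 8): for
+/// 6 * X = 10 (mod 256), gcd(6, 256) = 2, 10 is divisible by 2, and the
+/// multiplicative inverse of 6/2 = 3 modulo 256/2 = 128 is 43, so the minimum
+/// root is (43 * 10 mod 256) / 2 = 87; indeed 6 * 87 = 522 = 2 * 256 + 10.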
+static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
+ ScalarEvolution &SE) {
+ uint32_t BW = A.getBitWidth();
+ assert(BW == SE.getTypeSizeInBits(B->getType()));
+ assert(A != 0 && "A must be non-zero.");
+
+ // 1. D = gcd(A, N)
+ //
+ // The gcd of A and N may have only one prime factor: 2. The number of
+ // trailing zeros in A is its multiplicity
+ uint32_t Mult2 = A.countTrailingZeros();
+ // D = 2^Mult2
+
+ // 2. Check if B is divisible by D.
+ //
+ // B is divisible by D if and only if the multiplicity of prime factor 2 for B
+ // is not less than multiplicity of this prime factor for D.
+ if (SE.GetMinTrailingZeros(B) < Mult2)
+ return SE.getCouldNotCompute();
+
+ // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
+ // modulo (N / D).
+ //
+ // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
+ // (N / D) in general. The inverse itself always fits into BW bits, though,
+ // so we immediately truncate it.
+ APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D
+ APInt Mod(BW + 1, 0);
+ Mod.setBit(BW - Mult2); // Mod = N / D
+ APInt I = AD.multiplicativeInverse(Mod).trunc(BW);
+
+ // 4. Compute the minimum unsigned root of the equation:
+ // I * (B / D) mod (N / D)
+ // To simplify the computation, we factor out the divide by D:
+ // (I * B mod N) / D
+ const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
+ return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
+}
+
+/// For a given quadratic addrec, generate coefficients of the corresponding
+/// quadratic equation, multiplied by a common value to ensure that they are
+/// integers.
+/// The returned value is a tuple { A, B, C, M, BitWidth }, where
+/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
+/// were multiplied by, and BitWidth is the bit width of the original addrec
+/// coefficients.
+/// This function returns None if the addrec coefficients are not compile-
+/// time constants.
+static Optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
+GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
+ assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
+ const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
+ const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
+ const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
+ LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
+ << *AddRec << '\n');
+
+ // We currently can only solve this if the coefficients are constants.
+ if (!LC || !MC || !NC) {
+ LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
+ return None;
+ }
+
+ APInt L = LC->getAPInt();
+ APInt M = MC->getAPInt();
+ APInt N = NC->getAPInt();
+ assert(!N.isNullValue() && "This is not a quadratic addrec");
+
+ unsigned BitWidth = LC->getAPInt().getBitWidth();
+ unsigned NewWidth = BitWidth + 1;
+ LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
+ << BitWidth << '\n');
+ // The sign-extension (as opposed to a zero-extension) here matches the
+ // extension used in SolveQuadraticEquationWrap (with the same motivation).
+ N = N.sext(NewWidth);
+ M = M.sext(NewWidth);
+ L = L.sext(NewWidth);
+
+ // The increments are M, M+N, M+2N, ..., so the accumulated values are
+ // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
+ // L+M, L+2M+N, L+3M+3N, ...
+ // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
+ //
+ // The equation Acc = 0 is then
+ // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
+ // In a quadratic form it becomes:
+ // N n^2 + (2M-N) n + 2L = 0.
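+ //
+ // For instance (illustrative coefficients), {0,+,1,+,2} accumulates
+ // 0, 1, 4, 9, ... = n^2, and the scaled equation is 2 n^2 + 0 n + 0 = 0
+ // with the common multiplier T = 2.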
+
+ APInt A = N;
+ APInt B = 2 * M - A;
+ APInt C = 2 * L;
+ APInt T = APInt(NewWidth, 2);
+ LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
+ << "x + " << C << ", coeff bw: " << NewWidth
+ << ", multiplied by " << T << '\n');
+ return std::make_tuple(A, B, C, T, BitWidth);
+}
+
+/// Helper function to compare optional APInts:
+/// (a) if X and Y both exist, return min(X, Y),
+/// (b) if neither X nor Y exist, return None,
+/// (c) if exactly one of X and Y exists, return that value.
+static Optional<APInt> MinOptional(Optional<APInt> X, Optional<APInt> Y) {
+ if (X.hasValue() && Y.hasValue()) {
+ unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
+ APInt XW = X->sextOrSelf(W);
+ APInt YW = Y->sextOrSelf(W);
+ return XW.slt(YW) ? *X : *Y;
+ }
+ if (!X.hasValue() && !Y.hasValue())
+ return None;
+ return X.hasValue() ? *X : *Y;
+}
+
+/// Helper function to truncate an optional APInt to a given BitWidth.
+/// When solving addrec-related equations, it is preferable to return a value
+/// that has the same bit width as the original addrec's coefficients. If the
+/// solution fits in the original bit width, truncate it (except for i1).
+/// Returning a value of a different bit width may inhibit some optimizations.
+///
+/// In general, a solution to a quadratic equation generated from an addrec
+/// may require BW+1 bits, where BW is the bit width of the addrec's
+/// coefficients. The reason is that the coefficients of the quadratic
+/// equation are BW+1 bits wide (to avoid truncation when converting from
+/// the addrec to the equation).
+static Optional<APInt> TruncIfPossible(Optional<APInt> X, unsigned BitWidth) {
+ if (!X.hasValue())
+ return None;
+ unsigned W = X->getBitWidth();
+ if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
+ return X->trunc(BitWidth);
+ return X;
+}
+
+/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
+/// iterations. The values L, M, N are assumed to be signed, and they
+/// should all have the same bit widths.
+/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
+/// where BW is the bit width of the addrec's coefficients.
+/// If the calculated value is a BW-bit integer (for BW > 1), it will be
+/// returned as such, otherwise the bit width of the returned value may
+/// be greater than BW.
+///
+/// This function returns None if
+/// (a) the addrec coefficients are not constant, or
+/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
+/// like x^2 = 5, no integer solutions exist, in other cases an integer
+/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
+static Optional<APInt>
+SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
+ APInt A, B, C, M;
+ unsigned BitWidth;
+ auto T = GetQuadraticEquation(AddRec);
+ if (!T.hasValue())
+ return None;
+
+ std::tie(A, B, C, M, BitWidth) = *T;
+ LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
+ Optional<APInt> X = APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth+1);
+ if (!X.hasValue())
+ return None;
+
+ ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
+ ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
+ if (!V->isZero())
+ return None;
+
+ return TruncIfPossible(X, BitWidth);
+}
+
+/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
+/// iterations. The values M, N are assumed to be signed, and they
+/// should all have the same bit widths.
+/// Find the least n such that c(n) does not belong to the given range,
+/// while c(n-1) does.
+///
+/// This function returns None if
+/// (a) the addrec coefficients are not constant, or
+/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
+/// bounds of the range.
+static Optional<APInt>
+SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
+ const ConstantRange &Range, ScalarEvolution &SE) {
+ assert(AddRec->getOperand(0)->isZero() &&
+ "Starting value of addrec should be 0");
+ LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
+ << Range << ", addrec " << *AddRec << '\n');
+ // This case is handled in getNumIterationsInRange. Here we can assume that
+ // we start in the range.
+ assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
+ "Addrec's initial value should be in range");
+
+ APInt A, B, C, M;
+ unsigned BitWidth;
+ auto T = GetQuadraticEquation(AddRec);
+ if (!T.hasValue())
+ return None;
+
+ // Be careful about the return value: there can be two reasons for not
+ // returning an actual number. First, if no solutions to the equations
+ // were found, and second, if the solutions don't leave the given range.
+ // The first case means that the actual solution is "unknown", the second
+ // means that it's known, but not valid. If the solution is unknown, we
+ // cannot make any conclusions.
+ // Return a pair: the optional solution and a flag indicating if the
+ // solution was found.
+ auto SolveForBoundary = [&](APInt Bound) -> std::pair<Optional<APInt>,bool> {
+ // Solve for signed overflow and unsigned overflow, pick the lower
+ // solution.
+ LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
+ << Bound << " (before multiplying by " << M << ")\n");
+ Bound *= M; // The quadratic equation multiplier.
+
+ Optional<APInt> SO = None;
+ if (BitWidth > 1) {
+ LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
+ "signed overflow\n");
+ SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
+ }
+ LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
+ "unsigned overflow\n");
+ Optional<APInt> UO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound,
+ BitWidth+1);
+
+ auto LeavesRange = [&] (const APInt &X) {
+ ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
+ ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
+ if (Range.contains(V0->getValue()))
+ return false;
+ // X should be at least 1, so X-1 is non-negative.
+ ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
+ ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
+ if (Range.contains(V1->getValue()))
+ return true;
+ return false;
+ };
+
+ // If SolveQuadraticEquationWrap returns None, it means that there can
+ // be a solution, but the function failed to find it. We cannot treat it
+ // as "no solution".
+ if (!SO.hasValue() || !UO.hasValue())
+ return { None, false };
+
+ // Check the smaller value first to see if it leaves the range.
+ // At this point, both SO and UO must have values.
+ Optional<APInt> Min = MinOptional(SO, UO);
+ if (LeavesRange(*Min))
+ return { Min, true };
+ Optional<APInt> Max = Min == SO ? UO : SO;
+ if (LeavesRange(*Max))
+ return { Max, true };
+
+ // Solutions were found, but were eliminated, hence the "true".
+ return { None, true };
+ };
+
+ std::tie(A, B, C, M, BitWidth) = *T;
+ // Lower bound is inclusive, subtract 1 to represent the exiting value.
+ APInt Lower = Range.getLower().sextOrSelf(A.getBitWidth()) - 1;
+ APInt Upper = Range.getUpper().sextOrSelf(A.getBitWidth());
+ auto SL = SolveForBoundary(Lower);
+ auto SU = SolveForBoundary(Upper);
+ // If either of the solutions was unknown, no meaningful conclusions can
+ // be made.
+ if (!SL.second || !SU.second)
+ return None;
+
+ // Claim: The correct solution is not some value between Min and Max.
+ //
+ // Justification: Assuming that Min and Max are different values, one of
+ // them is when the first signed overflow happens, the other is when the
+ // first unsigned overflow happens. Crossing the range boundary is only
+ // possible via an overflow (treating 0 as a special case of it, modeling
+ // an overflow as crossing k*2^W for some k).
+ //
+ // The interesting case here is when Min was eliminated as an invalid
+ // solution, but Max was not. The argument is that if there was another
+ // overflow between Min and Max, it would also have been eliminated if
+ // it was considered.
+ //
+ // For a given boundary, it is possible to have two overflows of the same
+ // type (signed/unsigned) without having the other type in between: this
+ // can happen when the vertex of the parabola is between the iterations
+ // corresponding to the overflows. This is only possible when the two
+ // overflows cross k*2^W for the same k. In such case, if the second one
+ // left the range (and was the first one to do so), the first overflow
+ // would have to enter the range, which would mean that either we had left
+ // the range before or that we started outside of it. Both of these cases
+ // are contradictions.
+ //
+ // Claim: In the case where SolveForBoundary returns None, the correct
+ // solution is not some value between the Max for this boundary and the
+ // Min of the other boundary.
+ //
+ // Justification: Assume that we had such Max_A and Min_B corresponding
+ // to range boundaries A and B and such that Max_A < Min_B. If there was
+ // a solution between Max_A and Min_B, it would have to be caused by an
+ // overflow corresponding to either A or B. It cannot correspond to B,
+ // since Min_B is the first occurrence of such an overflow. If it
+ // corresponded to A, it would have to be either a signed or an unsigned
+ // overflow that is larger than both eliminated overflows for A. But
+ // between the eliminated overflows and this overflow, the values would
+ // cover the entire value space, thus crossing the other boundary, which
+ // is a contradiction.
+
+ return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
+ bool AllowPredicates) {
+
+ // This is only used for loops with a "x != y" exit test. The exit condition
+ // is now expressed as a single expression, V = x-y. So the exit test is
+ // effectively V != 0. We know and take advantage of the fact that this
+ // expression is only used in a comparison-with-zero context.
+
+ SmallPtrSet<const SCEVPredicate *, 4> Predicates;
+ // If the value is a constant
+ if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
+ // If the value is already zero, the branch will execute zero times.
+ if (C->getValue()->isZero()) return C;
+ return getCouldNotCompute(); // Otherwise it will loop infinitely.
+ }
+
+ const SCEVAddRecExpr *AddRec =
+ dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
+
+ if (!AddRec && AllowPredicates)
+ // Try to make this an AddRec using runtime tests, in the first X
+ // iterations of this loop, where X is the SCEV expression found by the
+ // algorithm below.
+ AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
+
+ if (!AddRec || AddRec->getLoop() != L)
+ return getCouldNotCompute();
+
+ // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
+ // the quadratic equation to solve it.
+ if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
+ // We can only use this value if the chrec ends up with an exact zero
+ // value at this index. When solving for "X*X != 5", for example, we
+ // should not accept a root of 2.
+ if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
+ const auto *R = cast<SCEVConstant>(getConstant(S.getValue()));
+ return ExitLimit(R, R, false, Predicates);
+ }
+ return getCouldNotCompute();
+ }
+
+ // Otherwise we can only handle this if it is affine.
+ if (!AddRec->isAffine())
+ return getCouldNotCompute();
+
+ // If this is an affine expression, the execution count of this branch is
+ // the minimum unsigned root of the following equation:
+ //
+ // Start + Step*N = 0 (mod 2^BW)
+ //
+ // equivalent to:
+ //
+ // Step*N = -Start (mod 2^BW)
+ //
+ // where BW is the common bit width of Start and Step.
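+ //
+ // For example (illustrative values, BW = 8), the addrec {6,+,2} first
+ // reaches 0 (mod 2^8) at N = 125, since 6 + 2 * 125 = 256.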
+
+ // Get the initial value for the loop.
+ const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
+ const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
+
+ // For now we handle only constant steps.
+ //
+ // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
+ // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
+ // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
+ // We have not yet seen any such cases.
+ const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
+ if (!StepC || StepC->getValue()->isZero())
+ return getCouldNotCompute();
+
+ // For positive steps (counting up until unsigned overflow):
+ // N = -Start/Step (as unsigned)
+ // For negative steps (counting down to zero):
+ // N = Start/-Step
+ // First compute the unsigned distance from zero in the direction of Step.
+ bool CountDown = StepC->getAPInt().isNegative();
+ const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
+
+ // Handle unitary steps, which cannot wraparound.
+ // 1*N = -Start; -1*N = Start (mod 2^BW), so:
+ // N = Distance (as unsigned)
+ if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) {
+ APInt MaxBECount = getUnsignedRangeMax(Distance);
+
+ // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
+ // we end up with a loop whose backedge-taken count is n - 1. Detect this
+ // case, and see if we can improve the bound.
+ //
+ // Explicitly handling this here is necessary because getUnsignedRange
+ // isn't context-sensitive; it doesn't know that we only care about the
+ // range inside the loop.
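+ //
+ // Concretely (an illustrative case): if Distance is "n - 1" and the loop
+ // entry is guarded by "n != 0", then Distance + 1 == n cannot be zero, so
+ // the maximum backedge-taken count can be tightened to unsigned_max(n) - 1
+ // rather than the full unsigned range of "n - 1".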
+ const SCEV *Zero = getZero(Distance->getType());
+ const SCEV *One = getOne(Distance->getType());
+ const SCEV *DistancePlusOne = getAddExpr(Distance, One);
+ if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
+ // If Distance + 1 doesn't overflow, we can compute the maximum distance
+ // as "unsigned_max(Distance + 1) - 1".
+ ConstantRange CR = getUnsignedRange(DistancePlusOne);
+ MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
+ }
+ return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates);
+ }
+
+ // If the condition controls loop exit (the loop exits only if the expression
+ // is true) and the addition is no-wrap we can use unsigned divide to
+ // compute the backedge count. In this case, the step may not divide the
+ // distance, but we don't care because if the condition is "missed" the loop
+ // will have undefined behavior due to wrapping.
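+  //
+  // For illustration (hypothetical AddRecs): for {6,+,-2} the exact count is
+  // 6 /u 2 == 3, while {7,+,-2} never reaches zero exactly; a loop exiting
+  // only on that condition would have to self-wrap, which hasNoSelfWrap()
+  // together with the absence of abnormal exits rules out.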
+ if (ControlsExit && AddRec->hasNoSelfWrap() &&
+ loopHasNoAbnormalExits(AddRec->getLoop())) {
+ const SCEV *Exact =
+ getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
+ const SCEV *Max =
+ Exact == getCouldNotCompute()
+ ? Exact
+ : getConstant(getUnsignedRangeMax(Exact));
+ return ExitLimit(Exact, Max, false, Predicates);
+ }
+
+ // Solve the general equation.
+ const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(),
+ getNegativeSCEV(Start), *this);
+ const SCEV *M = E == getCouldNotCompute()
+ ? E
+ : getConstant(getUnsignedRangeMax(E));
+ return ExitLimit(E, M, false, Predicates);
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
+ // Loops that look like: while (X == 0) are very strange indeed. We don't
+ // handle them yet except for the trivial case. This could be expanded in the
+ // future as needed.
+
+ // If the value is a constant, check to see if it is known to be non-zero
+ // already. If so, the backedge will execute zero times.
+ if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
+ if (!C->getValue()->isZero())
+ return getZero(C->getType());
+ return getCouldNotCompute(); // Otherwise it will loop infinitely.
+ }
+
+ // We could implement others, but I really doubt anyone writes loops like
+ // this, and if they did, they would already be constant folded.
+ return getCouldNotCompute();
+}
+
+std::pair<BasicBlock *, BasicBlock *>
+ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
+ // If the block has a unique predecessor, then there is no path from the
+ // predecessor to the block that does not go through the direct edge
+ // from the predecessor to the block.
+ if (BasicBlock *Pred = BB->getSinglePredecessor())
+ return {Pred, BB};
+
+ // A loop's header is defined to be a block that dominates the loop.
+ // If the header has a unique predecessor outside the loop, it must be
+ // a block that has exactly one successor that can reach the loop.
+ if (Loop *L = LI.getLoopFor(BB))
+ return {L->getLoopPredecessor(), L->getHeader()};
+
+ return {nullptr, nullptr};
+}
+
+/// SCEV structural equivalence is usually sufficient for testing whether two
+/// expressions are equal, however for the purposes of looking for a condition
+/// guarding a loop, it can be useful to be a little more general, since a
+/// front-end may have replicated the controlling expression.
+static bool HasSameValue(const SCEV *A, const SCEV *B) {
+ // Quick check to see if they are the same SCEV.
+ if (A == B) return true;
+
+ auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
+ // Not all instructions that are "identical" compute the same value. For
+ // instance, two distinct alloca instructions allocating the same type are
+ // identical and do not read memory; but compute distinct values.
+    return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) ||
+                                   isa<GetElementPtrInst>(A));
+ };
+
+ // Otherwise, if they're both SCEVUnknown, it's possible that they hold
+ // two different instructions with the same value. Check for this case.
+ if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
+ if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
+ if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
+ if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
+ if (ComputesEqualValues(AI, BI))
+ return true;
+
+ // Otherwise assume they may have a different value.
+ return false;
+}
+
+bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
+ const SCEV *&LHS, const SCEV *&RHS,
+ unsigned Depth) {
+ bool Changed = false;
+ // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
+ // '0 != 0'.
+ auto TrivialCase = [&](bool TriviallyTrue) {
+ LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
+ Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
+ return true;
+ };
+  // If we hit the max recursion limit, bail out.
+ if (Depth >= 3)
+ return false;
+
+ // Canonicalize a constant to the right side.
+ if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
+ // Check for both operands constant.
+ if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
+ if (ConstantExpr::getICmp(Pred,
+ LHSC->getValue(),
+ RHSC->getValue())->isNullValue())
+ return TrivialCase(false);
+ else
+ return TrivialCase(true);
+ }
+ // Otherwise swap the operands to put the constant on the right.
+ std::swap(LHS, RHS);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ Changed = true;
+ }
+
+ // If we're comparing an addrec with a value which is loop-invariant in the
+ // addrec's loop, put the addrec on the left. Also make a dominance check,
+ // as both operands could be addrecs loop-invariant in each other's loop.
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
+ const Loop *L = AR->getLoop();
+ if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
+ std::swap(LHS, RHS);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ Changed = true;
+ }
+ }
+
+ // If there's a constant operand, canonicalize comparisons with boundary
+ // cases, and canonicalize *-or-equal comparisons to regular comparisons.
+ if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
+ const APInt &RA = RC->getAPInt();
+
+ bool SimplifiedByConstantRange = false;
+
+ if (!ICmpInst::isEquality(Pred)) {
+ ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
+ if (ExactCR.isFullSet())
+ return TrivialCase(true);
+ else if (ExactCR.isEmptySet())
+ return TrivialCase(false);
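+
+      // For illustration (hypothetical predicates): "X u>= 0" has a full-set
+      // exact region and folds to trivially true, while "X u< 1" has the
+      // exact region {0} and is converted to the equality "X == 0" below.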
+
+ APInt NewRHS;
+ CmpInst::Predicate NewPred;
+ if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
+ ICmpInst::isEquality(NewPred)) {
+ // We were able to convert an inequality to an equality.
+ Pred = NewPred;
+ RHS = getConstant(NewRHS);
+ Changed = SimplifiedByConstantRange = true;
+ }
+ }
+
+ if (!SimplifiedByConstantRange) {
+ switch (Pred) {
+ default:
+ break;
+ case ICmpInst::ICMP_EQ:
+ case ICmpInst::ICMP_NE:
+ // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
+ if (!RA)
+ if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS))
+ if (const SCEVMulExpr *ME =
+ dyn_cast<SCEVMulExpr>(AE->getOperand(0)))
+ if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 &&
+ ME->getOperand(0)->isAllOnesValue()) {
+ RHS = AE->getOperand(1);
+ LHS = ME->getOperand(1);
+ Changed = true;
+ }
+ break;
+
+ // The "Should have been caught earlier!" messages refer to the fact
+ // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
+ // should have fired on the corresponding cases, and canonicalized the
+ // check to trivial case.
+
+ case ICmpInst::ICMP_UGE:
+ assert(!RA.isMinValue() && "Should have been caught earlier!");
+ Pred = ICmpInst::ICMP_UGT;
+ RHS = getConstant(RA - 1);
+ Changed = true;
+ break;
+ case ICmpInst::ICMP_ULE:
+ assert(!RA.isMaxValue() && "Should have been caught earlier!");
+ Pred = ICmpInst::ICMP_ULT;
+ RHS = getConstant(RA + 1);
+ Changed = true;
+ break;
+ case ICmpInst::ICMP_SGE:
+ assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
+ Pred = ICmpInst::ICMP_SGT;
+ RHS = getConstant(RA - 1);
+ Changed = true;
+ break;
+ case ICmpInst::ICMP_SLE:
+ assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
+ Pred = ICmpInst::ICMP_SLT;
+ RHS = getConstant(RA + 1);
+ Changed = true;
+ break;
+ }
+ }
+ }
+
+ // Check for obvious equality.
+ if (HasSameValue(LHS, RHS)) {
+ if (ICmpInst::isTrueWhenEqual(Pred))
+ return TrivialCase(true);
+ if (ICmpInst::isFalseWhenEqual(Pred))
+ return TrivialCase(false);
+ }
+
+ // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
+ // adding or subtracting 1 from one of the operands.
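+  //
+  // For illustration (hypothetical operands): "X s<= Y" becomes "X s< Y + 1"
+  // when getSignedRangeMax(Y) is not SINT_MAX (so the increment cannot
+  // overflow), and otherwise "X - 1 s< Y" is tried instead.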
+ switch (Pred) {
+ case ICmpInst::ICMP_SLE:
+ if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
+ RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
+ SCEV::FlagNSW);
+ Pred = ICmpInst::ICMP_SLT;
+ Changed = true;
+ } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
+ LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
+ SCEV::FlagNSW);
+ Pred = ICmpInst::ICMP_SLT;
+ Changed = true;
+ }
+ break;
+ case ICmpInst::ICMP_SGE:
+ if (!getSignedRangeMin(RHS).isMinSignedValue()) {
+ RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
+ SCEV::FlagNSW);
+ Pred = ICmpInst::ICMP_SGT;
+ Changed = true;
+ } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
+ LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
+ SCEV::FlagNSW);
+ Pred = ICmpInst::ICMP_SGT;
+ Changed = true;
+ }
+ break;
+ case ICmpInst::ICMP_ULE:
+ if (!getUnsignedRangeMax(RHS).isMaxValue()) {
+ RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
+ SCEV::FlagNUW);
+ Pred = ICmpInst::ICMP_ULT;
+ Changed = true;
+ } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
+ LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
+ Pred = ICmpInst::ICMP_ULT;
+ Changed = true;
+ }
+ break;
+ case ICmpInst::ICMP_UGE:
+ if (!getUnsignedRangeMin(RHS).isMinValue()) {
+ RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
+ Pred = ICmpInst::ICMP_UGT;
+ Changed = true;
+ } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
+ LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
+ SCEV::FlagNUW);
+ Pred = ICmpInst::ICMP_UGT;
+ Changed = true;
+ }
+ break;
+ default:
+ break;
+ }
+
+ // TODO: More simplifications are possible here.
+
+ // Recursively simplify until we either hit a recursion limit or nothing
+ // changes.
+ if (Changed)
+ return SimplifyICmpOperands(Pred, LHS, RHS, Depth+1);
+
+ return Changed;
+}
+
+bool ScalarEvolution::isKnownNegative(const SCEV *S) {
+ return getSignedRangeMax(S).isNegative();
+}
+
+bool ScalarEvolution::isKnownPositive(const SCEV *S) {
+ return getSignedRangeMin(S).isStrictlyPositive();
+}
+
+bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
+ return !getSignedRangeMin(S).isNegative();
+}
+
+bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
+ return !getSignedRangeMax(S).isStrictlyPositive();
+}
+
+bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
+ return isKnownNegative(S) || isKnownPositive(S);
+}
+
+std::pair<const SCEV *, const SCEV *>
+ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
+ // Compute SCEV on entry of loop L.
+ const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
+ if (Start == getCouldNotCompute())
+ return { Start, Start };
+ // Compute post increment SCEV for loop L.
+ const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
+ assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
+ return { Start, PostInc };
+}
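+
+// For illustration (hypothetical expression): splitting S = {A,+,B}<L> yields
+// Start = A (the value on entry to L) and PostInc = {A+B,+,B}<L> (the value
+// after the first increment); isKnownViaInduction below proves the predicate
+// for the Start pair at the loop guard and for the PostInc pair across the
+// backedge.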
+
+bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS) {
+ // First collect all loops.
+ SmallPtrSet<const Loop *, 8> LoopsUsed;
+ getUsedLoops(LHS, LoopsUsed);
+ getUsedLoops(RHS, LoopsUsed);
+
+ if (LoopsUsed.empty())
+ return false;
+
+ // Domination relationship must be a linear order on collected loops.
+#ifndef NDEBUG
+ for (auto *L1 : LoopsUsed)
+ for (auto *L2 : LoopsUsed)
+ assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
+ DT.dominates(L2->getHeader(), L1->getHeader())) &&
+ "Domination relationship is not a linear order");
+#endif
+
+ const Loop *MDL =
+ *std::max_element(LoopsUsed.begin(), LoopsUsed.end(),
+ [&](const Loop *L1, const Loop *L2) {
+ return DT.properlyDominates(L1->getHeader(), L2->getHeader());
+ });
+
+ // Get init and post increment value for LHS.
+ auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
+  // If LHS contains an unknown non-invariant SCEV, bail out.
+ if (SplitLHS.first == getCouldNotCompute())
+ return false;
+ assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
+ // Get init and post increment value for RHS.
+ auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
+  // If RHS contains an unknown non-invariant SCEV, bail out.
+ if (SplitRHS.first == getCouldNotCompute())
+ return false;
+ assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
+ // It is possible that init SCEV contains an invariant load but it does
+ // not dominate MDL and is not available at MDL loop entry, so we should
+ // check it here.
+ if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
+ !isAvailableAtLoopEntry(SplitRHS.first, MDL))
+ return false;
+
+ return isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first) &&
+ isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
+ SplitRHS.second);
+}
+
+bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS) {
+ // Canonicalize the inputs first.
+ (void)SimplifyICmpOperands(Pred, LHS, RHS);
+
+ if (isKnownViaInduction(Pred, LHS, RHS))
+ return true;
+
+ if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
+ return true;
+
+ // Otherwise see what can be done with some simple reasoning.
+ return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
+}
+
+bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
+ const SCEVAddRecExpr *LHS,
+ const SCEV *RHS) {
+ const Loop *L = LHS->getLoop();
+ return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
+ isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
+}
+
+bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS,
+ ICmpInst::Predicate Pred,
+ bool &Increasing) {
+ bool Result = isMonotonicPredicateImpl(LHS, Pred, Increasing);
+
+#ifndef NDEBUG
+ // Verify an invariant: inverting the predicate should turn a monotonically
+ // increasing change to a monotonically decreasing one, and vice versa.
+ bool IncreasingSwapped;
+ bool ResultSwapped = isMonotonicPredicateImpl(
+ LHS, ICmpInst::getSwappedPredicate(Pred), IncreasingSwapped);
+
+ assert(Result == ResultSwapped && "should be able to analyze both!");
+ if (ResultSwapped)
+ assert(Increasing == !IncreasingSwapped &&
+ "monotonicity should flip as we flip the predicate");
+#endif
+
+ return Result;
+}
+
+bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS,
+ ICmpInst::Predicate Pred,
+ bool &Increasing) {
+
+ // A zero step value for LHS means the induction variable is essentially a
+ // loop invariant value. We don't really depend on the predicate actually
+ // flipping from false to true (for increasing predicates, and the other way
+ // around for decreasing predicates), all we care about is that *if* the
+ // predicate changes then it only changes from false to true.
+ //
+ // A zero step value in itself is not very useful, but there may be places
+ // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
+ // as general as possible.
+
+ switch (Pred) {
+ default:
+ return false; // Conservative answer
+
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_UGE:
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULE:
+ if (!LHS->hasNoUnsignedWrap())
+ return false;
+
+ Increasing = Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE;
+ return true;
+
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_SGE:
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_SLE: {
+ if (!LHS->hasNoSignedWrap())
+ return false;
+
+ const SCEV *Step = LHS->getStepRecurrence(*this);
+
+ if (isKnownNonNegative(Step)) {
+ Increasing = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE;
+ return true;
+ }
+
+ if (isKnownNonPositive(Step)) {
+ Increasing = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE;
+ return true;
+ }
+
+ return false;
+ }
+
+ }
+
+ llvm_unreachable("switch has default clause!");
+}
+
+bool ScalarEvolution::isLoopInvariantPredicate(
+ ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
+ ICmpInst::Predicate &InvariantPred, const SCEV *&InvariantLHS,
+ const SCEV *&InvariantRHS) {
+
+ // If there is a loop-invariant, force it into the RHS, otherwise bail out.
+ if (!isLoopInvariant(RHS, L)) {
+ if (!isLoopInvariant(LHS, L))
+ return false;
+
+ std::swap(LHS, RHS);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+
+ const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
+ if (!ArLHS || ArLHS->getLoop() != L)
+ return false;
+
+ bool Increasing;
+ if (!isMonotonicPredicate(ArLHS, Pred, Increasing))
+ return false;
+
+ // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
+ // true as the loop iterates, and the backedge is control dependent on
+ // "ArLHS `Pred` RHS" == true then we can reason as follows:
+ //
+ // * if the predicate was false in the first iteration then the predicate
+ // is never evaluated again, since the loop exits without taking the
+ // backedge.
+ // * if the predicate was true in the first iteration then it will
+ // continue to be true for all future iterations since it is
+ // monotonically increasing.
+ //
+ // For both the above possibilities, we can replace the loop varying
+ // predicate with its value on the first iteration of the loop (which is
+ // loop invariant).
+ //
+ // A similar reasoning applies for a monotonically decreasing predicate, by
+ // replacing true with false and false with true in the above two bullets.
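+  //
+  // For illustration (hypothetical loop): with ArLHS = {Start,+,1}<nuw> and
+  // Pred = u>=, the predicate is monotonically increasing; if the backedge is
+  // taken only while "ArLHS u>= RHS" holds, then inside the loop the test is
+  // equivalent to the loop-invariant "Start u>= RHS".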
+
+ auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
+
+ if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
+ return false;
+
+ InvariantPred = Pred;
+ InvariantLHS = ArLHS->getStart();
+ InvariantRHS = RHS;
+ return true;
+}
+
+bool ScalarEvolution::isKnownPredicateViaConstantRanges(
+ ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
+ if (HasSameValue(LHS, RHS))
+ return ICmpInst::isTrueWhenEqual(Pred);
+
+ // This code is split out from isKnownPredicate because it is called from
+ // within isLoopEntryGuardedByCond.
+
+ auto CheckRanges =
+ [&](const ConstantRange &RangeLHS, const ConstantRange &RangeRHS) {
+ return ConstantRange::makeSatisfyingICmpRegion(Pred, RangeRHS)
+ .contains(RangeLHS);
+ };
+
+ // The check at the top of the function catches the case where the values are
+ // known to be equal.
+ if (Pred == CmpInst::ICMP_EQ)
+ return false;
+
+ if (Pred == CmpInst::ICMP_NE)
+ return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)) ||
+ CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)) ||
+ isKnownNonZero(getMinusSCEV(LHS, RHS));
+
+ if (CmpInst::isSigned(Pred))
+ return CheckRanges(getSignedRange(LHS), getSignedRange(RHS));
+
+ return CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS));
+}
+
+bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
+ const SCEV *LHS,
+ const SCEV *RHS) {
+ // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
+ // Return Y via OutY.
+ auto MatchBinaryAddToConst =
+ [this](const SCEV *Result, const SCEV *X, APInt &OutY,
+ SCEV::NoWrapFlags ExpectedFlags) {
+ const SCEV *NonConstOp, *ConstOp;
+ SCEV::NoWrapFlags FlagsPresent;
+
+ if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
+ !isa<SCEVConstant>(ConstOp) || NonConstOp != X)
+ return false;
+
+ OutY = cast<SCEVConstant>(ConstOp)->getAPInt();
+ return (FlagsPresent & ExpectedFlags) == ExpectedFlags;
+ };
+
+ APInt C;
+
+ switch (Pred) {
+ default:
+ break;
+
+ case ICmpInst::ICMP_SGE:
+ std::swap(LHS, RHS);
+ LLVM_FALLTHROUGH;
+ case ICmpInst::ICMP_SLE:
+ // X s<= (X + C)<nsw> if C >= 0
+ if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative())
+ return true;
+
+ // (X + C)<nsw> s<= X if C <= 0
+ if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
+ !C.isStrictlyPositive())
+ return true;
+ break;
+
+ case ICmpInst::ICMP_SGT:
+ std::swap(LHS, RHS);
+ LLVM_FALLTHROUGH;
+ case ICmpInst::ICMP_SLT:
+ // X s< (X + C)<nsw> if C > 0
+ if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
+ C.isStrictlyPositive())
+ return true;
+
+ // (X + C)<nsw> s< X if C < 0
+ if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
+ return true;
+ break;
+ }
+
+ return false;
+}
+
+bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
+ const SCEV *LHS,
+ const SCEV *RHS) {
+ if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
+ return false;
+
+  // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting
+  // on the stack can result in exponential time complexity.
+ SaveAndRestore<bool> Restore(ProvingSplitPredicate, true);
+
+ // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
+ //
+ // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
+ // isKnownPredicate. isKnownPredicate is more powerful, but also more
+ // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
+ // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
+ // use isKnownPredicate later if needed.
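+  //
+  // For illustration (hypothetical i8 values): if L = 100, then I u< 100
+  // holds exactly when 0 s<= I and I s< 100, because any I that is negative
+  // in the signed sense is u>= 128 and therefore not u< 100.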
+ return isKnownNonNegative(RHS) &&
+ isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
+ isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
+}
+
+bool ScalarEvolution::isImpliedViaGuard(BasicBlock *BB,
+ ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS) {
+ // No need to even try if we know the module has no guards.
+ if (!HasGuards)
+ return false;
+
+ return any_of(*BB, [&](Instruction &I) {
+ using namespace llvm::PatternMatch;
+
+ Value *Condition;
+ return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
+ m_Value(Condition))) &&
+ isImpliedCond(Pred, LHS, RHS, Condition, false);
+ });
+}
+
+/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
+/// protected by a conditional between LHS and RHS. This is used to
+/// eliminate casts.
+bool
+ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
+ ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS) {
+ // Interpret a null as meaning no loop, where there is obviously no guard
+ // (interprocedural conditions notwithstanding).
+ if (!L) return true;
+
+ if (VerifyIR)
+ assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
+ "This cannot be done on broken IR!");
+
+ if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
+ return true;
+
+ BasicBlock *Latch = L->getLoopLatch();
+ if (!Latch)
+ return false;
+
+ BranchInst *LoopContinuePredicate =
+ dyn_cast<BranchInst>(Latch->getTerminator());
+ if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
+ isImpliedCond(Pred, LHS, RHS,
+ LoopContinuePredicate->getCondition(),
+ LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
+ return true;
+
+ // We don't want more than one activation of the following loops on the stack
+ // -- that can lead to O(n!) time complexity.
+ if (WalkingBEDominatingConds)
+ return false;
+
+ SaveAndRestore<bool> ClearOnExit(WalkingBEDominatingConds, true);
+
+ // See if we can exploit a trip count to prove the predicate.
+ const auto &BETakenInfo = getBackedgeTakenInfo(L);
+ const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
+ if (LatchBECount != getCouldNotCompute()) {
+ // We know that Latch branches back to the loop header exactly
+    // LatchBECount times. This means the backedge condition at Latch is
+ // equivalent to "{0,+,1} u< LatchBECount".
+ Type *Ty = LatchBECount->getType();
+ auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
+ const SCEV *LoopCounter =
+ getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
+ if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
+ LatchBECount))
+ return true;
+ }
+
+ // Check conditions due to any @llvm.assume intrinsics.
+ for (auto &AssumeVH : AC.assumptions()) {
+ if (!AssumeVH)
+ continue;
+ auto *CI = cast<CallInst>(AssumeVH);
+ if (!DT.dominates(CI, Latch->getTerminator()))
+ continue;
+
+ if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
+ return true;
+ }
+
+ // If the loop is not reachable from the entry block, we risk running into an
+ // infinite loop as we walk up into the dom tree. These loops do not matter
+ // anyway, so we just return a conservative answer when we see them.
+ if (!DT.isReachableFromEntry(L->getHeader()))
+ return false;
+
+ if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
+ return true;
+
+ for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
+ DTN != HeaderDTN; DTN = DTN->getIDom()) {
+ assert(DTN && "should reach the loop header before reaching the root!");
+
+ BasicBlock *BB = DTN->getBlock();
+ if (isImpliedViaGuard(BB, Pred, LHS, RHS))
+ return true;
+
+ BasicBlock *PBB = BB->getSinglePredecessor();
+ if (!PBB)
+ continue;
+
+ BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
+ if (!ContinuePredicate || !ContinuePredicate->isConditional())
+ continue;
+
+ Value *Condition = ContinuePredicate->getCondition();
+
+ // If we have an edge `E` within the loop body that dominates the only
+ // latch, the condition guarding `E` also guards the backedge. This
+ // reasoning works only for loops with a single latch.
+
+ BasicBlockEdge DominatingEdge(PBB, BB);
+ if (DominatingEdge.isSingleEdge()) {
+ // We're constructively (and conservatively) enumerating edges within the
+ // loop body that dominate the latch. The dominator tree better agree
+ // with us on this:
+ assert(DT.dominates(DominatingEdge, Latch) && "should be!");
+
+ if (isImpliedCond(Pred, LHS, RHS, Condition,
+ BB != ContinuePredicate->getSuccessor(0)))
+ return true;
+ }
+ }
+
+ return false;
+}
+
+bool
+ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
+ ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS) {
+ // Interpret a null as meaning no loop, where there is obviously no guard
+ // (interprocedural conditions notwithstanding).
+ if (!L) return false;
+
+ if (VerifyIR)
+ assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
+ "This cannot be done on broken IR!");
+
+ // Both LHS and RHS must be available at loop entry.
+ assert(isAvailableAtLoopEntry(LHS, L) &&
+ "LHS is not available at Loop Entry");
+ assert(isAvailableAtLoopEntry(RHS, L) &&
+ "RHS is not available at Loop Entry");
+
+ if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
+ return true;
+
+ // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
+ // the facts (a >= b && a != b) separately. A typical situation is when the
+ // non-strict comparison is known from ranges and non-equality is known from
+ // dominating predicates. If we are proving strict comparison, we always try
+ // to prove non-equality and non-strict comparison separately.
+ auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
+ const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
+ bool ProvedNonStrictComparison = false;
+ bool ProvedNonEquality = false;
+
+ if (ProvingStrictComparison) {
+ ProvedNonStrictComparison =
+ isKnownViaNonRecursiveReasoning(NonStrictPredicate, LHS, RHS);
+ ProvedNonEquality =
+ isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, LHS, RHS);
+ if (ProvedNonStrictComparison && ProvedNonEquality)
+ return true;
+ }
+
+ // Try to prove (Pred, LHS, RHS) using isImpliedViaGuard.
+ auto ProveViaGuard = [&](BasicBlock *Block) {
+ if (isImpliedViaGuard(Block, Pred, LHS, RHS))
+ return true;
+ if (ProvingStrictComparison) {
+ if (!ProvedNonStrictComparison)
+ ProvedNonStrictComparison =
+ isImpliedViaGuard(Block, NonStrictPredicate, LHS, RHS);
+ if (!ProvedNonEquality)
+ ProvedNonEquality =
+ isImpliedViaGuard(Block, ICmpInst::ICMP_NE, LHS, RHS);
+ if (ProvedNonStrictComparison && ProvedNonEquality)
+ return true;
+ }
+ return false;
+ };
+
+ // Try to prove (Pred, LHS, RHS) using isImpliedCond.
+ auto ProveViaCond = [&](Value *Condition, bool Inverse) {
+ if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse))
+ return true;
+ if (ProvingStrictComparison) {
+ if (!ProvedNonStrictComparison)
+ ProvedNonStrictComparison =
+ isImpliedCond(NonStrictPredicate, LHS, RHS, Condition, Inverse);
+ if (!ProvedNonEquality)
+ ProvedNonEquality =
+ isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS, Condition, Inverse);
+ if (ProvedNonStrictComparison && ProvedNonEquality)
+ return true;
+ }
+ return false;
+ };
+
+ // Starting at the loop predecessor, climb up the predecessor chain, as long
+ // as there are predecessors that can be found that have unique successors
+ // leading to the original header.
+ for (std::pair<BasicBlock *, BasicBlock *>
+ Pair(L->getLoopPredecessor(), L->getHeader());
+ Pair.first;
+ Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
+
+ if (ProveViaGuard(Pair.first))
+ return true;
+
+ BranchInst *LoopEntryPredicate =
+ dyn_cast<BranchInst>(Pair.first->getTerminator());
+ if (!LoopEntryPredicate ||
+ LoopEntryPredicate->isUnconditional())
+ continue;
+
+ if (ProveViaCond(LoopEntryPredicate->getCondition(),
+ LoopEntryPredicate->getSuccessor(0) != Pair.second))
+ return true;
+ }
+
+ // Check conditions due to any @llvm.assume intrinsics.
+ for (auto &AssumeVH : AC.assumptions()) {
+ if (!AssumeVH)
+ continue;
+ auto *CI = cast<CallInst>(AssumeVH);
+ if (!DT.dominates(CI, L->getHeader()))
+ continue;
+
+ if (ProveViaCond(CI->getArgOperand(0), false))
+ return true;
+ }
+
+ return false;
+}
+
+bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS,
+ Value *FoundCondValue,
+ bool Inverse) {
+ if (!PendingLoopPredicates.insert(FoundCondValue).second)
+ return false;
+
+ auto ClearOnExit =
+ make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
+
+ // Recursively handle And and Or conditions.
+ if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
+ if (BO->getOpcode() == Instruction::And) {
+ if (!Inverse)
+ return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
+ isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
+ } else if (BO->getOpcode() == Instruction::Or) {
+ if (Inverse)
+ return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
+ isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
+ }
+ }
+
+ ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
+ if (!ICI) return false;
+
+  // Now that we have found a conditional branch that dominates the loop or
+  // controls the loop latch, check to see if it is the comparison we are
+  // looking for.
+ ICmpInst::Predicate FoundPred;
+ if (Inverse)
+ FoundPred = ICI->getInversePredicate();
+ else
+ FoundPred = ICI->getPredicate();
+
+ const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
+ const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
+
+ return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS);
+}
+
+bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
+ const SCEV *RHS,
+ ICmpInst::Predicate FoundPred,
+ const SCEV *FoundLHS,
+ const SCEV *FoundRHS) {
+ // Balance the types.
+ if (getTypeSizeInBits(LHS->getType()) <
+ getTypeSizeInBits(FoundLHS->getType())) {
+ if (CmpInst::isSigned(Pred)) {
+ LHS = getSignExtendExpr(LHS, FoundLHS->getType());
+ RHS = getSignExtendExpr(RHS, FoundLHS->getType());
+ } else {
+ LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
+ RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
+ }
+ } else if (getTypeSizeInBits(LHS->getType()) >
+ getTypeSizeInBits(FoundLHS->getType())) {
+ if (CmpInst::isSigned(FoundPred)) {
+ FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
+ FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
+ } else {
+ FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
+ FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
+ }
+ }
+
+ // Canonicalize the query to match the way instcombine will have
+ // canonicalized the comparison.
+ if (SimplifyICmpOperands(Pred, LHS, RHS))
+ if (LHS == RHS)
+ return CmpInst::isTrueWhenEqual(Pred);
+ if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
+ if (FoundLHS == FoundRHS)
+ return CmpInst::isFalseWhenEqual(FoundPred);
+
+ // Check to see if we can make the LHS or RHS match.
+ if (LHS == FoundRHS || RHS == FoundLHS) {
+ if (isa<SCEVConstant>(RHS)) {
+ std::swap(FoundLHS, FoundRHS);
+ FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
+ } else {
+ std::swap(LHS, RHS);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+ }
+
+ // Check whether the found predicate is the same as the desired predicate.
+ if (FoundPred == Pred)
+ return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
+
+ // Check whether swapping the found predicate makes it the same as the
+ // desired predicate.
+ if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
+ if (isa<SCEVConstant>(RHS))
+ return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
+ else
+ return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
+ RHS, LHS, FoundLHS, FoundRHS);
+ }
+
+  // Unsigned comparison is the same as signed comparison when both operands
+ // are non-negative.
+ if (CmpInst::isUnsigned(FoundPred) &&
+ CmpInst::getSignedPredicate(FoundPred) == Pred &&
+ isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS))
+ return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
+
+ // Check if we can make progress by sharpening ranges.
+ if (FoundPred == ICmpInst::ICMP_NE &&
+ (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
+
+ const SCEVConstant *C = nullptr;
+ const SCEV *V = nullptr;
+
+ if (isa<SCEVConstant>(FoundLHS)) {
+ C = cast<SCEVConstant>(FoundLHS);
+ V = FoundRHS;
+ } else {
+ C = cast<SCEVConstant>(FoundRHS);
+ V = FoundLHS;
+ }
+
+ // The guarding predicate tells us that C != V. If the known range
+ // of V is [C, t), we can sharpen the range to [C + 1, t). The
+    // range we consider has to correspond to the same signedness as the
+ // predicate we're interested in folding.
+
+ APInt Min = ICmpInst::isSigned(Pred) ?
+ getSignedRangeMin(V) : getUnsignedRangeMin(V);
+
+ if (Min == C->getAPInt()) {
+ // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
+ // This is true even if (Min + 1) wraps around -- in case of
+ // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
+
+ APInt SharperMin = Min + 1;
+
+ switch (Pred) {
+ case ICmpInst::ICMP_SGE:
+ case ICmpInst::ICMP_UGE:
+ // We know V `Pred` SharperMin. If this implies LHS `Pred`
+ // RHS, we're done.
+ if (isImpliedCondOperands(Pred, LHS, RHS, V,
+ getConstant(SharperMin)))
+ return true;
+ LLVM_FALLTHROUGH;
+
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_UGT:
+ // We know from the range information that (V `Pred` Min ||
+ // V == Min). We know from the guarding condition that !(V
+ // == Min). This gives us
+ //
+ // V `Pred` Min || V == Min && !(V == Min)
+ // => V `Pred` Min
+ //
+ // If V `Pred` Min implies LHS `Pred` RHS, we're done.
+
+ if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
+ return true;
+ LLVM_FALLTHROUGH;
+
+ default:
+ // No change
+ break;
+ }
+ }
+ }
+
+ // Check whether the actual condition is beyond sufficient.
+ if (FoundPred == ICmpInst::ICMP_EQ)
+ if (ICmpInst::isTrueWhenEqual(Pred))
+ if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
+ return true;
+ if (Pred == ICmpInst::ICMP_NE)
+ if (!ICmpInst::isTrueWhenEqual(FoundPred))
+ if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
+ return true;
+
+ // Otherwise assume the worst.
+ return false;
+}
+
+bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
+ const SCEV *&L, const SCEV *&R,
+ SCEV::NoWrapFlags &Flags) {
+ const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
+ if (!AE || AE->getNumOperands() != 2)
+ return false;
+
+ L = AE->getOperand(0);
+ R = AE->getOperand(1);
+ Flags = AE->getNoWrapFlags();
+ return true;
+}
+
+Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
+ const SCEV *Less) {
+ // We avoid subtracting expressions here because this function is usually
+ // fairly deep in the call stack (i.e. is called many times).
+
+ // X - X = 0.
+ if (More == Less)
+ return APInt(getTypeSizeInBits(More->getType()), 0);
+
+ if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
+ const auto *LAR = cast<SCEVAddRecExpr>(Less);
+ const auto *MAR = cast<SCEVAddRecExpr>(More);
+
+ if (LAR->getLoop() != MAR->getLoop())
+ return None;
+
+ // We look at affine expressions only; not for correctness but to keep
+ // getStepRecurrence cheap.
+ if (!LAR->isAffine() || !MAR->isAffine())
+ return None;
+
+ if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
+ return None;
+
+ Less = LAR->getStart();
+ More = MAR->getStart();
+
+ // fall through
+ }
+
+ if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
+ const auto &M = cast<SCEVConstant>(More)->getAPInt();
+ const auto &L = cast<SCEVConstant>(Less)->getAPInt();
+ return M - L;
+ }
+
+ SCEV::NoWrapFlags Flags;
+ const SCEV *LLess = nullptr, *RLess = nullptr;
+ const SCEV *LMore = nullptr, *RMore = nullptr;
+ const SCEVConstant *C1 = nullptr, *C2 = nullptr;
+ // Compare (X + C1) vs X.
+ if (splitBinaryAdd(Less, LLess, RLess, Flags))
+ if ((C1 = dyn_cast<SCEVConstant>(LLess)))
+ if (RLess == More)
+ return -(C1->getAPInt());
+
+ // Compare X vs (X + C2).
+ if (splitBinaryAdd(More, LMore, RMore, Flags))
+ if ((C2 = dyn_cast<SCEVConstant>(LMore)))
+ if (RMore == Less)
+ return C2->getAPInt();
+
+ // Compare (X + C1) vs (X + C2).
+ if (C1 && C2 && RLess == RMore)
+ return C2->getAPInt() - C1->getAPInt();
+
+ return None;
+}
+
+bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
+ ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
+ const SCEV *FoundLHS, const SCEV *FoundRHS) {
+ if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
+ return false;
+
+ const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
+ if (!AddRecLHS)
+ return false;
+
+ const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
+ if (!AddRecFoundLHS)
+ return false;
+
+ // We'd like to let SCEV reason about control dependencies, so we constrain
+ // both the inequalities to be about add recurrences on the same loop. This
+ // way we can use isLoopEntryGuardedByCond later.
+
+ const Loop *L = AddRecFoundLHS->getLoop();
+ if (L != AddRecLHS->getLoop())
+ return false;
+
+ // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
+ //
+ // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
+ // ... (2)
+ //
+ // Informal proof for (2), assuming (1) [*]:
+ //
+ // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
+ //
+ // Then
+ //
+ // FoundLHS s< FoundRHS s< INT_MIN - C
+ // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
+ // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
+ // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
+ // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
+ // <=> FoundLHS + C s< FoundRHS + C
+ //
+ // [*]: (1) can be proved by ruling out overflow.
+ //
+ // [**]: This can be proved by analyzing all the four possibilities:
+ // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
+ // (A s>= 0, B s>= 0).
+ //
+ // Note:
+ // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
+ // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
+ // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
+ // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
+ // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
+ // C)".
+
+ Optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
+ Optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
+ if (!LDiff || !RDiff || *LDiff != *RDiff)
+ return false;
+
+ if (LDiff->isMinValue())
+ return true;
+
+ APInt FoundRHSLimit;
+
+ if (Pred == CmpInst::ICMP_ULT) {
+ FoundRHSLimit = -(*RDiff);
+ } else {
+ assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
+    FoundRHSLimit =
+        APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
+ }
+
+ // Try to prove (1) or (2), as needed.
+ return isAvailableAtLoopEntry(FoundRHS, L) &&
+ isLoopEntryGuardedByCond(L, Pred, FoundRHS,
+ getConstant(FoundRHSLimit));
+}
+
+bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS,
+ const SCEV *FoundLHS,
+ const SCEV *FoundRHS, unsigned Depth) {
+ const PHINode *LPhi = nullptr, *RPhi = nullptr;
+
+ auto ClearOnExit = make_scope_exit([&]() {
+ if (LPhi) {
+ bool Erased = PendingMerges.erase(LPhi);
+ assert(Erased && "Failed to erase LPhi!");
+ (void)Erased;
+ }
+ if (RPhi) {
+ bool Erased = PendingMerges.erase(RPhi);
+ assert(Erased && "Failed to erase RPhi!");
+ (void)Erased;
+ }
+ });
+
+  // Find the respective Phis and check that they are not already pending.
+ if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
+ if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
+ if (!PendingMerges.insert(Phi).second)
+ return false;
+ LPhi = Phi;
+ }
+ if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
+ if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
+ // If we detect a loop of Phi nodes being processed by this method, for
+ // example:
+ //
+ // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
+ // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
+ //
+ // we don't want to deal with a case that complex, so return conservative
+ // answer false.
+ if (!PendingMerges.insert(Phi).second)
+ return false;
+ RPhi = Phi;
+ }
+
+ // If none of LHS, RHS is a Phi, nothing to do here.
+ if (!LPhi && !RPhi)
+ return false;
+
+  // If there is a SCEVUnknown Phi we are interested in, make it the LHS.
+ if (!LPhi) {
+ std::swap(LHS, RHS);
+ std::swap(FoundLHS, FoundRHS);
+ std::swap(LPhi, RPhi);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+
+ assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
+ const BasicBlock *LBB = LPhi->getParent();
+ const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
+
+ auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
+ return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
+ isImpliedCondOperandsViaRanges(Pred, S1, S2, FoundLHS, FoundRHS) ||
+ isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
+ };
+
+ if (RPhi && RPhi->getParent() == LBB) {
+ // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
+ // If we compare two Phis from the same block, and for each entry block
+ // the predicate is true for incoming values from this block, then the
+ // predicate is also true for the Phis.
+ for (const BasicBlock *IncBB : predecessors(LBB)) {
+ const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
+ const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
+ if (!ProvedEasily(L, R))
+ return false;
+ }
+ } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
+    // Case two: RHS is an AddRec whose loop's header is this same basic block.
+    // This means that there is a loop which has both an AddRec and an Unknown
+    // PHI; for it, we can compare the incoming values of the AddRec from above
+    // the loop and from the latch with the respective incoming values of LPhi.
+ // TODO: Generalize to handle loops with many inputs in a header.
+ if (LPhi->getNumIncomingValues() != 2) return false;
+
+ auto *RLoop = RAR->getLoop();
+ auto *Predecessor = RLoop->getLoopPredecessor();
+ assert(Predecessor && "Loop with AddRec with no predecessor?");
+ const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
+ if (!ProvedEasily(L1, RAR->getStart()))
+ return false;
+ auto *Latch = RLoop->getLoopLatch();
+ assert(Latch && "Loop with AddRec with no latch?");
+ const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
+ if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
+ return false;
+ } else {
+ // In all other cases go over inputs of LHS and compare each of them to RHS,
+ // the predicate is true for (LHS, RHS) if it is true for all such pairs.
+ // At this point RHS is either a non-Phi, or it is a Phi from some block
+ // different from LBB.
+ for (const BasicBlock *IncBB : predecessors(LBB)) {
+ // Check that RHS is available in this block.
+ if (!dominates(RHS, IncBB))
+ return false;
+ const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
+ if (!ProvedEasily(L, RHS))
+ return false;
+ }
+ }
+ return true;
+}
+
+bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS,
+ const SCEV *FoundLHS,
+ const SCEV *FoundRHS) {
+ if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
+ return true;
+
+ if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
+ return true;
+
+ return isImpliedCondOperandsHelper(Pred, LHS, RHS,
+ FoundLHS, FoundRHS) ||
+ // ~x < ~y --> x > y
+ isImpliedCondOperandsHelper(Pred, LHS, RHS,
+ getNotSCEV(FoundRHS),
+ getNotSCEV(FoundLHS));
+}
+
+/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
+template <typename MinMaxExprType>
+static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
+ const SCEV *Candidate) {
+ const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
+ if (!MinMaxExpr)
+ return false;
+
+ return find(MinMaxExpr->operands(), Candidate) != MinMaxExpr->op_end();
+}
+
+static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
+ ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS) {
+ // If both sides are affine addrecs for the same loop, with equal
+ // steps, and we know the recurrences don't wrap, then we only
+ // need to check the predicate on the starting values.
+
+ if (!ICmpInst::isRelational(Pred))
+ return false;
+
+ const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
+ if (!LAR)
+ return false;
+ const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
+ if (!RAR)
+ return false;
+ if (LAR->getLoop() != RAR->getLoop())
+ return false;
+ if (!LAR->isAffine() || !RAR->isAffine())
+ return false;
+
+ if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
+ return false;
+
+ SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
+ SCEV::FlagNSW : SCEV::FlagNUW;
+ if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
+ return false;
+
+ return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
+}
+
+/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
+/// expression?
+static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
+ ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS) {
+ switch (Pred) {
+ default:
+ return false;
+
+ case ICmpInst::ICMP_SGE:
+ std::swap(LHS, RHS);
+ LLVM_FALLTHROUGH;
+ case ICmpInst::ICMP_SLE:
+ return
+ // min(A, ...) <= A
+ IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
+ // A <= max(A, ...)
+ IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
+
+ case ICmpInst::ICMP_UGE:
+ std::swap(LHS, RHS);
+ LLVM_FALLTHROUGH;
+ case ICmpInst::ICMP_ULE:
+ return
+ // min(A, ...) <= A
+ IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
+ // A <= max(A, ...)
+ IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
+ }
+
+ llvm_unreachable("covered switch fell through?!");
+}
+
+bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS,
+ const SCEV *FoundLHS,
+ const SCEV *FoundRHS,
+ unsigned Depth) {
+ assert(getTypeSizeInBits(LHS->getType()) ==
+ getTypeSizeInBits(RHS->getType()) &&
+ "LHS and RHS have different sizes?");
+ assert(getTypeSizeInBits(FoundLHS->getType()) ==
+ getTypeSizeInBits(FoundRHS->getType()) &&
+ "FoundLHS and FoundRHS have different sizes?");
+ // We want to avoid hurting the compile time with analysis of too big trees.
+ if (Depth > MaxSCEVOperationsImplicationDepth)
+ return false;
+ // We only want to work with ICMP_SGT comparison so far.
+ // TODO: Extend to ICMP_UGT?
+ if (Pred == ICmpInst::ICMP_SLT) {
+ Pred = ICmpInst::ICMP_SGT;
+ std::swap(LHS, RHS);
+ std::swap(FoundLHS, FoundRHS);
+ }
+ if (Pred != ICmpInst::ICMP_SGT)
+ return false;
+
+ auto GetOpFromSExt = [&](const SCEV *S) {
+ if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
+ return Ext->getOperand();
+ // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
+ // the constant in some cases.
+ return S;
+ };
+
+ // Acquire values from extensions.
+ auto *OrigLHS = LHS;
+ auto *OrigFoundLHS = FoundLHS;
+ LHS = GetOpFromSExt(LHS);
+ FoundLHS = GetOpFromSExt(FoundLHS);
+
+  // Checks whether the SGT predicate can be proved trivially or using the
+  // found context.
+ auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
+ return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
+ isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
+ FoundRHS, Depth + 1);
+ };
+
+ if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
+ // We want to avoid creation of any new non-constant SCEV. Since we are
+ // going to compare the operands to RHS, we should be certain that we don't
+ // need any size extensions for this. So let's decline all cases when the
+ // sizes of types of LHS and RHS do not match.
+ // TODO: Maybe try to get RHS from sext to catch more cases?
+ if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
+ return false;
+
+ // Should not overflow.
+ if (!LHSAddExpr->hasNoSignedWrap())
+ return false;
+
+ auto *LL = LHSAddExpr->getOperand(0);
+ auto *LR = LHSAddExpr->getOperand(1);
+ auto *MinusOne = getNegativeSCEV(getOne(RHS->getType()));
+
+ // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
+ auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
+ return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
+ };
+ // Try to prove the following rule:
+ // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
+ // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
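+    //
+    // For illustration: since the addition is <nsw>, LL >= 0 gives
+    // LL + LR >= LR, and LR > RHS then gives LHS = LL + LR > RHS (and
+    // symmetrically with the roles of LL and LR swapped).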
+ if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
+ return true;
+ } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
+ Value *LL, *LR;
+ // FIXME: Once we have SDiv implemented, we can get rid of this matching.
+
+ using namespace llvm::PatternMatch;
+
+ if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
+ // Rules for division.
+ // We are going to perform some comparisons with Denominator and its
+      // derivative expressions. In the general case, creating a SCEV for it may
+ // lead to a complex analysis of the entire graph, and in particular it
+ // can request trip count recalculation for the same loop. This would
+ // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
+ // this, we only want to create SCEVs that are constants in this section.
+ // So we bail if Denominator is not a constant.
+ if (!isa<ConstantInt>(LR))
+ return false;
+
+ auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
+
+      // We want to make sure that LHS = FoundLHS / Denominator. If so, then a
+      // SCEV for the numerator already exists and matches FoundLHS.
+ auto *Numerator = getExistingSCEV(LL);
+ if (!Numerator || Numerator->getType() != FoundLHS->getType())
+ return false;
+
+ // Make sure that the numerator matches with FoundLHS and the denominator
+ // is positive.
+ if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
+ return false;
+
+ auto *DTy = Denominator->getType();
+ auto *FRHSTy = FoundRHS->getType();
+ if (DTy->isPointerTy() != FRHSTy->isPointerTy())
+ // One of types is a pointer and another one is not. We cannot extend
+ // them properly to a wider type, so let us just reject this case.
+ // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
+ // to avoid this check.
+ return false;
+
+ // Given that:
+ // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
+ auto *WTy = getWiderType(DTy, FRHSTy);
+ auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
+ auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
+
+ // Try to prove the following rule:
+ // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
+      // For example, given that FoundLHS > 2, FoundLHS is at least 3. If we
+      // divide it by a Denominator < 4, the result is at least 1.
+ auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
+ if (isKnownNonPositive(RHS) &&
+ IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
+ return true;
+
+ // Try to prove the following rule:
+ // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
+      // For example, given that FoundLHS > -3, FoundLHS is at least -2. If we
+      // divide it by a Denominator > 2, then:
+      // 1. If FoundLHS is negative, then the result is 0.
+      // 2. If FoundLHS is non-negative, then the result is non-negative.
+      // Either way, the result is non-negative.
+ auto *MinusOne = getNegativeSCEV(getOne(WTy));
+ auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
+ if (isKnownNegative(RHS) &&
+ IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
+ return true;
+ }
+ }
+
+ // If our expression contained SCEVUnknown Phis, and we split it down and now
+ // need to prove something for them, try to prove the predicate for every
+ // possible incoming values of those Phis.
+ if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
+ return true;
+
+ return false;
+}
+
+static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS) {
+ // zext x u<= sext x, sext x s<= zext x
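+  //
+  // For illustration (hypothetical i8 value): for x = -1, zext to i16 gives
+  // 255 while sext gives 0xFFFF; 255 u<= 0xFFFF and, read as signed,
+  // -1 s<= 255. For x >= 0 the two extensions are equal.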
+ switch (Pred) {
+ case ICmpInst::ICMP_SGE:
+ std::swap(LHS, RHS);
+ LLVM_FALLTHROUGH;
+ case ICmpInst::ICMP_SLE: {
+ // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
+ const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
+ const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
+ if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
+ return true;
+ break;
+ }
+ case ICmpInst::ICMP_UGE:
+ std::swap(LHS, RHS);
+ LLVM_FALLTHROUGH;
+ case ICmpInst::ICMP_ULE: {
+ // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
+ const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
+ const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
+ if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
+ return true;
+ break;
+ }
+ default:
+ break;
+ };
+ return false;
+}
+
+bool
+ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS) {
+ return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
+ isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
+ IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
+ IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
+ isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
+}
+
+bool
+ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS,
+ const SCEV *FoundLHS,
+ const SCEV *FoundRHS) {
+ switch (Pred) {
+ default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
+ case ICmpInst::ICMP_EQ:
+ case ICmpInst::ICMP_NE:
+ if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
+ return true;
+ break;
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_SLE:
+ if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
+ isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
+ return true;
+ break;
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_SGE:
+ if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
+ isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
+ return true;
+ break;
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULE:
+ if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
+ isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
+ return true;
+ break;
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_UGE:
+ if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
+ isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
+ return true;
+ break;
+ }
+
+ // Maybe it can be proved via operations?
+ if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
+ return true;
+
+ return false;
+}
+
+bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
+ const SCEV *LHS,
+ const SCEV *RHS,
+ const SCEV *FoundLHS,
+ const SCEV *FoundRHS) {
+ if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
+    // The restriction on `FoundRHS` can be lifted easily -- it exists only to
+ // reduce the compile time impact of this optimization.
+ return false;
+
+ Optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
+ if (!Addend)
+ return false;
+
+ const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
+
+ // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
+ // antecedent "`FoundLHS` `Pred` `FoundRHS`".
+ ConstantRange FoundLHSRange =
+ ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS);
+
+ // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
+ ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
+
+ // We can also compute the range of values for `LHS` that satisfy the
+ // consequent, "`LHS` `Pred` `RHS`":
+ const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
+ ConstantRange SatisfyingLHSRange =
+ ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS);
+
+ // The antecedent implies the consequent if every value of `LHS` that
+ // satisfies the antecedent also satisfies the consequent.
+ return SatisfyingLHSRange.contains(LHSRange);
+}
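+
+// Worked example (illustrative): with Pred == ICMP_ULT and FoundRHS == 8, the
+// antecedent constrains FoundLHS to [0, 8); if LHS == FoundLHS + 2 then LHS
+// lies in [2, 10), and with RHS == 10 the satisfying region for "LHS u< 10"
+// is [0, 10), which contains [2, 10), so the implication holds.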
+
+bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
+ bool IsSigned, bool NoWrap) {
+ assert(isKnownPositive(Stride) && "Positive stride expected!");
+
+ if (NoWrap) return false;
+
+ unsigned BitWidth = getTypeSizeInBits(RHS->getType());
+ const SCEV *One = getOne(Stride->getType());
+
+ if (IsSigned) {
+ APInt MaxRHS = getSignedRangeMax(RHS);
+ APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
+ APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
+
+ // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
+ return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
+ }
+
+ APInt MaxRHS = getUnsignedRangeMax(RHS);
+ APInt MaxValue = APInt::getMaxValue(BitWidth);
+ APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
+
+ // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
+ return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
+}
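+
+// Worked example (illustrative, unsigned i8): if the unsigned maximum of RHS
+// is 250 and the stride can be as large as 10, then 255 - 9 = 246 u< 250, so
+// this returns true -- the IV could be 249, still below RHS, and the next
+// step would wrap past 255.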
+
+bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
+ bool IsSigned, bool NoWrap) {
+ if (NoWrap) return false;
+
+ unsigned BitWidth = getTypeSizeInBits(RHS->getType());
+ const SCEV *One = getOne(Stride->getType());
+
+ if (IsSigned) {
+ APInt MinRHS = getSignedRangeMin(RHS);
+ APInt MinValue = APInt::getSignedMinValue(BitWidth);
+ APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
+
+ // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
+ return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
+ }
+
+ APInt MinRHS = getUnsignedRangeMin(RHS);
+ APInt MinValue = APInt::getMinValue(BitWidth);
+ APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
+
+ // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
+ return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
+}
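+
+// Worked example (illustrative, signed i8): if the signed minimum of RHS is
+// -125 and the stride can be as large as 10, then -128 + 9 = -119 s> -125, so
+// this returns true -- the IV could be -124, still above RHS, and the next
+// step down would wrap below -128.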
+
+const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
+ bool Equality) {
+ const SCEV *One = getOne(Step->getType());
+ Delta = Equality ? getAddExpr(Delta, Step)
+ : getAddExpr(Delta, getMinusSCEV(Step, One));
+ return getUDivExpr(Delta, Step);
+}
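+
+// Illustrative arithmetic: with Delta == 6 and Step == 3, Equality == false
+// yields (6 + 2) /u 3 == 2, i.e. the ceiling of Delta/Step, while
+// Equality == true yields (6 + 3) /u 3 == 3, i.e. floor(Delta/Step) + 1,
+// accounting for the extra iteration when the bound is inclusive.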
+
+const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
+ const SCEV *Stride,
+ const SCEV *End,
+ unsigned BitWidth,
+ bool IsSigned) {
+
+ assert(!isKnownNonPositive(Stride) &&
+ "Stride is expected strictly positive!");
+ // Calculate the maximum backedge count based on the range of values
+ // permitted by Start, End, and Stride.
+ const SCEV *MaxBECount;
+ APInt MinStart =
+ IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
+
+ APInt StrideForMaxBECount =
+ IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
+
+ // We already know that the stride is positive, so we paper over conservatism
+ // in our range computation by forcing StrideForMaxBECount to be at least one.
+ // In theory this is unnecessary, but we expect MaxBECount to be a
+ // SCEVConstant, and (udiv <constant> 0) is not constant folded by SCEV (there
+ // is nothing to constant fold it to).
+ APInt One(BitWidth, 1, IsSigned);
+ StrideForMaxBECount = APIntOps::smax(One, StrideForMaxBECount);
+
+ APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
+ : APInt::getMaxValue(BitWidth);
+ APInt Limit = MaxValue - (StrideForMaxBECount - 1);
+
+ // Although End can be a MAX expression we estimate MaxEnd considering only
+ // the case End = RHS of the loop termination condition. This is safe because
+ // in the other case (End - Start) is zero, leading to a zero maximum backedge
+ // taken count.
+ APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
+ : APIntOps::umin(getUnsignedRangeMax(End), Limit);
+
+ MaxBECount = computeBECount(getConstant(MaxEnd - MinStart) /* Delta */,
+ getConstant(StrideForMaxBECount) /* Step */,
+ false /* Equality */);
+
+ return MaxBECount;
+}
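+
+// Worked example (illustrative, unsigned i8): with MinStart == 0, a minimum
+// stride of 4 and an unsigned maximum of 100 for End, Limit is 255 - 3 = 252,
+// MaxEnd is umin(100, 252) = 100, and the resulting maximum backedge-taken
+// count is (100 - 0 + 3) /u 4 = 25.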
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
+ const Loop *L, bool IsSigned,
+ bool ControlsExit, bool AllowPredicates) {
+ SmallPtrSet<const SCEVPredicate *, 4> Predicates;
+
+ const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
+ bool PredicatedIV = false;
+
+ if (!IV && AllowPredicates) {
+ // Try to make this an AddRec using runtime tests, in the first X
+ // iterations of this loop, where X is the SCEV expression found by the
+ // algorithm below.
+ IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
+ PredicatedIV = true;
+ }
+
+ // Avoid weird loops
+ if (!IV || IV->getLoop() != L || !IV->isAffine())
+ return getCouldNotCompute();
+
+ bool NoWrap = ControlsExit &&
+ IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
+
+ const SCEV *Stride = IV->getStepRecurrence(*this);
+
+ bool PositiveStride = isKnownPositive(Stride);
+
+ // Avoid negative or zero stride values.
+ if (!PositiveStride) {
+ // We can compute the correct backedge taken count for loops with unknown
+ // strides if we can prove that the loop is not an infinite loop with side
+ // effects. Here's the loop structure we are trying to handle -
+ //
+ // i = start
+ // do {
+ // A[i] = i;
+ // i += s;
+ // } while (i < end);
+ //
+ // The backedge taken count for such loops is evaluated as -
+ // (max(end, start + stride) - start - 1) /u stride
+ //
+ // The additional preconditions that we need to check to prove correctness
+    // of the above formula are as follows -
+ //
+ // a) IV is either nuw or nsw depending upon signedness (indicated by the
+ // NoWrap flag).
+ // b) loop is single exit with no side effects.
+ //
+ //
+ // Precondition a) implies that if the stride is negative, this is a single
+ // trip loop. The backedge taken count formula reduces to zero in this case.
+ //
+ // Precondition b) implies that the unknown stride cannot be zero otherwise
+ // we have UB.
+ //
+ // The positive stride case is the same as isKnownPositive(Stride) returning
+ // true (original behavior of the function).
+ //
+ // We want to make sure that the stride is truly unknown as there are edge
+ // cases where ScalarEvolution propagates no wrap flags to the
+ // post-increment/decrement IV even though the increment/decrement operation
+ // itself is wrapping. The computed backedge taken count may be wrong in
+ // such cases. This is prevented by checking that the stride is not known to
+ // be either positive or non-positive. For example, no wrap flags are
+ // propagated to the post-increment IV of this loop with a trip count of 2 -
+ //
+ // unsigned char i;
+ // for(i=127; i<128; i+=129)
+ // A[i] = i;
+ //
+ if (PredicatedIV || !NoWrap || isKnownNonPositive(Stride) ||
+ !loopHasNoSideEffects(L))
+ return getCouldNotCompute();
+ } else if (!Stride->isOne() &&
+ doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
+    // Avoid proven overflow cases: this will ensure that the backedge taken
+    // count will not generate any unsigned overflow. Relaxed no-overflow
+    // conditions exploit NoWrapFlags, allowing us to optimize in the presence
+    // of undefined behavior, as in C.
+ return getCouldNotCompute();
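+
+  // A worked instance of the formula above (illustrative): for start = 0,
+  // end = 10 and stride = 3 it gives (max(10, 0 + 3) - 0 - 1) /u 3 = 3
+  // backedges, matching a do-while loop whose body runs for i = 0, 3, 6, 9
+  // and whose post-increment test fails at i = 12.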
+
+ ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT
+ : ICmpInst::ICMP_ULT;
+ const SCEV *Start = IV->getStart();
+ const SCEV *End = RHS;
+ // When the RHS is not invariant, we do not know the end bound of the loop and
+ // cannot calculate the ExactBECount needed by ExitLimit. However, we can
+ // calculate the MaxBECount, given the start, stride and max value for the end
+ // bound of the loop (RHS), and the fact that IV does not overflow (which is
+ // checked above).
+ if (!isLoopInvariant(RHS, L)) {
+ const SCEV *MaxBECount = computeMaxBECountForLT(
+ Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
+ return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
+ false /*MaxOrZero*/, Predicates);
+ }
+ // If the backedge is taken at least once, then it will be taken
+ // (End-Start)/Stride times (rounded up to a multiple of Stride), where Start
+ // is the LHS value of the less-than comparison the first time it is evaluated
+ // and End is the RHS.
+ const SCEV *BECountIfBackedgeTaken =
+ computeBECount(getMinusSCEV(End, Start), Stride, false);
+  // If the loop entry is guarded by the result of the backedge test of the
+  // first loop iteration, then we know the backedge will be taken at least
+  // once and so the backedge taken count is as above. If not, we use the
+  // expression (max(End,Start)-Start)/Stride to describe the backedge count:
+  // if the backedge is taken at least once, max(End,Start) is End and the
+  // result is as above; if not, max(End,Start) is Start, so we get a backedge
+  // count of zero.
+ const SCEV *BECount;
+ if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS))
+ BECount = BECountIfBackedgeTaken;
+ else {
+ End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
+ BECount = computeBECount(getMinusSCEV(End, Start), Stride, false);
+ }
+
+ const SCEV *MaxBECount;
+ bool MaxOrZero = false;
+ if (isa<SCEVConstant>(BECount))
+ MaxBECount = BECount;
+ else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
+ // If we know exactly how many times the backedge will be taken if it's
+ // taken at least once, then the backedge count will either be that or
+ // zero.
+ MaxBECount = BECountIfBackedgeTaken;
+ MaxOrZero = true;
+ } else {
+ MaxBECount = computeMaxBECountForLT(
+ Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
+ }
+
+ if (isa<SCEVCouldNotCompute>(MaxBECount) &&
+ !isa<SCEVCouldNotCompute>(BECount))
+ MaxBECount = getConstant(getUnsignedRangeMax(BECount));
+
+ return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
+ const Loop *L, bool IsSigned,
+ bool ControlsExit, bool AllowPredicates) {
+ SmallPtrSet<const SCEVPredicate *, 4> Predicates;
+ // We handle only IV > Invariant
+ if (!isLoopInvariant(RHS, L))
+ return getCouldNotCompute();
+
+ const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
+ if (!IV && AllowPredicates)
+ // Try to make this an AddRec using runtime tests, in the first X
+ // iterations of this loop, where X is the SCEV expression found by the
+ // algorithm below.
+ IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
+
+ // Avoid weird loops
+ if (!IV || IV->getLoop() != L || !IV->isAffine())
+ return getCouldNotCompute();
+
+ bool NoWrap = ControlsExit &&
+ IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
+
+ const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
+
+ // Avoid negative or zero stride values
+ if (!isKnownPositive(Stride))
+ return getCouldNotCompute();
+
+  // Avoid proven overflow cases: this will ensure that the backedge taken count
+  // will not generate any unsigned overflow. Relaxed no-overflow conditions
+  // exploit NoWrapFlags, allowing us to optimize in the presence of undefined
+  // behavior, as in C.
+ if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap))
+ return getCouldNotCompute();
+
+ ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT
+ : ICmpInst::ICMP_UGT;
+
+ const SCEV *Start = IV->getStart();
+ const SCEV *End = RHS;
+ if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS))
+ End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
+
+ const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false);
+
+ APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
+ : getUnsignedRangeMax(Start);
+
+ APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
+ : getUnsignedRangeMin(Stride);
+
+ unsigned BitWidth = getTypeSizeInBits(LHS->getType());
+ APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
+ : APInt::getMinValue(BitWidth) + (MinStride - 1);
+
+ // Although End can be a MIN expression we estimate MinEnd considering only
+ // the case End = RHS. This is safe because in the other case (Start - End)
+ // is zero, leading to a zero maximum backedge taken count.
+ APInt MinEnd =
+ IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
+ : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
+
+ const SCEV *MaxBECount = isa<SCEVConstant>(BECount)
+ ? BECount
+ : computeBECount(getConstant(MaxStart - MinEnd),
+ getConstant(MinStride), false);
+
+ if (isa<SCEVCouldNotCompute>(MaxBECount))
+ MaxBECount = BECount;
+
+ return ExitLimit(BECount, MaxBECount, false, Predicates);
+}
+
+const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
+ ScalarEvolution &SE) const {
+ if (Range.isFullSet()) // Infinite loop.
+ return SE.getCouldNotCompute();
+
+ // If the start is a non-zero constant, shift the range to simplify things.
+ if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
+ if (!SC->getValue()->isZero()) {
+ SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
+ Operands[0] = SE.getZero(SC->getType());
+ const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
+ getNoWrapFlags(FlagNW));
+ if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
+ return ShiftedAddRec->getNumIterationsInRange(
+ Range.subtract(SC->getAPInt()), SE);
+ // This is strange and shouldn't happen.
+ return SE.getCouldNotCompute();
+ }
+
+ // The only time we can solve this is when we have all constant indices.
+ // Otherwise, we cannot determine the overflow conditions.
+ if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
+ return SE.getCouldNotCompute();
+
+ // Okay at this point we know that all elements of the chrec are constants and
+ // that the start element is zero.
+
+ // First check to see if the range contains zero. If not, the first
+ // iteration exits.
+ unsigned BitWidth = SE.getTypeSizeInBits(getType());
+ if (!Range.contains(APInt(BitWidth, 0)))
+ return SE.getZero(getType());
+
+ if (isAffine()) {
+ // If this is an affine expression then we have this situation:
+ // Solve {0,+,A} in Range === Ax in Range
+
+ // We know that zero is in the range. If A is positive then we know that
+ // the upper value of the range must be the first possible exit value.
+ // If A is negative then the lower of the range is the last possible loop
+ // value. Also note that we already checked for a full range.
+ APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
+ APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
+
+ // The exit value should be (End+A)/A.
+ APInt ExitVal = (End + A).udiv(A);
+ ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
+
+ // Evaluate at the exit value. If we really did fall out of the valid
+ // range, then we computed our trip count, otherwise wrap around or other
+ // things must have happened.
+ ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
+ if (Range.contains(Val->getValue()))
+ return SE.getCouldNotCompute(); // Something strange happened
+
+ // Ensure that the previous value is in the range. This is a sanity check.
+ assert(Range.contains(
+ EvaluateConstantChrecAtConstant(this,
+ ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
+ "Linear scev computation is off in a bad way!");
+ return SE.getConstant(ExitValue);
+ }
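+
+  // Illustrative example for the affine case above: for {0,+,3} and
+  // Range = [0, 10), A = 3 and End = 9, so ExitVal = (9 + 3) /u 3 = 4; the
+  // chrec evaluates to 12 at iteration 4 (outside the range) and to 9 at
+  // iteration 3 (still inside), so 4 is returned.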
+
+ if (isQuadratic()) {
+ if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
+ return SE.getConstant(S.getValue());
+ }
+
+ return SE.getCouldNotCompute();
+}
+
+const SCEVAddRecExpr *
+SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
+ assert(getNumOperands() > 1 && "AddRec with zero step?");
+ // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
+ // but in this case we cannot guarantee that the value returned will be an
+ // AddRec because SCEV does not have a fixed point where it stops
+ // simplification: it is legal to return ({rec1} + {rec2}). For example, it
+ // may happen if we reach arithmetic depth limit while simplifying. So we
+ // construct the returned value explicitly.
+ SmallVector<const SCEV *, 3> Ops;
+ // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
+ // (this + Step) is {A+B,+,B+C,+...,+,N}.
+ for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
+ Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
+ // We know that the last operand is not a constant zero (otherwise it would
+ // have been popped out earlier). This guarantees us that if the result has
+ // the same last operand, then it will also not be popped out, meaning that
+ // the returned value will be an AddRec.
+ const SCEV *Last = getOperand(getNumOperands() - 1);
+  assert(!Last->isZero() && "Recurrence with zero step?");
+ Ops.push_back(Last);
+ return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
+ SCEV::FlagAnyWrap));
+}
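+
+// Illustrative example: for {1,+,2,+,3} the step is {2,+,3} and the sum built
+// above is {1+2,+,2+3,+,3} = {3,+,5,+,3}, which is the value of the original
+// recurrence shifted by one iteration.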
+
+// Return true when S contains at least an undef value.
+static inline bool containsUndefs(const SCEV *S) {
+ return SCEVExprContains(S, [](const SCEV *S) {
+ if (const auto *SU = dyn_cast<SCEVUnknown>(S))
+ return isa<UndefValue>(SU->getValue());
+ return false;
+ });
+}
+
+namespace {
+
+// Collect all steps of SCEV expressions.
+struct SCEVCollectStrides {
+ ScalarEvolution &SE;
+ SmallVectorImpl<const SCEV *> &Strides;
+
+ SCEVCollectStrides(ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &S)
+ : SE(SE), Strides(S) {}
+
+ bool follow(const SCEV *S) {
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S))
+ Strides.push_back(AR->getStepRecurrence(SE));
+ return true;
+ }
+
+ bool isDone() const { return false; }
+};
+
+// Collect all SCEVUnknown and SCEVMulExpr expressions.
+struct SCEVCollectTerms {
+ SmallVectorImpl<const SCEV *> &Terms;
+
+ SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T) : Terms(T) {}
+
+ bool follow(const SCEV *S) {
+ if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S) ||
+ isa<SCEVSignExtendExpr>(S)) {
+ if (!containsUndefs(S))
+ Terms.push_back(S);
+
+ // Stop recursion: once we collected a term, do not walk its operands.
+ return false;
+ }
+
+ // Keep looking.
+ return true;
+ }
+
+ bool isDone() const { return false; }
+};
+
+// Check if a SCEV contains an AddRecExpr.
+struct SCEVHasAddRec {
+ bool &ContainsAddRec;
+
+ SCEVHasAddRec(bool &ContainsAddRec) : ContainsAddRec(ContainsAddRec) {
+ ContainsAddRec = false;
+ }
+
+ bool follow(const SCEV *S) {
+ if (isa<SCEVAddRecExpr>(S)) {
+ ContainsAddRec = true;
+
+      // Stop recursion: once we have found an AddRec, do not walk its operands.
+ return false;
+ }
+
+ // Keep looking.
+ return true;
+ }
+
+ bool isDone() const { return false; }
+};
+
+// Find factors that are multiplied with an expression that (possibly as a
+// subexpression) contains an AddRecExpr. In the expression:
+//
+// 8 * (100 + %p * %q * (%a + {0, +, 1}_loop))
+//
+// "%p * %q" are factors multiplied by the expression "(%a + {0, +, 1}_loop)"
+// that contains the AddRec {0, +, 1}_loop. %p * %q are likely to be array size
+// parameters as they form a product with an induction variable.
+//
+// This collector expects all array size parameters to be in the same MulExpr.
+// It might be necessary to later add support for collecting parameters that are
+// spread over different nested MulExpr.
+struct SCEVCollectAddRecMultiplies {
+ SmallVectorImpl<const SCEV *> &Terms;
+ ScalarEvolution &SE;
+
+  SCEVCollectAddRecMultiplies(SmallVectorImpl<const SCEV *> &T,
+                              ScalarEvolution &SE)
+      : Terms(T), SE(SE) {}
+
+ bool follow(const SCEV *S) {
+ if (auto *Mul = dyn_cast<SCEVMulExpr>(S)) {
+ bool HasAddRec = false;
+ SmallVector<const SCEV *, 0> Operands;
+ for (auto Op : Mul->operands()) {
+ const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(Op);
+ if (Unknown && !isa<CallInst>(Unknown->getValue())) {
+ Operands.push_back(Op);
+ } else if (Unknown) {
+ HasAddRec = true;
+ } else {
+ bool ContainsAddRec;
+          SCEVHasAddRec ContainsAddRecVisitor(ContainsAddRec);
+          visitAll(Op, ContainsAddRecVisitor);
+ HasAddRec |= ContainsAddRec;
+ }
+ }
+ if (Operands.size() == 0)
+ return true;
+
+ if (!HasAddRec)
+ return false;
+
+ Terms.push_back(SE.getMulExpr(Operands));
+ // Stop recursion: once we collected a term, do not walk its operands.
+ return false;
+ }
+
+ // Keep looking.
+ return true;
+ }
+
+ bool isDone() const { return false; }
+};
+
+} // end anonymous namespace
+
+/// Find parametric terms in this SCEVAddRecExpr. We first look for parameters
+/// in two places:
+/// 1) The strides of AddRec expressions.
+/// 2) Unknowns that are multiplied with AddRec expressions.
+void ScalarEvolution::collectParametricTerms(const SCEV *Expr,
+ SmallVectorImpl<const SCEV *> &Terms) {
+ SmallVector<const SCEV *, 4> Strides;
+ SCEVCollectStrides StrideCollector(*this, Strides);
+ visitAll(Expr, StrideCollector);
+
+ LLVM_DEBUG({
+ dbgs() << "Strides:\n";
+ for (const SCEV *S : Strides)
+ dbgs() << *S << "\n";
+ });
+
+ for (const SCEV *S : Strides) {
+ SCEVCollectTerms TermCollector(Terms);
+ visitAll(S, TermCollector);
+ }
+
+ LLVM_DEBUG({
+ dbgs() << "Terms:\n";
+ for (const SCEV *T : Terms)
+ dbgs() << *T << "\n";
+ });
+
+ SCEVCollectAddRecMultiplies MulCollector(Terms, *this);
+ visitAll(Expr, MulCollector);
+}
+
+static bool findArrayDimensionsRec(ScalarEvolution &SE,
+ SmallVectorImpl<const SCEV *> &Terms,
+ SmallVectorImpl<const SCEV *> &Sizes) {
+ int Last = Terms.size() - 1;
+ const SCEV *Step = Terms[Last];
+
+ // End of recursion.
+ if (Last == 0) {
+ if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Step)) {
+ SmallVector<const SCEV *, 2> Qs;
+ for (const SCEV *Op : M->operands())
+ if (!isa<SCEVConstant>(Op))
+ Qs.push_back(Op);
+
+ Step = SE.getMulExpr(Qs);
+ }
+
+ Sizes.push_back(Step);
+ return true;
+ }
+
+ for (const SCEV *&Term : Terms) {
+ // Normalize the terms before the next call to findArrayDimensionsRec.
+ const SCEV *Q, *R;
+ SCEVDivision::divide(SE, Term, Step, &Q, &R);
+
+ // Bail out when GCD does not evenly divide one of the terms.
+ if (!R->isZero())
+ return false;
+
+ Term = Q;
+ }
+
+ // Remove all SCEVConstants.
+ Terms.erase(
+ remove_if(Terms, [](const SCEV *E) { return isa<SCEVConstant>(E); }),
+ Terms.end());
+
+ if (Terms.size() > 0)
+ if (!findArrayDimensionsRec(SE, Terms, Sizes))
+ return false;
+
+ Sizes.push_back(Step);
+ return true;
+}
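+
+// Worked example (illustrative): for Terms = {%m * %o, %o} the initial Step is
+// %o; dividing each term by it leaves {%m, 1}, the constant is dropped, the
+// recursive call pushes %m, and finally %o is pushed, producing
+// Sizes = {%m, %o}.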
+
+// Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter.
+static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
+ for (const SCEV *T : Terms)
+ if (SCEVExprContains(T, isa<SCEVUnknown, const SCEV *>))
+ return true;
+ return false;
+}
+
+// Return the number of product terms in S.
+static inline int numberOfTerms(const SCEV *S) {
+ if (const SCEVMulExpr *Expr = dyn_cast<SCEVMulExpr>(S))
+ return Expr->getNumOperands();
+ return 1;
+}
+
+static const SCEV *removeConstantFactors(ScalarEvolution &SE, const SCEV *T) {
+ if (isa<SCEVConstant>(T))
+ return nullptr;
+
+ if (isa<SCEVUnknown>(T))
+ return T;
+
+ if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(T)) {
+ SmallVector<const SCEV *, 2> Factors;
+ for (const SCEV *Op : M->operands())
+ if (!isa<SCEVConstant>(Op))
+ Factors.push_back(Op);
+
+ return SE.getMulExpr(Factors);
+ }
+
+ return T;
+}
+
+/// Return the size of an element read or written by Inst.
+const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
+ Type *Ty;
+ if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
+ Ty = Store->getValueOperand()->getType();
+ else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
+ Ty = Load->getType();
+ else
+ return nullptr;
+
+ Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
+ return getSizeOfExpr(ETy, Ty);
+}
+
+void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
+ SmallVectorImpl<const SCEV *> &Sizes,
+ const SCEV *ElementSize) {
+ if (Terms.size() < 1 || !ElementSize)
+ return;
+
+ // Early return when Terms do not contain parameters: we do not delinearize
+  // non-parametric SCEVs.
+ if (!containsParameters(Terms))
+ return;
+
+ LLVM_DEBUG({
+ dbgs() << "Terms:\n";
+ for (const SCEV *T : Terms)
+ dbgs() << *T << "\n";
+ });
+
+ // Remove duplicates.
+ array_pod_sort(Terms.begin(), Terms.end());
+ Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end());
+
+ // Put larger terms first.
+ llvm::sort(Terms, [](const SCEV *LHS, const SCEV *RHS) {
+ return numberOfTerms(LHS) > numberOfTerms(RHS);
+ });
+
+  // Try to divide all terms by the element size. If a term is not divisible
+  // by the element size, proceed with the original term.
+ for (const SCEV *&Term : Terms) {
+ const SCEV *Q, *R;
+ SCEVDivision::divide(*this, Term, ElementSize, &Q, &R);
+ if (!Q->isZero())
+ Term = Q;
+ }
+
+ SmallVector<const SCEV *, 4> NewTerms;
+
+ // Remove constant factors.
+ for (const SCEV *T : Terms)
+ if (const SCEV *NewT = removeConstantFactors(*this, T))
+ NewTerms.push_back(NewT);
+
+ LLVM_DEBUG({
+ dbgs() << "Terms after sorting:\n";
+ for (const SCEV *T : NewTerms)
+ dbgs() << *T << "\n";
+ });
+
+ if (NewTerms.empty() || !findArrayDimensionsRec(*this, NewTerms, Sizes)) {
+ Sizes.clear();
+ return;
+ }
+
+ // The last element to be pushed into Sizes is the size of an element.
+ Sizes.push_back(ElementSize);
+
+ LLVM_DEBUG({
+ dbgs() << "Sizes:\n";
+ for (const SCEV *S : Sizes)
+ dbgs() << *S << "\n";
+ });
+}
+
+void ScalarEvolution::computeAccessFunctions(
+ const SCEV *Expr, SmallVectorImpl<const SCEV *> &Subscripts,
+ SmallVectorImpl<const SCEV *> &Sizes) {
+ // Early exit in case this SCEV is not an affine multivariate function.
+ if (Sizes.empty())
+ return;
+
+ if (auto *AR = dyn_cast<SCEVAddRecExpr>(Expr))
+ if (!AR->isAffine())
+ return;
+
+ const SCEV *Res = Expr;
+ int Last = Sizes.size() - 1;
+ for (int i = Last; i >= 0; i--) {
+ const SCEV *Q, *R;
+ SCEVDivision::divide(*this, Res, Sizes[i], &Q, &R);
+
+ LLVM_DEBUG({
+ dbgs() << "Res: " << *Res << "\n";
+ dbgs() << "Sizes[i]: " << *Sizes[i] << "\n";
+ dbgs() << "Res divided by Sizes[i]:\n";
+ dbgs() << "Quotient: " << *Q << "\n";
+ dbgs() << "Remainder: " << *R << "\n";
+ });
+
+ Res = Q;
+
+ // Do not record the last subscript corresponding to the size of elements in
+ // the array.
+ if (i == Last) {
+
+ // Bail out if the remainder is too complex.
+ if (isa<SCEVAddRecExpr>(R)) {
+ Subscripts.clear();
+ Sizes.clear();
+ return;
+ }
+
+ continue;
+ }
+
+ // Record the access function for the current subscript.
+ Subscripts.push_back(R);
+ }
+
+  // Also push the remainder of the last division in the last position: it
+  // will be the access function of the innermost dimension.
+ Subscripts.push_back(Res);
+
+ std::reverse(Subscripts.begin(), Subscripts.end());
+
+ LLVM_DEBUG({
+ dbgs() << "Subscripts:\n";
+ for (const SCEV *S : Subscripts)
+ dbgs() << *S << "\n";
+ });
+}
+
+/// Splits the SCEV into two vectors of SCEVs representing the subscripts and
+/// sizes of an array access. Returns the remainder of the delinearization that
+/// is the offset start of the array. The SCEV->delinearize algorithm computes
+/// the multiples of SCEV coefficients: that is a pattern matching of sub
+/// expressions in the stride and base of a SCEV corresponding to the
+/// computation of a GCD (greatest common divisor) of base and stride. When
+/// SCEV->delinearize fails, it returns the SCEV unchanged.
+///
+/// For example: when analyzing the memory access A[i][j][k] in this loop nest
+///
+/// void foo(long n, long m, long o, double A[n][m][o]) {
+///
+/// for (long i = 0; i < n; i++)
+/// for (long j = 0; j < m; j++)
+/// for (long k = 0; k < o; k++)
+/// A[i][j][k] = 1.0;
+/// }
+///
+/// the delinearization input is the following AddRec SCEV:
+///
+/// AddRec: {{{%A,+,(8 * %m * %o)}<%for.i>,+,(8 * %o)}<%for.j>,+,8}<%for.k>
+///
+/// From this SCEV, we are able to say that the base offset of the access is %A
+/// because it appears as an offset that does not divide any of the strides in
+/// the loops:
+///
+/// CHECK: Base offset: %A
+///
+/// and then SCEV->delinearize determines the size of some of the dimensions of
+/// the array as these are the multiples by which the strides are happening:
+///
+/// CHECK: ArrayDecl[UnknownSize][%m][%o] with elements of sizeof(double) bytes.
+///
+/// Note that the outermost dimension remains of UnknownSize because there are
+/// no strides that would help identifying the size of the last dimension: when
+/// the array has been statically allocated, one could compute the size of that
+/// dimension by dividing the overall size of the array by the size of the known
+/// dimensions: %m * %o * 8.
+///
+/// Finally delinearize provides the access functions for the array reference
+/// that does correspond to A[i][j][k] of the above C testcase:
+///
+/// CHECK: ArrayRef[{0,+,1}<%for.i>][{0,+,1}<%for.j>][{0,+,1}<%for.k>]
+///
+/// The testcases are checking the output of a function pass:
+/// DelinearizationPass that walks through all loads and stores of a function
+/// asking for the SCEV of the memory access with respect to all enclosing
+/// loops, calling SCEV->delinearize on that and printing the results.
+void ScalarEvolution::delinearize(const SCEV *Expr,
+ SmallVectorImpl<const SCEV *> &Subscripts,
+ SmallVectorImpl<const SCEV *> &Sizes,
+ const SCEV *ElementSize) {
+ // First step: collect parametric terms.
+ SmallVector<const SCEV *, 4> Terms;
+ collectParametricTerms(Expr, Terms);
+
+ if (Terms.empty())
+ return;
+
+ // Second step: find subscript sizes.
+ findArrayDimensions(Terms, Sizes, ElementSize);
+
+ if (Sizes.empty())
+ return;
+
+ // Third step: compute the access functions for each subscript.
+ computeAccessFunctions(Expr, Subscripts, Sizes);
+
+ if (Subscripts.empty())
+ return;
+
+ LLVM_DEBUG({
+ dbgs() << "succeeded to delinearize " << *Expr << "\n";
+ dbgs() << "ArrayDecl[UnknownSize]";
+ for (const SCEV *S : Sizes)
+ dbgs() << "[" << *S << "]";
+
+ dbgs() << "\nArrayRef";
+ for (const SCEV *S : Subscripts)
+ dbgs() << "[" << *S << "]";
+ dbgs() << "\n";
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// SCEVCallbackVH Class Implementation
+//===----------------------------------------------------------------------===//
+
+void ScalarEvolution::SCEVCallbackVH::deleted() {
+ assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
+ if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
+ SE->ConstantEvolutionLoopExitValue.erase(PN);
+ SE->eraseValueFromMap(getValPtr());
+ // this now dangles!
+}
+
+void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
+ assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
+
+ // Forget all the expressions associated with users of the old value,
+ // so that future queries will recompute the expressions using the new
+ // value.
+ Value *Old = getValPtr();
+ SmallVector<User *, 16> Worklist(Old->user_begin(), Old->user_end());
+ SmallPtrSet<User *, 8> Visited;
+ while (!Worklist.empty()) {
+ User *U = Worklist.pop_back_val();
+ // Deleting the Old value will cause this to dangle. Postpone
+ // that until everything else is done.
+ if (U == Old)
+ continue;
+ if (!Visited.insert(U).second)
+ continue;
+ if (PHINode *PN = dyn_cast<PHINode>(U))
+ SE->ConstantEvolutionLoopExitValue.erase(PN);
+ SE->eraseValueFromMap(U);
+ Worklist.insert(Worklist.end(), U->user_begin(), U->user_end());
+ }
+ // Delete the Old value.
+ if (PHINode *PN = dyn_cast<PHINode>(Old))
+ SE->ConstantEvolutionLoopExitValue.erase(PN);
+ SE->eraseValueFromMap(Old);
+ // this now dangles!
+}
+
+ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
+ : CallbackVH(V), SE(se) {}
+
+//===----------------------------------------------------------------------===//
+// ScalarEvolution Class Implementation
+//===----------------------------------------------------------------------===//
+
+ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
+ AssumptionCache &AC, DominatorTree &DT,
+ LoopInfo &LI)
+ : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
+ CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
+ LoopDispositions(64), BlockDispositions(64) {
+ // To use guards for proving predicates, we need to scan every instruction in
+ // relevant basic blocks, and not just terminators. Doing this is a waste of
+ // time if the IR does not actually contain any calls to
+ // @llvm.experimental.guard, so do a quick check and remember this beforehand.
+ //
+ // This pessimizes the case where a pass that preserves ScalarEvolution wants
+ // to _add_ guards to the module when there weren't any before, and wants
+ // ScalarEvolution to optimize based on those guards. For now we prefer to be
+ // efficient in lieu of being smart in that rather obscure case.
+
+ auto *GuardDecl = F.getParent()->getFunction(
+ Intrinsic::getName(Intrinsic::experimental_guard));
+ HasGuards = GuardDecl && !GuardDecl->use_empty();
+}
+
+ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
+ : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
+ LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
+ ValueExprMap(std::move(Arg.ValueExprMap)),
+ PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
+ PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
+ PendingMerges(std::move(Arg.PendingMerges)),
+ MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
+ BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
+ PredicatedBackedgeTakenCounts(
+ std::move(Arg.PredicatedBackedgeTakenCounts)),
+ ConstantEvolutionLoopExitValue(
+ std::move(Arg.ConstantEvolutionLoopExitValue)),
+ ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
+ LoopDispositions(std::move(Arg.LoopDispositions)),
+ LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
+ BlockDispositions(std::move(Arg.BlockDispositions)),
+ UnsignedRanges(std::move(Arg.UnsignedRanges)),
+ SignedRanges(std::move(Arg.SignedRanges)),
+ UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
+ UniquePreds(std::move(Arg.UniquePreds)),
+ SCEVAllocator(std::move(Arg.SCEVAllocator)),
+ LoopUsers(std::move(Arg.LoopUsers)),
+ PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
+ FirstUnknown(Arg.FirstUnknown) {
+ Arg.FirstUnknown = nullptr;
+}
+
+ScalarEvolution::~ScalarEvolution() {
+ // Iterate through all the SCEVUnknown instances and call their
+ // destructors, so that they release their references to their values.
+ for (SCEVUnknown *U = FirstUnknown; U;) {
+ SCEVUnknown *Tmp = U;
+ U = U->Next;
+ Tmp->~SCEVUnknown();
+ }
+ FirstUnknown = nullptr;
+
+ ExprValueMap.clear();
+ ValueExprMap.clear();
+ HasRecMap.clear();
+
+ // Free any extra memory created for ExitNotTakenInfo in the unlikely event
+ // that a loop had multiple computable exits.
+ for (auto &BTCI : BackedgeTakenCounts)
+ BTCI.second.clear();
+ for (auto &BTCI : PredicatedBackedgeTakenCounts)
+ BTCI.second.clear();
+
+ assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
+ assert(PendingPhiRanges.empty() && "getRangeRef garbage");
+ assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
+ assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
+ assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
+}
+
+bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
+ return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
+}
+
+static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
+ const Loop *L) {
+ // Print all inner loops first
+ for (Loop *I : *L)
+ PrintLoopInfo(OS, SE, I);
+
+ OS << "Loop ";
+ L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+ OS << ": ";
+
+ SmallVector<BasicBlock *, 8> ExitingBlocks;
+ L->getExitingBlocks(ExitingBlocks);
+ if (ExitingBlocks.size() != 1)
+ OS << "<multiple exits> ";
+
+ if (SE->hasLoopInvariantBackedgeTakenCount(L))
+ OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L) << "\n";
+ else
+ OS << "Unpredictable backedge-taken count.\n";
+
+ if (ExitingBlocks.size() > 1)
+ for (BasicBlock *ExitingBlock : ExitingBlocks) {
+ OS << " exit count for " << ExitingBlock->getName() << ": "
+ << *SE->getExitCount(L, ExitingBlock) << "\n";
+ }
+
+ OS << "Loop ";
+ L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+ OS << ": ";
+
+ if (!isa<SCEVCouldNotCompute>(SE->getConstantMaxBackedgeTakenCount(L))) {
+ OS << "max backedge-taken count is " << *SE->getConstantMaxBackedgeTakenCount(L);
+ if (SE->isBackedgeTakenCountMaxOrZero(L))
+ OS << ", actual taken count either this or zero.";
+ } else {
+ OS << "Unpredictable max backedge-taken count. ";
+ }
+
+ OS << "\n"
+ "Loop ";
+ L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+ OS << ": ";
+
+ SCEVUnionPredicate Pred;
+ auto PBT = SE->getPredicatedBackedgeTakenCount(L, Pred);
+ if (!isa<SCEVCouldNotCompute>(PBT)) {
+ OS << "Predicated backedge-taken count is " << *PBT << "\n";
+ OS << " Predicates:\n";
+ Pred.print(OS, 4);
+ } else {
+ OS << "Unpredictable predicated backedge-taken count. ";
+ }
+ OS << "\n";
+
+ if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
+ OS << "Loop ";
+ L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+ OS << ": ";
+ OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
+ }
+}
+
+static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) {
+ switch (LD) {
+ case ScalarEvolution::LoopVariant:
+ return "Variant";
+ case ScalarEvolution::LoopInvariant:
+ return "Invariant";
+ case ScalarEvolution::LoopComputable:
+ return "Computable";
+ }
+ llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!");
+}
+
+void ScalarEvolution::print(raw_ostream &OS) const {
+ // ScalarEvolution's implementation of the print method is to print
+ // out SCEV values of all instructions that are interesting. Doing
+ // this potentially causes it to create new SCEV objects though,
+ // which technically conflicts with the const qualifier. This isn't
+ // observable from outside the class though, so casting away the
+ // const isn't dangerous.
+ ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
+
+ OS << "Classifying expressions for: ";
+ F.printAsOperand(OS, /*PrintType=*/false);
+ OS << "\n";
+ for (Instruction &I : instructions(F))
+ if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
+ OS << I << '\n';
+ OS << " --> ";
+ const SCEV *SV = SE.getSCEV(&I);
+ SV->print(OS);
+ if (!isa<SCEVCouldNotCompute>(SV)) {
+ OS << " U: ";
+ SE.getUnsignedRange(SV).print(OS);
+ OS << " S: ";
+ SE.getSignedRange(SV).print(OS);
+ }
+
+ const Loop *L = LI.getLoopFor(I.getParent());
+
+ const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
+ if (AtUse != SV) {
+ OS << " --> ";
+ AtUse->print(OS);
+ if (!isa<SCEVCouldNotCompute>(AtUse)) {
+ OS << " U: ";
+ SE.getUnsignedRange(AtUse).print(OS);
+ OS << " S: ";
+ SE.getSignedRange(AtUse).print(OS);
+ }
+ }
+
+ if (L) {
+ OS << "\t\t" "Exits: ";
+ const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
+ if (!SE.isLoopInvariant(ExitValue, L)) {
+ OS << "<<Unknown>>";
+ } else {
+ OS << *ExitValue;
+ }
+
+ bool First = true;
+ for (auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
+ if (First) {
+ OS << "\t\t" "LoopDispositions: { ";
+ First = false;
+ } else {
+ OS << ", ";
+ }
+
+ Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+ OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter));
+ }
+
+ for (auto *InnerL : depth_first(L)) {
+ if (InnerL == L)
+ continue;
+ if (First) {
+ OS << "\t\t" "LoopDispositions: { ";
+ First = false;
+ } else {
+ OS << ", ";
+ }
+
+ InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+ OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL));
+ }
+
+ OS << " }";
+ }
+
+ OS << "\n";
+ }
+
+ OS << "Determining loop execution counts for: ";
+ F.printAsOperand(OS, /*PrintType=*/false);
+ OS << "\n";
+ for (Loop *I : LI)
+ PrintLoopInfo(OS, &SE, I);
+}
+
+ScalarEvolution::LoopDisposition
+ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
+ auto &Values = LoopDispositions[S];
+ for (auto &V : Values) {
+ if (V.getPointer() == L)
+ return V.getInt();
+ }
+ Values.emplace_back(L, LoopVariant);
+ LoopDisposition D = computeLoopDisposition(S, L);
+ auto &Values2 = LoopDispositions[S];
+ for (auto &V : make_range(Values2.rbegin(), Values2.rend())) {
+ if (V.getPointer() == L) {
+ V.setInt(D);
+ break;
+ }
+ }
+ return D;
+}
+
+ScalarEvolution::LoopDisposition
+ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
+ switch (static_cast<SCEVTypes>(S->getSCEVType())) {
+ case scConstant:
+ return LoopInvariant;
+ case scTruncate:
+ case scZeroExtend:
+ case scSignExtend:
+ return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L);
+ case scAddRecExpr: {
+ const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
+
+ // If L is the addrec's loop, it's computable.
+ if (AR->getLoop() == L)
+ return LoopComputable;
+
+ // Add recurrences are never invariant in the function-body (null loop).
+ if (!L)
+ return LoopVariant;
+
+ // Everything that is not defined at loop entry is variant.
+ if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
+ return LoopVariant;
+ assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
+ " dominate the contained loop's header?");
+
+ // This recurrence is invariant w.r.t. L if AR's loop contains L.
+ if (AR->getLoop()->contains(L))
+ return LoopInvariant;
+
+ // This recurrence is variant w.r.t. L if any of its operands
+ // are variant.
+ for (auto *Op : AR->operands())
+ if (!isLoopInvariant(Op, L))
+ return LoopVariant;
+
+ // Otherwise it's loop-invariant.
+ return LoopInvariant;
+ }
+ case scAddExpr:
+ case scMulExpr:
+ case scUMaxExpr:
+ case scSMaxExpr:
+ case scUMinExpr:
+ case scSMinExpr: {
+ bool HasVarying = false;
+ for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
+ LoopDisposition D = getLoopDisposition(Op, L);
+ if (D == LoopVariant)
+ return LoopVariant;
+ if (D == LoopComputable)
+ HasVarying = true;
+ }
+ return HasVarying ? LoopComputable : LoopInvariant;
+ }
+ case scUDivExpr: {
+ const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
+ LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L);
+ if (LD == LoopVariant)
+ return LoopVariant;
+ LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L);
+ if (RD == LoopVariant)
+ return LoopVariant;
+ return (LD == LoopInvariant && RD == LoopInvariant) ?
+ LoopInvariant : LoopComputable;
+ }
+ case scUnknown:
+ // All non-instruction values are loop invariant. All instructions are loop
+ // invariant if they are not contained in the specified loop.
+ // Instructions are never considered invariant in the function body
+ // (null loop) because they are defined within the "loop".
+ if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
+ return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
+ return LoopInvariant;
+ case scCouldNotCompute:
+ llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
+ }
+ llvm_unreachable("Unknown SCEV kind!");
+}
+
+bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
+ return getLoopDisposition(S, L) == LoopInvariant;
+}
+
+bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
+ return getLoopDisposition(S, L) == LoopComputable;
+}
+
+ScalarEvolution::BlockDisposition
+ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
+ auto &Values = BlockDispositions[S];
+ for (auto &V : Values) {
+ if (V.getPointer() == BB)
+ return V.getInt();
+ }
+ Values.emplace_back(BB, DoesNotDominateBlock);
+ BlockDisposition D = computeBlockDisposition(S, BB);
+ auto &Values2 = BlockDispositions[S];
+ for (auto &V : make_range(Values2.rbegin(), Values2.rend())) {
+ if (V.getPointer() == BB) {
+ V.setInt(D);
+ break;
+ }
+ }
+ return D;
+}
+
+ScalarEvolution::BlockDisposition
+ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
+ switch (static_cast<SCEVTypes>(S->getSCEVType())) {
+ case scConstant:
+ return ProperlyDominatesBlock;
+ case scTruncate:
+ case scZeroExtend:
+ case scSignExtend:
+ return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB);
+ case scAddRecExpr: {
+ // This uses a "dominates" query instead of "properly dominates" query
+ // to test for proper dominance too, because the instruction which
+ // produces the addrec's value is a PHI, and a PHI effectively properly
+ // dominates its entire containing block.
+ const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
+ if (!DT.dominates(AR->getLoop()->getHeader(), BB))
+ return DoesNotDominateBlock;
+
+ // Fall through into SCEVNAryExpr handling.
+ LLVM_FALLTHROUGH;
+ }
+ case scAddExpr:
+ case scMulExpr:
+ case scUMaxExpr:
+ case scSMaxExpr:
+ case scUMinExpr:
+ case scSMinExpr: {
+ const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
+ bool Proper = true;
+ for (const SCEV *NAryOp : NAry->operands()) {
+ BlockDisposition D = getBlockDisposition(NAryOp, BB);
+ if (D == DoesNotDominateBlock)
+ return DoesNotDominateBlock;
+ if (D == DominatesBlock)
+ Proper = false;
+ }
+ return Proper ? ProperlyDominatesBlock : DominatesBlock;
+ }
+ case scUDivExpr: {
+ const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
+ const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
+ BlockDisposition LD = getBlockDisposition(LHS, BB);
+ if (LD == DoesNotDominateBlock)
+ return DoesNotDominateBlock;
+ BlockDisposition RD = getBlockDisposition(RHS, BB);
+ if (RD == DoesNotDominateBlock)
+ return DoesNotDominateBlock;
+ return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
+ ProperlyDominatesBlock : DominatesBlock;
+ }
+ case scUnknown:
+ if (Instruction *I =
+ dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
+ if (I->getParent() == BB)
+ return DominatesBlock;
+ if (DT.properlyDominates(I->getParent(), BB))
+ return ProperlyDominatesBlock;
+ return DoesNotDominateBlock;
+ }
+ return ProperlyDominatesBlock;
+ case scCouldNotCompute:
+ llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
+ }
+ llvm_unreachable("Unknown SCEV kind!");
+}
+
+bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
+ return getBlockDisposition(S, BB) >= DominatesBlock;
+}
+
+bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
+ return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
+}
+
+bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
+ return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
+}
+
+bool ScalarEvolution::ExitLimit::hasOperand(const SCEV *S) const {
+ auto IsS = [&](const SCEV *X) { return S == X; };
+ auto ContainsS = [&](const SCEV *X) {
+ return !isa<SCEVCouldNotCompute>(X) && SCEVExprContains(X, IsS);
+ };
+ return ContainsS(ExactNotTaken) || ContainsS(MaxNotTaken);
+}
+
+void
+ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
+ ValuesAtScopes.erase(S);
+ LoopDispositions.erase(S);
+ BlockDispositions.erase(S);
+ UnsignedRanges.erase(S);
+ SignedRanges.erase(S);
+ ExprValueMap.erase(S);
+ HasRecMap.erase(S);
+ MinTrailingZerosCache.erase(S);
+
+ for (auto I = PredicatedSCEVRewrites.begin();
+ I != PredicatedSCEVRewrites.end();) {
+ std::pair<const SCEV *, const Loop *> Entry = I->first;
+ if (Entry.first == S)
+ PredicatedSCEVRewrites.erase(I++);
+ else
+ ++I;
+ }
+
+ auto RemoveSCEVFromBackedgeMap =
+ [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
+ for (auto I = Map.begin(), E = Map.end(); I != E;) {
+ BackedgeTakenInfo &BEInfo = I->second;
+ if (BEInfo.hasOperand(S, this)) {
+ BEInfo.clear();
+ Map.erase(I++);
+ } else
+ ++I;
+ }
+ };
+
+ RemoveSCEVFromBackedgeMap(BackedgeTakenCounts);
+ RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
+}
+
+void
+ScalarEvolution::getUsedLoops(const SCEV *S,
+ SmallPtrSetImpl<const Loop *> &LoopsUsed) {
+ struct FindUsedLoops {
+ FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
+ : LoopsUsed(LoopsUsed) {}
+ SmallPtrSetImpl<const Loop *> &LoopsUsed;
+ bool follow(const SCEV *S) {
+ if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
+ LoopsUsed.insert(AR->getLoop());
+ return true;
+ }
+
+ bool isDone() const { return false; }
+ };
+
+ FindUsedLoops F(LoopsUsed);
+ SCEVTraversal<FindUsedLoops>(F).visitAll(S);
+}
+
+void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
+ SmallPtrSet<const Loop *, 8> LoopsUsed;
+ getUsedLoops(S, LoopsUsed);
+ for (auto *L : LoopsUsed)
+ LoopUsers[L].push_back(S);
+}
+
+void ScalarEvolution::verify() const {
+ ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
+ ScalarEvolution SE2(F, TLI, AC, DT, LI);
+
+ SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
+
+  // Maps SCEV expressions from one ScalarEvolution "universe" to another.
+ struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
+ SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
+
+ const SCEV *visitConstant(const SCEVConstant *Constant) {
+ return SE.getConstant(Constant->getAPInt());
+ }
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ return SE.getUnknown(Expr->getValue());
+ }
+
+ const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
+ return SE.getCouldNotCompute();
+ }
+ };
+
+ SCEVMapper SCM(SE2);
+
+ while (!LoopStack.empty()) {
+ auto *L = LoopStack.pop_back_val();
+ LoopStack.insert(LoopStack.end(), L->begin(), L->end());
+
+ auto *CurBECount = SCM.visit(
+ const_cast<ScalarEvolution *>(this)->getBackedgeTakenCount(L));
+ auto *NewBECount = SE2.getBackedgeTakenCount(L);
+
+ if (CurBECount == SE2.getCouldNotCompute() ||
+ NewBECount == SE2.getCouldNotCompute()) {
+ // NB! This situation is legal, but is very suspicious -- whatever pass
+      // changed the loop to make a trip count go from could not compute to
+ // computable or vice-versa *should have* invalidated SCEV. However, we
+ // choose not to assert here (for now) since we don't want false
+ // positives.
+ continue;
+ }
+
+ if (containsUndefs(CurBECount) || containsUndefs(NewBECount)) {
+ // SCEV treats "undef" as an unknown but consistent value (i.e. it does
+ // not propagate undef aggressively). This means we can (and do) fail
+ // verification in cases where a transform makes the trip count of a loop
+ // go from "undef" to "undef+1" (say). The transform is fine, since in
+ // both cases the loop iterates "undef" times, but SCEV thinks we
+ // increased the trip count of the loop by 1 incorrectly.
+ continue;
+ }
+
+ if (SE.getTypeSizeInBits(CurBECount->getType()) >
+ SE.getTypeSizeInBits(NewBECount->getType()))
+ NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
+ else if (SE.getTypeSizeInBits(CurBECount->getType()) <
+ SE.getTypeSizeInBits(NewBECount->getType()))
+ CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
+
+ const SCEV *Delta = SE2.getMinusSCEV(CurBECount, NewBECount);
+
+ // Unless VerifySCEVStrict is set, we only compare constant deltas.
+ if ((VerifySCEVStrict || isa<SCEVConstant>(Delta)) && !Delta->isZero()) {
+ dbgs() << "Trip Count for " << *L << " Changed!\n";
+ dbgs() << "Old: " << *CurBECount << "\n";
+ dbgs() << "New: " << *NewBECount << "\n";
+ dbgs() << "Delta: " << *Delta << "\n";
+ std::abort();
+ }
+ }
+}
+
+bool ScalarEvolution::invalidate(
+ Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &Inv) {
+ // Invalidate the ScalarEvolution object whenever it isn't preserved or one
+ // of its dependencies is invalidated.
+ auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
+ return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
+ Inv.invalidate<AssumptionAnalysis>(F, PA) ||
+ Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
+ Inv.invalidate<LoopAnalysis>(F, PA);
+}
+
+AnalysisKey ScalarEvolutionAnalysis::Key;
+
+ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F),
+ AM.getResult<AssumptionAnalysis>(F),
+ AM.getResult<DominatorTreeAnalysis>(F),
+ AM.getResult<LoopAnalysis>(F));
+}
+
+PreservedAnalyses
+ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
+ AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
+ return PreservedAnalyses::all();
+}
+
+INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
+ "Scalar Evolution Analysis", false, true)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
+ "Scalar Evolution Analysis", false, true)
+
+char ScalarEvolutionWrapperPass::ID = 0;
+
+ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
+ initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
+ SE.reset(new ScalarEvolution(
+ F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
+ getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
+ getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
+ getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
+ return false;
+}
+
+void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
+
+void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
+ SE->print(OS);
+}
+
+void ScalarEvolutionWrapperPass::verifyAnalysis() const {
+ if (!VerifySCEV)
+ return;
+
+ SE->verify();
+}
+
+void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequiredTransitive<AssumptionCacheTracker>();
+ AU.addRequiredTransitive<LoopInfoWrapperPass>();
+ AU.addRequiredTransitive<DominatorTreeWrapperPass>();
+ AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
+}
+
+const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
+ const SCEV *RHS) {
+ FoldingSetNodeID ID;
+ assert(LHS->getType() == RHS->getType() &&
+ "Type mismatch between LHS and RHS");
+ // Unique this node based on the arguments
+ ID.AddInteger(SCEVPredicate::P_Equal);
+ ID.AddPointer(LHS);
+ ID.AddPointer(RHS);
+ void *IP = nullptr;
+ if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
+ return S;
+ SCEVEqualPredicate *Eq = new (SCEVAllocator)
+ SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS);
+ UniquePreds.InsertNode(Eq, IP);
+ return Eq;
+}
+
+const SCEVPredicate *ScalarEvolution::getWrapPredicate(
+ const SCEVAddRecExpr *AR,
+ SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
+ FoldingSetNodeID ID;
+ // Unique this node based on the arguments
+ ID.AddInteger(SCEVPredicate::P_Wrap);
+ ID.AddPointer(AR);
+ ID.AddInteger(AddedFlags);
+ void *IP = nullptr;
+ if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
+ return S;
+ auto *OF = new (SCEVAllocator)
+ SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
+ UniquePreds.InsertNode(OF, IP);
+ return OF;
+}
+
+namespace {
+
+class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
+public:
+
+ /// Rewrites \p S in the context of a loop L and the SCEV predication
+ /// infrastructure.
+ ///
+ /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
+ /// equivalences present in \p Pred.
+ ///
+ /// If \p NewPreds is non-null, rewrite is free to add further predicates to
+ /// \p NewPreds such that the result will be an AddRecExpr.
+ static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
+ SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
+ SCEVUnionPredicate *Pred) {
+ SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
+ return Rewriter.visit(S);
+ }
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ if (Pred) {
+ auto ExprPreds = Pred->getPredicatesForExpr(Expr);
+ for (auto *Pred : ExprPreds)
+ if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
+ if (IPred->getLHS() == Expr)
+ return IPred->getRHS();
+ }
+ return convertToAddRecWithPreds(Expr);
+ }
+
+ const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+ const SCEV *Operand = visit(Expr->getOperand());
+ const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
+ if (AR && AR->getLoop() == L && AR->isAffine()) {
+ // This couldn't be folded because the operand didn't have the nuw
+ // flag. Add the nusw flag as an assumption that we could make.
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ Type *Ty = Expr->getType();
+ if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
+ return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
+ SE.getSignExtendExpr(Step, Ty), L,
+ AR->getNoWrapFlags());
+ }
+ return SE.getZeroExtendExpr(Operand, Expr->getType());
+ }
+
+ const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+ const SCEV *Operand = visit(Expr->getOperand());
+ const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
+ if (AR && AR->getLoop() == L && AR->isAffine()) {
+ // This couldn't be folded because the operand didn't have the nsw
+ // flag. Add the nssw flag as an assumption that we could make.
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ Type *Ty = Expr->getType();
+ if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
+ return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
+ SE.getSignExtendExpr(Step, Ty), L,
+ AR->getNoWrapFlags());
+ }
+ return SE.getSignExtendExpr(Operand, Expr->getType());
+ }
+
+private:
+ explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
+ SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
+ SCEVUnionPredicate *Pred)
+ : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
+
+ bool addOverflowAssumption(const SCEVPredicate *P) {
+ if (!NewPreds) {
+ // Check if we've already made this assumption.
+ return Pred && Pred->implies(P);
+ }
+ NewPreds->insert(P);
+ return true;
+ }
+
+ bool addOverflowAssumption(const SCEVAddRecExpr *AR,
+ SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
+ auto *A = SE.getWrapPredicate(AR, AddedFlags);
+ return addOverflowAssumption(A);
+ }
+
+ // If \p Expr represents a PHINode, we try to see if it can be represented
+ // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
+ // to add this predicate as a runtime overflow check, we return the AddRec.
+ // If \p Expr does not meet these conditions (is not a PHI node, or we
+ // couldn't create an AddRec for it, or couldn't add the predicate), we just
+ // return \p Expr.
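+ //
+ // A rough illustration: a loop phi such as %iv = phi [%start, %ph],
+ // [%iv.next, %latch], possibly with casts on its increment, may be rewritten
+ // here to an addrec like {%start,+,%step}<L>, guarded by the predicates that
+ // createAddRecFromPHIWithCasts returns.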
+ const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
+ if (!isa<PHINode>(Expr->getValue()))
+ return Expr;
+ Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
+ if (!PredicatedRewrite)
+ return Expr;
+ for (auto *P : PredicatedRewrite->second) {
+ // Wrap predicates from outer loops are not supported.
+ if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
+ auto *AR = cast<const SCEVAddRecExpr>(WP->getExpr());
+ if (L != AR->getLoop())
+ return Expr;
+ }
+ if (!addOverflowAssumption(P))
+ return Expr;
+ }
+ return PredicatedRewrite->first;
+ }
+
+ SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
+ SCEVUnionPredicate *Pred;
+ const Loop *L;
+};
+
+} // end anonymous namespace
+
+const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
+ SCEVUnionPredicate &Preds) {
+ return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
+}
+
+const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
+ const SCEV *S, const Loop *L,
+ SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
+ SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
+ S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
+ auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
+
+ if (!AddRec)
+ return nullptr;
+
+ // Since the transformation was successful, we can now transfer the SCEV
+ // predicates.
+ for (auto *P : TransformPreds)
+ Preds.insert(P);
+
+ return AddRec;
+}
+
+/// SCEV predicates
+SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
+ SCEVPredicateKind Kind)
+ : FastID(ID), Kind(Kind) {}
+
+SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
+ const SCEV *LHS, const SCEV *RHS)
+ : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {
+ assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
+ assert(LHS != RHS && "LHS and RHS are the same SCEV");
+}
+
+bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
+ const auto *Op = dyn_cast<SCEVEqualPredicate>(N);
+
+ if (!Op)
+ return false;
+
+ return Op->LHS == LHS && Op->RHS == RHS;
+}
+
+bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
+
+const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
+
+void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
+ OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
+}
+
+SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
+ const SCEVAddRecExpr *AR,
+ IncrementWrapFlags Flags)
+ : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
+
+const SCEV *SCEVWrapPredicate::getExpr() const { return AR; }
+
+bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
+ const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
+
+ return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
+}
+
+bool SCEVWrapPredicate::isAlwaysTrue() const {
+ SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
+ IncrementWrapFlags IFlags = Flags;
+
+ if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
+ IFlags = clearFlags(IFlags, IncrementNSSW);
+
+ return IFlags == IncrementAnyWrap;
+}
+
+void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
+ OS.indent(Depth) << *getExpr() << " Added Flags: ";
+ if (SCEVWrapPredicate::IncrementNUSW & getFlags())
+ OS << "<nusw>";
+ if (SCEVWrapPredicate::IncrementNSSW & getFlags())
+ OS << "<nssw>";
+ OS << "\n";
+}
+
+SCEVWrapPredicate::IncrementWrapFlags
+SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
+ ScalarEvolution &SE) {
+ IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
+ SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
+
+ // We can safely transfer the NSW flag as NSSW.
+ if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
+ ImpliedFlags = IncrementNSSW;
+
+ if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
+ // If the increment is positive, the SCEV NUW flag will also imply the
+ // WrapPredicate NUSW flag.
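+ // (Illustrative: for an addrec such as {0,+,1}<nuw>, the constant step 1 is
+ // non-negative, so IncrementNUSW is added below.)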
+ if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
+ if (Step->getValue()->getValue().isNonNegative())
+ ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
+ }
+
+ return ImpliedFlags;
+}
+
+/// Union predicates don't get cached so create a dummy set ID for it.
+SCEVUnionPredicate::SCEVUnionPredicate()
+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
+
+bool SCEVUnionPredicate::isAlwaysTrue() const {
+ return all_of(Preds,
+ [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
+}
+
+ArrayRef<const SCEVPredicate *>
+SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
+ auto I = SCEVToPreds.find(Expr);
+ if (I == SCEVToPreds.end())
+ return ArrayRef<const SCEVPredicate *>();
+ return I->second;
+}
+
+bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+ if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
+ return all_of(Set->Preds,
+ [this](const SCEVPredicate *I) { return this->implies(I); });
+
+ auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
+ if (ScevPredsIt == SCEVToPreds.end())
+ return false;
+ auto &SCEVPreds = ScevPredsIt->second;
+
+ return any_of(SCEVPreds,
+ [N](const SCEVPredicate *I) { return I->implies(N); });
+}
+
+const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }
+
+void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
+ for (auto Pred : Preds)
+ Pred->print(OS, Depth);
+}
+
+void SCEVUnionPredicate::add(const SCEVPredicate *N) {
+ if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
+ for (auto Pred : Set->Preds)
+ add(Pred);
+ return;
+ }
+
+ if (implies(N))
+ return;
+
+ const SCEV *Key = N->getExpr();
+ assert(Key && "Only SCEVUnionPredicate doesn't have an "
+               "associated expression!");
+
+ SCEVToPreds[Key].push_back(N);
+ Preds.push_back(N);
+}
+
+PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
+ Loop &L)
+ : SE(SE), L(L) {}
+
+const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
+ const SCEV *Expr = SE.getSCEV(V);
+ RewriteEntry &Entry = RewriteMap[Expr];
+
+ // If we already have an entry and the version matches, return it.
+ if (Entry.second && Generation == Entry.first)
+ return Entry.second;
+
+ // We found an entry but it's stale. Rewrite the stale entry
+ // according to the current predicate.
+ if (Entry.second)
+ Expr = Entry.second;
+
+ const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds);
+ Entry = {Generation, NewSCEV};
+
+ return NewSCEV;
+}
+
+const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
+ if (!BackedgeCount) {
+ SCEVUnionPredicate BackedgePred;
+ BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, BackedgePred);
+ addPredicate(BackedgePred);
+ }
+ return BackedgeCount;
+}
+
+void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
+ if (Preds.implies(&Pred))
+ return;
+ Preds.add(&Pred);
+ updateGeneration();
+}
+
+const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const {
+ return Preds;
+}
+
+void PredicatedScalarEvolution::updateGeneration() {
+ // If the generation number wrapped, recompute everything.
+ if (++Generation == 0) {
+ for (auto &II : RewriteMap) {
+ const SCEV *Rewritten = II.second.second;
+ II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)};
+ }
+ }
+}
+
+void PredicatedScalarEvolution::setNoOverflow(
+ Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
+ const SCEV *Expr = getSCEV(V);
+ const auto *AR = cast<SCEVAddRecExpr>(Expr);
+
+ auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
+
+ // Clear the statically implied flags.
+ Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
+ addPredicate(*SE.getWrapPredicate(AR, Flags));
+
+ auto II = FlagsMap.insert({V, Flags});
+ if (!II.second)
+ II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
+}
+
+bool PredicatedScalarEvolution::hasNoOverflow(
+ Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
+ const SCEV *Expr = getSCEV(V);
+ const auto *AR = cast<SCEVAddRecExpr>(Expr);
+
+ Flags = SCEVWrapPredicate::clearFlags(
+ Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
+
+ auto II = FlagsMap.find(V);
+
+ if (II != FlagsMap.end())
+ Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
+
+ return Flags == SCEVWrapPredicate::IncrementAnyWrap;
+}
+
+const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
+ const SCEV *Expr = this->getSCEV(V);
+ SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
+ auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
+
+ if (!New)
+ return nullptr;
+
+ for (auto *P : NewPreds)
+ Preds.add(P);
+
+ updateGeneration();
+ RewriteMap[SE.getSCEV(V)] = {Generation, New};
+ return New;
+}
+
+PredicatedScalarEvolution::PredicatedScalarEvolution(
+ const PredicatedScalarEvolution &Init)
+ : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds),
+ Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
+ for (const auto &I : Init.FlagsMap)
+ FlagsMap.insert(I);
+}
+
+void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
+ // For each block.
+ for (auto *BB : L.getBlocks())
+ for (auto &I : *BB) {
+ if (!SE.isSCEVable(I.getType()))
+ continue;
+
+ auto *Expr = SE.getSCEV(&I);
+ auto II = RewriteMap.find(Expr);
+
+ if (II == RewriteMap.end())
+ continue;
+
+ // Don't print things that are not interesting.
+ if (II->second.second == Expr)
+ continue;
+
+ OS.indent(Depth) << "[PSE]" << I << ":\n";
+ OS.indent(Depth + 2) << *Expr << "\n";
+ OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
+ }
+}
+
+// Match the mathematical pattern A - (A / B) * B, where A and B can be
+// arbitrary expressions.
+// This is not always easy, as A and B can be folded (imagine A is X / 2 and B
+// is 4; then A / B becomes X / 8).
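+// For example (a rough illustration): a urem such as (%n urem %m) is
+// typically represented as (%n + (-1 * (%n /u %m) * %m)), which this routine
+// recognizes with LHS = %n and RHS = %m.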
+bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
+ const SCEV *&RHS) {
+ const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
+ if (Add == nullptr || Add->getNumOperands() != 2)
+ return false;
+
+ const SCEV *A = Add->getOperand(1);
+ const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
+
+ if (Mul == nullptr)
+ return false;
+
+ const auto MatchURemWithDivisor = [&](const SCEV *B) {
+ // (SomeExpr + (-(SomeExpr / B) * B)).
+ if (Expr == getURemExpr(A, B)) {
+ LHS = A;
+ RHS = B;
+ return true;
+ }
+ return false;
+ };
+
+ // (SomeExpr + (-1 * (SomeExpr / B) * B)).
+ if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
+ return MatchURemWithDivisor(Mul->getOperand(1)) ||
+ MatchURemWithDivisor(Mul->getOperand(2));
+
+ // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
+ if (Mul->getNumOperands() == 2)
+ return MatchURemWithDivisor(Mul->getOperand(1)) ||
+ MatchURemWithDivisor(Mul->getOperand(0)) ||
+ MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
+ MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
+ return false;
+}
diff --git a/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp b/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp
new file mode 100644
index 000000000000..96da0a24cddd
--- /dev/null
+++ b/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp
@@ -0,0 +1,147 @@
+//===- ScalarEvolutionAliasAnalysis.cpp - SCEV-based Alias Analysis -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ScalarEvolutionAliasAnalysis pass, which implements a
+// simple alias analysis in terms of ScalarEvolution queries.
+//
+// This differs from traditional loop dependence analysis in that it tests
+// for dependencies within a single iteration of a loop, rather than
+// dependencies between different iterations.
+//
+// ScalarEvolution has a more complete understanding of pointer arithmetic
+// than BasicAliasAnalysis' collection of ad-hoc analyses.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
+using namespace llvm;
+
+AliasResult SCEVAAResult::alias(const MemoryLocation &LocA,
+ const MemoryLocation &LocB, AAQueryInfo &AAQI) {
+ // If either of the memory references is empty, it doesn't matter what the
+ // pointer values are. This allows the code below to ignore this special
+ // case.
+ if (LocA.Size.isZero() || LocB.Size.isZero())
+ return NoAlias;
+
+ // This is SCEVAAResult. Get the SCEVs!
+ const SCEV *AS = SE.getSCEV(const_cast<Value *>(LocA.Ptr));
+ const SCEV *BS = SE.getSCEV(const_cast<Value *>(LocB.Ptr));
+
+ // If they evaluate to the same expression, it's a MustAlias.
+ if (AS == BS)
+ return MustAlias;
+
+ // If something is known about the difference between the two addresses,
+ // see if it's enough to prove a NoAlias.
+ if (SE.getEffectiveSCEVType(AS->getType()) ==
+ SE.getEffectiveSCEVType(BS->getType())) {
+ unsigned BitWidth = SE.getTypeSizeInBits(AS->getType());
+ APInt ASizeInt(BitWidth, LocA.Size.hasValue()
+ ? LocA.Size.getValue()
+ : MemoryLocation::UnknownSize);
+ APInt BSizeInt(BitWidth, LocB.Size.hasValue()
+ ? LocB.Size.getValue()
+ : MemoryLocation::UnknownSize);
+
+ // Compute the difference between the two pointers.
+ const SCEV *BA = SE.getMinusSCEV(BS, AS);
+
+ // Test whether the difference is known to be great enough that memory of
+ // the given sizes don't overlap. This assumes that ASizeInt and BSizeInt
+ // are non-zero, which is special-cased above.
+ if (ASizeInt.ule(SE.getUnsignedRange(BA).getUnsignedMin()) &&
+ (-BSizeInt).uge(SE.getUnsignedRange(BA).getUnsignedMax()))
+ return NoAlias;
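+ // For example (a rough illustration): if both accesses are 8 bytes wide
+ // and SCEV proves that BS - AS lies in the unsigned range [16, 100], then
+ // 8 u<= 16 and (-8) u>= 100, so the two accesses cannot overlap.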
+
+ // Folding the subtraction while preserving range information can be tricky
+ // (because of INT_MIN, etc.); if the prior test failed, swap AS and BS
+ // and try again to see if things fold better that way.
+
+ // Compute the difference between the two pointers.
+ const SCEV *AB = SE.getMinusSCEV(AS, BS);
+
+ // Test whether the difference is known to be great enough that memory of
+ // the given sizes don't overlap. This assumes that ASizeInt and BSizeInt
+ // are non-zero, which is special-cased above.
+ if (BSizeInt.ule(SE.getUnsignedRange(AB).getUnsignedMin()) &&
+ (-ASizeInt).uge(SE.getUnsignedRange(AB).getUnsignedMax()))
+ return NoAlias;
+ }
+
+ // If ScalarEvolution can find an underlying object, form a new query.
+ // The correctness of this depends on ScalarEvolution not recognizing
+ // inttoptr and ptrtoint operators.
+ Value *AO = GetBaseValue(AS);
+ Value *BO = GetBaseValue(BS);
+ if ((AO && AO != LocA.Ptr) || (BO && BO != LocB.Ptr))
+ if (alias(MemoryLocation(AO ? AO : LocA.Ptr,
+ AO ? LocationSize::unknown() : LocA.Size,
+ AO ? AAMDNodes() : LocA.AATags),
+ MemoryLocation(BO ? BO : LocB.Ptr,
+ BO ? LocationSize::unknown() : LocB.Size,
+ BO ? AAMDNodes() : LocB.AATags),
+ AAQI) == NoAlias)
+ return NoAlias;
+
+ // Forward the query to the next analysis.
+ return AAResultBase::alias(LocA, LocB, AAQI);
+}
+
+/// Given an expression, try to find a base value.
+///
+/// Returns null if none was found.
+Value *SCEVAAResult::GetBaseValue(const SCEV *S) {
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
+ // In an addrec, assume that the base will be in the start, rather
+ // than the step.
+ return GetBaseValue(AR->getStart());
+ } else if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
+ // If there's a pointer operand, it'll be sorted at the end of the list.
+ const SCEV *Last = A->getOperand(A->getNumOperands() - 1);
+ if (Last->getType()->isPointerTy())
+ return GetBaseValue(Last);
+ } else if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
+ // This is a leaf node.
+ return U->getValue();
+ }
+ // No identified object found.
+ return nullptr;
+}
+
+AnalysisKey SCEVAA::Key;
+
+SCEVAAResult SCEVAA::run(Function &F, FunctionAnalysisManager &AM) {
+ return SCEVAAResult(AM.getResult<ScalarEvolutionAnalysis>(F));
+}
+
+char SCEVAAWrapperPass::ID = 0;
+INITIALIZE_PASS_BEGIN(SCEVAAWrapperPass, "scev-aa",
+ "ScalarEvolution-based Alias Analysis", false, true)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
+INITIALIZE_PASS_END(SCEVAAWrapperPass, "scev-aa",
+ "ScalarEvolution-based Alias Analysis", false, true)
+
+FunctionPass *llvm::createSCEVAAWrapperPass() {
+ return new SCEVAAWrapperPass();
+}
+
+SCEVAAWrapperPass::SCEVAAWrapperPass() : FunctionPass(ID) {
+ initializeSCEVAAWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+bool SCEVAAWrapperPass::runOnFunction(Function &F) {
+ Result.reset(
+ new SCEVAAResult(getAnalysis<ScalarEvolutionWrapperPass>().getSE()));
+ return false;
+}
+
+void SCEVAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequired<ScalarEvolutionWrapperPass>();
+}
diff --git a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
new file mode 100644
index 000000000000..bceec921188e
--- /dev/null
+++ b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
@@ -0,0 +1,2452 @@
+//===- ScalarEvolutionExpander.cpp - Scalar Evolution Analysis ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementation of the scalar evolution expander,
+// which is used to generate the code corresponding to a given scalar evolution
+// expression.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+using namespace PatternMatch;
+
+/// ReuseOrCreateCast - Arrange for there to be a cast of V to Ty at IP,
+/// reusing an existing cast if a suitable one exists, moving an existing
+/// cast if a suitable one exists but isn't in the right place, or
+/// creating a new one.
+Value *SCEVExpander::ReuseOrCreateCast(Value *V, Type *Ty,
+ Instruction::CastOps Op,
+ BasicBlock::iterator IP) {
+ // This function must be called with the builder having a valid insertion
+ // point. It doesn't need to be the actual IP where the uses of the returned
+ // cast will be added, but it must dominate such IP.
+ // We use this precondition to produce a cast that will dominate all its
+ // uses. In particular, this is crucial for the case where the builder's
+ // insertion point *is* the point where we were asked to put the cast.
+ // Since we don't know the builder's insertion point is actually
+ // where the uses will be added (only that it dominates it), we are
+ // not allowed to move it.
+ BasicBlock::iterator BIP = Builder.GetInsertPoint();
+
+ Instruction *Ret = nullptr;
+
+ // Check to see if there is already a cast!
+ for (User *U : V->users())
+ if (U->getType() == Ty)
+ if (CastInst *CI = dyn_cast<CastInst>(U))
+ if (CI->getOpcode() == Op) {
+ // If the cast isn't where we want it, create a new cast at IP.
+ // Likewise, do not reuse a cast at BIP because it must dominate
+ // instructions that might be inserted before BIP.
+ if (BasicBlock::iterator(CI) != IP || BIP == IP) {
+ // Create a new cast, and leave the old cast in place in case
+ // it is being used as an insert point.
+ Ret = CastInst::Create(Op, V, Ty, "", &*IP);
+ Ret->takeName(CI);
+ CI->replaceAllUsesWith(Ret);
+ break;
+ }
+ Ret = CI;
+ break;
+ }
+
+ // Create a new cast.
+ if (!Ret)
+ Ret = CastInst::Create(Op, V, Ty, V->getName(), &*IP);
+
+ // We assert at the end of the function since IP might point to an
+ // instruction with different dominance properties than a cast
+ // (an invoke for example) and not dominate BIP (but the cast does).
+ assert(SE.DT.dominates(Ret, &*BIP));
+
+ rememberInstruction(Ret);
+ return Ret;
+}
+
+static BasicBlock::iterator findInsertPointAfter(Instruction *I,
+ BasicBlock *MustDominate) {
+ BasicBlock::iterator IP = ++I->getIterator();
+ if (auto *II = dyn_cast<InvokeInst>(I))
+ IP = II->getNormalDest()->begin();
+
+ while (isa<PHINode>(IP))
+ ++IP;
+
+ if (isa<FuncletPadInst>(IP) || isa<LandingPadInst>(IP)) {
+ ++IP;
+ } else if (isa<CatchSwitchInst>(IP)) {
+ IP = MustDominate->getFirstInsertionPt();
+ } else {
+ assert(!IP->isEHPad() && "unexpected eh pad!");
+ }
+
+ return IP;
+}
+
+/// InsertNoopCastOfTo - Insert a cast of V to the specified type,
+/// which must be possible with a noop cast, doing what we can to share
+/// the casts.
+Value *SCEVExpander::InsertNoopCastOfTo(Value *V, Type *Ty) {
+ Instruction::CastOps Op = CastInst::getCastOpcode(V, false, Ty, false);
+ assert((Op == Instruction::BitCast ||
+ Op == Instruction::PtrToInt ||
+ Op == Instruction::IntToPtr) &&
+ "InsertNoopCastOfTo cannot perform non-noop casts!");
+ assert(SE.getTypeSizeInBits(V->getType()) == SE.getTypeSizeInBits(Ty) &&
+ "InsertNoopCastOfTo cannot change sizes!");
+
+ // Short-circuit unnecessary bitcasts.
+ if (Op == Instruction::BitCast) {
+ if (V->getType() == Ty)
+ return V;
+ if (CastInst *CI = dyn_cast<CastInst>(V)) {
+ if (CI->getOperand(0)->getType() == Ty)
+ return CI->getOperand(0);
+ }
+ }
+ // Short-circuit unnecessary inttoptr<->ptrtoint casts.
+ if ((Op == Instruction::PtrToInt || Op == Instruction::IntToPtr) &&
+ SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(V->getType())) {
+ if (CastInst *CI = dyn_cast<CastInst>(V))
+ if ((CI->getOpcode() == Instruction::PtrToInt ||
+ CI->getOpcode() == Instruction::IntToPtr) &&
+ SE.getTypeSizeInBits(CI->getType()) ==
+ SE.getTypeSizeInBits(CI->getOperand(0)->getType()))
+ return CI->getOperand(0);
+ if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
+ if ((CE->getOpcode() == Instruction::PtrToInt ||
+ CE->getOpcode() == Instruction::IntToPtr) &&
+ SE.getTypeSizeInBits(CE->getType()) ==
+ SE.getTypeSizeInBits(CE->getOperand(0)->getType()))
+ return CE->getOperand(0);
+ }
+
+ // Fold a cast of a constant.
+ if (Constant *C = dyn_cast<Constant>(V))
+ return ConstantExpr::getCast(Op, C, Ty);
+
+ // Cast the argument at the beginning of the entry block, after
+ // any bitcasts of other arguments.
+ if (Argument *A = dyn_cast<Argument>(V)) {
+ BasicBlock::iterator IP = A->getParent()->getEntryBlock().begin();
+ while ((isa<BitCastInst>(IP) &&
+ isa<Argument>(cast<BitCastInst>(IP)->getOperand(0)) &&
+ cast<BitCastInst>(IP)->getOperand(0) != A) ||
+ isa<DbgInfoIntrinsic>(IP))
+ ++IP;
+ return ReuseOrCreateCast(A, Ty, Op, IP);
+ }
+
+ // Cast the instruction immediately after the instruction.
+ Instruction *I = cast<Instruction>(V);
+ BasicBlock::iterator IP = findInsertPointAfter(I, Builder.GetInsertBlock());
+ return ReuseOrCreateCast(I, Ty, Op, IP);
+}
+
+/// InsertBinop - Insert the specified binary operator, doing a small amount
+/// of work to avoid inserting an obviously redundant operation, and hoisting
+/// to an outer loop when the opportunity is there and it is safe.
+Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode,
+ Value *LHS, Value *RHS,
+ SCEV::NoWrapFlags Flags, bool IsSafeToHoist) {
+ // Fold a binop with constant operands.
+ if (Constant *CLHS = dyn_cast<Constant>(LHS))
+ if (Constant *CRHS = dyn_cast<Constant>(RHS))
+ return ConstantExpr::get(Opcode, CLHS, CRHS);
+
+ // Do a quick scan to see if we have this binop nearby. If so, reuse it.
+ unsigned ScanLimit = 6;
+ BasicBlock::iterator BlockBegin = Builder.GetInsertBlock()->begin();
+ // Scanning starts from the last instruction before the insertion point.
+ BasicBlock::iterator IP = Builder.GetInsertPoint();
+ if (IP != BlockBegin) {
+ --IP;
+ for (; ScanLimit; --IP, --ScanLimit) {
+ // Don't count dbg.value against the ScanLimit, to avoid perturbing the
+ // generated code.
+ if (isa<DbgInfoIntrinsic>(IP))
+ ScanLimit++;
+
+ auto canGenerateIncompatiblePoison = [&Flags](Instruction *I) {
+ // Ensure that no-wrap flags match.
+ if (isa<OverflowingBinaryOperator>(I)) {
+ if (I->hasNoSignedWrap() != (Flags & SCEV::FlagNSW))
+ return true;
+ if (I->hasNoUnsignedWrap() != (Flags & SCEV::FlagNUW))
+ return true;
+ }
+ // Conservatively, do not use any instruction which has any exact
+ // flags set.
+ if (isa<PossiblyExactOperator>(I) && I->isExact())
+ return true;
+ return false;
+ };
+ if (IP->getOpcode() == (unsigned)Opcode && IP->getOperand(0) == LHS &&
+ IP->getOperand(1) == RHS && !canGenerateIncompatiblePoison(&*IP))
+ return &*IP;
+ if (IP == BlockBegin) break;
+ }
+ }
+
+ // Save the original insertion point so we can restore it when we're done.
+ DebugLoc Loc = Builder.GetInsertPoint()->getDebugLoc();
+ SCEVInsertPointGuard Guard(Builder, this);
+
+ if (IsSafeToHoist) {
+ // Move the insertion point out of as many loops as we can.
+ while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) {
+ if (!L->isLoopInvariant(LHS) || !L->isLoopInvariant(RHS)) break;
+ BasicBlock *Preheader = L->getLoopPreheader();
+ if (!Preheader) break;
+
+ // Ok, move up a level.
+ Builder.SetInsertPoint(Preheader->getTerminator());
+ }
+ }
+
+ // If we haven't found this binop, insert it.
+ Instruction *BO = cast<Instruction>(Builder.CreateBinOp(Opcode, LHS, RHS));
+ BO->setDebugLoc(Loc);
+ if (Flags & SCEV::FlagNUW)
+ BO->setHasNoUnsignedWrap();
+ if (Flags & SCEV::FlagNSW)
+ BO->setHasNoSignedWrap();
+ rememberInstruction(BO);
+
+ return BO;
+}
+
+/// FactorOutConstant - Test if S is divisible by Factor, using signed
+/// division. If so, update S with Factor divided out and return true.
+/// S need not be evenly divisible if a reasonable remainder can be
+/// computed.
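+///
+/// For example (illustrative): with S = {8,+,12}<L> and Factor = 4, S is
+/// rewritten to {2,+,3}<L> and Remainder is left unchanged.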
+static bool FactorOutConstant(const SCEV *&S, const SCEV *&Remainder,
+ const SCEV *Factor, ScalarEvolution &SE,
+ const DataLayout &DL) {
+ // Everything is divisible by one.
+ if (Factor->isOne())
+ return true;
+
+ // x/x == 1.
+ if (S == Factor) {
+ S = SE.getConstant(S->getType(), 1);
+ return true;
+ }
+
+ // For a Constant, check for a multiple of the given factor.
+ if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
+ // 0/x == 0.
+ if (C->isZero())
+ return true;
+ // Check for divisibility.
+ if (const SCEVConstant *FC = dyn_cast<SCEVConstant>(Factor)) {
+ ConstantInt *CI =
+ ConstantInt::get(SE.getContext(), C->getAPInt().sdiv(FC->getAPInt()));
+ // If the quotient is zero and the remainder is non-zero, reject
+ // the value at this scale. It will be considered for subsequent
+ // smaller scales.
+ if (!CI->isZero()) {
+ const SCEV *Div = SE.getConstant(CI);
+ S = Div;
+ Remainder = SE.getAddExpr(
+ Remainder, SE.getConstant(C->getAPInt().srem(FC->getAPInt())));
+ return true;
+ }
+ }
+ }
+
+ // In a Mul, check if there is a constant operand which is a multiple
+ // of the given factor.
+ if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
+ // Size is known, check if there is a constant operand which is a multiple
+ // of the given factor. If so, we can factor it.
+ const SCEVConstant *FC = cast<SCEVConstant>(Factor);
+ if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0)))
+ if (!C->getAPInt().srem(FC->getAPInt())) {
+ SmallVector<const SCEV *, 4> NewMulOps(M->op_begin(), M->op_end());
+ NewMulOps[0] = SE.getConstant(C->getAPInt().sdiv(FC->getAPInt()));
+ S = SE.getMulExpr(NewMulOps);
+ return true;
+ }
+ }
+
+ // In an AddRec, check if both start and step are divisible.
+ if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
+ const SCEV *Step = A->getStepRecurrence(SE);
+ const SCEV *StepRem = SE.getConstant(Step->getType(), 0);
+ if (!FactorOutConstant(Step, StepRem, Factor, SE, DL))
+ return false;
+ if (!StepRem->isZero())
+ return false;
+ const SCEV *Start = A->getStart();
+ if (!FactorOutConstant(Start, Remainder, Factor, SE, DL))
+ return false;
+ S = SE.getAddRecExpr(Start, Step, A->getLoop(),
+ A->getNoWrapFlags(SCEV::FlagNW));
+ return true;
+ }
+
+ return false;
+}
+
+/// SimplifyAddOperands - Sort and simplify a list of add operands. NumAddRecs
+/// is the number of SCEVAddRecExprs present, which are kept at the end of
+/// the list.
+///
+static void SimplifyAddOperands(SmallVectorImpl<const SCEV *> &Ops,
+ Type *Ty,
+ ScalarEvolution &SE) {
+ unsigned NumAddRecs = 0;
+ for (unsigned i = Ops.size(); i > 0 && isa<SCEVAddRecExpr>(Ops[i-1]); --i)
+ ++NumAddRecs;
+ // Group Ops into non-addrecs and addrecs.
+ SmallVector<const SCEV *, 8> NoAddRecs(Ops.begin(), Ops.end() - NumAddRecs);
+ SmallVector<const SCEV *, 8> AddRecs(Ops.end() - NumAddRecs, Ops.end());
+ // Let ScalarEvolution sort and simplify the non-addrecs list.
+ const SCEV *Sum = NoAddRecs.empty() ?
+ SE.getConstant(Ty, 0) :
+ SE.getAddExpr(NoAddRecs);
+ // If it returned an add, use the operands. Otherwise it simplified
+ // the sum into a single value, so just use that.
+ Ops.clear();
+ if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Sum))
+ Ops.append(Add->op_begin(), Add->op_end());
+ else if (!Sum->isZero())
+ Ops.push_back(Sum);
+ // Then append the addrecs.
+ Ops.append(AddRecs.begin(), AddRecs.end());
+}
+
+/// SplitAddRecs - Flatten a list of add operands, moving addrec start values
+/// out to the top level. For example, convert {a + b,+,c} to a, b, {0,+,c}.
+/// This helps expose more opportunities for folding parts of the expressions
+/// into GEP indices.
+///
+static void SplitAddRecs(SmallVectorImpl<const SCEV *> &Ops,
+ Type *Ty,
+ ScalarEvolution &SE) {
+ // Find the addrecs.
+ SmallVector<const SCEV *, 8> AddRecs;
+ for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+ while (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(Ops[i])) {
+ const SCEV *Start = A->getStart();
+ if (Start->isZero()) break;
+ const SCEV *Zero = SE.getConstant(Ty, 0);
+ AddRecs.push_back(SE.getAddRecExpr(Zero,
+ A->getStepRecurrence(SE),
+ A->getLoop(),
+ A->getNoWrapFlags(SCEV::FlagNW)));
+ if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Start)) {
+ Ops[i] = Zero;
+ Ops.append(Add->op_begin(), Add->op_end());
+ e += Add->getNumOperands();
+ } else {
+ Ops[i] = Start;
+ }
+ }
+ if (!AddRecs.empty()) {
+ // Add the addrecs onto the end of the list.
+ Ops.append(AddRecs.begin(), AddRecs.end());
+ // Resort the operand list, moving any constants to the front.
+ SimplifyAddOperands(Ops, Ty, SE);
+ }
+}
+
+/// expandAddToGEP - Expand an addition expression with a pointer type into
+/// a GEP instead of using ptrtoint+arithmetic+inttoptr. This helps
+/// BasicAliasAnalysis and other passes analyze the result. See the rules
+/// for getelementptr vs. inttoptr in
+/// http://llvm.org/docs/LangRef.html#pointeraliasing
+/// for details.
+///
+/// Design note: The correctness of using getelementptr here depends on
+/// ScalarEvolution not recognizing inttoptr and ptrtoint operators, as
+/// they may introduce pointer arithmetic which may not be safely converted
+/// into getelementptr.
+///
+/// Design note: It might seem desirable for this function to be more
+/// loop-aware. If some of the indices are loop-invariant while others
+/// aren't, it might seem desirable to emit multiple GEPs, keeping the
+/// loop-invariant portions of the overall computation outside the loop.
+/// However, there are a few reasons this is not done here. Hoisting simple
+/// arithmetic is a low-level optimization that often isn't very
+/// important until late in the optimization process. In fact, passes
+/// like InstructionCombining will combine GEPs, even if it means
+/// pushing loop-invariant computation down into loops, so even if the
+/// GEPs were split here, the work would quickly be undone. The
+/// LoopStrengthReduction pass, which is usually run quite late (and
+/// after the last InstructionCombining pass), takes care of hoisting
+/// loop-invariant portions of expressions, after considering what
+/// can be folded using target addressing modes.
+///
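+/// For example (a rough illustration): expanding an address such as
+/// (%p + 4 * %i), where %p has type i32*, yields
+/// "getelementptr i32, i32* %p, i64 %i" here instead of
+/// ptrtoint/add/inttoptr arithmetic.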
+Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin,
+ const SCEV *const *op_end,
+ PointerType *PTy,
+ Type *Ty,
+ Value *V) {
+ Type *OriginalElTy = PTy->getElementType();
+ Type *ElTy = OriginalElTy;
+ SmallVector<Value *, 4> GepIndices;
+ SmallVector<const SCEV *, 8> Ops(op_begin, op_end);
+ bool AnyNonZeroIndices = false;
+
+ // Split AddRecs up into parts as either of the parts may be usable
+ // without the other.
+ SplitAddRecs(Ops, Ty, SE);
+
+ Type *IntPtrTy = DL.getIntPtrType(PTy);
+
+ // Descend down the pointer's type and attempt to convert the other
+ // operands into GEP indices, at each level. The first index in a GEP
+ // indexes into the array implied by the pointer operand; the rest of
+ // the indices index into the element or field type selected by the
+ // preceding index.
+ for (;;) {
+ // If the scale size is not 0, attempt to factor out a scale for
+ // array indexing.
+ SmallVector<const SCEV *, 8> ScaledOps;
+ if (ElTy->isSized()) {
+ const SCEV *ElSize = SE.getSizeOfExpr(IntPtrTy, ElTy);
+ if (!ElSize->isZero()) {
+ SmallVector<const SCEV *, 8> NewOps;
+ for (const SCEV *Op : Ops) {
+ const SCEV *Remainder = SE.getConstant(Ty, 0);
+ if (FactorOutConstant(Op, Remainder, ElSize, SE, DL)) {
+ // Op now has ElSize factored out.
+ ScaledOps.push_back(Op);
+ if (!Remainder->isZero())
+ NewOps.push_back(Remainder);
+ AnyNonZeroIndices = true;
+ } else {
+ // The operand was not divisible, so add it to the list of operands
+ // we'll scan next iteration.
+ NewOps.push_back(Op);
+ }
+ }
+ // If we made any changes, update Ops.
+ if (!ScaledOps.empty()) {
+ Ops = NewOps;
+ SimplifyAddOperands(Ops, Ty, SE);
+ }
+ }
+ }
+
+ // Record the scaled array index for this level of the type. If
+ // we didn't find any operands that could be factored, tentatively
+ // assume that element zero was selected (since the zero offset
+ // would obviously be folded away).
+ Value *Scaled = ScaledOps.empty() ?
+ Constant::getNullValue(Ty) :
+ expandCodeFor(SE.getAddExpr(ScaledOps), Ty);
+ GepIndices.push_back(Scaled);
+
+ // Collect struct field index operands.
+ while (StructType *STy = dyn_cast<StructType>(ElTy)) {
+ bool FoundFieldNo = false;
+ // An empty struct has no fields.
+ if (STy->getNumElements() == 0) break;
+ // Field offsets are known. See if a constant offset falls within any of
+ // the struct fields.
+ if (Ops.empty())
+ break;
+ if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[0]))
+ if (SE.getTypeSizeInBits(C->getType()) <= 64) {
+ const StructLayout &SL = *DL.getStructLayout(STy);
+ uint64_t FullOffset = C->getValue()->getZExtValue();
+ if (FullOffset < SL.getSizeInBytes()) {
+ unsigned ElIdx = SL.getElementContainingOffset(FullOffset);
+ GepIndices.push_back(
+ ConstantInt::get(Type::getInt32Ty(Ty->getContext()), ElIdx));
+ ElTy = STy->getTypeAtIndex(ElIdx);
+ Ops[0] =
+ SE.getConstant(Ty, FullOffset - SL.getElementOffset(ElIdx));
+ AnyNonZeroIndices = true;
+ FoundFieldNo = true;
+ }
+ }
+ // If no struct field offsets were found, tentatively assume that
+ // field zero was selected (since the zero offset would obviously
+ // be folded away).
+ if (!FoundFieldNo) {
+ ElTy = STy->getTypeAtIndex(0u);
+ GepIndices.push_back(
+ Constant::getNullValue(Type::getInt32Ty(Ty->getContext())));
+ }
+ }
+
+ if (ArrayType *ATy = dyn_cast<ArrayType>(ElTy))
+ ElTy = ATy->getElementType();
+ else
+ break;
+ }
+
+ // If none of the operands were convertible to proper GEP indices, cast
+ // the base to i8* and do an ugly getelementptr with that. It's still
+ // better than ptrtoint+arithmetic+inttoptr at least.
+ if (!AnyNonZeroIndices) {
+ // Cast the base to i8*.
+ V = InsertNoopCastOfTo(V,
+ Type::getInt8PtrTy(Ty->getContext(), PTy->getAddressSpace()));
+
+ assert(!isa<Instruction>(V) ||
+ SE.DT.dominates(cast<Instruction>(V), &*Builder.GetInsertPoint()));
+
+ // Expand the operands for a plain byte offset.
+ Value *Idx = expandCodeFor(SE.getAddExpr(Ops), Ty);
+
+ // Fold a GEP with constant operands.
+ if (Constant *CLHS = dyn_cast<Constant>(V))
+ if (Constant *CRHS = dyn_cast<Constant>(Idx))
+ return ConstantExpr::getGetElementPtr(Type::getInt8Ty(Ty->getContext()),
+ CLHS, CRHS);
+
+ // Do a quick scan to see if we have this GEP nearby. If so, reuse it.
+ unsigned ScanLimit = 6;
+ BasicBlock::iterator BlockBegin = Builder.GetInsertBlock()->begin();
+ // Scanning starts from the last instruction before the insertion point.
+ BasicBlock::iterator IP = Builder.GetInsertPoint();
+ if (IP != BlockBegin) {
+ --IP;
+ for (; ScanLimit; --IP, --ScanLimit) {
+ // Don't count dbg.value against the ScanLimit, to avoid perturbing the
+ // generated code.
+ if (isa<DbgInfoIntrinsic>(IP))
+ ScanLimit++;
+ if (IP->getOpcode() == Instruction::GetElementPtr &&
+ IP->getOperand(0) == V && IP->getOperand(1) == Idx)
+ return &*IP;
+ if (IP == BlockBegin) break;
+ }
+ }
+
+ // Save the original insertion point so we can restore it when we're done.
+ SCEVInsertPointGuard Guard(Builder, this);
+
+ // Move the insertion point out of as many loops as we can.
+ while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) {
+ if (!L->isLoopInvariant(V) || !L->isLoopInvariant(Idx)) break;
+ BasicBlock *Preheader = L->getLoopPreheader();
+ if (!Preheader) break;
+
+ // Ok, move up a level.
+ Builder.SetInsertPoint(Preheader->getTerminator());
+ }
+
+ // Emit a GEP.
+ Value *GEP = Builder.CreateGEP(Builder.getInt8Ty(), V, Idx, "uglygep");
+ rememberInstruction(GEP);
+
+ return GEP;
+ }
+
+ {
+ SCEVInsertPointGuard Guard(Builder, this);
+
+ // Move the insertion point out of as many loops as we can.
+ while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) {
+ if (!L->isLoopInvariant(V)) break;
+
+ bool AnyIndexNotLoopInvariant = any_of(
+ GepIndices, [L](Value *Op) { return !L->isLoopInvariant(Op); });
+
+ if (AnyIndexNotLoopInvariant)
+ break;
+
+ BasicBlock *Preheader = L->getLoopPreheader();
+ if (!Preheader) break;
+
+ // Ok, move up a level.
+ Builder.SetInsertPoint(Preheader->getTerminator());
+ }
+
+ // Insert a pretty getelementptr. Note that this GEP is not marked inbounds,
+ // because ScalarEvolution may have changed the address arithmetic to
+ // compute a value which is beyond the end of the allocated object.
+ Value *Casted = V;
+ if (V->getType() != PTy)
+ Casted = InsertNoopCastOfTo(Casted, PTy);
+ Value *GEP = Builder.CreateGEP(OriginalElTy, Casted, GepIndices, "scevgep");
+ Ops.push_back(SE.getUnknown(GEP));
+ rememberInstruction(GEP);
+ }
+
+ return expand(SE.getAddExpr(Ops));
+}
+
+Value *SCEVExpander::expandAddToGEP(const SCEV *Op, PointerType *PTy, Type *Ty,
+ Value *V) {
+ const SCEV *const Ops[1] = {Op};
+ return expandAddToGEP(Ops, Ops + 1, PTy, Ty, V);
+}
+
+/// PickMostRelevantLoop - Given two loops pick the one that's most relevant for
+/// SCEV expansion. If they are nested, this is the most nested. If they are
+/// neighboring, pick the later one.
+static const Loop *PickMostRelevantLoop(const Loop *A, const Loop *B,
+ DominatorTree &DT) {
+ if (!A) return B;
+ if (!B) return A;
+ if (A->contains(B)) return B;
+ if (B->contains(A)) return A;
+ if (DT.dominates(A->getHeader(), B->getHeader())) return B;
+ if (DT.dominates(B->getHeader(), A->getHeader())) return A;
+ return A; // Arbitrarily break the tie.
+}
+
+/// getRelevantLoop - Get the most relevant loop associated with the given
+/// expression, according to PickMostRelevantLoop.
+const Loop *SCEVExpander::getRelevantLoop(const SCEV *S) {
+ // Test whether we've already computed the most relevant loop for this SCEV.
+ auto Pair = RelevantLoops.insert(std::make_pair(S, nullptr));
+ if (!Pair.second)
+ return Pair.first->second;
+
+ if (isa<SCEVConstant>(S))
+ // A constant has no relevant loops.
+ return nullptr;
+ if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
+ if (const Instruction *I = dyn_cast<Instruction>(U->getValue()))
+ return Pair.first->second = SE.LI.getLoopFor(I->getParent());
+ // A non-instruction has no relevant loops.
+ return nullptr;
+ }
+ if (const SCEVNAryExpr *N = dyn_cast<SCEVNAryExpr>(S)) {
+ const Loop *L = nullptr;
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S))
+ L = AR->getLoop();
+ for (const SCEV *Op : N->operands())
+ L = PickMostRelevantLoop(L, getRelevantLoop(Op), SE.DT);
+ return RelevantLoops[N] = L;
+ }
+ if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(S)) {
+ const Loop *Result = getRelevantLoop(C->getOperand());
+ return RelevantLoops[C] = Result;
+ }
+ if (const SCEVUDivExpr *D = dyn_cast<SCEVUDivExpr>(S)) {
+ const Loop *Result = PickMostRelevantLoop(
+ getRelevantLoop(D->getLHS()), getRelevantLoop(D->getRHS()), SE.DT);
+ return RelevantLoops[D] = Result;
+ }
+ llvm_unreachable("Unexpected SCEV type!");
+}
+
+namespace {
+
+/// LoopCompare - Compare loops by PickMostRelevantLoop.
+class LoopCompare {
+ DominatorTree &DT;
+public:
+ explicit LoopCompare(DominatorTree &dt) : DT(dt) {}
+
+ bool operator()(std::pair<const Loop *, const SCEV *> LHS,
+ std::pair<const Loop *, const SCEV *> RHS) const {
+ // Keep pointer operands sorted at the front; visitAddExpr below relies on
+ // expanding them first to form GEPs.
+ if (LHS.second->getType()->isPointerTy() !=
+ RHS.second->getType()->isPointerTy())
+ return LHS.second->getType()->isPointerTy();
+
+ // Compare loops with PickMostRelevantLoop.
+ if (LHS.first != RHS.first)
+ return PickMostRelevantLoop(LHS.first, RHS.first, DT) != LHS.first;
+
+ // If one operand is a non-constant negative and the other is not,
+ // put the non-constant negative on the right so that a sub can
+ // be used instead of a negate and add.
+ if (LHS.second->isNonConstantNegative()) {
+ if (!RHS.second->isNonConstantNegative())
+ return false;
+ } else if (RHS.second->isNonConstantNegative())
+ return true;
+
+ // Otherwise they are equivalent according to this comparison.
+ return false;
+ }
+};
+
+} // end anonymous namespace
+
+Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
+ Type *Ty = SE.getEffectiveSCEVType(S->getType());
+
+ // Collect all the add operands in a loop, along with their associated loops.
+ // Iterate in reverse so that constants are emitted last, all else equal, and
+ // so that pointer operands are inserted first, which the code below relies on
+ // to form more involved GEPs.
+ SmallVector<std::pair<const Loop *, const SCEV *>, 8> OpsAndLoops;
+ for (std::reverse_iterator<SCEVAddExpr::op_iterator> I(S->op_end()),
+ E(S->op_begin()); I != E; ++I)
+ OpsAndLoops.push_back(std::make_pair(getRelevantLoop(*I), *I));
+
+ // Sort by loop. Use a stable sort so that constants follow non-constants and
+ // pointer operands precede non-pointer operands.
+ llvm::stable_sort(OpsAndLoops, LoopCompare(SE.DT));
+
+ // Emit instructions to add all the operands. Hoist as much as possible
+ // out of loops, and form meaningful getelementptrs where possible.
+ Value *Sum = nullptr;
+ for (auto I = OpsAndLoops.begin(), E = OpsAndLoops.end(); I != E;) {
+ const Loop *CurLoop = I->first;
+ const SCEV *Op = I->second;
+ if (!Sum) {
+ // This is the first operand. Just expand it.
+ Sum = expand(Op);
+ ++I;
+ } else if (PointerType *PTy = dyn_cast<PointerType>(Sum->getType())) {
+ // The running sum expression is a pointer. Try to form a getelementptr
+ // at this level with that as the base.
+ SmallVector<const SCEV *, 4> NewOps;
+ for (; I != E && I->first == CurLoop; ++I) {
+ // If the operand is a SCEVUnknown and not an instruction, peek through
+ // it to enable more of it to be folded into the GEP.
+ const SCEV *X = I->second;
+ if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(X))
+ if (!isa<Instruction>(U->getValue()))
+ X = SE.getSCEV(U->getValue());
+ NewOps.push_back(X);
+ }
+ Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, Sum);
+ } else if (PointerType *PTy = dyn_cast<PointerType>(Op->getType())) {
+ // The running sum is an integer, and there's a pointer at this level.
+ // Try to form a getelementptr. If the running sum is instructions,
+ // use a SCEVUnknown to avoid re-analyzing them.
+ SmallVector<const SCEV *, 4> NewOps;
+ NewOps.push_back(isa<Instruction>(Sum) ? SE.getUnknown(Sum) :
+ SE.getSCEV(Sum));
+ for (++I; I != E && I->first == CurLoop; ++I)
+ NewOps.push_back(I->second);
+ Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, expand(Op));
+ } else if (Op->isNonConstantNegative()) {
+ // Instead of doing a negate and add, just do a subtract.
+ Value *W = expandCodeFor(SE.getNegativeSCEV(Op), Ty);
+ Sum = InsertNoopCastOfTo(Sum, Ty);
+ Sum = InsertBinop(Instruction::Sub, Sum, W, SCEV::FlagAnyWrap,
+ /*IsSafeToHoist*/ true);
+ ++I;
+ } else {
+ // A simple add.
+ Value *W = expandCodeFor(Op, Ty);
+ Sum = InsertNoopCastOfTo(Sum, Ty);
+ // Canonicalize a constant to the RHS.
+ if (isa<Constant>(Sum)) std::swap(Sum, W);
+ Sum = InsertBinop(Instruction::Add, Sum, W, S->getNoWrapFlags(),
+ /*IsSafeToHoist*/ true);
+ ++I;
+ }
+ }
+
+ return Sum;
+}
+
+Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
+ Type *Ty = SE.getEffectiveSCEVType(S->getType());
+
+ // Collect all the mul operands in a loop, along with their associated loops.
+ // Iterate in reverse so that constants are emitted last, all else equal.
+ SmallVector<std::pair<const Loop *, const SCEV *>, 8> OpsAndLoops;
+ for (std::reverse_iterator<SCEVMulExpr::op_iterator> I(S->op_end()),
+ E(S->op_begin()); I != E; ++I)
+ OpsAndLoops.push_back(std::make_pair(getRelevantLoop(*I), *I));
+
+ // Sort by loop. Use a stable sort so that constants follow non-constants.
+ llvm::stable_sort(OpsAndLoops, LoopCompare(SE.DT));
+
+ // Emit instructions to mul all the operands. Hoist as much as possible
+ // out of loops.
+ Value *Prod = nullptr;
+ auto I = OpsAndLoops.begin();
+
+ // Expand the calculation of X pow N in the following manner:
+ // Let N = P1 + P2 + ... + PK, where all P are powers of 2. Then:
+ // X pow N = (X pow P1) * (X pow P2) * ... * (X pow PK).
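+ // For example (illustrative): if the same operand X occurs five times,
+ // N = 5 = 4 + 1, so the code below computes X*X and (X*X)*(X*X) and then
+ // multiplies the latter by X, i.e. three multiplies in total.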
+ const auto ExpandOpBinPowN = [this, &I, &OpsAndLoops, &Ty]() {
+ auto E = I;
+ // Calculate how many times the same operand from the same loop is included
+ // into this power.
+ uint64_t Exponent = 0;
+ const uint64_t MaxExponent = UINT64_MAX >> 1;
+ // No one sane will ever try to calculate such huge exponents, but if we
+ // need this, we stop on UINT64_MAX / 2 because we need to exit the loop
+ // below when the power of 2 exceeds our Exponent, and we want it to be
+ // 1u << 31 at most to not deal with unsigned overflow.
+ while (E != OpsAndLoops.end() && *I == *E && Exponent != MaxExponent) {
+ ++Exponent;
+ ++E;
+ }
+ assert(Exponent > 0 && "Trying to calculate a zeroth exponent of operand?");
+
+ // Calculate powers with exponents 1, 2, 4, 8 etc. and include those of them
+ // that are needed into the result.
+ Value *P = expandCodeFor(I->second, Ty);
+ Value *Result = nullptr;
+ if (Exponent & 1)
+ Result = P;
+ for (uint64_t BinExp = 2; BinExp <= Exponent; BinExp <<= 1) {
+ P = InsertBinop(Instruction::Mul, P, P, SCEV::FlagAnyWrap,
+ /*IsSafeToHoist*/ true);
+ if (Exponent & BinExp)
+ Result = Result ? InsertBinop(Instruction::Mul, Result, P,
+ SCEV::FlagAnyWrap,
+ /*IsSafeToHoist*/ true)
+ : P;
+ }
+
+ I = E;
+ assert(Result && "Nothing was expanded?");
+ return Result;
+ };
+
+ while (I != OpsAndLoops.end()) {
+ if (!Prod) {
+ // This is the first operand. Just expand it.
+ Prod = ExpandOpBinPowN();
+ } else if (I->second->isAllOnesValue()) {
+ // Instead of doing a multiply by negative one, just do a negate.
+ Prod = InsertNoopCastOfTo(Prod, Ty);
+ Prod = InsertBinop(Instruction::Sub, Constant::getNullValue(Ty), Prod,
+ SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true);
+ ++I;
+ } else {
+ // A simple mul.
+ Value *W = ExpandOpBinPowN();
+ Prod = InsertNoopCastOfTo(Prod, Ty);
+ // Canonicalize a constant to the RHS.
+ if (isa<Constant>(Prod)) std::swap(Prod, W);
+ const APInt *RHS;
+ if (match(W, m_Power2(RHS))) {
+ // Canonicalize Prod*(1<<C) to Prod<<C.
+ assert(!Ty->isVectorTy() && "vector types are not SCEVable");
+ auto NWFlags = S->getNoWrapFlags();
+ // Clear the nsw flag if the shl would produce a poison value.
+ if (RHS->logBase2() == RHS->getBitWidth() - 1)
+ NWFlags = ScalarEvolution::clearFlags(NWFlags, SCEV::FlagNSW);
+ Prod = InsertBinop(Instruction::Shl, Prod,
+ ConstantInt::get(Ty, RHS->logBase2()), NWFlags,
+ /*IsSafeToHoist*/ true);
+ } else {
+ Prod = InsertBinop(Instruction::Mul, Prod, W, S->getNoWrapFlags(),
+ /*IsSafeToHoist*/ true);
+ }
+ }
+ }
+
+ return Prod;
+}
+
+Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
+ Type *Ty = SE.getEffectiveSCEVType(S->getType());
+
+ Value *LHS = expandCodeFor(S->getLHS(), Ty);
+ if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getRHS())) {
+ const APInt &RHS = SC->getAPInt();
+ if (RHS.isPowerOf2())
+ return InsertBinop(Instruction::LShr, LHS,
+ ConstantInt::get(Ty, RHS.logBase2()),
+ SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true);
+ }
+
+ Value *RHS = expandCodeFor(S->getRHS(), Ty);
+ return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap,
+ /*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS()));
+}
+
+/// Move parts of Base into Rest to leave Base with the minimal
+/// expression that provides a pointer operand suitable for a
+/// GEP expansion.
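+///
+/// For example (illustrative): given Base = {%p,+,4}<L> and Rest = %n, the
+/// loop below leaves Base = %p and folds {0,+,4}<L> into Rest.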
+static void ExposePointerBase(const SCEV *&Base, const SCEV *&Rest,
+ ScalarEvolution &SE) {
+ while (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(Base)) {
+ Base = A->getStart();
+ Rest = SE.getAddExpr(Rest,
+ SE.getAddRecExpr(SE.getConstant(A->getType(), 0),
+ A->getStepRecurrence(SE),
+ A->getLoop(),
+ A->getNoWrapFlags(SCEV::FlagNW)));
+ }
+ if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(Base)) {
+ Base = A->getOperand(A->getNumOperands()-1);
+ SmallVector<const SCEV *, 8> NewAddOps(A->op_begin(), A->op_end());
+ NewAddOps.back() = Rest;
+ Rest = SE.getAddExpr(NewAddOps);
+ ExposePointerBase(Base, Rest, SE);
+ }
+}
+
+/// Determine if this is a well-behaved chain of instructions leading back to
+/// the PHI. If so, it may be reused by expanded expressions.
+bool SCEVExpander::isNormalAddRecExprPHI(PHINode *PN, Instruction *IncV,
+ const Loop *L) {
+ if (IncV->getNumOperands() == 0 || isa<PHINode>(IncV) ||
+ (isa<CastInst>(IncV) && !isa<BitCastInst>(IncV)))
+ return false;
+ // If any of the operands don't dominate the insert position, bail.
+ // Addrec operands are always loop-invariant, so this can only happen
+ // if there are instructions which haven't been hoisted.
+ if (L == IVIncInsertLoop) {
+ for (User::op_iterator OI = IncV->op_begin()+1,
+ OE = IncV->op_end(); OI != OE; ++OI)
+ if (Instruction *OInst = dyn_cast<Instruction>(OI))
+ if (!SE.DT.dominates(OInst, IVIncInsertPos))
+ return false;
+ }
+ // Advance to the next instruction.
+ IncV = dyn_cast<Instruction>(IncV->getOperand(0));
+ if (!IncV)
+ return false;
+
+ if (IncV->mayHaveSideEffects())
+ return false;
+
+ if (IncV == PN)
+ return true;
+
+ return isNormalAddRecExprPHI(PN, IncV, L);
+}
+
+/// getIVIncOperand returns an induction variable increment's induction
+/// variable operand.
+///
+/// If allowScale is set, any type of GEP is allowed as long as the nonIV
+/// operands dominate InsertPos.
+///
+/// If allowScale is not set, ensure that a GEP increment conforms to one of the
+/// simple patterns generated by getAddRecExprPHILiterally and
+/// expandAddToGEP. If the pattern isn't recognized, return NULL.
+Instruction *SCEVExpander::getIVIncOperand(Instruction *IncV,
+ Instruction *InsertPos,
+ bool allowScale) {
+ if (IncV == InsertPos)
+ return nullptr;
+
+ switch (IncV->getOpcode()) {
+ default:
+ return nullptr;
+ // Check for a simple Add/Sub or GEP of a loop invariant step.
+ case Instruction::Add:
+ case Instruction::Sub: {
+ Instruction *OInst = dyn_cast<Instruction>(IncV->getOperand(1));
+ if (!OInst || SE.DT.dominates(OInst, InsertPos))
+ return dyn_cast<Instruction>(IncV->getOperand(0));
+ return nullptr;
+ }
+ case Instruction::BitCast:
+ return dyn_cast<Instruction>(IncV->getOperand(0));
+ case Instruction::GetElementPtr:
+ for (auto I = IncV->op_begin() + 1, E = IncV->op_end(); I != E; ++I) {
+ if (isa<Constant>(*I))
+ continue;
+ if (Instruction *OInst = dyn_cast<Instruction>(*I)) {
+ if (!SE.DT.dominates(OInst, InsertPos))
+ return nullptr;
+ }
+ if (allowScale) {
+ // allow any kind of GEP as long as it can be hoisted.
+ continue;
+ }
+ // This must be a pointer addition of constants (pretty), which is already
+ // handled, or some number of address-size elements (ugly). Ugly geps
+ // have 2 operands. i1* is used by the expander to represent an
+ // address-size element.
+ if (IncV->getNumOperands() != 2)
+ return nullptr;
+ unsigned AS = cast<PointerType>(IncV->getType())->getAddressSpace();
+ if (IncV->getType() != Type::getInt1PtrTy(SE.getContext(), AS)
+ && IncV->getType() != Type::getInt8PtrTy(SE.getContext(), AS))
+ return nullptr;
+ break;
+ }
+ return dyn_cast<Instruction>(IncV->getOperand(0));
+ }
+}
+
+/// If the insert point of the current builder or any of the builders on the
+/// stack of saved builders has 'I' as its insert point, update it to point to
+/// the instruction after 'I'. This is intended to be used when the instruction
+/// 'I' is being moved. If this fixup is not done and 'I' is moved to a
+/// different block, the inconsistent insert point (with a mismatched
+/// Instruction and Block) can lead to an instruction being inserted in a block
+/// other than its parent.
+void SCEVExpander::fixupInsertPoints(Instruction *I) {
+ BasicBlock::iterator It(*I);
+ BasicBlock::iterator NewInsertPt = std::next(It);
+ if (Builder.GetInsertPoint() == It)
+ Builder.SetInsertPoint(&*NewInsertPt);
+ for (auto *InsertPtGuard : InsertPointGuards)
+ if (InsertPtGuard->GetInsertPoint() == It)
+ InsertPtGuard->SetInsertPoint(NewInsertPt);
+}
+
+/// hoistIVInc - Attempt to hoist a simple IV increment above InsertPos to make
+/// it available to other uses in this loop. Recursively hoist any operands,
+/// until we reach a value that dominates InsertPos.
+bool SCEVExpander::hoistIVInc(Instruction *IncV, Instruction *InsertPos) {
+ if (SE.DT.dominates(IncV, InsertPos))
+ return true;
+
+ // InsertPos must itself dominate IncV so that IncV's new position satisfies
+ // its existing users.
+ if (isa<PHINode>(InsertPos) ||
+ !SE.DT.dominates(InsertPos->getParent(), IncV->getParent()))
+ return false;
+
+ if (!SE.LI.movementPreservesLCSSAForm(IncV, InsertPos))
+ return false;
+
+ // Check that the chain of IV operands leading back to Phi can be hoisted.
+ SmallVector<Instruction*, 4> IVIncs;
+ for(;;) {
+ Instruction *Oper = getIVIncOperand(IncV, InsertPos, /*allowScale*/true);
+ if (!Oper)
+ return false;
+ // IncV is safe to hoist.
+ IVIncs.push_back(IncV);
+ IncV = Oper;
+ if (SE.DT.dominates(IncV, InsertPos))
+ break;
+ }
+ for (auto I = IVIncs.rbegin(), E = IVIncs.rend(); I != E; ++I) {
+ fixupInsertPoints(*I);
+ (*I)->moveBefore(InsertPos);
+ }
+ return true;
+}
+
+/// Determine if this cyclic phi is in a form that would have been generated by
+/// LSR. We don't care if the phi was actually expanded in this pass, as long
+/// as it is in a low-cost form, for example, no implied multiplication. This
+/// should match any patterns generated by getAddRecExprPHILiterally and
+/// expandAddToGEP.
+bool SCEVExpander::isExpandedAddRecExprPHI(PHINode *PN, Instruction *IncV,
+ const Loop *L) {
+ for (Instruction *IVOper = IncV;
+ (IVOper = getIVIncOperand(IVOper, L->getLoopPreheader()->getTerminator(),
+ /*allowScale=*/false));) {
+ if (IVOper == PN)
+ return true;
+ }
+ return false;
+}
+
+/// expandIVInc - Expand an IV increment at Builder's current InsertPos.
+/// Typically this is the LatchBlock terminator or IVIncInsertPos, but we may
+/// need to materialize IV increments elsewhere to handle difficult situations.
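+/// For illustration (names and types are hypothetical): an integer-typed IV
+/// produces something like %x.iv.next = add i64 %x.iv, %step (or a sub when
+/// useSubtract is set), while a pointer-typed IV is advanced with a GEP
+/// built by expandAddToGEP.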
+Value *SCEVExpander::expandIVInc(PHINode *PN, Value *StepV, const Loop *L,
+ Type *ExpandTy, Type *IntTy,
+ bool useSubtract) {
+ Value *IncV;
+ // If the PHI is a pointer, use a GEP, otherwise use an add or sub.
+ if (ExpandTy->isPointerTy()) {
+ PointerType *GEPPtrTy = cast<PointerType>(ExpandTy);
+ // If the step isn't constant, don't use an implicitly scaled GEP, because
+ // that would require a multiply inside the loop.
+ if (!isa<ConstantInt>(StepV))
+ GEPPtrTy = PointerType::get(Type::getInt1Ty(SE.getContext()),
+ GEPPtrTy->getAddressSpace());
+ IncV = expandAddToGEP(SE.getSCEV(StepV), GEPPtrTy, IntTy, PN);
+ if (IncV->getType() != PN->getType()) {
+ IncV = Builder.CreateBitCast(IncV, PN->getType());
+ rememberInstruction(IncV);
+ }
+ } else {
+ IncV = useSubtract ?
+ Builder.CreateSub(PN, StepV, Twine(IVName) + ".iv.next") :
+ Builder.CreateAdd(PN, StepV, Twine(IVName) + ".iv.next");
+ rememberInstruction(IncV);
+ }
+ return IncV;
+}
+
+/// Hoist the addrec instruction chain rooted in the loop phi above the
+/// position. This routine assumes that this is possible (has been checked).
+void SCEVExpander::hoistBeforePos(DominatorTree *DT, Instruction *InstToHoist,
+ Instruction *Pos, PHINode *LoopPhi) {
+ do {
+ if (DT->dominates(InstToHoist, Pos))
+ break;
+ // Make sure the increment is where we want it. But don't move it
+ // down past a potential existing post-inc user.
+ fixupInsertPoints(InstToHoist);
+ InstToHoist->moveBefore(Pos);
+ Pos = InstToHoist;
+ InstToHoist = cast<Instruction>(InstToHoist->getOperand(0));
+ } while (InstToHoist != LoopPhi);
+}
+
+/// Check whether we can cheaply express the requested SCEV in terms of
+/// the available PHI SCEV by truncation and/or inversion of the step.
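+/// For example (illustrative values), an available phi for {0,+,1}<i64> can
+/// satisfy a request for {0,+,1}<i32> via truncation, and a request for
+/// {R,+,-1} can be served from a {0,+,1} phi by inverting the step, using the
+/// identity noted below.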
+static bool canBeCheaplyTransformed(ScalarEvolution &SE,
+ const SCEVAddRecExpr *Phi,
+ const SCEVAddRecExpr *Requested,
+ bool &InvertStep) {
+ Type *PhiTy = SE.getEffectiveSCEVType(Phi->getType());
+ Type *RequestedTy = SE.getEffectiveSCEVType(Requested->getType());
+
+ if (RequestedTy->getIntegerBitWidth() > PhiTy->getIntegerBitWidth())
+ return false;
+
+ // Try to truncate it if necessary.
+ Phi = dyn_cast<SCEVAddRecExpr>(SE.getTruncateOrNoop(Phi, RequestedTy));
+ if (!Phi)
+ return false;
+
+ // Check whether truncation will help.
+ if (Phi == Requested) {
+ InvertStep = false;
+ return true;
+ }
+
+ // Check whether inverting will help: {R,+,-1} == R - {0,+,1}.
+ if (SE.getAddExpr(Requested->getStart(),
+ SE.getNegativeSCEV(Requested)) == Phi) {
+ InvertStep = true;
+ return true;
+ }
+
+ return false;
+}
+
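+// IsIncrementNSW/IsIncrementNUW check whether the increment AR + Step can be
+// marked nsw/nuw: both operands are extended to twice the bit width, and if
+// adding the extended values folds to the same SCEV as extending the narrow
+// sum, the narrow add cannot wrap in that sense.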
+static bool IsIncrementNSW(ScalarEvolution &SE, const SCEVAddRecExpr *AR) {
+ if (!isa<IntegerType>(AR->getType()))
+ return false;
+
+ unsigned BitWidth = cast<IntegerType>(AR->getType())->getBitWidth();
+ Type *WideTy = IntegerType::get(AR->getType()->getContext(), BitWidth * 2);
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ const SCEV *OpAfterExtend = SE.getAddExpr(SE.getSignExtendExpr(Step, WideTy),
+ SE.getSignExtendExpr(AR, WideTy));
+ const SCEV *ExtendAfterOp =
+ SE.getSignExtendExpr(SE.getAddExpr(AR, Step), WideTy);
+ return ExtendAfterOp == OpAfterExtend;
+}
+
+static bool IsIncrementNUW(ScalarEvolution &SE, const SCEVAddRecExpr *AR) {
+ if (!isa<IntegerType>(AR->getType()))
+ return false;
+
+ unsigned BitWidth = cast<IntegerType>(AR->getType())->getBitWidth();
+ Type *WideTy = IntegerType::get(AR->getType()->getContext(), BitWidth * 2);
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ const SCEV *OpAfterExtend = SE.getAddExpr(SE.getZeroExtendExpr(Step, WideTy),
+ SE.getZeroExtendExpr(AR, WideTy));
+ const SCEV *ExtendAfterOp =
+ SE.getZeroExtendExpr(SE.getAddExpr(AR, Step), WideTy);
+ return ExtendAfterOp == OpAfterExtend;
+}
+
+/// getAddRecExprPHILiterally - Helper for expandAddRecExprLiterally. Expand
+/// the base addrec, which is the addrec without any non-loop-dominating
+/// values, and return the PHI.
+PHINode *
+SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
+ const Loop *L,
+ Type *ExpandTy,
+ Type *IntTy,
+ Type *&TruncTy,
+ bool &InvertStep) {
+ assert((!IVIncInsertLoop||IVIncInsertPos) && "Uninitialized insert position");
+
+ // Reuse a previously-inserted PHI, if present.
+ BasicBlock *LatchBlock = L->getLoopLatch();
+ if (LatchBlock) {
+ PHINode *AddRecPhiMatch = nullptr;
+ Instruction *IncV = nullptr;
+ TruncTy = nullptr;
+ InvertStep = false;
+
+ // Only try partially matching SCEVs that need truncation and/or
+ // step-inversion if we know this loop is outside the current loop.
+ bool TryNonMatchingSCEV =
+ IVIncInsertLoop &&
+ SE.DT.properlyDominates(LatchBlock, IVIncInsertLoop->getHeader());
+
+ for (PHINode &PN : L->getHeader()->phis()) {
+ if (!SE.isSCEVable(PN.getType()))
+ continue;
+
+ const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
+ if (!PhiSCEV)
+ continue;
+
+ bool IsMatchingSCEV = PhiSCEV == Normalized;
+ // We only handle truncation and inversion of phi recurrences for the
+ // expanded expression if the expanded expression's loop dominates the
+ // loop we insert into. Check now, so we can bail out early.
+ if (!IsMatchingSCEV && !TryNonMatchingSCEV)
+ continue;
+
+ // TODO: this possibly can be reworked to avoid this cast at all.
+ Instruction *TempIncV =
+ dyn_cast<Instruction>(PN.getIncomingValueForBlock(LatchBlock));
+ if (!TempIncV)
+ continue;
+
+ // Check whether we can reuse this PHI node.
+ if (LSRMode) {
+ if (!isExpandedAddRecExprPHI(&PN, TempIncV, L))
+ continue;
+ if (L == IVIncInsertLoop && !hoistIVInc(TempIncV, IVIncInsertPos))
+ continue;
+ } else {
+ if (!isNormalAddRecExprPHI(&PN, TempIncV, L))
+ continue;
+ }
+
+ // Stop if we have found an exact match SCEV.
+ if (IsMatchingSCEV) {
+ IncV = TempIncV;
+ TruncTy = nullptr;
+ InvertStep = false;
+ AddRecPhiMatch = &PN;
+ break;
+ }
+
+ // Try whether the phi can be translated into the requested form
+ // (truncated and/or offset by a constant).
+ if ((!TruncTy || InvertStep) &&
+ canBeCheaplyTransformed(SE, PhiSCEV, Normalized, InvertStep)) {
+ // Record the phi node. But don't stop; we might find an exact match
+ // later.
+ AddRecPhiMatch = &PN;
+ IncV = TempIncV;
+ TruncTy = SE.getEffectiveSCEVType(Normalized->getType());
+ }
+ }
+
+ if (AddRecPhiMatch) {
+ // Potentially, move the increment. We have made sure in
+ // isExpandedAddRecExprPHI or hoistIVInc that this is possible.
+ if (L == IVIncInsertLoop)
+ hoistBeforePos(&SE.DT, IncV, IVIncInsertPos, AddRecPhiMatch);
+
+ // Ok, the add recurrence looks usable.
+ // Remember this PHI, even in post-inc mode.
+ InsertedValues.insert(AddRecPhiMatch);
+ // Remember the increment.
+ rememberInstruction(IncV);
+ return AddRecPhiMatch;
+ }
+ }
+
+ // Save the original insertion point so we can restore it when we're done.
+ SCEVInsertPointGuard Guard(Builder, this);
+
+ // Another AddRec may need to be recursively expanded below. For example, if
+ // this AddRec is quadratic, the StepV may itself be an AddRec in this
+ // loop. Remove this loop from the PostIncLoops set before expanding such
+ // AddRecs. Otherwise, we cannot find a valid position for the step
+ // (i.e. StepV can never dominate its loop header). Ideally, we could do
+ // SavedIncLoops.swap(PostIncLoops), but we generally have a single element,
+ // so it's not worth implementing SmallPtrSet::swap.
+ PostIncLoopSet SavedPostIncLoops = PostIncLoops;
+ PostIncLoops.clear();
+
+ // Expand code for the start value into the loop preheader.
+ assert(L->getLoopPreheader() &&
+ "Can't expand add recurrences without a loop preheader!");
+ Value *StartV = expandCodeFor(Normalized->getStart(), ExpandTy,
+ L->getLoopPreheader()->getTerminator());
+
+ // StartV must have been inserted into L's preheader to dominate the new
+ // phi.
+ assert(!isa<Instruction>(StartV) ||
+ SE.DT.properlyDominates(cast<Instruction>(StartV)->getParent(),
+ L->getHeader()));
+
+ // Expand code for the step value. Do this before creating the PHI so that PHI
+ // reuse code doesn't see an incomplete PHI.
+ const SCEV *Step = Normalized->getStepRecurrence(SE);
+ // If the stride is negative, insert a sub instead of an add for the increment
+ // (unless it's a constant, because subtracts of constants are canonicalized
+ // to adds).
+ bool useSubtract = !ExpandTy->isPointerTy() && Step->isNonConstantNegative();
+ if (useSubtract)
+ Step = SE.getNegativeSCEV(Step);
+ // Expand the step somewhere that dominates the loop header.
+ Value *StepV = expandCodeFor(Step, IntTy, &L->getHeader()->front());
+
+ // The no-wrap behavior proved by IsIncrement(NUW|NSW) is only applicable if
+ // we actually do emit an addition. It does not apply if we emit a
+ // subtraction.
+ bool IncrementIsNUW = !useSubtract && IsIncrementNUW(SE, Normalized);
+ bool IncrementIsNSW = !useSubtract && IsIncrementNSW(SE, Normalized);
+
+ // Create the PHI.
+ BasicBlock *Header = L->getHeader();
+ Builder.SetInsertPoint(Header, Header->begin());
+ pred_iterator HPB = pred_begin(Header), HPE = pred_end(Header);
+ PHINode *PN = Builder.CreatePHI(ExpandTy, std::distance(HPB, HPE),
+ Twine(IVName) + ".iv");
+ rememberInstruction(PN);
+
+ // Create the step instructions and populate the PHI.
+ for (pred_iterator HPI = HPB; HPI != HPE; ++HPI) {
+ BasicBlock *Pred = *HPI;
+
+ // Add a start value.
+ if (!L->contains(Pred)) {
+ PN->addIncoming(StartV, Pred);
+ continue;
+ }
+
+ // Create a step value and add it to the PHI.
+ // If IVIncInsertLoop is non-null and equal to the addrec's loop, insert the
+ // instructions at IVIncInsertPos.
+ Instruction *InsertPos = L == IVIncInsertLoop ?
+ IVIncInsertPos : Pred->getTerminator();
+ Builder.SetInsertPoint(InsertPos);
+ Value *IncV = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract);
+
+ if (isa<OverflowingBinaryOperator>(IncV)) {
+ if (IncrementIsNUW)
+ cast<BinaryOperator>(IncV)->setHasNoUnsignedWrap();
+ if (IncrementIsNSW)
+ cast<BinaryOperator>(IncV)->setHasNoSignedWrap();
+ }
+ PN->addIncoming(IncV, Pred);
+ }
+
+ // After expanding subexpressions, restore the PostIncLoops set so the caller
+ // can ensure that IVIncrement dominates the current uses.
+ PostIncLoops = SavedPostIncLoops;
+
+ // Remember this PHI, even in post-inc mode.
+ InsertedValues.insert(PN);
+
+ return PN;
+}
+
+Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
+ Type *STy = S->getType();
+ Type *IntTy = SE.getEffectiveSCEVType(STy);
+ const Loop *L = S->getLoop();
+
+ // Determine a normalized form of this expression, which is the expression
+ // before any post-inc adjustment is made.
+ const SCEVAddRecExpr *Normalized = S;
+ if (PostIncLoops.count(L)) {
+ PostIncLoopSet Loops;
+ Loops.insert(L);
+ Normalized = cast<SCEVAddRecExpr>(normalizeForPostIncUse(S, Loops, SE));
+ }
+
+ // Strip off any non-loop-dominating component from the addrec start.
+ const SCEV *Start = Normalized->getStart();
+ const SCEV *PostLoopOffset = nullptr;
+ if (!SE.properlyDominates(Start, L->getHeader())) {
+ PostLoopOffset = Start;
+ Start = SE.getConstant(Normalized->getType(), 0);
+ Normalized = cast<SCEVAddRecExpr>(
+ SE.getAddRecExpr(Start, Normalized->getStepRecurrence(SE),
+ Normalized->getLoop(),
+ Normalized->getNoWrapFlags(SCEV::FlagNW)));
+ }
+
+ // Strip off any non-loop-dominating component from the addrec step.
+ const SCEV *Step = Normalized->getStepRecurrence(SE);
+ const SCEV *PostLoopScale = nullptr;
+ if (!SE.dominates(Step, L->getHeader())) {
+ PostLoopScale = Step;
+ Step = SE.getConstant(Normalized->getType(), 1);
+ if (!Start->isZero()) {
+ // The normalization below assumes that Start is constant zero, so if
+ // it isn't, re-associate Start to PostLoopOffset.
+ assert(!PostLoopOffset && "Start not-null but PostLoopOffset set?");
+ PostLoopOffset = Start;
+ Start = SE.getConstant(Normalized->getType(), 0);
+ }
+ Normalized =
+ cast<SCEVAddRecExpr>(SE.getAddRecExpr(
+ Start, Step, Normalized->getLoop(),
+ Normalized->getNoWrapFlags(SCEV::FlagNW)));
+ }
+
+ // Expand the core addrec. If we need post-loop scaling, force it to
+ // expand to an integer type to avoid the need for additional casting.
+ Type *ExpandTy = PostLoopScale ? IntTy : STy;
+ // We can't use a pointer type for the addrec if the pointer type is
+ // non-integral.
+ Type *AddRecPHIExpandTy =
+ DL.isNonIntegralPointerType(STy) ? Normalized->getType() : ExpandTy;
+
+ // In some cases, we decide to reuse an existing phi node but need to truncate
+ // it and/or invert the step.
+ Type *TruncTy = nullptr;
+ bool InvertStep = false;
+ PHINode *PN = getAddRecExprPHILiterally(Normalized, L, AddRecPHIExpandTy,
+ IntTy, TruncTy, InvertStep);
+
+ // Accommodate post-inc mode, if necessary.
+ Value *Result;
+ if (!PostIncLoops.count(L))
+ Result = PN;
+ else {
+ // In PostInc mode, use the post-incremented value.
+ BasicBlock *LatchBlock = L->getLoopLatch();
+ assert(LatchBlock && "PostInc mode requires a unique loop latch!");
+ Result = PN->getIncomingValueForBlock(LatchBlock);
+
+ // For an expansion to use the postinc form, the client must call
+ // expandCodeFor with an InsertPoint that is either outside the PostIncLoop
+ // or dominated by IVIncInsertPos.
+ if (isa<Instruction>(Result) &&
+ !SE.DT.dominates(cast<Instruction>(Result),
+ &*Builder.GetInsertPoint())) {
+ // The induction variable's postinc expansion does not dominate this use.
+ // IVUsers tries to prevent this case, so it is rare. However, it can
+ // happen when an IVUser outside the loop is not dominated by the latch
+ // block. Adjusting IVIncInsertPos before expansion begins cannot handle
+ // all cases. Consider a phi outside whose operand is replaced during
+ // expansion with the value of the postinc user. Without fundamentally
+ // changing the way postinc users are tracked, the only remedy is
+ // inserting an extra IV increment. StepV might fold into PostLoopOffset,
+ // but hopefully expandCodeFor handles that.
+ bool useSubtract =
+ !ExpandTy->isPointerTy() && Step->isNonConstantNegative();
+ if (useSubtract)
+ Step = SE.getNegativeSCEV(Step);
+ Value *StepV;
+ {
+ // Expand the step somewhere that dominates the loop header.
+ SCEVInsertPointGuard Guard(Builder, this);
+ StepV = expandCodeFor(Step, IntTy, &L->getHeader()->front());
+ }
+ Result = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract);
+ }
+ }
+
+ // We have decided to reuse an induction variable of a dominating loop. Apply
+ // truncation and/or inversion of the step.
+ if (TruncTy) {
+ Type *ResTy = Result->getType();
+ // Normalize the result type.
+ if (ResTy != SE.getEffectiveSCEVType(ResTy))
+ Result = InsertNoopCastOfTo(Result, SE.getEffectiveSCEVType(ResTy));
+ // Truncate the result.
+ if (TruncTy != Result->getType()) {
+ Result = Builder.CreateTrunc(Result, TruncTy);
+ rememberInstruction(Result);
+ }
+ // Invert the result.
+ if (InvertStep) {
+ Result = Builder.CreateSub(expandCodeFor(Normalized->getStart(), TruncTy),
+ Result);
+ rememberInstruction(Result);
+ }
+ }
+
+ // Re-apply any non-loop-dominating scale.
+ if (PostLoopScale) {
+ assert(S->isAffine() && "Can't linearly scale non-affine recurrences.");
+ Result = InsertNoopCastOfTo(Result, IntTy);
+ Result = Builder.CreateMul(Result,
+ expandCodeFor(PostLoopScale, IntTy));
+ rememberInstruction(Result);
+ }
+
+ // Re-apply any non-loop-dominating offset.
+ if (PostLoopOffset) {
+ if (PointerType *PTy = dyn_cast<PointerType>(ExpandTy)) {
+ if (Result->getType()->isIntegerTy()) {
+ Value *Base = expandCodeFor(PostLoopOffset, ExpandTy);
+ Result = expandAddToGEP(SE.getUnknown(Result), PTy, IntTy, Base);
+ } else {
+ Result = expandAddToGEP(PostLoopOffset, PTy, IntTy, Result);
+ }
+ } else {
+ Result = InsertNoopCastOfTo(Result, IntTy);
+ Result = Builder.CreateAdd(Result,
+ expandCodeFor(PostLoopOffset, IntTy));
+ rememberInstruction(Result);
+ }
+ }
+
+ return Result;
+}
+
+Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
+ // In canonical mode we compute the addrec as an expression of a canonical IV
+ // using evaluateAtIteration and expand the resulting SCEV expression. This
+// way we avoid introducing new IVs to carry on the computation of the addrec
+ // throughout the loop.
+ //
+ // For nested addrecs evaluateAtIteration might need a canonical IV of a
+ // type wider than the addrec itself. Emitting a canonical IV of the
+ // proper type might produce non-legal types, for example expanding an i64
+ // {0,+,2,+,1} addrec would need an i65 canonical IV. To avoid this just fall
+ // back to non-canonical mode for nested addrecs.
+ if (!CanonicalMode || (S->getNumOperands() > 2))
+ return expandAddRecExprLiterally(S);
+
+ Type *Ty = SE.getEffectiveSCEVType(S->getType());
+ const Loop *L = S->getLoop();
+
+ // First check for an existing canonical IV in a suitable type.
+ PHINode *CanonicalIV = nullptr;
+ if (PHINode *PN = L->getCanonicalInductionVariable())
+ if (SE.getTypeSizeInBits(PN->getType()) >= SE.getTypeSizeInBits(Ty))
+ CanonicalIV = PN;
+
+ // Rewrite an AddRec in terms of the canonical induction variable, if
+ // its type is narrower.
+ if (CanonicalIV &&
+ SE.getTypeSizeInBits(CanonicalIV->getType()) >
+ SE.getTypeSizeInBits(Ty)) {
+ SmallVector<const SCEV *, 4> NewOps(S->getNumOperands());
+ for (unsigned i = 0, e = S->getNumOperands(); i != e; ++i)
+ NewOps[i] = SE.getAnyExtendExpr(S->op_begin()[i], CanonicalIV->getType());
+ Value *V = expand(SE.getAddRecExpr(NewOps, S->getLoop(),
+ S->getNoWrapFlags(SCEV::FlagNW)));
+ BasicBlock::iterator NewInsertPt =
+ findInsertPointAfter(cast<Instruction>(V), Builder.GetInsertBlock());
+ V = expandCodeFor(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr,
+ &*NewInsertPt);
+ return V;
+ }
+
+ // {X,+,F} --> X + {0,+,F}
+ if (!S->getStart()->isZero()) {
+ SmallVector<const SCEV *, 4> NewOps(S->op_begin(), S->op_end());
+ NewOps[0] = SE.getConstant(Ty, 0);
+ const SCEV *Rest = SE.getAddRecExpr(NewOps, L,
+ S->getNoWrapFlags(SCEV::FlagNW));
+
+ // Turn things like ptrtoint+arithmetic+inttoptr into GEP. See the
+ // comments on expandAddToGEP for details.
+ const SCEV *Base = S->getStart();
+ // Dig into the expression to find the pointer base for a GEP.
+ const SCEV *ExposedRest = Rest;
+ ExposePointerBase(Base, ExposedRest, SE);
+ // If we found a pointer, expand the AddRec with a GEP.
+ if (PointerType *PTy = dyn_cast<PointerType>(Base->getType())) {
+ // Make sure the Base isn't something exotic, such as a multiplied
+ // or divided pointer value. In those cases, the result type isn't
+ // actually a pointer type.
+ if (!isa<SCEVMulExpr>(Base) && !isa<SCEVUDivExpr>(Base)) {
+ Value *StartV = expand(Base);
+ assert(StartV->getType() == PTy && "Pointer type mismatch for GEP!");
+ return expandAddToGEP(ExposedRest, PTy, Ty, StartV);
+ }
+ }
+
+ // Just do a normal add. Pre-expand the operands to suppress folding.
+ //
+ // The LHS and RHS values are factored out of the expand call to make the
+ // output independent of the argument evaluation order.
+ const SCEV *AddExprLHS = SE.getUnknown(expand(S->getStart()));
+ const SCEV *AddExprRHS = SE.getUnknown(expand(Rest));
+ return expand(SE.getAddExpr(AddExprLHS, AddExprRHS));
+ }
+
+ // If we don't yet have a canonical IV, create one.
+ if (!CanonicalIV) {
+ // Create and insert the PHI node for the induction variable in the
+ // specified loop.
+ BasicBlock *Header = L->getHeader();
+ pred_iterator HPB = pred_begin(Header), HPE = pred_end(Header);
+ CanonicalIV = PHINode::Create(Ty, std::distance(HPB, HPE), "indvar",
+ &Header->front());
+ rememberInstruction(CanonicalIV);
+
+ SmallSet<BasicBlock *, 4> PredSeen;
+ Constant *One = ConstantInt::get(Ty, 1);
+ for (pred_iterator HPI = HPB; HPI != HPE; ++HPI) {
+ BasicBlock *HP = *HPI;
+ if (!PredSeen.insert(HP).second) {
+ // There must be an incoming value for each predecessor, even the
+ // duplicates!
+ CanonicalIV->addIncoming(CanonicalIV->getIncomingValueForBlock(HP), HP);
+ continue;
+ }
+
+ if (L->contains(HP)) {
+ // Insert a unit add instruction right before the terminator
+ // corresponding to the back-edge.
+ Instruction *Add = BinaryOperator::CreateAdd(CanonicalIV, One,
+ "indvar.next",
+ HP->getTerminator());
+ Add->setDebugLoc(HP->getTerminator()->getDebugLoc());
+ rememberInstruction(Add);
+ CanonicalIV->addIncoming(Add, HP);
+ } else {
+ CanonicalIV->addIncoming(Constant::getNullValue(Ty), HP);
+ }
+ }
+ }
+
+ // {0,+,1} --> Insert a canonical induction variable into the loop!
+ if (S->isAffine() && S->getOperand(1)->isOne()) {
+ assert(Ty == SE.getEffectiveSCEVType(CanonicalIV->getType()) &&
+ "IVs with types different from the canonical IV should "
+ "already have been handled!");
+ return CanonicalIV;
+ }
+
+ // {0,+,F} --> {0,+,1} * F
+
+ // If this is a simple linear addrec, emit it now as a special case.
+ if (S->isAffine()) // {0,+,F} --> i*F
+ return
+ expand(SE.getTruncateOrNoop(
+ SE.getMulExpr(SE.getUnknown(CanonicalIV),
+ SE.getNoopOrAnyExtend(S->getOperand(1),
+ CanonicalIV->getType())),
+ Ty));
+
+ // If this is a chain of recurrences, turn it into a closed form, using the
+ // folders, then expandCodeFor the closed form. This allows the folders to
+ // simplify the expression without having to build a bunch of special code
+ // into this folder.
+ const SCEV *IH = SE.getUnknown(CanonicalIV); // Get I as a "symbolic" SCEV.
+
+ // Promote S up to the canonical IV type, if the cast is foldable.
+ const SCEV *NewS = S;
+ const SCEV *Ext = SE.getNoopOrAnyExtend(S, CanonicalIV->getType());
+ if (isa<SCEVAddRecExpr>(Ext))
+ NewS = Ext;
+
+ const SCEV *V = cast<SCEVAddRecExpr>(NewS)->evaluateAtIteration(IH, SE);
+ //cerr << "Evaluated: " << *this << "\n to: " << *V << "\n";
+
+ // Truncate the result down to the original type, if needed.
+ const SCEV *T = SE.getTruncateOrNoop(V, Ty);
+ return expand(T);
+}
+
+Value *SCEVExpander::visitTruncateExpr(const SCEVTruncateExpr *S) {
+ Type *Ty = SE.getEffectiveSCEVType(S->getType());
+ Value *V = expandCodeFor(S->getOperand(),
+ SE.getEffectiveSCEVType(S->getOperand()->getType()));
+ Value *I = Builder.CreateTrunc(V, Ty);
+ rememberInstruction(I);
+ return I;
+}
+
+Value *SCEVExpander::visitZeroExtendExpr(const SCEVZeroExtendExpr *S) {
+ Type *Ty = SE.getEffectiveSCEVType(S->getType());
+ Value *V = expandCodeFor(S->getOperand(),
+ SE.getEffectiveSCEVType(S->getOperand()->getType()));
+ Value *I = Builder.CreateZExt(V, Ty);
+ rememberInstruction(I);
+ return I;
+}
+
+Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) {
+ Type *Ty = SE.getEffectiveSCEVType(S->getType());
+ Value *V = expandCodeFor(S->getOperand(),
+ SE.getEffectiveSCEVType(S->getOperand()->getType()));
+ Value *I = Builder.CreateSExt(V, Ty);
+ rememberInstruction(I);
+ return I;
+}
+
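+// The min/max visitors below expand an n-ary min/max as a chain of
+// icmp+select pairs, folding from the last operand toward the first; e.g.
+// smax(a, b, c) becomes two compares and two selects (operand names are
+// illustrative).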
+Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) {
+ Value *LHS = expand(S->getOperand(S->getNumOperands()-1));
+ Type *Ty = LHS->getType();
+ for (int i = S->getNumOperands()-2; i >= 0; --i) {
+ // In the case of mixed integer and pointer types, do the
+ // rest of the comparisons as integer.
+ Type *OpTy = S->getOperand(i)->getType();
+ if (OpTy->isIntegerTy() != Ty->isIntegerTy()) {
+ Ty = SE.getEffectiveSCEVType(Ty);
+ LHS = InsertNoopCastOfTo(LHS, Ty);
+ }
+ Value *RHS = expandCodeFor(S->getOperand(i), Ty);
+ Value *ICmp = Builder.CreateICmpSGT(LHS, RHS);
+ rememberInstruction(ICmp);
+ Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "smax");
+ rememberInstruction(Sel);
+ LHS = Sel;
+ }
+ // In the case of mixed integer and pointer types, cast the
+ // final result back to the pointer type.
+ if (LHS->getType() != S->getType())
+ LHS = InsertNoopCastOfTo(LHS, S->getType());
+ return LHS;
+}
+
+Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) {
+ Value *LHS = expand(S->getOperand(S->getNumOperands()-1));
+ Type *Ty = LHS->getType();
+ for (int i = S->getNumOperands()-2; i >= 0; --i) {
+ // In the case of mixed integer and pointer types, do the
+ // rest of the comparisons as integer.
+ Type *OpTy = S->getOperand(i)->getType();
+ if (OpTy->isIntegerTy() != Ty->isIntegerTy()) {
+ Ty = SE.getEffectiveSCEVType(Ty);
+ LHS = InsertNoopCastOfTo(LHS, Ty);
+ }
+ Value *RHS = expandCodeFor(S->getOperand(i), Ty);
+ Value *ICmp = Builder.CreateICmpUGT(LHS, RHS);
+ rememberInstruction(ICmp);
+ Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "umax");
+ rememberInstruction(Sel);
+ LHS = Sel;
+ }
+ // In the case of mixed integer and pointer types, cast the
+ // final result back to the pointer type.
+ if (LHS->getType() != S->getType())
+ LHS = InsertNoopCastOfTo(LHS, S->getType());
+ return LHS;
+}
+
+Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) {
+ Value *LHS = expand(S->getOperand(S->getNumOperands() - 1));
+ Type *Ty = LHS->getType();
+ for (int i = S->getNumOperands() - 2; i >= 0; --i) {
+ // In the case of mixed integer and pointer types, do the
+ // rest of the comparisons as integer.
+ Type *OpTy = S->getOperand(i)->getType();
+ if (OpTy->isIntegerTy() != Ty->isIntegerTy()) {
+ Ty = SE.getEffectiveSCEVType(Ty);
+ LHS = InsertNoopCastOfTo(LHS, Ty);
+ }
+ Value *RHS = expandCodeFor(S->getOperand(i), Ty);
+ Value *ICmp = Builder.CreateICmpSLT(LHS, RHS);
+ rememberInstruction(ICmp);
+ Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "smin");
+ rememberInstruction(Sel);
+ LHS = Sel;
+ }
+ // In the case of mixed integer and pointer types, cast the
+ // final result back to the pointer type.
+ if (LHS->getType() != S->getType())
+ LHS = InsertNoopCastOfTo(LHS, S->getType());
+ return LHS;
+}
+
+Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) {
+ Value *LHS = expand(S->getOperand(S->getNumOperands() - 1));
+ Type *Ty = LHS->getType();
+ for (int i = S->getNumOperands() - 2; i >= 0; --i) {
+ // In the case of mixed integer and pointer types, do the
+ // rest of the comparisons as integer.
+ Type *OpTy = S->getOperand(i)->getType();
+ if (OpTy->isIntegerTy() != Ty->isIntegerTy()) {
+ Ty = SE.getEffectiveSCEVType(Ty);
+ LHS = InsertNoopCastOfTo(LHS, Ty);
+ }
+ Value *RHS = expandCodeFor(S->getOperand(i), Ty);
+ Value *ICmp = Builder.CreateICmpULT(LHS, RHS);
+ rememberInstruction(ICmp);
+ Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "umin");
+ rememberInstruction(Sel);
+ LHS = Sel;
+ }
+ // In the case of mixed integer and pointer types, cast the
+ // final result back to the pointer type.
+ if (LHS->getType() != S->getType())
+ LHS = InsertNoopCastOfTo(LHS, S->getType());
+ return LHS;
+}
+
+Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty,
+ Instruction *IP) {
+ setInsertPoint(IP);
+ return expandCodeFor(SH, Ty);
+}
+
+Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty) {
+ // Expand the code for this SCEV.
+ Value *V = expand(SH);
+ if (Ty) {
+ assert(SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(SH->getType()) &&
+ "non-trivial casts should be done with the SCEVs directly!");
+ V = InsertNoopCastOfTo(V, Ty);
+ }
+ return V;
+}
+
+ScalarEvolution::ValueOffsetPair
+SCEVExpander::FindValueInExprValueMap(const SCEV *S,
+ const Instruction *InsertPt) {
+ SetVector<ScalarEvolution::ValueOffsetPair> *Set = SE.getSCEVValues(S);
+ // If the expansion is not in CanonicalMode, and the SCEV contains any
+ // sub scAddRecExpr type SCEV, it is required to expand the SCEV literally.
+ if (CanonicalMode || !SE.containsAddRecurrence(S)) {
+ // If S is scConstant, it may be worse to reuse an existing Value.
+ if (S->getSCEVType() != scConstant && Set) {
+ // Choose a Value from the set which dominates the insertPt.
+ // insertPt should be inside the Value's parent loop so as not to break
+ // the LCSSA form.
+ for (auto const &VOPair : *Set) {
+ Value *V = VOPair.first;
+ ConstantInt *Offset = VOPair.second;
+ Instruction *EntInst = nullptr;
+ if (V && isa<Instruction>(V) && (EntInst = cast<Instruction>(V)) &&
+ S->getType() == V->getType() &&
+ EntInst->getFunction() == InsertPt->getFunction() &&
+ SE.DT.dominates(EntInst, InsertPt) &&
+ (SE.LI.getLoopFor(EntInst->getParent()) == nullptr ||
+ SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt)))
+ return {V, Offset};
+ }
+ }
+ }
+ return {nullptr, nullptr};
+}
+
+// The expansion of SCEV will either reuse a previous Value in ExprValueMap,
+// or expand the SCEV literally. Specifically, if the expansion is in LSRMode,
+// and the SCEV contains any sub scAddRecExpr type SCEV, it will be expanded
+// literally, to prevent LSR's transformed SCEV from being reverted. Otherwise,
+// the expansion will try to reuse Value from ExprValueMap, and only when it
+// fails, expand the SCEV literally.
+Value *SCEVExpander::expand(const SCEV *S) {
+ // Compute an insertion point for this SCEV object. Hoist the instructions
+ // as far out in the loop nest as possible.
+ Instruction *InsertPt = &*Builder.GetInsertPoint();
+
+ // We can move the insertion point only if there are no div or rem
+ // operations; otherwise we risk moving it across the check for a zero
+ // denominator.
+ auto SafeToHoist = [](const SCEV *S) {
+ return !SCEVExprContains(S, [](const SCEV *S) {
+ if (const auto *D = dyn_cast<SCEVUDivExpr>(S)) {
+ if (const auto *SC = dyn_cast<SCEVConstant>(D->getRHS()))
+ // Division by non-zero constants can be hoisted.
+ return SC->getValue()->isZero();
+ // All other divisions should not be moved as they may be
+ // divisions by zero and should be kept within the
+ // conditions of the surrounding loops that guard their
+ // execution (see PR35406).
+ return true;
+ }
+ return false;
+ });
+ };
+ if (SafeToHoist(S)) {
+ for (Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock());;
+ L = L->getParentLoop()) {
+ if (SE.isLoopInvariant(S, L)) {
+ if (!L) break;
+ if (BasicBlock *Preheader = L->getLoopPreheader())
+ InsertPt = Preheader->getTerminator();
+ else
+ // LSR sets the insertion point for AddRec start/step values to the
+ // block start to simplify value reuse, even though it's an invalid
+ // position. SCEVExpander must correct for this in all cases.
+ InsertPt = &*L->getHeader()->getFirstInsertionPt();
+ } else {
+ // If the SCEV is computable at this level, insert it into the header
+ // after the PHIs (and after any other instructions that we've inserted
+ // there) so that it is guaranteed to dominate any user inside the loop.
+ if (L && SE.hasComputableLoopEvolution(S, L) && !PostIncLoops.count(L))
+ InsertPt = &*L->getHeader()->getFirstInsertionPt();
+ while (InsertPt->getIterator() != Builder.GetInsertPoint() &&
+ (isInsertedInstruction(InsertPt) ||
+ isa<DbgInfoIntrinsic>(InsertPt)))
+ InsertPt = &*std::next(InsertPt->getIterator());
+ break;
+ }
+ }
+ }
+
+ // IndVarSimplify sometimes sets the insertion point at the block start, even
+ // when there are PHIs at that point. We must correct for this.
+ if (isa<PHINode>(*InsertPt))
+ InsertPt = &*InsertPt->getParent()->getFirstInsertionPt();
+
+ // Check to see if we already expanded this here.
+ auto I = InsertedExpressions.find(std::make_pair(S, InsertPt));
+ if (I != InsertedExpressions.end())
+ return I->second;
+
+ SCEVInsertPointGuard Guard(Builder, this);
+ Builder.SetInsertPoint(InsertPt);
+
+ // Expand the expression into instructions.
+ ScalarEvolution::ValueOffsetPair VO = FindValueInExprValueMap(S, InsertPt);
+ Value *V = VO.first;
+
+ if (!V)
+ V = visit(S);
+ else if (VO.second) {
+ if (PointerType *Vty = dyn_cast<PointerType>(V->getType())) {
+ Type *Ety = Vty->getPointerElementType();
+ int64_t Offset = VO.second->getSExtValue();
+ int64_t ESize = SE.getTypeSizeInBits(Ety);
+ if ((Offset * 8) % ESize == 0) {
+ ConstantInt *Idx =
+ ConstantInt::getSigned(VO.second->getType(), -(Offset * 8) / ESize);
+ V = Builder.CreateGEP(Ety, V, Idx, "scevgep");
+ } else {
+ ConstantInt *Idx =
+ ConstantInt::getSigned(VO.second->getType(), -Offset);
+ unsigned AS = Vty->getAddressSpace();
+ V = Builder.CreateBitCast(V, Type::getInt8PtrTy(SE.getContext(), AS));
+ V = Builder.CreateGEP(Type::getInt8Ty(SE.getContext()), V, Idx,
+ "uglygep");
+ V = Builder.CreateBitCast(V, Vty);
+ }
+ } else {
+ V = Builder.CreateSub(V, VO.second);
+ }
+ }
+ // Remember the expanded value for this SCEV at this location.
+ //
+ // This is independent of PostIncLoops. The mapped value simply materializes
+ // the expression at this insertion point. If the mapped value happened to be
+ // a postinc expansion, it could be reused by a non-postinc user, but only if
+ // its insertion point was already at the head of the loop.
+ InsertedExpressions[std::make_pair(S, InsertPt)] = V;
+ return V;
+}
+
+void SCEVExpander::rememberInstruction(Value *I) {
+ if (!PostIncLoops.empty())
+ InsertedPostIncValues.insert(I);
+ else
+ InsertedValues.insert(I);
+}
+
+/// getOrInsertCanonicalInductionVariable - This method returns the
+/// canonical induction variable of the specified type for the specified
+/// loop (inserting one if there is none). A canonical induction variable
+/// starts at zero and steps by one on each iteration.
+PHINode *
+SCEVExpander::getOrInsertCanonicalInductionVariable(const Loop *L,
+ Type *Ty) {
+ assert(Ty->isIntegerTy() && "Can only insert integer induction variables!");
+
+ // Build a SCEV for {0,+,1}<L>.
+ // Conservatively use FlagAnyWrap for now.
+ const SCEV *H = SE.getAddRecExpr(SE.getConstant(Ty, 0),
+ SE.getConstant(Ty, 1), L, SCEV::FlagAnyWrap);
+
+ // Emit code for it.
+ SCEVInsertPointGuard Guard(Builder, this);
+ PHINode *V =
+ cast<PHINode>(expandCodeFor(H, nullptr, &L->getHeader()->front()));
+
+ return V;
+}
+
+/// replaceCongruentIVs - Check for congruent phis in this loop header and
+/// replace them with their most canonical representative. Return the number of
+/// phis eliminated.
+///
+/// This does not depend on any SCEVExpander state but should be used in
+/// the same context that SCEVExpander is used.
+unsigned
+SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT,
+ SmallVectorImpl<WeakTrackingVH> &DeadInsts,
+ const TargetTransformInfo *TTI) {
+ // Find integer phis in order of increasing width.
+ SmallVector<PHINode*, 8> Phis;
+ for (PHINode &PN : L->getHeader()->phis())
+ Phis.push_back(&PN);
+
+ if (TTI)
+ llvm::sort(Phis, [](Value *LHS, Value *RHS) {
+ // Put pointers at the back and make sure pointer < pointer = false.
+ if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy())
+ return RHS->getType()->isIntegerTy() && !LHS->getType()->isIntegerTy();
+ return RHS->getType()->getPrimitiveSizeInBits() <
+ LHS->getType()->getPrimitiveSizeInBits();
+ });
+
+ unsigned NumElim = 0;
+ DenseMap<const SCEV *, PHINode *> ExprToIVMap;
+ // Process phis from wide to narrow. Map wide phis to their truncation
+ // so narrow phis can reuse them.
+ for (PHINode *Phi : Phis) {
+ auto SimplifyPHINode = [&](PHINode *PN) -> Value * {
+ if (Value *V = SimplifyInstruction(PN, {DL, &SE.TLI, &SE.DT, &SE.AC}))
+ return V;
+ if (!SE.isSCEVable(PN->getType()))
+ return nullptr;
+ auto *Const = dyn_cast<SCEVConstant>(SE.getSCEV(PN));
+ if (!Const)
+ return nullptr;
+ return Const->getValue();
+ };
+
+ // Fold constant phis. They may be congruent to other constant phis and
+ // would confuse the logic below that expects proper IVs.
+ if (Value *V = SimplifyPHINode(Phi)) {
+ if (V->getType() != Phi->getType())
+ continue;
+ Phi->replaceAllUsesWith(V);
+ DeadInsts.emplace_back(Phi);
+ ++NumElim;
+ DEBUG_WITH_TYPE(DebugType, dbgs()
+ << "INDVARS: Eliminated constant iv: " << *Phi << '\n');
+ continue;
+ }
+
+ if (!SE.isSCEVable(Phi->getType()))
+ continue;
+
+ PHINode *&OrigPhiRef = ExprToIVMap[SE.getSCEV(Phi)];
+ if (!OrigPhiRef) {
+ OrigPhiRef = Phi;
+ if (Phi->getType()->isIntegerTy() && TTI &&
+ TTI->isTruncateFree(Phi->getType(), Phis.back()->getType())) {
+ // This phi can be freely truncated to the narrowest phi type. Map the
+ // truncated expression to it so it will be reused for narrow types.
+ const SCEV *TruncExpr =
+ SE.getTruncateExpr(SE.getSCEV(Phi), Phis.back()->getType());
+ ExprToIVMap[TruncExpr] = Phi;
+ }
+ continue;
+ }
+
+ // Replacing a pointer phi with an integer phi or vice-versa doesn't make
+ // sense.
+ if (OrigPhiRef->getType()->isPointerTy() != Phi->getType()->isPointerTy())
+ continue;
+
+ if (BasicBlock *LatchBlock = L->getLoopLatch()) {
+ Instruction *OrigInc = dyn_cast<Instruction>(
+ OrigPhiRef->getIncomingValueForBlock(LatchBlock));
+ Instruction *IsomorphicInc =
+ dyn_cast<Instruction>(Phi->getIncomingValueForBlock(LatchBlock));
+
+ if (OrigInc && IsomorphicInc) {
+ // If this phi has the same width but is more canonical, replace the
+ // original with it. As part of the "more canonical" determination,
+ // respect a prior decision to use an IV chain.
+ if (OrigPhiRef->getType() == Phi->getType() &&
+ !(ChainedPhis.count(Phi) ||
+ isExpandedAddRecExprPHI(OrigPhiRef, OrigInc, L)) &&
+ (ChainedPhis.count(Phi) ||
+ isExpandedAddRecExprPHI(Phi, IsomorphicInc, L))) {
+ std::swap(OrigPhiRef, Phi);
+ std::swap(OrigInc, IsomorphicInc);
+ }
+ // Replacing the congruent phi is sufficient because acyclic
+ // redundancy elimination, CSE/GVN, should handle the
+ // rest. However, once SCEV proves that a phi is congruent,
+ // it's often the head of an IV user cycle that is isomorphic
+ // with the original phi. It's worth eagerly cleaning up the
+ // common case of a single IV increment so that DeleteDeadPHIs
+ // can remove cycles that had postinc uses.
+ const SCEV *TruncExpr =
+ SE.getTruncateOrNoop(SE.getSCEV(OrigInc), IsomorphicInc->getType());
+ if (OrigInc != IsomorphicInc &&
+ TruncExpr == SE.getSCEV(IsomorphicInc) &&
+ SE.LI.replacementPreservesLCSSAForm(IsomorphicInc, OrigInc) &&
+ hoistIVInc(OrigInc, IsomorphicInc)) {
+ DEBUG_WITH_TYPE(DebugType,
+ dbgs() << "INDVARS: Eliminated congruent iv.inc: "
+ << *IsomorphicInc << '\n');
+ Value *NewInc = OrigInc;
+ if (OrigInc->getType() != IsomorphicInc->getType()) {
+ Instruction *IP = nullptr;
+ if (PHINode *PN = dyn_cast<PHINode>(OrigInc))
+ IP = &*PN->getParent()->getFirstInsertionPt();
+ else
+ IP = OrigInc->getNextNode();
+
+ IRBuilder<> Builder(IP);
+ Builder.SetCurrentDebugLocation(IsomorphicInc->getDebugLoc());
+ NewInc = Builder.CreateTruncOrBitCast(
+ OrigInc, IsomorphicInc->getType(), IVName);
+ }
+ IsomorphicInc->replaceAllUsesWith(NewInc);
+ DeadInsts.emplace_back(IsomorphicInc);
+ }
+ }
+ }
+ DEBUG_WITH_TYPE(DebugType, dbgs() << "INDVARS: Eliminated congruent iv: "
+ << *Phi << '\n');
+ ++NumElim;
+ Value *NewIV = OrigPhiRef;
+ if (OrigPhiRef->getType() != Phi->getType()) {
+ IRBuilder<> Builder(&*L->getHeader()->getFirstInsertionPt());
+ Builder.SetCurrentDebugLocation(Phi->getDebugLoc());
+ NewIV = Builder.CreateTruncOrBitCast(OrigPhiRef, Phi->getType(), IVName);
+ }
+ Phi->replaceAllUsesWith(NewIV);
+ DeadInsts.emplace_back(Phi);
+ }
+ return NumElim;
+}
+
+Value *SCEVExpander::getExactExistingExpansion(const SCEV *S,
+ const Instruction *At, Loop *L) {
+ Optional<ScalarEvolution::ValueOffsetPair> VO =
+ getRelatedExistingExpansion(S, At, L);
+ if (VO && VO.getValue().second == nullptr)
+ return VO.getValue().first;
+ return nullptr;
+}
+
+Optional<ScalarEvolution::ValueOffsetPair>
+SCEVExpander::getRelatedExistingExpansion(const SCEV *S, const Instruction *At,
+ Loop *L) {
+ using namespace llvm::PatternMatch;
+
+ SmallVector<BasicBlock *, 4> ExitingBlocks;
+ L->getExitingBlocks(ExitingBlocks);
+
+ // Look for suitable value in simple conditions at the loop exits.
+ for (BasicBlock *BB : ExitingBlocks) {
+ ICmpInst::Predicate Pred;
+ Instruction *LHS, *RHS;
+
+ if (!match(BB->getTerminator(),
+ m_Br(m_ICmp(Pred, m_Instruction(LHS), m_Instruction(RHS)),
+ m_BasicBlock(), m_BasicBlock())))
+ continue;
+
+ if (SE.getSCEV(LHS) == S && SE.DT.dominates(LHS, At))
+ return ScalarEvolution::ValueOffsetPair(LHS, nullptr);
+
+ if (SE.getSCEV(RHS) == S && SE.DT.dominates(RHS, At))
+ return ScalarEvolution::ValueOffsetPair(RHS, nullptr);
+ }
+
+ // Use expand's logic which is used for reusing a previous Value in
+ // ExprValueMap.
+ ScalarEvolution::ValueOffsetPair VO = FindValueInExprValueMap(S, At);
+ if (VO.first)
+ return VO;
+
+ // There is potential to make this significantly smarter, but this simple
+ // heuristic already gets some interesting cases.
+
+ // Could not find a suitable value.
+ return None;
+}
+
+bool SCEVExpander::isHighCostExpansionHelper(
+ const SCEV *S, Loop *L, const Instruction *At,
+ SmallPtrSetImpl<const SCEV *> &Processed) {
+
+ // If we can find an existing value for this scev available at the point "At"
+ // then consider the expression cheap.
+ if (At && getRelatedExistingExpansion(S, At, L))
+ return false;
+
+ // Zero/One operand expressions
+ switch (S->getSCEVType()) {
+ case scUnknown:
+ case scConstant:
+ return false;
+ case scTruncate:
+ return isHighCostExpansionHelper(cast<SCEVTruncateExpr>(S)->getOperand(),
+ L, At, Processed);
+ case scZeroExtend:
+ return isHighCostExpansionHelper(cast<SCEVZeroExtendExpr>(S)->getOperand(),
+ L, At, Processed);
+ case scSignExtend:
+ return isHighCostExpansionHelper(cast<SCEVSignExtendExpr>(S)->getOperand(),
+ L, At, Processed);
+ }
+
+ if (!Processed.insert(S).second)
+ return false;
+
+ if (auto *UDivExpr = dyn_cast<SCEVUDivExpr>(S)) {
+ // If the divisor is a power of two and the SCEV type fits in a native
+ // integer (and the LHS not expensive), consider the division cheap
+ // irrespective of whether it occurs in the user code since it can be
+ // lowered into a right shift.
+ if (auto *SC = dyn_cast<SCEVConstant>(UDivExpr->getRHS()))
+ if (SC->getAPInt().isPowerOf2()) {
+ if (isHighCostExpansionHelper(UDivExpr->getLHS(), L, At, Processed))
+ return true;
+ const DataLayout &DL =
+ L->getHeader()->getParent()->getParent()->getDataLayout();
+ unsigned Width = cast<IntegerType>(UDivExpr->getType())->getBitWidth();
+ return DL.isIllegalInteger(Width);
+ }
+
+ // UDivExpr is very likely a UDiv that ScalarEvolution's HowFarToZero or
+ // HowManyLessThans produced to compute a precise expression, rather than a
+ // UDiv from the user's code. If we can't find a UDiv in the code with some
+ // simple searching, assume the former and consider UDivExpr expensive to
+ // compute.
+ BasicBlock *ExitingBB = L->getExitingBlock();
+ if (!ExitingBB)
+ return true;
+
+ // At the beginning of this function we already tried to find an existing
+ // value for plain 'S'. Now try to look up 'S + 1', since it is a common
+ // pattern involving division. This is just a simple search heuristic.
+ if (!At)
+ At = &ExitingBB->back();
+ if (!getRelatedExistingExpansion(
+ SE.getAddExpr(S, SE.getConstant(S->getType(), 1)), At, L))
+ return true;
+ }
+
+ // HowManyLessThans uses a Max expression whenever the loop is not guarded by
+ // the exit condition.
+ if (isa<SCEVMinMaxExpr>(S))
+ return true;
+
+ // Recurse past nary expressions, which commonly occur in the
+ // BackedgeTakenCount. They may already exist in program code, and if not,
+ // they are not too expensive to rematerialize.
+ if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(S)) {
+ for (auto *Op : NAry->operands())
+ if (isHighCostExpansionHelper(Op, L, At, Processed))
+ return true;
+ }
+
+ // If we haven't recognized an expensive SCEV pattern, assume it's an
+ // expression produced by program code.
+ return false;
+}
+
+Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
+ Instruction *IP) {
+ assert(IP);
+ switch (Pred->getKind()) {
+ case SCEVPredicate::P_Union:
+ return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP);
+ case SCEVPredicate::P_Equal:
+ return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP);
+ case SCEVPredicate::P_Wrap: {
+ auto *AddRecPred = cast<SCEVWrapPredicate>(Pred);
+ return expandWrapPredicate(AddRecPred, IP);
+ }
+ }
+ llvm_unreachable("Unknown SCEV predicate type");
+}
+
+Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred,
+ Instruction *IP) {
+ Value *Expr0 = expandCodeFor(Pred->getLHS(), Pred->getLHS()->getType(), IP);
+ Value *Expr1 = expandCodeFor(Pred->getRHS(), Pred->getRHS()->getType(), IP);
+
+ Builder.SetInsertPoint(IP);
+ auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check");
+ return I;
+}
+
+Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
+ Instruction *Loc, bool Signed) {
+ assert(AR->isAffine() && "Cannot generate RT check for "
+ "non-affine expression");
+
+ SCEVUnionPredicate Pred;
+ const SCEV *ExitCount =
+ SE.getPredicatedBackedgeTakenCount(AR->getLoop(), Pred);
+
+ assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count");
+
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ const SCEV *Start = AR->getStart();
+
+ Type *ARTy = AR->getType();
+ unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType());
+ unsigned DstBits = SE.getTypeSizeInBits(ARTy);
+
+ // The expression {Start,+,Step} has nusw/nssw if
+ // Step < 0, Start - |Step| * Backedge <= Start
+ // Step >= 0, Start + |Step| * Backedge > Start
+ // and |Step| * Backedge doesn't unsigned overflow.
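+ // As a small illustration with hypothetical values: for {10,+,-3} and a
+ // backedge-taken count of 4, |Step| * Backedge = 12 and Start - 12 = -2,
+ // which is <= Start, so the negative-step condition above holds.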
+
+ IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits);
+ Builder.SetInsertPoint(Loc);
+ Value *TripCountVal = expandCodeFor(ExitCount, CountTy, Loc);
+
+ IntegerType *Ty =
+ IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(ARTy));
+ Type *ARExpandTy = DL.isNonIntegralPointerType(ARTy) ? ARTy : Ty;
+
+ Value *StepValue = expandCodeFor(Step, Ty, Loc);
+ Value *NegStepValue = expandCodeFor(SE.getNegativeSCEV(Step), Ty, Loc);
+ Value *StartValue = expandCodeFor(Start, ARExpandTy, Loc);
+
+ ConstantInt *Zero =
+ ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits));
+
+ Builder.SetInsertPoint(Loc);
+ // Compute |Step|
+ Value *StepCompare = Builder.CreateICmp(ICmpInst::ICMP_SLT, StepValue, Zero);
+ Value *AbsStep = Builder.CreateSelect(StepCompare, NegStepValue, StepValue);
+
+ // Get the backedge-taken count and truncate or extend it to the AR type.
+ Value *TruncTripCount = Builder.CreateZExtOrTrunc(TripCountVal, Ty);
+ auto *MulF = Intrinsic::getDeclaration(Loc->getModule(),
+ Intrinsic::umul_with_overflow, Ty);
+
+ // Compute |Step| * Backedge
+ CallInst *Mul = Builder.CreateCall(MulF, {AbsStep, TruncTripCount}, "mul");
+ Value *MulV = Builder.CreateExtractValue(Mul, 0, "mul.result");
+ Value *OfMul = Builder.CreateExtractValue(Mul, 1, "mul.overflow");
+
+ // Compute:
+ // Start + |Step| * Backedge < Start
+ // Start - |Step| * Backedge > Start
+ Value *Add = nullptr, *Sub = nullptr;
+ if (PointerType *ARPtrTy = dyn_cast<PointerType>(ARExpandTy)) {
+ const SCEV *MulS = SE.getSCEV(MulV);
+ const SCEV *NegMulS = SE.getNegativeSCEV(MulS);
+ Add = Builder.CreateBitCast(expandAddToGEP(MulS, ARPtrTy, Ty, StartValue),
+ ARPtrTy);
+ Sub = Builder.CreateBitCast(
+ expandAddToGEP(NegMulS, ARPtrTy, Ty, StartValue), ARPtrTy);
+ } else {
+ Add = Builder.CreateAdd(StartValue, MulV);
+ Sub = Builder.CreateSub(StartValue, MulV);
+ }
+
+ Value *EndCompareGT = Builder.CreateICmp(
+ Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, Sub, StartValue);
+
+ Value *EndCompareLT = Builder.CreateICmp(
+ Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, Add, StartValue);
+
+ // Select the answer based on the sign of Step.
+ Value *EndCheck =
+ Builder.CreateSelect(StepCompare, EndCompareGT, EndCompareLT);
+
+ // If the backedge taken count type is larger than the AR type,
+ // check that we don't drop any bits by truncating it. If we are
+ // dropping bits, then we have overflow (unless the step is zero).
+ if (SE.getTypeSizeInBits(CountTy) > SE.getTypeSizeInBits(Ty)) {
+ auto MaxVal = APInt::getMaxValue(DstBits).zext(SrcBits);
+ auto *BackedgeCheck =
+ Builder.CreateICmp(ICmpInst::ICMP_UGT, TripCountVal,
+ ConstantInt::get(Loc->getContext(), MaxVal));
+ BackedgeCheck = Builder.CreateAnd(
+ BackedgeCheck, Builder.CreateICmp(ICmpInst::ICMP_NE, StepValue, Zero));
+
+ EndCheck = Builder.CreateOr(EndCheck, BackedgeCheck);
+ }
+
+ EndCheck = Builder.CreateOr(EndCheck, OfMul);
+ return EndCheck;
+}
+
+Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred,
+ Instruction *IP) {
+ const auto *A = cast<SCEVAddRecExpr>(Pred->getExpr());
+ Value *NSSWCheck = nullptr, *NUSWCheck = nullptr;
+
+ // Add a check for NUSW
+ if (Pred->getFlags() & SCEVWrapPredicate::IncrementNUSW)
+ NUSWCheck = generateOverflowCheck(A, IP, false);
+
+ // Add a check for NSSW
+ if (Pred->getFlags() & SCEVWrapPredicate::IncrementNSSW)
+ NSSWCheck = generateOverflowCheck(A, IP, true);
+
+ if (NUSWCheck && NSSWCheck)
+ return Builder.CreateOr(NUSWCheck, NSSWCheck);
+
+ if (NUSWCheck)
+ return NUSWCheck;
+
+ if (NSSWCheck)
+ return NSSWCheck;
+
+ return ConstantInt::getFalse(IP->getContext());
+}
+
+Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union,
+ Instruction *IP) {
+ auto *BoolType = IntegerType::get(IP->getContext(), 1);
+ Value *Check = ConstantInt::getNullValue(BoolType);
+
+ // Loop over all checks in this set.
+ for (auto Pred : Union->getPredicates()) {
+ auto *NextCheck = expandCodeForPredicate(Pred, IP);
+ Builder.SetInsertPoint(IP);
+ Check = Builder.CreateOr(Check, NextCheck);
+ }
+
+ return Check;
+}
+
+namespace {
+// Search for a SCEV subexpression that is not safe to expand. Any expression
+// that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely
+// UDiv expressions. We don't know if the UDiv is derived from an IR divide
+// instruction, but the important thing is that we prove the denominator is
+// nonzero before expansion.
+//
+// IVUsers already checks that IV-derived expressions are safe. So this check is
+// only needed when the expression includes some subexpression that is not IV
+// derived.
+//
+// Currently, we only allow division by a nonzero constant here. If this is
+// inadequate, we could easily allow division by SCEVUnknown by using
+// ValueTracking to check isKnownNonZero().
+//
+// We cannot generally expand recurrences unless the step dominates the loop
+// header. The expander handles the special case of affine recurrences by
+// scaling the recurrence outside the loop, but this technique isn't generally
+// applicable. Expanding a nested recurrence outside a loop requires computing
+// binomial coefficients. This could be done, but the recurrence has to be in a
+// perfectly reduced form, which can't be guaranteed.
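+//
+// For example, a SCEVUDivExpr whose right-hand side is a SCEVUnknown (rather
+// than a nonzero constant) is flagged as unsafe below, since the divisor
+// could be zero at the chosen expansion point.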
+struct SCEVFindUnsafe {
+ ScalarEvolution &SE;
+ bool IsUnsafe;
+
+ SCEVFindUnsafe(ScalarEvolution &se): SE(se), IsUnsafe(false) {}
+
+ bool follow(const SCEV *S) {
+ if (const SCEVUDivExpr *D = dyn_cast<SCEVUDivExpr>(S)) {
+ const SCEVConstant *SC = dyn_cast<SCEVConstant>(D->getRHS());
+ if (!SC || SC->getValue()->isZero()) {
+ IsUnsafe = true;
+ return false;
+ }
+ }
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ if (!AR->isAffine() && !SE.dominates(Step, AR->getLoop()->getHeader())) {
+ IsUnsafe = true;
+ return false;
+ }
+ }
+ return true;
+ }
+ bool isDone() const { return IsUnsafe; }
+};
+}
+
+namespace llvm {
+bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE) {
+ SCEVFindUnsafe Search(SE);
+ visitAll(S, Search);
+ return !Search.IsUnsafe;
+}
+
+bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint,
+ ScalarEvolution &SE) {
+ if (!isSafeToExpand(S, SE))
+ return false;
+ // We have to prove that the expanded site of S dominates InsertionPoint.
+ // This is easy when not in the same block, but hard when S is an instruction
+ // to be expanded somewhere inside the same block as our insertion point.
+ // What we really need here is something analogous to an OrderedBasicBlock,
+ // but for the moment, we paper over the problem by handling two common and
+ // cheap to check cases.
+ if (SE.properlyDominates(S, InsertionPoint->getParent()))
+ return true;
+ if (SE.dominates(S, InsertionPoint->getParent())) {
+ if (InsertionPoint->getParent()->getTerminator() == InsertionPoint)
+ return true;
+ if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S))
+ for (const Value *V : InsertionPoint->operand_values())
+ if (V == U->getValue())
+ return true;
+ }
+ return false;
+}
+}
diff --git a/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp b/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp
new file mode 100644
index 000000000000..209ae66ca53e
--- /dev/null
+++ b/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp
@@ -0,0 +1,117 @@
+//===- ScalarEvolutionNormalization.cpp - See below -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for working with "normalized" expressions.
+// See the comments at the top of ScalarEvolutionNormalization.h for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ScalarEvolutionNormalization.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+using namespace llvm;
+
+/// TransformKind - Different types of transformations that
+/// TransformForPostIncUse can do.
+enum TransformKind {
+ /// Normalize - Normalize according to the given loops.
+ Normalize,
+ /// Denormalize - Perform the inverse transform on the expression with the
+ /// given loop set.
+ Denormalize
+};
+
+namespace {
+struct NormalizeDenormalizeRewriter
+ : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> {
+ const TransformKind Kind;
+
+ // NB! Pred is a function_ref. Storing it here is okay only because
+ // we're careful about the lifetime of NormalizeDenormalizeRewriter.
+ const NormalizePredTy Pred;
+
+ NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred,
+ ScalarEvolution &SE)
+ : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind),
+ Pred(Pred) {}
+ const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr);
+};
+} // namespace
+
+const SCEV *
+NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) {
+ SmallVector<const SCEV *, 8> Operands;
+
+ transform(AR->operands(), std::back_inserter(Operands),
+ [&](const SCEV *Op) { return visit(Op); });
+
+ if (!Pred(AR))
+ return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
+
+ // Normalization and denormalization are fancy names for decrementing and
+ // incrementing a SCEV expression with respect to a set of loops. Since
+ // Pred(AR) has returned true, we know we need to normalize or denormalize AR
+ // with respect to its loop.
+
+ if (Kind == Denormalize) {
+ // Denormalization / "partial increment" is essentially the same as \c
+ // SCEVAddRecExpr::getPostIncExpr. Here we use an explicit loop to make the
+ // symmetry with Normalization clear.
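+ // As an illustration, a recurrence {a,+,b,+,c} denormalizes to
+ // {a+b,+,b+c,+,c}, i.e. its post-increment form.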
+ for (int i = 0, e = Operands.size() - 1; i < e; i++)
+ Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]);
+ } else {
+ assert(Kind == Normalize && "Only two possibilities!");
+
+ // Normalization / "partial decrement" is a bit more subtle. Since
+ // incrementing a SCEV expression (in general) changes the step of the SCEV
+ // expression as well, we cannot use the step of the current expression.
+ // Instead, we have to use the step of the very expression we're trying to
+ // compute!
+ //
+ // We solve the issue by recursively building up the result, starting from
+ // the "least significant" operand in the add recurrence:
+ //
+ // Base case:
+ // Single operand add recurrence. It's its own normalization.
+ //
+ // N-operand case:
+ // {S_{N-1},+,S_{N-2},+,...,+,S_0} = S
+ //
+ // Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its
+ // normalization by induction. We subtract the normalized step
+ // recurrence from S_{N-1} to get the normalization of S.
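+ //
+ // As an illustration, for {a,+,b,+,c} the step is {b,+,c}, whose
+ // normalization is {b-c,+,c}; subtracting it from the leading operand
+ // gives {a-(b-c),+,b-c,+,c}, which is what the loop below computes from
+ // the least significant operand upwards.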
+
+ for (int i = Operands.size() - 2; i >= 0; i--)
+ Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]);
+ }
+
+ return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
+}
+
+const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
+ const PostIncLoopSet &Loops,
+ ScalarEvolution &SE) {
+ auto Pred = [&](const SCEVAddRecExpr *AR) {
+ return Loops.count(AR->getLoop());
+ };
+ return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
+}
+
+const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
+ ScalarEvolution &SE) {
+ return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
+}
+
+const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
+ const PostIncLoopSet &Loops,
+ ScalarEvolution &SE) {
+ auto Pred = [&](const SCEVAddRecExpr *AR) {
+ return Loops.count(AR->getLoop());
+ };
+ return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S);
+}
diff --git a/llvm/lib/Analysis/ScopedNoAliasAA.cpp b/llvm/lib/Analysis/ScopedNoAliasAA.cpp
new file mode 100644
index 000000000000..094e4a3d5dc8
--- /dev/null
+++ b/llvm/lib/Analysis/ScopedNoAliasAA.cpp
@@ -0,0 +1,210 @@
+//===- ScopedNoAliasAA.cpp - Scoped No-Alias Alias Analysis ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ScopedNoAlias alias-analysis pass, which implements
+// metadata-based scoped no-alias support.
+//
+// Alias-analysis scopes are defined by an id (which can be a string or some
+// other metadata node), a domain node, and an optional descriptive string.
+// A domain is defined by an id (which can be a string or some other metadata
+// node), and an optional descriptive string.
+//
+// !dom0 = metadata !{ metadata !"domain of foo()" }
+// !scope1 = metadata !{ metadata !scope1, metadata !dom0, metadata !"scope 1" }
+// !scope2 = metadata !{ metadata !scope2, metadata !dom0, metadata !"scope 2" }
+//
+// Loads and stores can be tagged with an alias-analysis scope, and also, with
+// a noalias tag for a specific scope:
+//
+// ... = load %ptr1, !alias.scope !{ !scope1 }
+// ... = load %ptr2, !alias.scope !{ !scope1, !scope2 }, !noalias !{ !scope1 }
+//
+// When evaluating an aliasing query, if one of the instructions has a set of
+// noalias scopes in some domain that is a superset of the alias scopes in
+// that domain of the other instruction, then the two memory accesses are
+// assumed not to alias.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ScopedNoAliasAA.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace llvm;
+
+// A handy option for disabling scoped no-alias functionality. The same effect
+// can also be achieved by stripping the associated metadata tags from IR, but
+// this option is sometimes more convenient.
+static cl::opt<bool> EnableScopedNoAlias("enable-scoped-noalias",
+ cl::init(true), cl::Hidden);
+
+namespace {
+
+/// This is a simple wrapper around an MDNode which provides a higher-level
+/// interface by hiding the details of how alias analysis information is encoded
+/// in its operands.
+class AliasScopeNode {
+ const MDNode *Node = nullptr;
+
+public:
+ AliasScopeNode() = default;
+ explicit AliasScopeNode(const MDNode *N) : Node(N) {}
+
+ /// Get the MDNode for this AliasScopeNode.
+ const MDNode *getNode() const { return Node; }
+
+ /// Get the MDNode for this AliasScopeNode's domain.
+ const MDNode *getDomain() const {
+ if (Node->getNumOperands() < 2)
+ return nullptr;
+ return dyn_cast_or_null<MDNode>(Node->getOperand(1));
+ }
+};
+
+} // end anonymous namespace
+
+AliasResult ScopedNoAliasAAResult::alias(const MemoryLocation &LocA,
+ const MemoryLocation &LocB,
+ AAQueryInfo &AAQI) {
+ if (!EnableScopedNoAlias)
+ return AAResultBase::alias(LocA, LocB, AAQI);
+
+ // Get the attached MDNodes.
+ const MDNode *AScopes = LocA.AATags.Scope, *BScopes = LocB.AATags.Scope;
+
+ const MDNode *ANoAlias = LocA.AATags.NoAlias, *BNoAlias = LocB.AATags.NoAlias;
+
+ if (!mayAliasInScopes(AScopes, BNoAlias))
+ return NoAlias;
+
+ if (!mayAliasInScopes(BScopes, ANoAlias))
+ return NoAlias;
+
+ // If they may alias, chain to the next AliasAnalysis.
+ return AAResultBase::alias(LocA, LocB, AAQI);
+}
+
+ModRefInfo ScopedNoAliasAAResult::getModRefInfo(const CallBase *Call,
+ const MemoryLocation &Loc,
+ AAQueryInfo &AAQI) {
+ if (!EnableScopedNoAlias)
+ return AAResultBase::getModRefInfo(Call, Loc, AAQI);
+
+ if (!mayAliasInScopes(Loc.AATags.Scope,
+ Call->getMetadata(LLVMContext::MD_noalias)))
+ return ModRefInfo::NoModRef;
+
+ if (!mayAliasInScopes(Call->getMetadata(LLVMContext::MD_alias_scope),
+ Loc.AATags.NoAlias))
+ return ModRefInfo::NoModRef;
+
+ return AAResultBase::getModRefInfo(Call, Loc, AAQI);
+}
+
+ModRefInfo ScopedNoAliasAAResult::getModRefInfo(const CallBase *Call1,
+ const CallBase *Call2,
+ AAQueryInfo &AAQI) {
+ if (!EnableScopedNoAlias)
+ return AAResultBase::getModRefInfo(Call1, Call2, AAQI);
+
+ if (!mayAliasInScopes(Call1->getMetadata(LLVMContext::MD_alias_scope),
+ Call2->getMetadata(LLVMContext::MD_noalias)))
+ return ModRefInfo::NoModRef;
+
+ if (!mayAliasInScopes(Call2->getMetadata(LLVMContext::MD_alias_scope),
+ Call1->getMetadata(LLVMContext::MD_noalias)))
+ return ModRefInfo::NoModRef;
+
+ return AAResultBase::getModRefInfo(Call1, Call2, AAQI);
+}
+
+static void collectMDInDomain(const MDNode *List, const MDNode *Domain,
+ SmallPtrSetImpl<const MDNode *> &Nodes) {
+ for (const MDOperand &MDOp : List->operands())
+ if (const MDNode *MD = dyn_cast<MDNode>(MDOp))
+ if (AliasScopeNode(MD).getDomain() == Domain)
+ Nodes.insert(MD);
+}
+
+bool ScopedNoAliasAAResult::mayAliasInScopes(const MDNode *Scopes,
+ const MDNode *NoAlias) const {
+ if (!Scopes || !NoAlias)
+ return true;
+
+ // Collect the set of scope domains relevant to the noalias scopes.
+ SmallPtrSet<const MDNode *, 16> Domains;
+ for (const MDOperand &MDOp : NoAlias->operands())
+ if (const MDNode *NAMD = dyn_cast<MDNode>(MDOp))
+ if (const MDNode *Domain = AliasScopeNode(NAMD).getDomain())
+ Domains.insert(Domain);
+
+ // We alias unless, for some domain, the set of noalias scopes in that domain
+ // is a superset of the set of alias scopes in that domain.
+ for (const MDNode *Domain : Domains) {
+ SmallPtrSet<const MDNode *, 16> ScopeNodes;
+ collectMDInDomain(Scopes, Domain, ScopeNodes);
+ if (ScopeNodes.empty())
+ continue;
+
+ SmallPtrSet<const MDNode *, 16> NANodes;
+ collectMDInDomain(NoAlias, Domain, NANodes);
+
+ // To not alias, all of the nodes in ScopeNodes must be in NANodes.
+ bool FoundAll = true;
+ for (const MDNode *SMD : ScopeNodes)
+ if (!NANodes.count(SMD)) {
+ FoundAll = false;
+ break;
+ }
+
+ if (FoundAll)
+ return false;
+ }
+
+ return true;
+}
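+
+// Walk-through of the example in the file header (illustrative only): LocA is
+// the load tagged !alias.scope !{ !scope1 } and LocB is the load additionally
+// tagged !noalias !{ !scope1 }. Within domain !dom0 the noalias set {!scope1}
+// covers LocA's alias-scope set {!scope1}, so mayAliasInScopes() returns false
+// and alias() above reports NoAlias for this pair of accesses.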
+
+AnalysisKey ScopedNoAliasAA::Key;
+
+ScopedNoAliasAAResult ScopedNoAliasAA::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ return ScopedNoAliasAAResult();
+}
+
+char ScopedNoAliasAAWrapperPass::ID = 0;
+
+INITIALIZE_PASS(ScopedNoAliasAAWrapperPass, "scoped-noalias",
+ "Scoped NoAlias Alias Analysis", false, true)
+
+ImmutablePass *llvm::createScopedNoAliasAAWrapperPass() {
+ return new ScopedNoAliasAAWrapperPass();
+}
+
+ScopedNoAliasAAWrapperPass::ScopedNoAliasAAWrapperPass() : ImmutablePass(ID) {
+ initializeScopedNoAliasAAWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+bool ScopedNoAliasAAWrapperPass::doInitialization(Module &M) {
+ Result.reset(new ScopedNoAliasAAResult());
+ return false;
+}
+
+bool ScopedNoAliasAAWrapperPass::doFinalization(Module &M) {
+ Result.reset();
+ return false;
+}
+
+void ScopedNoAliasAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+}
diff --git a/llvm/lib/Analysis/StackSafetyAnalysis.cpp b/llvm/lib/Analysis/StackSafetyAnalysis.cpp
new file mode 100644
index 000000000000..1b3638698950
--- /dev/null
+++ b/llvm/lib/Analysis/StackSafetyAnalysis.cpp
@@ -0,0 +1,676 @@
+//===- StackSafetyAnalysis.cpp - Stack memory safety analysis -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/StackSafetyAnalysis.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "stack-safety"
+
+static cl::opt<int> StackSafetyMaxIterations("stack-safety-max-iterations",
+ cl::init(20), cl::Hidden);
+
+namespace {
+
+/// Rewrite an SCEV expression for a memory access address to an expression that
+/// represents the offset from the given alloca.
+class AllocaOffsetRewriter : public SCEVRewriteVisitor<AllocaOffsetRewriter> {
+ const Value *AllocaPtr;
+
+public:
+ AllocaOffsetRewriter(ScalarEvolution &SE, const Value *AllocaPtr)
+ : SCEVRewriteVisitor(SE), AllocaPtr(AllocaPtr) {}
+
+ const SCEV *visit(const SCEV *Expr) {
+ // Only rewrite the expression if the alloca is used in an addition
+ // expression (it can be used in other types of expressions if it's cast to
+ // an int and passed as an argument).
+ if (!isa<SCEVAddRecExpr>(Expr) && !isa<SCEVAddExpr>(Expr) &&
+ !isa<SCEVUnknown>(Expr))
+ return Expr;
+ return SCEVRewriteVisitor<AllocaOffsetRewriter>::visit(Expr);
+ }
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ // FIXME: look through one or several levels of definitions?
+ // This can be inttoptr(AllocaPtr) and SCEV would not unwrap
+ // it for us.
+ if (Expr->getValue() == AllocaPtr)
+ return SE.getZero(Expr->getType());
+ return Expr;
+ }
+};
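+
+// Illustration (not part of the analysis itself): for an access through a GEP
+// such as
+//   %p = getelementptr i8, i8* %alloca, i64 %i
+// the SCEV of %p is an add of SCEVUnknown(%alloca) and %i; visitUnknown()
+// replaces the %alloca term with zero, so the rewritten expression describes
+// the byte offset of the access relative to the alloca, which the callers
+// below then turn into a ConstantRange.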
+
+/// Describes a use of an address as a function call argument.
+struct PassAsArgInfo {
+ /// Function being called.
+ const GlobalValue *Callee = nullptr;
+ /// Index of the argument that passes the address.
+ size_t ParamNo = 0;
+ // Offset range of the address from the base address (alloca or calling
+ // function argument).
+ // The range must never be set to the empty set; that is an invalid access
+ // range that could cause the empty set to be propagated by ConstantRange::add.
+ ConstantRange Offset;
+ PassAsArgInfo(const GlobalValue *Callee, size_t ParamNo, ConstantRange Offset)
+ : Callee(Callee), ParamNo(ParamNo), Offset(Offset) {}
+
+ StringRef getName() const { return Callee->getName(); }
+};
+
+raw_ostream &operator<<(raw_ostream &OS, const PassAsArgInfo &P) {
+ return OS << "@" << P.getName() << "(arg" << P.ParamNo << ", " << P.Offset
+ << ")";
+}
+
+/// Describes uses of an address (alloca or parameter) inside the function.
+struct UseInfo {
+ // Access range of the address (alloca or parameter).
+ // It is allowed to be the empty set when there are no known accesses.
+ ConstantRange Range;
+
+ // List of calls which pass address as an argument.
+ SmallVector<PassAsArgInfo, 4> Calls;
+
+ explicit UseInfo(unsigned PointerSize) : Range{PointerSize, false} {}
+
+ void updateRange(ConstantRange R) { Range = Range.unionWith(R); }
+};
+
+raw_ostream &operator<<(raw_ostream &OS, const UseInfo &U) {
+ OS << U.Range;
+ for (auto &Call : U.Calls)
+ OS << ", " << Call;
+ return OS;
+}
+
+struct AllocaInfo {
+ const AllocaInst *AI = nullptr;
+ uint64_t Size = 0;
+ UseInfo Use;
+
+ AllocaInfo(unsigned PointerSize, const AllocaInst *AI, uint64_t Size)
+ : AI(AI), Size(Size), Use(PointerSize) {}
+
+ StringRef getName() const { return AI->getName(); }
+};
+
+raw_ostream &operator<<(raw_ostream &OS, const AllocaInfo &A) {
+ return OS << A.getName() << "[" << A.Size << "]: " << A.Use;
+}
+
+struct ParamInfo {
+ const Argument *Arg = nullptr;
+ UseInfo Use;
+
+ explicit ParamInfo(unsigned PointerSize, const Argument *Arg)
+ : Arg(Arg), Use(PointerSize) {}
+
+ StringRef getName() const { return Arg ? Arg->getName() : "<N/A>"; }
+};
+
+raw_ostream &operator<<(raw_ostream &OS, const ParamInfo &P) {
+ return OS << P.getName() << "[]: " << P.Use;
+}
+
+/// Calculate the allocation size of a given alloca. Returns 0 if the
+/// size cannot be statically determined.
+uint64_t getStaticAllocaAllocationSize(const AllocaInst *AI) {
+ const DataLayout &DL = AI->getModule()->getDataLayout();
+ uint64_t Size = DL.getTypeAllocSize(AI->getAllocatedType());
+ if (AI->isArrayAllocation()) {
+ auto C = dyn_cast<ConstantInt>(AI->getArraySize());
+ if (!C)
+ return 0;
+ Size *= C->getZExtValue();
+ }
+ return Size;
+}
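+
+// For example (illustrative): an `alloca [4 x i32]` yields 16 on a target
+// where i32 has an alloc size of 4 bytes, while `alloca i32, i32 %n` yields 0
+// because its array size is not a compile-time constant.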
+
+} // end anonymous namespace
+
+/// Describes uses of allocas and parameters inside of a single function.
+struct StackSafetyInfo::FunctionInfo {
+ // May be a Function or a GlobalAlias
+ const GlobalValue *GV = nullptr;
+ // Information about uses of allocas.
+ SmallVector<AllocaInfo, 4> Allocas;
+ // Information about uses of parameters.
+ SmallVector<ParamInfo, 4> Params;
+ // TODO: describe return value as depending on one or more of its arguments.
+
+ // StackSafetyDataFlowAnalysis counter stored here for faster access.
+ int UpdateCount = 0;
+
+ FunctionInfo(const StackSafetyInfo &SSI) : FunctionInfo(*SSI.Info) {}
+
+ explicit FunctionInfo(const Function *F) : GV(F) {}
+ // Creates FunctionInfo that forwards all the parameters to the aliasee.
+ explicit FunctionInfo(const GlobalAlias *A);
+
+ FunctionInfo(FunctionInfo &&) = default;
+
+ bool IsDSOLocal() const { return GV->isDSOLocal(); };
+
+ bool IsInterposable() const { return GV->isInterposable(); };
+
+ StringRef getName() const { return GV->getName(); }
+
+ void print(raw_ostream &O) const {
+ // TODO: Consider different printout format after
+ // StackSafetyDataFlowAnalysis. Calls and parameters are irrelevant then.
+ O << " @" << getName() << (IsDSOLocal() ? "" : " dso_preemptable")
+ << (IsInterposable() ? " interposable" : "") << "\n";
+ O << " args uses:\n";
+ for (auto &P : Params)
+ O << " " << P << "\n";
+ O << " allocas uses:\n";
+ for (auto &AS : Allocas)
+ O << " " << AS << "\n";
+ }
+
+private:
+ FunctionInfo(const FunctionInfo &) = default;
+};
+
+StackSafetyInfo::FunctionInfo::FunctionInfo(const GlobalAlias *A) : GV(A) {
+ unsigned PointerSize = A->getParent()->getDataLayout().getPointerSizeInBits();
+ const GlobalObject *Aliasee = A->getBaseObject();
+ const FunctionType *Type = cast<FunctionType>(Aliasee->getValueType());
+ // 'Forward' all parameters to this alias to the aliasee
+ for (unsigned ArgNo = 0; ArgNo < Type->getNumParams(); ArgNo++) {
+ Params.emplace_back(PointerSize, nullptr);
+ UseInfo &US = Params.back().Use;
+ US.Calls.emplace_back(Aliasee, ArgNo, ConstantRange(APInt(PointerSize, 0)));
+ }
+}
+
+namespace {
+
+class StackSafetyLocalAnalysis {
+ const Function &F;
+ const DataLayout &DL;
+ ScalarEvolution &SE;
+ unsigned PointerSize = 0;
+
+ const ConstantRange UnknownRange;
+
+ ConstantRange offsetFromAlloca(Value *Addr, const Value *AllocaPtr);
+ ConstantRange getAccessRange(Value *Addr, const Value *AllocaPtr,
+ uint64_t AccessSize);
+ ConstantRange getMemIntrinsicAccessRange(const MemIntrinsic *MI, const Use &U,
+ const Value *AllocaPtr);
+
+ bool analyzeAllUses(const Value *Ptr, UseInfo &AS);
+
+ ConstantRange getRange(uint64_t Lower, uint64_t Upper) const {
+ return ConstantRange(APInt(PointerSize, Lower), APInt(PointerSize, Upper));
+ }
+
+public:
+ StackSafetyLocalAnalysis(const Function &F, ScalarEvolution &SE)
+ : F(F), DL(F.getParent()->getDataLayout()), SE(SE),
+ PointerSize(DL.getPointerSizeInBits()),
+ UnknownRange(PointerSize, true) {}
+
+ // Run the transformation on the associated function.
+ StackSafetyInfo run();
+};
+
+ConstantRange
+StackSafetyLocalAnalysis::offsetFromAlloca(Value *Addr,
+ const Value *AllocaPtr) {
+ if (!SE.isSCEVable(Addr->getType()))
+ return UnknownRange;
+
+ AllocaOffsetRewriter Rewriter(SE, AllocaPtr);
+ const SCEV *Expr = Rewriter.visit(SE.getSCEV(Addr));
+ ConstantRange Offset = SE.getUnsignedRange(Expr).zextOrTrunc(PointerSize);
+ assert(!Offset.isEmptySet());
+ return Offset;
+}
+
+ConstantRange StackSafetyLocalAnalysis::getAccessRange(Value *Addr,
+ const Value *AllocaPtr,
+ uint64_t AccessSize) {
+ if (!SE.isSCEVable(Addr->getType()))
+ return UnknownRange;
+
+ AllocaOffsetRewriter Rewriter(SE, AllocaPtr);
+ const SCEV *Expr = Rewriter.visit(SE.getSCEV(Addr));
+
+ ConstantRange AccessStartRange =
+ SE.getUnsignedRange(Expr).zextOrTrunc(PointerSize);
+ ConstantRange SizeRange = getRange(0, AccessSize);
+ ConstantRange AccessRange = AccessStartRange.add(SizeRange);
+ assert(!AccessRange.isEmptySet());
+ return AccessRange;
+}
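+
+// For instance (illustrative numbers): a 4-byte access whose start offset
+// relative to the alloca is known to lie in [0, 8) gets SizeRange == [0, 4)
+// and therefore AccessRange == [0, 11), i.e. byte offsets 0 through 10 may be
+// touched.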
+
+ConstantRange StackSafetyLocalAnalysis::getMemIntrinsicAccessRange(
+ const MemIntrinsic *MI, const Use &U, const Value *AllocaPtr) {
+ if (auto MTI = dyn_cast<MemTransferInst>(MI)) {
+ if (MTI->getRawSource() != U && MTI->getRawDest() != U)
+ return getRange(0, 1);
+ } else {
+ if (MI->getRawDest() != U)
+ return getRange(0, 1);
+ }
+ const auto *Len = dyn_cast<ConstantInt>(MI->getLength());
+ // Non-constant size => unsafe. FIXME: try SCEV getRange.
+ if (!Len)
+ return UnknownRange;
+ ConstantRange AccessRange = getAccessRange(U, AllocaPtr, Len->getZExtValue());
+ return AccessRange;
+}
+
+/// The function analyzes all local uses of Ptr (alloca or argument) and
+/// calculates the local access range and records all function calls to which
+/// the address is passed.
+bool StackSafetyLocalAnalysis::analyzeAllUses(const Value *Ptr, UseInfo &US) {
+ SmallPtrSet<const Value *, 16> Visited;
+ SmallVector<const Value *, 8> WorkList;
+ WorkList.push_back(Ptr);
+
+ // A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc.
+ while (!WorkList.empty()) {
+ const Value *V = WorkList.pop_back_val();
+ for (const Use &UI : V->uses()) {
+ auto I = cast<const Instruction>(UI.getUser());
+ assert(V == UI.get());
+
+ switch (I->getOpcode()) {
+ case Instruction::Load: {
+ US.updateRange(
+ getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType())));
+ break;
+ }
+
+ case Instruction::VAArg:
+ // "va-arg" from a pointer is safe.
+ break;
+ case Instruction::Store: {
+ if (V == I->getOperand(0)) {
+ // Stored the pointer - conservatively assume it may be unsafe.
+ US.updateRange(UnknownRange);
+ return false;
+ }
+ US.updateRange(getAccessRange(
+ UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType())));
+ break;
+ }
+
+ case Instruction::Ret:
+ // Information leak.
+ // FIXME: Process parameters correctly. This is a leak only if we return
+ // alloca.
+ US.updateRange(UnknownRange);
+ return false;
+
+ case Instruction::Call:
+ case Instruction::Invoke: {
+ ImmutableCallSite CS(I);
+
+ if (I->isLifetimeStartOrEnd())
+ break;
+
+ if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
+ US.updateRange(getMemIntrinsicAccessRange(MI, UI, Ptr));
+ break;
+ }
+
+ // FIXME: consult devirt?
+ // Do not follow aliases, otherwise we could inadvertently follow
+ // dso_preemptable aliases or aliases with interposable linkage.
+ const GlobalValue *Callee =
+ dyn_cast<GlobalValue>(CS.getCalledValue()->stripPointerCasts());
+ if (!Callee) {
+ US.updateRange(UnknownRange);
+ return false;
+ }
+
+ assert(isa<Function>(Callee) || isa<GlobalAlias>(Callee));
+
+ ImmutableCallSite::arg_iterator B = CS.arg_begin(), E = CS.arg_end();
+ for (ImmutableCallSite::arg_iterator A = B; A != E; ++A) {
+ if (A->get() == V) {
+ ConstantRange Offset = offsetFromAlloca(UI, Ptr);
+ US.Calls.emplace_back(Callee, A - B, Offset);
+ }
+ }
+
+ break;
+ }
+
+ default:
+ if (Visited.insert(I).second)
+ WorkList.push_back(cast<const Instruction>(I));
+ }
+ }
+ }
+
+ return true;
+}
+
+StackSafetyInfo StackSafetyLocalAnalysis::run() {
+ StackSafetyInfo::FunctionInfo Info(&F);
+ assert(!F.isDeclaration() &&
+ "Can't run StackSafety on a function declaration");
+
+ LLVM_DEBUG(dbgs() << "[StackSafety] " << F.getName() << "\n");
+
+ for (auto &I : instructions(F)) {
+ if (auto AI = dyn_cast<AllocaInst>(&I)) {
+ Info.Allocas.emplace_back(PointerSize, AI,
+ getStaticAllocaAllocationSize(AI));
+ AllocaInfo &AS = Info.Allocas.back();
+ analyzeAllUses(AI, AS.Use);
+ }
+ }
+
+ for (const Argument &A : make_range(F.arg_begin(), F.arg_end())) {
+ Info.Params.emplace_back(PointerSize, &A);
+ ParamInfo &PS = Info.Params.back();
+ analyzeAllUses(&A, PS.Use);
+ }
+
+ LLVM_DEBUG(dbgs() << "[StackSafety] done\n");
+ LLVM_DEBUG(Info.print(dbgs()));
+ return StackSafetyInfo(std::move(Info));
+}
+
+class StackSafetyDataFlowAnalysis {
+ using FunctionMap =
+ std::map<const GlobalValue *, StackSafetyInfo::FunctionInfo>;
+
+ FunctionMap Functions;
+ // Callee-to-Caller multimap.
+ DenseMap<const GlobalValue *, SmallVector<const GlobalValue *, 4>> Callers;
+ SetVector<const GlobalValue *> WorkList;
+
+ unsigned PointerSize = 0;
+ const ConstantRange UnknownRange;
+
+ ConstantRange getArgumentAccessRange(const GlobalValue *Callee,
+ unsigned ParamNo) const;
+ bool updateOneUse(UseInfo &US, bool UpdateToFullSet);
+ void updateOneNode(const GlobalValue *Callee,
+ StackSafetyInfo::FunctionInfo &FS);
+ void updateOneNode(const GlobalValue *Callee) {
+ updateOneNode(Callee, Functions.find(Callee)->second);
+ }
+ void updateAllNodes() {
+ for (auto &F : Functions)
+ updateOneNode(F.first, F.second);
+ }
+ void runDataFlow();
+#ifndef NDEBUG
+ void verifyFixedPoint();
+#endif
+
+public:
+ StackSafetyDataFlowAnalysis(
+ Module &M, std::function<const StackSafetyInfo &(Function &)> FI);
+ StackSafetyGlobalInfo run();
+};
+
+StackSafetyDataFlowAnalysis::StackSafetyDataFlowAnalysis(
+ Module &M, std::function<const StackSafetyInfo &(Function &)> FI)
+ : PointerSize(M.getDataLayout().getPointerSizeInBits()),
+ UnknownRange(PointerSize, true) {
+ // Without ThinLTO, run the local analysis for every function in the TU and
+ // then run the DFA.
+ for (auto &F : M.functions())
+ if (!F.isDeclaration())
+ Functions.emplace(&F, FI(F));
+ for (auto &A : M.aliases())
+ if (isa<Function>(A.getBaseObject()))
+ Functions.emplace(&A, StackSafetyInfo::FunctionInfo(&A));
+}
+
+ConstantRange
+StackSafetyDataFlowAnalysis::getArgumentAccessRange(const GlobalValue *Callee,
+ unsigned ParamNo) const {
+ auto IT = Functions.find(Callee);
+ // Unknown callee (outside of LTO domain or an indirect call).
+ if (IT == Functions.end())
+ return UnknownRange;
+ const StackSafetyInfo::FunctionInfo &FS = IT->second;
+ // The definition of this symbol may not be the definition in this linkage
+ // unit.
+ if (!FS.IsDSOLocal() || FS.IsInterposable())
+ return UnknownRange;
+ if (ParamNo >= FS.Params.size()) // possibly vararg
+ return UnknownRange;
+ return FS.Params[ParamNo].Use.Range;
+}
+
+bool StackSafetyDataFlowAnalysis::updateOneUse(UseInfo &US,
+ bool UpdateToFullSet) {
+ bool Changed = false;
+ for (auto &CS : US.Calls) {
+ assert(!CS.Offset.isEmptySet() &&
+ "Param range can't be empty-set, invalid offset range");
+
+ ConstantRange CalleeRange = getArgumentAccessRange(CS.Callee, CS.ParamNo);
+ CalleeRange = CalleeRange.add(CS.Offset);
+ if (!US.Range.contains(CalleeRange)) {
+ Changed = true;
+ if (UpdateToFullSet)
+ US.Range = UnknownRange;
+ else
+ US.Range = US.Range.unionWith(CalleeRange);
+ }
+ }
+ return Changed;
+}
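+
+// Example of the range propagation above (numbers chosen for illustration): if
+// a callee accesses its parameter within [0, 4) and the caller passes the
+// address at constant offset 8 (CS.Offset == [8, 9)), then
+// CalleeRange.add(CS.Offset) == [8, 12), and the caller's own use range is
+// widened to include it unless it already contains that range.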
+
+void StackSafetyDataFlowAnalysis::updateOneNode(
+ const GlobalValue *Callee, StackSafetyInfo::FunctionInfo &FS) {
+ bool UpdateToFullSet = FS.UpdateCount > StackSafetyMaxIterations;
+ bool Changed = false;
+ for (auto &AS : FS.Allocas)
+ Changed |= updateOneUse(AS.Use, UpdateToFullSet);
+ for (auto &PS : FS.Params)
+ Changed |= updateOneUse(PS.Use, UpdateToFullSet);
+
+ if (Changed) {
+ LLVM_DEBUG(dbgs() << "=== update [" << FS.UpdateCount
+ << (UpdateToFullSet ? ", full-set" : "") << "] "
+ << FS.getName() << "\n");
+ // Callers of this function may need updating.
+ for (auto &CallerID : Callers[Callee])
+ WorkList.insert(CallerID);
+
+ ++FS.UpdateCount;
+ }
+}
+
+void StackSafetyDataFlowAnalysis::runDataFlow() {
+ Callers.clear();
+ WorkList.clear();
+
+ SmallVector<const GlobalValue *, 16> Callees;
+ for (auto &F : Functions) {
+ Callees.clear();
+ StackSafetyInfo::FunctionInfo &FS = F.second;
+ for (auto &AS : FS.Allocas)
+ for (auto &CS : AS.Use.Calls)
+ Callees.push_back(CS.Callee);
+ for (auto &PS : FS.Params)
+ for (auto &CS : PS.Use.Calls)
+ Callees.push_back(CS.Callee);
+
+ llvm::sort(Callees);
+ Callees.erase(std::unique(Callees.begin(), Callees.end()), Callees.end());
+
+ for (auto &Callee : Callees)
+ Callers[Callee].push_back(F.first);
+ }
+
+ updateAllNodes();
+
+ while (!WorkList.empty()) {
+ const GlobalValue *Callee = WorkList.back();
+ WorkList.pop_back();
+ updateOneNode(Callee);
+ }
+}
+
+#ifndef NDEBUG
+void StackSafetyDataFlowAnalysis::verifyFixedPoint() {
+ WorkList.clear();
+ updateAllNodes();
+ assert(WorkList.empty());
+}
+#endif
+
+StackSafetyGlobalInfo StackSafetyDataFlowAnalysis::run() {
+ runDataFlow();
+ LLVM_DEBUG(verifyFixedPoint());
+
+ StackSafetyGlobalInfo SSI;
+ for (auto &F : Functions)
+ SSI.emplace(F.first, std::move(F.second));
+ return SSI;
+}
+
+void print(const StackSafetyGlobalInfo &SSI, raw_ostream &O, const Module &M) {
+ size_t Count = 0;
+ for (auto &F : M.functions())
+ if (!F.isDeclaration()) {
+ SSI.find(&F)->second.print(O);
+ O << "\n";
+ ++Count;
+ }
+ for (auto &A : M.aliases()) {
+ SSI.find(&A)->second.print(O);
+ O << "\n";
+ ++Count;
+ }
+ assert(Count == SSI.size() && "Unexpected functions in the result");
+}
+
+} // end anonymous namespace
+
+StackSafetyInfo::StackSafetyInfo() = default;
+StackSafetyInfo::StackSafetyInfo(StackSafetyInfo &&) = default;
+StackSafetyInfo &StackSafetyInfo::operator=(StackSafetyInfo &&) = default;
+
+StackSafetyInfo::StackSafetyInfo(FunctionInfo &&Info)
+ : Info(new FunctionInfo(std::move(Info))) {}
+
+StackSafetyInfo::~StackSafetyInfo() = default;
+
+void StackSafetyInfo::print(raw_ostream &O) const { Info->print(O); }
+
+AnalysisKey StackSafetyAnalysis::Key;
+
+StackSafetyInfo StackSafetyAnalysis::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ StackSafetyLocalAnalysis SSLA(F, AM.getResult<ScalarEvolutionAnalysis>(F));
+ return SSLA.run();
+}
+
+PreservedAnalyses StackSafetyPrinterPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ OS << "'Stack Safety Local Analysis' for function '" << F.getName() << "'\n";
+ AM.getResult<StackSafetyAnalysis>(F).print(OS);
+ return PreservedAnalyses::all();
+}
+
+char StackSafetyInfoWrapperPass::ID = 0;
+
+StackSafetyInfoWrapperPass::StackSafetyInfoWrapperPass() : FunctionPass(ID) {
+ initializeStackSafetyInfoWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+void StackSafetyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addRequired<ScalarEvolutionWrapperPass>();
+ AU.setPreservesAll();
+}
+
+void StackSafetyInfoWrapperPass::print(raw_ostream &O, const Module *M) const {
+ SSI.print(O);
+}
+
+bool StackSafetyInfoWrapperPass::runOnFunction(Function &F) {
+ StackSafetyLocalAnalysis SSLA(
+ F, getAnalysis<ScalarEvolutionWrapperPass>().getSE());
+ SSI = StackSafetyInfo(SSLA.run());
+ return false;
+}
+
+AnalysisKey StackSafetyGlobalAnalysis::Key;
+
+StackSafetyGlobalInfo
+StackSafetyGlobalAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
+ FunctionAnalysisManager &FAM =
+ AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+
+ StackSafetyDataFlowAnalysis SSDFA(
+ M, [&FAM](Function &F) -> const StackSafetyInfo & {
+ return FAM.getResult<StackSafetyAnalysis>(F);
+ });
+ return SSDFA.run();
+}
+
+PreservedAnalyses StackSafetyGlobalPrinterPass::run(Module &M,
+ ModuleAnalysisManager &AM) {
+ OS << "'Stack Safety Analysis' for module '" << M.getName() << "'\n";
+ print(AM.getResult<StackSafetyGlobalAnalysis>(M), OS, M);
+ return PreservedAnalyses::all();
+}
+
+char StackSafetyGlobalInfoWrapperPass::ID = 0;
+
+StackSafetyGlobalInfoWrapperPass::StackSafetyGlobalInfoWrapperPass()
+ : ModulePass(ID) {
+ initializeStackSafetyGlobalInfoWrapperPassPass(
+ *PassRegistry::getPassRegistry());
+}
+
+void StackSafetyGlobalInfoWrapperPass::print(raw_ostream &O,
+ const Module *M) const {
+ ::print(SSI, O, *M);
+}
+
+void StackSafetyGlobalInfoWrapperPass::getAnalysisUsage(
+ AnalysisUsage &AU) const {
+ AU.addRequired<StackSafetyInfoWrapperPass>();
+}
+
+bool StackSafetyGlobalInfoWrapperPass::runOnModule(Module &M) {
+ StackSafetyDataFlowAnalysis SSDFA(
+ M, [this](Function &F) -> const StackSafetyInfo & {
+ return getAnalysis<StackSafetyInfoWrapperPass>(F).getResult();
+ });
+ SSI = SSDFA.run();
+ return false;
+}
+
+static const char LocalPassArg[] = "stack-safety-local";
+static const char LocalPassName[] = "Stack Safety Local Analysis";
+INITIALIZE_PASS_BEGIN(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
+ false, true)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
+INITIALIZE_PASS_END(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
+ false, true)
+
+static const char GlobalPassName[] = "Stack Safety Analysis";
+INITIALIZE_PASS_BEGIN(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
+ GlobalPassName, false, false)
+INITIALIZE_PASS_DEPENDENCY(StackSafetyInfoWrapperPass)
+INITIALIZE_PASS_END(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
+ GlobalPassName, false, false)
diff --git a/llvm/lib/Analysis/StratifiedSets.h b/llvm/lib/Analysis/StratifiedSets.h
new file mode 100644
index 000000000000..60ea2451b0ef
--- /dev/null
+++ b/llvm/lib/Analysis/StratifiedSets.h
@@ -0,0 +1,596 @@
+//===- StratifiedSets.h - Abstract stratified sets implementation. --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_STRATIFIEDSETS_H
+#define LLVM_ADT_STRATIFIEDSETS_H
+
+#include "AliasAnalysisSummary.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include <bitset>
+#include <cassert>
+#include <cmath>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+namespace llvm {
+namespace cflaa {
+/// An index into Stratified Sets.
+typedef unsigned StratifiedIndex;
+/// NOTE: ^ This can't be a short -- bootstrapping clang has a case where
+/// ~1M sets exist.
+
+// Container of information related to a value in a StratifiedSet.
+struct StratifiedInfo {
+ StratifiedIndex Index;
+ /// For field sensitivity, etc. we can tack fields on here.
+};
+
+/// A "link" between two StratifiedSets.
+struct StratifiedLink {
+ /// This is a value used to signify "does not exist" where the
+ /// StratifiedIndex type is used.
+ ///
+ /// This is used instead of Optional<StratifiedIndex> because
+ /// Optional<StratifiedIndex> would eat up a considerable amount of extra
+ /// memory, after struct padding/alignment is taken into account.
+ static const StratifiedIndex SetSentinel;
+
+ /// The index for the set "above" current
+ StratifiedIndex Above;
+
+ /// The link for the set "below" current
+ StratifiedIndex Below;
+
+ /// Attributes for these StratifiedSets.
+ AliasAttrs Attrs;
+
+ StratifiedLink() : Above(SetSentinel), Below(SetSentinel) {}
+
+ bool hasBelow() const { return Below != SetSentinel; }
+ bool hasAbove() const { return Above != SetSentinel; }
+
+ void clearBelow() { Below = SetSentinel; }
+ void clearAbove() { Above = SetSentinel; }
+};
+
+/// These are stratified sets, as described in "Fast algorithms for
+/// Dyck-CFL-reachability with applications to Alias Analysis" by Zhang Q, Lyu M
+/// R, Yuan H, and Su Z. -- in short, this is meant to represent different sets
+/// of Value*s. If two Value*s are in the same set, or if both sets have
+/// overlapping attributes, then the Value*s are said to alias.
+///
+/// Sets may be related by position, meaning that one set may be considered as
+/// above or below another. In CFL Alias Analysis, this gives us an indication
+/// of how two variables are related; if the set of variable A is below a set
+/// containing variable B, then at some point, a variable that has interacted
+/// with B (or B itself) was either used in order to extract the variable A, or
+/// was used as storage of variable A.
+///
+/// Sets may also have attributes (as noted above). These attributes are
+/// generally used for noting whether a variable in the set has interacted with
+/// a variable whose origins we don't quite know (i.e. globals/arguments), or if
+/// the variable may have had operations performed on it (modified in a function
+/// call). All attributes that exist in a set A must exist in all sets marked as
+/// below set A.
+template <typename T> class StratifiedSets {
+public:
+ StratifiedSets() = default;
+ StratifiedSets(StratifiedSets &&) = default;
+ StratifiedSets &operator=(StratifiedSets &&) = default;
+
+ StratifiedSets(DenseMap<T, StratifiedInfo> Map,
+ std::vector<StratifiedLink> Links)
+ : Values(std::move(Map)), Links(std::move(Links)) {}
+
+ Optional<StratifiedInfo> find(const T &Elem) const {
+ auto Iter = Values.find(Elem);
+ if (Iter == Values.end())
+ return None;
+ return Iter->second;
+ }
+
+ const StratifiedLink &getLink(StratifiedIndex Index) const {
+ assert(inbounds(Index));
+ return Links[Index];
+ }
+
+private:
+ DenseMap<T, StratifiedInfo> Values;
+ std::vector<StratifiedLink> Links;
+
+ bool inbounds(StratifiedIndex Idx) const { return Idx < Links.size(); }
+};
+
+/// Generic Builder class that produces StratifiedSets instances.
+///
+/// The goal of this builder is to efficiently produce correct StratifiedSets
+/// instances. To this end, we use a few tricks:
+/// > Set chains (A method for linking sets together)
+/// > Set remaps (A method for marking a set as an alias [irony?] of another)
+///
+/// ==== Set chains ====
+/// This builder has a notion of some value A being above, below, or with some
+/// other value B:
+/// > The `A above B` relationship implies that there is a reference edge
+/// going from A to B. Namely, it notes that A can store anything in B's set.
+/// > The `A below B` relationship is the opposite of `A above B`. It implies
+/// that there's a dereference edge going from A to B.
+/// > The `A with B` relationship states that there's an assignment edge going
+/// from A to B, and that A and B should be treated as equals.
+///
+/// As an example, take the following code snippet:
+///
+/// %a = alloca i32, align 4
+/// %ap = alloca i32*, align 8
+/// %app = alloca i32**, align 8
+/// store %a, %ap
+/// store %ap, %app
+/// %aw = getelementptr %ap, i32 0
+///
+/// Given this, the following relations exist:
+/// - %a below %ap & %ap above %a
+/// - %ap below %app & %app above %ap
+/// - %aw with %ap & %ap with %aw
+///
+/// These relations produce the following sets:
+/// [{%a}, {%ap, %aw}, {%app}]
+///
+/// ...Which state that the only MayAlias relationship in the above program is
+/// between %ap and %aw.
+///
+/// Because LLVM allows arbitrary casts, code like the following needs to be
+/// supported:
+/// %ip = alloca i64, align 8
+/// %ipp = alloca i64*, align 8
+/// %i = bitcast i64** ipp to i64
+/// store i64* %ip, i64** %ipp
+/// store i64 %i, i64* %ip
+///
+/// Which, because %ipp ends up *both* above and below %ip, is fun.
+///
+/// This is solved by merging %i and %ipp into a single set (...which is the
+/// only way to solve this, since their bit patterns are equivalent). Any sets
+/// that ended up in between %i and %ipp at the time of merging (in this case,
+/// the set containing %ip) also get conservatively merged into the set of %i
+/// and %ipp. In short, the resulting StratifiedSet from the above code would be
+/// {%ip, %ipp, %i}.
+///
+/// ==== Set remaps ====
+/// More of an implementation detail than anything -- when merging sets, we need
+/// to update the numbers of all of the elements mapped to those sets. Rather
+/// than doing this at each merge, we note in the BuilderLink structure that a
+/// remap has occurred, and use this information so we can defer renumbering set
+/// elements until build time.
+template <typename T> class StratifiedSetsBuilder {
+ /// Represents a Stratified Set, with information about the Stratified
+ /// Set above it, the set below it, and whether the current set has been
+ /// remapped to another.
+ struct BuilderLink {
+ const StratifiedIndex Number;
+
+ BuilderLink(StratifiedIndex N) : Number(N) {
+ Remap = StratifiedLink::SetSentinel;
+ }
+
+ bool hasAbove() const {
+ assert(!isRemapped());
+ return Link.hasAbove();
+ }
+
+ bool hasBelow() const {
+ assert(!isRemapped());
+ return Link.hasBelow();
+ }
+
+ void setBelow(StratifiedIndex I) {
+ assert(!isRemapped());
+ Link.Below = I;
+ }
+
+ void setAbove(StratifiedIndex I) {
+ assert(!isRemapped());
+ Link.Above = I;
+ }
+
+ void clearBelow() {
+ assert(!isRemapped());
+ Link.clearBelow();
+ }
+
+ void clearAbove() {
+ assert(!isRemapped());
+ Link.clearAbove();
+ }
+
+ StratifiedIndex getBelow() const {
+ assert(!isRemapped());
+ assert(hasBelow());
+ return Link.Below;
+ }
+
+ StratifiedIndex getAbove() const {
+ assert(!isRemapped());
+ assert(hasAbove());
+ return Link.Above;
+ }
+
+ AliasAttrs getAttrs() {
+ assert(!isRemapped());
+ return Link.Attrs;
+ }
+
+ void setAttrs(AliasAttrs Other) {
+ assert(!isRemapped());
+ Link.Attrs |= Other;
+ }
+
+ bool isRemapped() const { return Remap != StratifiedLink::SetSentinel; }
+
+ /// For initial remapping to another set
+ void remapTo(StratifiedIndex Other) {
+ assert(!isRemapped());
+ Remap = Other;
+ }
+
+ StratifiedIndex getRemapIndex() const {
+ assert(isRemapped());
+ return Remap;
+ }
+
+ /// Should only be called when we're already remapped.
+ void updateRemap(StratifiedIndex Other) {
+ assert(isRemapped());
+ Remap = Other;
+ }
+
+ /// Prefer the above functions to calling things directly on what's returned
+ /// from this -- they guard against unexpected calls when the current
+ /// BuilderLink is remapped.
+ const StratifiedLink &getLink() const { return Link; }
+
+ private:
+ StratifiedLink Link;
+ StratifiedIndex Remap;
+ };
+
+ /// This function performs all of the set unioning/value renumbering
+ /// that we've been putting off, and generates a vector<StratifiedLink> that
+ /// may be placed in a StratifiedSets instance.
+ void finalizeSets(std::vector<StratifiedLink> &StratLinks) {
+ DenseMap<StratifiedIndex, StratifiedIndex> Remaps;
+ for (auto &Link : Links) {
+ if (Link.isRemapped())
+ continue;
+
+ StratifiedIndex Number = StratLinks.size();
+ Remaps.insert(std::make_pair(Link.Number, Number));
+ StratLinks.push_back(Link.getLink());
+ }
+
+ for (auto &Link : StratLinks) {
+ if (Link.hasAbove()) {
+ auto &Above = linksAt(Link.Above);
+ auto Iter = Remaps.find(Above.Number);
+ assert(Iter != Remaps.end());
+ Link.Above = Iter->second;
+ }
+
+ if (Link.hasBelow()) {
+ auto &Below = linksAt(Link.Below);
+ auto Iter = Remaps.find(Below.Number);
+ assert(Iter != Remaps.end());
+ Link.Below = Iter->second;
+ }
+ }
+
+ for (auto &Pair : Values) {
+ auto &Info = Pair.second;
+ auto &Link = linksAt(Info.Index);
+ auto Iter = Remaps.find(Link.Number);
+ assert(Iter != Remaps.end());
+ Info.Index = Iter->second;
+ }
+ }
+
+ /// There's a guarantee in StratifiedLink that all attribute bits set in a
+ /// link's Attrs will also be set in the Attrs of every link "below" it.
+ static void propagateAttrs(std::vector<StratifiedLink> &Links) {
+ const auto getHighestParentAbove = [&Links](StratifiedIndex Idx) {
+ const auto *Link = &Links[Idx];
+ while (Link->hasAbove()) {
+ Idx = Link->Above;
+ Link = &Links[Idx];
+ }
+ return Idx;
+ };
+
+ SmallSet<StratifiedIndex, 16> Visited;
+ for (unsigned I = 0, E = Links.size(); I < E; ++I) {
+ auto CurrentIndex = getHighestParentAbove(I);
+ if (!Visited.insert(CurrentIndex).second)
+ continue;
+
+ while (Links[CurrentIndex].hasBelow()) {
+ auto &CurrentBits = Links[CurrentIndex].Attrs;
+ auto NextIndex = Links[CurrentIndex].Below;
+ auto &NextBits = Links[NextIndex].Attrs;
+ NextBits |= CurrentBits;
+ CurrentIndex = NextIndex;
+ }
+ }
+ }
+
+public:
+ /// Builds a StratifiedSet from the information we've been given since either
+ /// construction or the prior build() call.
+ StratifiedSets<T> build() {
+ std::vector<StratifiedLink> StratLinks;
+ finalizeSets(StratLinks);
+ propagateAttrs(StratLinks);
+ Links.clear();
+ return StratifiedSets<T>(std::move(Values), std::move(StratLinks));
+ }
+
+ bool has(const T &Elem) const { return get(Elem).hasValue(); }
+
+ bool add(const T &Main) {
+ if (get(Main).hasValue())
+ return false;
+
+ auto NewIndex = getNewUnlinkedIndex();
+ return addAtMerging(Main, NewIndex);
+ }
+
+ /// Restructures the stratified sets as necessary to make "ToAdd" in a
+ /// set above "Main". There are some cases where this is not possible (see
+ /// above), so we merge them such that ToAdd and Main are in the same set.
+ bool addAbove(const T &Main, const T &ToAdd) {
+ assert(has(Main));
+ auto Index = *indexOf(Main);
+ if (!linksAt(Index).hasAbove())
+ addLinkAbove(Index);
+
+ auto Above = linksAt(Index).getAbove();
+ return addAtMerging(ToAdd, Above);
+ }
+
+ /// Restructures the stratified sets as necessary to make "ToAdd" in a
+ /// set below "Main". There are some cases where this is not possible (see
+ /// above), so we merge them such that ToAdd and Main are in the same set.
+ bool addBelow(const T &Main, const T &ToAdd) {
+ assert(has(Main));
+ auto Index = *indexOf(Main);
+ if (!linksAt(Index).hasBelow())
+ addLinkBelow(Index);
+
+ auto Below = linksAt(Index).getBelow();
+ return addAtMerging(ToAdd, Below);
+ }
+
+ bool addWith(const T &Main, const T &ToAdd) {
+ assert(has(Main));
+ auto MainIndex = *indexOf(Main);
+ return addAtMerging(ToAdd, MainIndex);
+ }
+
+ void noteAttributes(const T &Main, AliasAttrs NewAttrs) {
+ assert(has(Main));
+ auto *Info = *get(Main);
+ auto &Link = linksAt(Info->Index);
+ Link.setAttrs(NewAttrs);
+ }
+
+private:
+ DenseMap<T, StratifiedInfo> Values;
+ std::vector<BuilderLink> Links;
+
+ /// Adds the given element at the given index, merging sets if necessary.
+ bool addAtMerging(const T &ToAdd, StratifiedIndex Index) {
+ StratifiedInfo Info = {Index};
+ auto Pair = Values.insert(std::make_pair(ToAdd, Info));
+ if (Pair.second)
+ return true;
+
+ auto &Iter = Pair.first;
+ auto &IterSet = linksAt(Iter->second.Index);
+ auto &ReqSet = linksAt(Index);
+
+ // Failed to add where we wanted to. Merge the sets.
+ if (&IterSet != &ReqSet)
+ merge(IterSet.Number, ReqSet.Number);
+
+ return false;
+ }
+
+ /// Gets the BuilderLink at the given index, taking set remapping into
+ /// account.
+ BuilderLink &linksAt(StratifiedIndex Index) {
+ auto *Start = &Links[Index];
+ if (!Start->isRemapped())
+ return *Start;
+
+ auto *Current = Start;
+ while (Current->isRemapped())
+ Current = &Links[Current->getRemapIndex()];
+
+ auto NewRemap = Current->Number;
+
+ // Run through everything that has yet to be updated, and update them to
+ // remap to NewRemap
+ Current = Start;
+ while (Current->isRemapped()) {
+ auto *Next = &Links[Current->getRemapIndex()];
+ Current->updateRemap(NewRemap);
+ Current = Next;
+ }
+
+ return *Current;
+ }
+
+ /// Merges two sets into one another. Assumes that these sets are not
+ /// already one and the same.
+ void merge(StratifiedIndex Idx1, StratifiedIndex Idx2) {
+ assert(inbounds(Idx1) && inbounds(Idx2));
+ assert(&linksAt(Idx1) != &linksAt(Idx2) &&
+ "Merging a set into itself is not allowed");
+
+ // CASE 1: If the set at `Idx1` is above or below `Idx2`, we need to merge
+ // both the given sets, and all sets between them, into one.
+ if (tryMergeUpwards(Idx1, Idx2))
+ return;
+
+ if (tryMergeUpwards(Idx2, Idx1))
+ return;
+
+ // CASE 2: The set at `Idx1` is not in the same chain as the set at `Idx2`.
+ // We therefore need to merge the two chains together.
+ mergeDirect(Idx1, Idx2);
+ }
+
+ /// Merges two sets assuming that the set at `Idx1` is unreachable from
+ /// traversing above or below the set at `Idx2`.
+ void mergeDirect(StratifiedIndex Idx1, StratifiedIndex Idx2) {
+ assert(inbounds(Idx1) && inbounds(Idx2));
+
+ auto *LinksInto = &linksAt(Idx1);
+ auto *LinksFrom = &linksAt(Idx2);
+ // Merging everything above LinksInto then proceeding to merge everything
+ // below LinksInto becomes problematic, so we go as far "up" as possible!
+ while (LinksInto->hasAbove() && LinksFrom->hasAbove()) {
+ LinksInto = &linksAt(LinksInto->getAbove());
+ LinksFrom = &linksAt(LinksFrom->getAbove());
+ }
+
+ if (LinksFrom->hasAbove()) {
+ LinksInto->setAbove(LinksFrom->getAbove());
+ auto &NewAbove = linksAt(LinksInto->getAbove());
+ NewAbove.setBelow(LinksInto->Number);
+ }
+
+ // Merging strategy:
+ // > If neither has links below, stop.
+ // > If only `LinksInto` has links below, stop.
+ // > If only `LinksFrom` has links below, reset `LinksInto.Below` to
+ // match `LinksFrom.Below`
+ // > If both have links below, deal with those next.
+ while (LinksInto->hasBelow() && LinksFrom->hasBelow()) {
+ auto FromAttrs = LinksFrom->getAttrs();
+ LinksInto->setAttrs(FromAttrs);
+
+ // Remap needs to happen after getBelow(), but before
+ // assignment of LinksFrom
+ auto *NewLinksFrom = &linksAt(LinksFrom->getBelow());
+ LinksFrom->remapTo(LinksInto->Number);
+ LinksFrom = NewLinksFrom;
+ LinksInto = &linksAt(LinksInto->getBelow());
+ }
+
+ if (LinksFrom->hasBelow()) {
+ LinksInto->setBelow(LinksFrom->getBelow());
+ auto &NewBelow = linksAt(LinksInto->getBelow());
+ NewBelow.setAbove(LinksInto->Number);
+ }
+
+ LinksInto->setAttrs(LinksFrom->getAttrs());
+ LinksFrom->remapTo(LinksInto->Number);
+ }
+
+ /// Checks to see if lowerIndex is at a level lower than upperIndex. If so, it
+ /// will merge lowerIndex with upperIndex (and all of the sets between) and
+ /// return true. Otherwise, it will return false.
+ bool tryMergeUpwards(StratifiedIndex LowerIndex, StratifiedIndex UpperIndex) {
+ assert(inbounds(LowerIndex) && inbounds(UpperIndex));
+ auto *Lower = &linksAt(LowerIndex);
+ auto *Upper = &linksAt(UpperIndex);
+ if (Lower == Upper)
+ return true;
+
+ SmallVector<BuilderLink *, 8> Found;
+ auto *Current = Lower;
+ auto Attrs = Current->getAttrs();
+ while (Current->hasAbove() && Current != Upper) {
+ Found.push_back(Current);
+ Attrs |= Current->getAttrs();
+ Current = &linksAt(Current->getAbove());
+ }
+
+ if (Current != Upper)
+ return false;
+
+ Upper->setAttrs(Attrs);
+
+ if (Lower->hasBelow()) {
+ auto NewBelowIndex = Lower->getBelow();
+ Upper->setBelow(NewBelowIndex);
+ auto &NewBelow = linksAt(NewBelowIndex);
+ NewBelow.setAbove(UpperIndex);
+ } else {
+ Upper->clearBelow();
+ }
+
+ for (const auto &Ptr : Found)
+ Ptr->remapTo(Upper->Number);
+
+ return true;
+ }
+
+ Optional<const StratifiedInfo *> get(const T &Val) const {
+ auto Result = Values.find(Val);
+ if (Result == Values.end())
+ return None;
+ return &Result->second;
+ }
+
+ Optional<StratifiedInfo *> get(const T &Val) {
+ auto Result = Values.find(Val);
+ if (Result == Values.end())
+ return None;
+ return &Result->second;
+ }
+
+ Optional<StratifiedIndex> indexOf(const T &Val) {
+ auto MaybeVal = get(Val);
+ if (!MaybeVal.hasValue())
+ return None;
+ auto *Info = *MaybeVal;
+ auto &Link = linksAt(Info->Index);
+ return Link.Number;
+ }
+
+ StratifiedIndex addLinkBelow(StratifiedIndex Set) {
+ auto At = addLinks();
+ Links[Set].setBelow(At);
+ Links[At].setAbove(Set);
+ return At;
+ }
+
+ StratifiedIndex addLinkAbove(StratifiedIndex Set) {
+ auto At = addLinks();
+ Links[At].setBelow(Set);
+ Links[Set].setAbove(At);
+ return At;
+ }
+
+ StratifiedIndex getNewUnlinkedIndex() { return addLinks(); }
+
+ StratifiedIndex addLinks() {
+ auto Link = Links.size();
+ Links.push_back(BuilderLink(Link));
+ return Link;
+ }
+
+ bool inbounds(StratifiedIndex N) const { return N < Links.size(); }
+};
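+
+// Usage sketch for the builder (illustrative; V_A, V_AP, V_APP and V_AW stand
+// for the Value*s of %a, %ap, %app and %aw from the example further above):
+//
+//   StratifiedSetsBuilder<Value *> Builder;
+//   Builder.add(V_APP);            // set {%app}
+//   Builder.addBelow(V_APP, V_AP); // {%ap} one level below {%app}
+//   Builder.addBelow(V_AP, V_A);   // {%a} one level below {%ap}
+//   Builder.addWith(V_AP, V_AW);   // %aw joins the set of %ap
+//   StratifiedSets<Value *> Sets = Builder.build();
+//
+// Afterwards Sets.find(V_AP) and Sets.find(V_AW) report the same
+// StratifiedInfo::Index, reproducing the sets [{%a}, {%ap, %aw}, {%app}].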
+}
+}
+#endif // LLVM_ADT_STRATIFIEDSETS_H
diff --git a/llvm/lib/Analysis/SyncDependenceAnalysis.cpp b/llvm/lib/Analysis/SyncDependenceAnalysis.cpp
new file mode 100644
index 000000000000..8447dc87069d
--- /dev/null
+++ b/llvm/lib/Analysis/SyncDependenceAnalysis.cpp
@@ -0,0 +1,380 @@
+//===- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an algorithm that returns for a divergent branch
+// the set of basic blocks whose phi nodes become divergent due to divergent
+// control. These are the blocks that are reachable by two disjoint paths from
+// the branch, or loop exits that have a reaching path that is disjoint from a
+// path to the loop latch.
+//
+// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
+// control-induced divergence in phi nodes.
+//
+// -- Summary --
+// The SyncDependenceAnalysis lazily computes sync dependences [3].
+// The analysis evaluates the disjoint path criterion [2] by a reduction
+// to SSA construction. The SSA construction algorithm is implemented as
+// a simple data-flow analysis [1].
+//
+// [1] "A Simple, Fast Dominance Algorithm", SPI '01, Cooper, Harvey and Kennedy
+// [2] "Efficiently Computing Static Single Assignment Form
+// and the Control Dependence Graph", TOPLAS '91,
+// Cytron, Ferrante, Rosen, Wegman and Zadeck
+// [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
+// [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
+//
+// -- Sync dependence --
+// Sync dependence [4] characterizes the control flow aspect of the
+// propagation of branch divergence. For example,
+//
+// %cond = icmp slt i32 %tid, 10
+// br i1 %cond, label %then, label %else
+// then:
+// br label %merge
+// else:
+// br label %merge
+// merge:
+// %a = phi i32 [ 0, %then ], [ 1, %else ]
+//
+// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
+// because %tid is not on its use-def chains, %a is sync dependent on %tid
+// because the branch "br i1 %cond" depends on %tid and affects which value %a
+// is assigned to.
+//
+// -- Reduction to SSA construction --
+// There are two disjoint paths from A to X, if a certain variant of SSA
+// construction places a phi node in X under the following set-up scheme [2].
+//
+// This variant of SSA construction ignores incoming undef values.
+// That is, paths from the entry without a definition do not result in
+// phi nodes.
+//
+//        entry
+//       /     \
+//      A       \
+//     / \       Y
+//    B   C     /
+//     \ / \   /
+//      D    E
+//       \  /
+//        F
+// Assume that A contains a divergent branch. We are interested
+// in the set of all blocks where each block is reachable from A
+// via two disjoint paths. This would be the set {D, F} in this
+// case.
+// To generally reduce this query to SSA construction we introduce
+// a virtual variable x and assign to x different values in each
+// successor block of A.
+//           entry
+//          /     \
+//         A       \
+//        / \       Y
+//   x = 0   x = 1 /
+//        \ / \   /
+//         D    E
+//          \  /
+//           F
+// Our flavor of SSA construction for x will construct the following
+//            entry
+//           /     \
+//          A       \
+//         / \       Y
+//   x0 = 0   x1 = 1 /
+//         \ / \    /
+//        x2=phi   E
+//            \   /
+//           x3=phi
+// The blocks D and F contain phi nodes and are thus each reachable
+// by two disjoint paths from A.
+//
+// -- Remarks --
+// In case of loop exits we need to check the disjoint path criterion for loops
+// [2]. To this end, we check whether the definition of x differs between the
+// loop exit and the loop header (_after_ SSA construction).
+//
+//===----------------------------------------------------------------------===//
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/SyncDependenceAnalysis.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+
+#include <stack>
+#include <unordered_set>
+
+#define DEBUG_TYPE "sync-dependence"
+
+namespace llvm {
+
+ConstBlockSet SyncDependenceAnalysis::EmptyBlockSet;
+
+SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
+ const PostDominatorTree &PDT,
+ const LoopInfo &LI)
+ : FuncRPOT(DT.getRoot()->getParent()), DT(DT), PDT(PDT), LI(LI) {}
+
+SyncDependenceAnalysis::~SyncDependenceAnalysis() {}
+
+using FunctionRPOT = ReversePostOrderTraversal<const Function *>;
+
+// divergence propagator for reducible CFGs
+struct DivergencePropagator {
+ const FunctionRPOT &FuncRPOT;
+ const DominatorTree &DT;
+ const PostDominatorTree &PDT;
+ const LoopInfo &LI;
+
+ // identified join points
+ std::unique_ptr<ConstBlockSet> JoinBlocks;
+
+ // reached loop exits (by a path disjoint to a path to the loop header)
+ SmallPtrSet<const BasicBlock *, 4> ReachedLoopExits;
+
+ // if DefMap[B] == C then C is the dominating definition at block B
+ // if DefMap[B] ~ undef then we haven't seen B yet
+ // if DefMap[B] == B then B is a join point of disjoint paths from X or B is
+ // an immediate successor of X (initial value).
+ using DefiningBlockMap = std::map<const BasicBlock *, const BasicBlock *>;
+ DefiningBlockMap DefMap;
+
+ // all blocks with pending visits
+ std::unordered_set<const BasicBlock *> PendingUpdates;
+
+ DivergencePropagator(const FunctionRPOT &FuncRPOT, const DominatorTree &DT,
+ const PostDominatorTree &PDT, const LoopInfo &LI)
+ : FuncRPOT(FuncRPOT), DT(DT), PDT(PDT), LI(LI),
+ JoinBlocks(new ConstBlockSet) {}
+
+ // set the definition at @block and mark @block as pending for a visit
+ void addPending(const BasicBlock &Block, const BasicBlock &DefBlock) {
+ bool WasAdded = DefMap.emplace(&Block, &DefBlock).second;
+ if (WasAdded)
+ PendingUpdates.insert(&Block);
+ }
+
+ void printDefs(raw_ostream &Out) {
+ Out << "Propagator::DefMap {\n";
+ for (const auto *Block : FuncRPOT) {
+ auto It = DefMap.find(Block);
+ Out << Block->getName() << " : ";
+ if (It == DefMap.end()) {
+ Out << "\n";
+ } else {
+ const auto *DefBlock = It->second;
+ Out << (DefBlock ? DefBlock->getName() : "<null>") << "\n";
+ }
+ }
+ Out << "}\n";
+ }
+
+ // process @succBlock with reaching definition @defBlock
+ // the original divergent branch was in @parentLoop (if any)
+ void visitSuccessor(const BasicBlock &SuccBlock, const Loop *ParentLoop,
+ const BasicBlock &DefBlock) {
+
+ // @succBlock is a loop exit
+ if (ParentLoop && !ParentLoop->contains(&SuccBlock)) {
+ DefMap.emplace(&SuccBlock, &DefBlock);
+ ReachedLoopExits.insert(&SuccBlock);
+ return;
+ }
+
+ // first reaching def?
+ auto ItLastDef = DefMap.find(&SuccBlock);
+ if (ItLastDef == DefMap.end()) {
+ addPending(SuccBlock, DefBlock);
+ return;
+ }
+
+ // a join of at least two definitions
+ if (ItLastDef->second != &DefBlock) {
+ // do we know this join already?
+ if (!JoinBlocks->insert(&SuccBlock).second)
+ return;
+
+ // update the definition
+ addPending(SuccBlock, SuccBlock);
+ }
+ }
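+
+ // Note on visitSuccessor above (descriptive only): recording the first
+ // reaching definition and turning @succBlock into its own definition once a
+ // second, different definition arrives is exactly the phi-placement
+ // condition of the SSA construction sketched in the file header; that is the
+ // point at which @succBlock is added to JoinBlocks.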
+
+ // find all blocks reachable by two disjoint paths from @rootTerm.
+ // This method works for both divergent terminators and loops with
+ // divergent exits.
+ // @rootBlock is either the block containing the branch or the header of the
+ // divergent loop.
+ // @nodeSuccessors is the set of successors of the node (Loop or Terminator)
+ // headed by @rootBlock.
+ // @parentLoop is the parent loop of the Loop or the loop that contains the
+ // Terminator.
+ template <typename SuccessorIterable>
+ std::unique_ptr<ConstBlockSet>
+ computeJoinPoints(const BasicBlock &RootBlock,
+ SuccessorIterable NodeSuccessors, const Loop *ParentLoop) {
+ assert(JoinBlocks);
+
+ LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints. Parent loop: " << (ParentLoop ? ParentLoop->getName() : "<null>") << "\n" );
+
+ // bootstrap with branch targets
+ for (const auto *SuccBlock : NodeSuccessors) {
+ DefMap.emplace(SuccBlock, SuccBlock);
+
+ if (ParentLoop && !ParentLoop->contains(SuccBlock)) {
+ // immediate loop exit from node.
+ ReachedLoopExits.insert(SuccBlock);
+ } else {
+ // regular successor
+ PendingUpdates.insert(SuccBlock);
+ }
+ }
+
+ LLVM_DEBUG(
+ dbgs() << "SDA: rpo order:\n";
+ for (const auto * RpoBlock : FuncRPOT) {
+ dbgs() << "- " << RpoBlock->getName() << "\n";
+ }
+ );
+
+ auto ItBeginRPO = FuncRPOT.begin();
+
+ // skip until term (TODO RPOT won't let us start at @term directly)
+ for (; *ItBeginRPO != &RootBlock; ++ItBeginRPO) {}
+
+ auto ItEndRPO = FuncRPOT.end();
+ assert(ItBeginRPO != ItEndRPO);
+
+ // propagate definitions at the immediate successors of the node in RPO
+ auto ItBlockRPO = ItBeginRPO;
+ while ((++ItBlockRPO != ItEndRPO) &&
+ !PendingUpdates.empty()) {
+ const auto *Block = *ItBlockRPO;
+ LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
+
+ // skip Block if not pending update
+ auto ItPending = PendingUpdates.find(Block);
+ if (ItPending == PendingUpdates.end())
+ continue;
+ PendingUpdates.erase(ItPending);
+
+ // propagate definition at Block to its successors
+ auto ItDef = DefMap.find(Block);
+ const auto *DefBlock = ItDef->second;
+ assert(DefBlock);
+
+ auto *BlockLoop = LI.getLoopFor(Block);
+ if (ParentLoop &&
+ (ParentLoop != BlockLoop && ParentLoop->contains(BlockLoop))) {
+      // If the block is the header of a nested loop, pretend it's a
+      // single node with the loop's exits as successors.
+ SmallVector<BasicBlock *, 4> BlockLoopExits;
+ BlockLoop->getExitBlocks(BlockLoopExits);
+ for (const auto *BlockLoopExit : BlockLoopExits) {
+ visitSuccessor(*BlockLoopExit, ParentLoop, *DefBlock);
+ }
+
+ } else {
+ // the successors are either on the same loop level or loop exits
+ for (const auto *SuccBlock : successors(Block)) {
+ visitSuccessor(*SuccBlock, ParentLoop, *DefBlock);
+ }
+ }
+ }
+
+ LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
+
+    // We need to know the definition at the parent loop header to decide
+    // whether the definition at the header differs from the definition at
+    // the loop exits, which would indicate divergent loop exits.
+ //
+ // A // loop header
+ // |
+ // B // nested loop header
+ // |
+ // C -> X (exit from B loop) -..-> (A latch)
+ // |
+ // D -> back to B (B latch)
+ // |
+ // proper exit from both loops
+ //
+ // analyze reached loop exits
+ if (!ReachedLoopExits.empty()) {
+      assert(ParentLoop && "reached a loop exit without a parent loop");
+      const BasicBlock *ParentLoopHeader = ParentLoop->getHeader();
+
+ auto ItHeaderDef = DefMap.find(ParentLoopHeader);
+ const auto *HeaderDefBlock = (ItHeaderDef == DefMap.end()) ? nullptr : ItHeaderDef->second;
+
+ LLVM_DEBUG(printDefs(dbgs()));
+ assert(HeaderDefBlock && "no definition at header of carrying loop");
+
+ for (const auto *ExitBlock : ReachedLoopExits) {
+ auto ItExitDef = DefMap.find(ExitBlock);
+ assert((ItExitDef != DefMap.end()) &&
+ "no reaching def at reachable loop exit");
+ if (ItExitDef->second != HeaderDefBlock) {
+ JoinBlocks->insert(ExitBlock);
+ }
+ }
+ }
+
+ return std::move(JoinBlocks);
+ }
+};
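+
+// Illustration (hypothetical CFG, not part of the analysis): for a divergent
+// terminator in A with successors B and C that both branch to D, the
+// bootstrap step seeds DefMap[B] = B and DefMap[C] = C; when D is reached,
+// the two reaching definitions disagree, so D is recorded as a join block:
+//
+//        A
+//       / \
+//      B   C
+//       \ /
+//        D      <- join_blocks(A's terminator) == { D }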
+
+const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) {
+ using LoopExitVec = SmallVector<BasicBlock *, 4>;
+ LoopExitVec LoopExits;
+ Loop.getExitBlocks(LoopExits);
+ if (LoopExits.size() < 1) {
+ return EmptyBlockSet;
+ }
+
+ // already available in cache?
+ auto ItCached = CachedLoopExitJoins.find(&Loop);
+ if (ItCached != CachedLoopExitJoins.end()) {
+ return *ItCached->second;
+ }
+
+ // compute all join points
+ DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
+ auto JoinBlocks = Propagator.computeJoinPoints<const LoopExitVec &>(
+ *Loop.getHeader(), LoopExits, Loop.getParentLoop());
+
+ auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks));
+ assert(ItInserted.second);
+ return *ItInserted.first->second;
+}
+
+const ConstBlockSet &
+SyncDependenceAnalysis::join_blocks(const Instruction &Term) {
+ // trivial case
+ if (Term.getNumSuccessors() < 1) {
+ return EmptyBlockSet;
+ }
+
+ // already available in cache?
+ auto ItCached = CachedBranchJoins.find(&Term);
+ if (ItCached != CachedBranchJoins.end())
+ return *ItCached->second;
+
+ // compute all join points
+ DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
+ const auto &TermBlock = *Term.getParent();
+ auto JoinBlocks = Propagator.computeJoinPoints<succ_const_range>(
+ TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock));
+
+ auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks));
+ assert(ItInserted.second);
+ return *ItInserted.first->second;
+}
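+
+// Illustration only: a hypothetical client of this analysis would query
+// join_blocks() for a divergent terminator and mark the phi nodes of the
+// returned blocks as divergent (constructor arguments as declared in
+// SyncDependenceAnalysis.h; markPhiNodesAsDivergent is a made-up helper):
+//
+//   SyncDependenceAnalysis SDA(DT, PDT, LI);
+//   for (const BasicBlock *JoinBB : SDA.join_blocks(*DivergentBranch))
+//     markPhiNodesAsDivergent(*JoinBB);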
+
+} // namespace llvm
diff --git a/llvm/lib/Analysis/SyntheticCountsUtils.cpp b/llvm/lib/Analysis/SyntheticCountsUtils.cpp
new file mode 100644
index 000000000000..22766e5f07f5
--- /dev/null
+++ b/llvm/lib/Analysis/SyntheticCountsUtils.cpp
@@ -0,0 +1,104 @@
+//===--- SyntheticCountsUtils.cpp - synthetic counts propagation utils ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines utilities for propagating synthetic counts.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/SyntheticCountsUtils.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SCCIterator.h"
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/ModuleSummaryIndex.h"
+
+using namespace llvm;
+
+// Given an SCC, propagate entry counts along the edges of the SCC nodes.
+template <typename CallGraphType>
+void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
+ const SccTy &SCC, GetProfCountTy GetProfCount, AddCountTy AddCount) {
+
+ DenseSet<NodeRef> SCCNodes;
+ SmallVector<std::pair<NodeRef, EdgeRef>, 8> SCCEdges, NonSCCEdges;
+
+ for (auto &Node : SCC)
+ SCCNodes.insert(Node);
+
+ // Partition the edges coming out of the SCC into those whose destination is
+ // in the SCC and the rest.
+ for (const auto &Node : SCCNodes) {
+ for (auto &E : children_edges<CallGraphType>(Node)) {
+ if (SCCNodes.count(CGT::edge_dest(E)))
+ SCCEdges.emplace_back(Node, E);
+ else
+ NonSCCEdges.emplace_back(Node, E);
+ }
+ }
+
+  // For nodes in the same SCC, update the counts in two steps:
+  // 1. Compute the additional count for each node by propagating the counts
+  //    along all incoming edges to the node that originate from within the
+  //    same SCC and summing them up.
+  // 2. Add the additional counts to the nodes in the SCC.
+  // This ensures that the order of traversal of nodes within the SCC doesn't
+  // affect the final result.
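+  //
+  // For illustration (made-up numbers): in an SCC {A, B}, if GetProfCount
+  // returns 10 for the A->B edge and 5 for the B->A edge, step 1 records
+  // AdditionalCounts[B] = 10 and AdditionalCounts[A] = 5 from the pre-update
+  // counts, and step 2 applies both afterwards, so visiting A before B or
+  // B before A yields the same result.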
+
+ DenseMap<NodeRef, Scaled64> AdditionalCounts;
+ for (auto &E : SCCEdges) {
+ auto OptProfCount = GetProfCount(E.first, E.second);
+ if (!OptProfCount)
+ continue;
+ auto Callee = CGT::edge_dest(E.second);
+ AdditionalCounts[Callee] += OptProfCount.getValue();
+ }
+
+ // Update the counts for the nodes in the SCC.
+ for (auto &Entry : AdditionalCounts)
+ AddCount(Entry.first, Entry.second);
+
+ // Now update the counts for nodes outside the SCC.
+ for (auto &E : NonSCCEdges) {
+ auto OptProfCount = GetProfCount(E.first, E.second);
+ if (!OptProfCount)
+ continue;
+ auto Callee = CGT::edge_dest(E.second);
+ AddCount(Callee, OptProfCount.getValue());
+ }
+}
+
+/// Propagate synthetic entry counts on a callgraph \p CG.
+///
+/// This performs a reverse post-order traversal of the callgraph SCCs. For
+/// each SCC, it first propagates the entry counts to the nodes within the SCC
+/// through call edges and updates them in one shot. Then the entry counts are
+/// propagated to nodes outside the SCC. This requires a \p GraphTraits
+/// specialization for \p CallGraphType.
+
+template <typename CallGraphType>
+void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG,
+ GetProfCountTy GetProfCount,
+ AddCountTy AddCount) {
+ std::vector<SccTy> SCCs;
+
+ // Collect all the SCCs.
+ for (auto I = scc_begin(CG); !I.isAtEnd(); ++I)
+ SCCs.push_back(*I);
+
+  // The callgraph SCCs need to be visited in top-down order for propagation.
+  // The SCC iterator returns the SCCs in bottom-up order, so reverse the list
+  // and call propagateFromSCC on each.
+ for (auto &SCC : reverse(SCCs))
+ propagateFromSCC(SCC, GetProfCount, AddCount);
+}
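+
+// Illustration only: a caller supplies the two callbacks roughly as sketched
+// below (callback shapes as declared in SyntheticCountsUtils.h; Counts and
+// getEdgeWeight are hypothetical):
+//
+//   SyntheticCountsUtils<const CallGraph *>::propagate(
+//       &CG,
+//       [&](NodeRef Caller, EdgeRef E) -> Optional<Scaled64> {
+//         return getEdgeWeight(Caller, E) * Counts[Caller];
+//       },
+//       [&](NodeRef Callee, Scaled64 New) { Counts[Callee] += New; });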
+
+template class llvm::SyntheticCountsUtils<const CallGraph *>;
+template class llvm::SyntheticCountsUtils<ModuleSummaryIndex *>;
diff --git a/llvm/lib/Analysis/TargetLibraryInfo.cpp b/llvm/lib/Analysis/TargetLibraryInfo.cpp
new file mode 100644
index 000000000000..230969698054
--- /dev/null
+++ b/llvm/lib/Analysis/TargetLibraryInfo.cpp
@@ -0,0 +1,1638 @@
+//===-- TargetLibraryInfo.cpp - Runtime library information ----------------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the TargetLibraryInfo class.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/ADT/Triple.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/Support/CommandLine.h"
+using namespace llvm;
+
+static cl::opt<TargetLibraryInfoImpl::VectorLibrary> ClVectorLibrary(
+ "vector-library", cl::Hidden, cl::desc("Vector functions library"),
+ cl::init(TargetLibraryInfoImpl::NoLibrary),
+ cl::values(clEnumValN(TargetLibraryInfoImpl::NoLibrary, "none",
+ "No vector functions library"),
+ clEnumValN(TargetLibraryInfoImpl::Accelerate, "Accelerate",
+ "Accelerate framework"),
+ clEnumValN(TargetLibraryInfoImpl::MASSV, "MASSV",
+ "IBM MASS vector library"),
+ clEnumValN(TargetLibraryInfoImpl::SVML, "SVML",
+ "Intel SVML library")));
+
+StringLiteral const TargetLibraryInfoImpl::StandardNames[LibFunc::NumLibFuncs] =
+ {
+#define TLI_DEFINE_STRING
+#include "llvm/Analysis/TargetLibraryInfo.def"
+};
+
+static bool hasSinCosPiStret(const Triple &T) {
+ // Only Darwin variants have _stret versions of combined trig functions.
+ if (!T.isOSDarwin())
+ return false;
+
+ // The ABI is rather complicated on x86, so don't do anything special there.
+ if (T.getArch() == Triple::x86)
+ return false;
+
+ if (T.isMacOSX() && T.isMacOSXVersionLT(10, 9))
+ return false;
+
+ if (T.isiOS() && T.isOSVersionLT(7, 0))
+ return false;
+
+ return true;
+}
+
+static bool hasBcmp(const Triple &TT) {
+  // POSIX removed support for bcmp() in 2001, but glibc and several other
+  // libc implementations still provide it.
+ if (TT.isOSLinux())
+ return TT.isGNUEnvironment() || TT.isMusl();
+ // Both NetBSD and OpenBSD are planning to remove the function. Windows does
+ // not have it.
+ return TT.isOSFreeBSD() || TT.isOSSolaris();
+}
+
+/// Initialize the set of available library functions based on the specified
+/// target triple. This should be carefully written so that a missing target
+/// triple gets a sane set of defaults.
+static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T,
+ ArrayRef<StringLiteral> StandardNames) {
+ // Verify that the StandardNames array is in alphabetical order.
+ assert(std::is_sorted(StandardNames.begin(), StandardNames.end(),
+ [](StringRef LHS, StringRef RHS) {
+ return LHS < RHS;
+ }) &&
+ "TargetLibraryInfoImpl function names must be sorted");
+
+  // Mark the IO-unlocked variants as unavailable by default;
+  // they are re-enabled per system below.
+ TLI.setUnavailable(LibFunc_getchar_unlocked);
+ TLI.setUnavailable(LibFunc_putc_unlocked);
+ TLI.setUnavailable(LibFunc_putchar_unlocked);
+ TLI.setUnavailable(LibFunc_fputc_unlocked);
+ TLI.setUnavailable(LibFunc_fgetc_unlocked);
+ TLI.setUnavailable(LibFunc_fread_unlocked);
+ TLI.setUnavailable(LibFunc_fwrite_unlocked);
+ TLI.setUnavailable(LibFunc_fputs_unlocked);
+ TLI.setUnavailable(LibFunc_fgets_unlocked);
+
+ bool ShouldExtI32Param = false, ShouldExtI32Return = false,
+ ShouldSignExtI32Param = false;
+ // PowerPC64, Sparc64, SystemZ need signext/zeroext on i32 parameters and
+ // returns corresponding to C-level ints and unsigned ints.
+ if (T.isPPC64() || T.getArch() == Triple::sparcv9 ||
+ T.getArch() == Triple::systemz) {
+ ShouldExtI32Param = true;
+ ShouldExtI32Return = true;
+ }
+ // Mips, on the other hand, needs signext on i32 parameters corresponding
+ // to both signed and unsigned ints.
+ if (T.isMIPS()) {
+ ShouldSignExtI32Param = true;
+ }
+ TLI.setShouldExtI32Param(ShouldExtI32Param);
+ TLI.setShouldExtI32Return(ShouldExtI32Return);
+ TLI.setShouldSignExtI32Param(ShouldSignExtI32Param);
+
+  // All library functions are disabled for AMD GPUs. In particular, there are
+  // no library implementations of memcpy and memset, and these can be
+  // difficult to lower in the backend.
+  if (T.getArch() == Triple::r600 || T.getArch() == Triple::amdgcn) {
+    TLI.disableAllFunctions();
+    TLI.setUnavailable(LibFunc_memcpy);
+    TLI.setUnavailable(LibFunc_memset);
+    TLI.setUnavailable(LibFunc_memset_pattern16);
+    return;
+  }
+
+ // memset_pattern16 is only available on iOS 3.0 and Mac OS X 10.5 and later.
+ // All versions of watchOS support it.
+ if (T.isMacOSX()) {
+    // IO-unlocked variants are available on Mac OS X.
+ TLI.setAvailable(LibFunc_getc_unlocked);
+ TLI.setAvailable(LibFunc_getchar_unlocked);
+ TLI.setAvailable(LibFunc_putc_unlocked);
+ TLI.setAvailable(LibFunc_putchar_unlocked);
+
+ if (T.isMacOSXVersionLT(10, 5))
+ TLI.setUnavailable(LibFunc_memset_pattern16);
+ } else if (T.isiOS()) {
+ if (T.isOSVersionLT(3, 0))
+ TLI.setUnavailable(LibFunc_memset_pattern16);
+ } else if (!T.isWatchOS()) {
+ TLI.setUnavailable(LibFunc_memset_pattern16);
+ }
+
+ if (!hasSinCosPiStret(T)) {
+ TLI.setUnavailable(LibFunc_sinpi);
+ TLI.setUnavailable(LibFunc_sinpif);
+ TLI.setUnavailable(LibFunc_cospi);
+ TLI.setUnavailable(LibFunc_cospif);
+ TLI.setUnavailable(LibFunc_sincospi_stret);
+ TLI.setUnavailable(LibFunc_sincospif_stret);
+ }
+
+ if (!hasBcmp(T))
+ TLI.setUnavailable(LibFunc_bcmp);
+
+ if (T.isMacOSX() && T.getArch() == Triple::x86 &&
+ !T.isMacOSXVersionLT(10, 7)) {
+ // x86-32 OSX has a scheme where fwrite and fputs (and some other functions
+ // we don't care about) have two versions; on recent OSX, the one we want
+ // has a $UNIX2003 suffix. The two implementations are identical except
+ // for the return value in some edge cases. However, we don't want to
+ // generate code that depends on the old symbols.
+ TLI.setAvailableWithName(LibFunc_fwrite, "fwrite$UNIX2003");
+ TLI.setAvailableWithName(LibFunc_fputs, "fputs$UNIX2003");
+ }
+
+ // iprintf and friends are only available on XCore, TCE, and Emscripten.
+ if (T.getArch() != Triple::xcore && T.getArch() != Triple::tce &&
+ T.getOS() != Triple::Emscripten) {
+ TLI.setUnavailable(LibFunc_iprintf);
+ TLI.setUnavailable(LibFunc_siprintf);
+ TLI.setUnavailable(LibFunc_fiprintf);
+ }
+
+ // __small_printf and friends are only available on Emscripten.
+ if (T.getOS() != Triple::Emscripten) {
+ TLI.setUnavailable(LibFunc_small_printf);
+ TLI.setUnavailable(LibFunc_small_sprintf);
+ TLI.setUnavailable(LibFunc_small_fprintf);
+ }
+
+ if (T.isOSWindows() && !T.isOSCygMing()) {
+ // XXX: The earliest documentation available at the moment is for VS2015/VC19:
+ // https://docs.microsoft.com/en-us/cpp/c-runtime-library/floating-point-support?view=vs-2015
+ // XXX: In order to use an MSVCRT older than VC19,
+ // the specific library version must be explicit in the target triple,
+ // e.g., x86_64-pc-windows-msvc18.
+ bool hasPartialC99 = true;
+ if (T.isKnownWindowsMSVCEnvironment()) {
+ unsigned Major, Minor, Micro;
+ T.getEnvironmentVersion(Major, Minor, Micro);
+ hasPartialC99 = (Major == 0 || Major >= 19);
+ }
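+    // For illustration: a triple such as "x86_64-pc-windows-msvc18" yields
+    // Major == 18, so hasPartialC99 becomes false and the C99 subset below is
+    // marked unavailable; an unversioned MSVC triple yields Major == 0 and
+    // keeps the default of true.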
+
+    // The newest target architectures support the float C89 math functions,
+    // in part.
+ bool isARM = (T.getArch() == Triple::aarch64 ||
+ T.getArch() == Triple::arm);
+ bool hasPartialFloat = (isARM ||
+ T.getArch() == Triple::x86_64);
+
+ // Win32 does not support float C89 math functions, in general.
+ if (!hasPartialFloat) {
+ TLI.setUnavailable(LibFunc_acosf);
+ TLI.setUnavailable(LibFunc_asinf);
+ TLI.setUnavailable(LibFunc_atan2f);
+ TLI.setUnavailable(LibFunc_atanf);
+ TLI.setUnavailable(LibFunc_ceilf);
+ TLI.setUnavailable(LibFunc_cosf);
+ TLI.setUnavailable(LibFunc_coshf);
+ TLI.setUnavailable(LibFunc_expf);
+ TLI.setUnavailable(LibFunc_floorf);
+ TLI.setUnavailable(LibFunc_fmodf);
+ TLI.setUnavailable(LibFunc_log10f);
+ TLI.setUnavailable(LibFunc_logf);
+ TLI.setUnavailable(LibFunc_modff);
+ TLI.setUnavailable(LibFunc_powf);
+ TLI.setUnavailable(LibFunc_sinf);
+ TLI.setUnavailable(LibFunc_sinhf);
+ TLI.setUnavailable(LibFunc_sqrtf);
+ TLI.setUnavailable(LibFunc_tanf);
+ TLI.setUnavailable(LibFunc_tanhf);
+ }
+ if (!isARM)
+ TLI.setUnavailable(LibFunc_fabsf);
+ TLI.setUnavailable(LibFunc_frexpf);
+ TLI.setUnavailable(LibFunc_ldexpf);
+
+ // Win32 does not support long double C89 math functions.
+ TLI.setUnavailable(LibFunc_acosl);
+ TLI.setUnavailable(LibFunc_asinl);
+ TLI.setUnavailable(LibFunc_atan2l);
+ TLI.setUnavailable(LibFunc_atanl);
+ TLI.setUnavailable(LibFunc_ceill);
+ TLI.setUnavailable(LibFunc_cosl);
+ TLI.setUnavailable(LibFunc_coshl);
+ TLI.setUnavailable(LibFunc_expl);
+ TLI.setUnavailable(LibFunc_fabsl);
+ TLI.setUnavailable(LibFunc_floorl);
+ TLI.setUnavailable(LibFunc_fmodl);
+ TLI.setUnavailable(LibFunc_frexpl);
+ TLI.setUnavailable(LibFunc_ldexpl);
+ TLI.setUnavailable(LibFunc_log10l);
+ TLI.setUnavailable(LibFunc_logl);
+ TLI.setUnavailable(LibFunc_modfl);
+ TLI.setUnavailable(LibFunc_powl);
+ TLI.setUnavailable(LibFunc_sinl);
+ TLI.setUnavailable(LibFunc_sinhl);
+ TLI.setUnavailable(LibFunc_sqrtl);
+ TLI.setUnavailable(LibFunc_tanl);
+ TLI.setUnavailable(LibFunc_tanhl);
+
+ // Win32 does not fully support C99 math functions.
+ if (!hasPartialC99) {
+ TLI.setUnavailable(LibFunc_acosh);
+ TLI.setUnavailable(LibFunc_acoshf);
+ TLI.setUnavailable(LibFunc_asinh);
+ TLI.setUnavailable(LibFunc_asinhf);
+ TLI.setUnavailable(LibFunc_atanh);
+ TLI.setUnavailable(LibFunc_atanhf);
+ TLI.setAvailableWithName(LibFunc_cabs, "_cabs");
+ TLI.setUnavailable(LibFunc_cabsf);
+ TLI.setUnavailable(LibFunc_cbrt);
+ TLI.setUnavailable(LibFunc_cbrtf);
+ TLI.setAvailableWithName(LibFunc_copysign, "_copysign");
+ TLI.setAvailableWithName(LibFunc_copysignf, "_copysignf");
+ TLI.setUnavailable(LibFunc_exp2);
+ TLI.setUnavailable(LibFunc_exp2f);
+ TLI.setUnavailable(LibFunc_expm1);
+ TLI.setUnavailable(LibFunc_expm1f);
+ TLI.setUnavailable(LibFunc_fmax);
+ TLI.setUnavailable(LibFunc_fmaxf);
+ TLI.setUnavailable(LibFunc_fmin);
+ TLI.setUnavailable(LibFunc_fminf);
+ TLI.setUnavailable(LibFunc_log1p);
+ TLI.setUnavailable(LibFunc_log1pf);
+ TLI.setUnavailable(LibFunc_log2);
+ TLI.setUnavailable(LibFunc_log2f);
+ TLI.setAvailableWithName(LibFunc_logb, "_logb");
+ if (hasPartialFloat)
+ TLI.setAvailableWithName(LibFunc_logbf, "_logbf");
+ else
+ TLI.setUnavailable(LibFunc_logbf);
+ TLI.setUnavailable(LibFunc_rint);
+ TLI.setUnavailable(LibFunc_rintf);
+ TLI.setUnavailable(LibFunc_round);
+ TLI.setUnavailable(LibFunc_roundf);
+ TLI.setUnavailable(LibFunc_trunc);
+ TLI.setUnavailable(LibFunc_truncf);
+ }
+
+ // Win32 does not support long double C99 math functions.
+ TLI.setUnavailable(LibFunc_acoshl);
+ TLI.setUnavailable(LibFunc_asinhl);
+ TLI.setUnavailable(LibFunc_atanhl);
+ TLI.setUnavailable(LibFunc_cabsl);
+ TLI.setUnavailable(LibFunc_cbrtl);
+ TLI.setUnavailable(LibFunc_copysignl);
+ TLI.setUnavailable(LibFunc_exp2l);
+ TLI.setUnavailable(LibFunc_expm1l);
+ TLI.setUnavailable(LibFunc_fmaxl);
+ TLI.setUnavailable(LibFunc_fminl);
+ TLI.setUnavailable(LibFunc_log1pl);
+ TLI.setUnavailable(LibFunc_log2l);
+ TLI.setUnavailable(LibFunc_logbl);
+ TLI.setUnavailable(LibFunc_nearbyintl);
+ TLI.setUnavailable(LibFunc_rintl);
+ TLI.setUnavailable(LibFunc_roundl);
+ TLI.setUnavailable(LibFunc_truncl);
+
+ // Win32 does not support these functions, but
+ // they are generally available on POSIX-compliant systems.
+ TLI.setUnavailable(LibFunc_access);
+ TLI.setUnavailable(LibFunc_bcmp);
+ TLI.setUnavailable(LibFunc_bcopy);
+ TLI.setUnavailable(LibFunc_bzero);
+ TLI.setUnavailable(LibFunc_chmod);
+ TLI.setUnavailable(LibFunc_chown);
+ TLI.setUnavailable(LibFunc_closedir);
+ TLI.setUnavailable(LibFunc_ctermid);
+ TLI.setUnavailable(LibFunc_fdopen);
+ TLI.setUnavailable(LibFunc_ffs);
+ TLI.setUnavailable(LibFunc_fileno);
+ TLI.setUnavailable(LibFunc_flockfile);
+ TLI.setUnavailable(LibFunc_fseeko);
+ TLI.setUnavailable(LibFunc_fstat);
+ TLI.setUnavailable(LibFunc_fstatvfs);
+ TLI.setUnavailable(LibFunc_ftello);
+ TLI.setUnavailable(LibFunc_ftrylockfile);
+ TLI.setUnavailable(LibFunc_funlockfile);
+ TLI.setUnavailable(LibFunc_getitimer);
+ TLI.setUnavailable(LibFunc_getlogin_r);
+ TLI.setUnavailable(LibFunc_getpwnam);
+ TLI.setUnavailable(LibFunc_gettimeofday);
+ TLI.setUnavailable(LibFunc_htonl);
+ TLI.setUnavailable(LibFunc_htons);
+ TLI.setUnavailable(LibFunc_lchown);
+ TLI.setUnavailable(LibFunc_lstat);
+ TLI.setUnavailable(LibFunc_memccpy);
+ TLI.setUnavailable(LibFunc_mkdir);
+ TLI.setUnavailable(LibFunc_ntohl);
+ TLI.setUnavailable(LibFunc_ntohs);
+ TLI.setUnavailable(LibFunc_open);
+ TLI.setUnavailable(LibFunc_opendir);
+ TLI.setUnavailable(LibFunc_pclose);
+ TLI.setUnavailable(LibFunc_popen);
+ TLI.setUnavailable(LibFunc_pread);
+ TLI.setUnavailable(LibFunc_pwrite);
+ TLI.setUnavailable(LibFunc_read);
+ TLI.setUnavailable(LibFunc_readlink);
+ TLI.setUnavailable(LibFunc_realpath);
+ TLI.setUnavailable(LibFunc_rmdir);
+ TLI.setUnavailable(LibFunc_setitimer);
+ TLI.setUnavailable(LibFunc_stat);
+ TLI.setUnavailable(LibFunc_statvfs);
+ TLI.setUnavailable(LibFunc_stpcpy);
+ TLI.setUnavailable(LibFunc_stpncpy);
+ TLI.setUnavailable(LibFunc_strcasecmp);
+ TLI.setUnavailable(LibFunc_strncasecmp);
+ TLI.setUnavailable(LibFunc_times);
+ TLI.setUnavailable(LibFunc_uname);
+ TLI.setUnavailable(LibFunc_unlink);
+ TLI.setUnavailable(LibFunc_unsetenv);
+ TLI.setUnavailable(LibFunc_utime);
+ TLI.setUnavailable(LibFunc_utimes);
+ TLI.setUnavailable(LibFunc_write);
+ }
+
+ switch (T.getOS()) {
+ case Triple::MacOSX:
+ // exp10 and exp10f are not available on OS X until 10.9 and iOS until 7.0
+ // and their names are __exp10 and __exp10f. exp10l is not available on
+ // OS X or iOS.
+ TLI.setUnavailable(LibFunc_exp10l);
+ if (T.isMacOSXVersionLT(10, 9)) {
+ TLI.setUnavailable(LibFunc_exp10);
+ TLI.setUnavailable(LibFunc_exp10f);
+ } else {
+ TLI.setAvailableWithName(LibFunc_exp10, "__exp10");
+ TLI.setAvailableWithName(LibFunc_exp10f, "__exp10f");
+ }
+ break;
+ case Triple::IOS:
+ case Triple::TvOS:
+ case Triple::WatchOS:
+ TLI.setUnavailable(LibFunc_exp10l);
+ if (!T.isWatchOS() && (T.isOSVersionLT(7, 0) ||
+ (T.isOSVersionLT(9, 0) &&
+ (T.getArch() == Triple::x86 ||
+ T.getArch() == Triple::x86_64)))) {
+ TLI.setUnavailable(LibFunc_exp10);
+ TLI.setUnavailable(LibFunc_exp10f);
+ } else {
+ TLI.setAvailableWithName(LibFunc_exp10, "__exp10");
+ TLI.setAvailableWithName(LibFunc_exp10f, "__exp10f");
+ }
+ break;
+ case Triple::Linux:
+    // exp10, exp10f, and exp10l are available on Linux (GLIBC) but are
+    // extremely buggy prior to glibc version 2.18. Until this version is
+    // widely deployed or we have a reasonable detection strategy, we cannot
+    // use exp10 reliably on Linux.
+ //
+ // Fall through to disable all of them.
+ LLVM_FALLTHROUGH;
+ default:
+ TLI.setUnavailable(LibFunc_exp10);
+ TLI.setUnavailable(LibFunc_exp10f);
+ TLI.setUnavailable(LibFunc_exp10l);
+ }
+
+ // ffsl is available on at least Darwin, Mac OS X, iOS, FreeBSD, and
+ // Linux (GLIBC):
+ // http://developer.apple.com/library/mac/#documentation/Darwin/Reference/ManPages/man3/ffsl.3.html
+ // http://svn.freebsd.org/base/head/lib/libc/string/ffsl.c
+ // http://www.gnu.org/software/gnulib/manual/html_node/ffsl.html
+ switch (T.getOS()) {
+ case Triple::Darwin:
+ case Triple::MacOSX:
+ case Triple::IOS:
+ case Triple::TvOS:
+ case Triple::WatchOS:
+ case Triple::FreeBSD:
+ case Triple::Linux:
+ break;
+ default:
+ TLI.setUnavailable(LibFunc_ffsl);
+ }
+
+ // ffsll is available on at least FreeBSD and Linux (GLIBC):
+ // http://svn.freebsd.org/base/head/lib/libc/string/ffsll.c
+ // http://www.gnu.org/software/gnulib/manual/html_node/ffsll.html
+ switch (T.getOS()) {
+ case Triple::Darwin:
+ case Triple::MacOSX:
+ case Triple::IOS:
+ case Triple::TvOS:
+ case Triple::WatchOS:
+ case Triple::FreeBSD:
+ case Triple::Linux:
+ break;
+ default:
+ TLI.setUnavailable(LibFunc_ffsll);
+ }
+
+ // The following functions are available on at least FreeBSD:
+ // http://svn.freebsd.org/base/head/lib/libc/string/fls.c
+ // http://svn.freebsd.org/base/head/lib/libc/string/flsl.c
+ // http://svn.freebsd.org/base/head/lib/libc/string/flsll.c
+ if (!T.isOSFreeBSD()) {
+ TLI.setUnavailable(LibFunc_fls);
+ TLI.setUnavailable(LibFunc_flsl);
+ TLI.setUnavailable(LibFunc_flsll);
+ }
+
+ // The following functions are only available on GNU/Linux (using glibc).
+  // Linux variants without glibc (e.g., bionic, musl) may have some subset.
+ if (!T.isOSLinux() || !T.isGNUEnvironment()) {
+ TLI.setUnavailable(LibFunc_dunder_strdup);
+ TLI.setUnavailable(LibFunc_dunder_strtok_r);
+ TLI.setUnavailable(LibFunc_dunder_isoc99_scanf);
+ TLI.setUnavailable(LibFunc_dunder_isoc99_sscanf);
+ TLI.setUnavailable(LibFunc_under_IO_getc);
+ TLI.setUnavailable(LibFunc_under_IO_putc);
+    // But Android and musl do have memalign.
+ if (!T.isAndroid() && !T.isMusl())
+ TLI.setUnavailable(LibFunc_memalign);
+ TLI.setUnavailable(LibFunc_fopen64);
+ TLI.setUnavailable(LibFunc_fseeko64);
+ TLI.setUnavailable(LibFunc_fstat64);
+ TLI.setUnavailable(LibFunc_fstatvfs64);
+ TLI.setUnavailable(LibFunc_ftello64);
+ TLI.setUnavailable(LibFunc_lstat64);
+ TLI.setUnavailable(LibFunc_open64);
+ TLI.setUnavailable(LibFunc_stat64);
+ TLI.setUnavailable(LibFunc_statvfs64);
+ TLI.setUnavailable(LibFunc_tmpfile64);
+
+ // Relaxed math functions are included in math-finite.h on Linux (GLIBC).
+ TLI.setUnavailable(LibFunc_acos_finite);
+ TLI.setUnavailable(LibFunc_acosf_finite);
+ TLI.setUnavailable(LibFunc_acosl_finite);
+ TLI.setUnavailable(LibFunc_acosh_finite);
+ TLI.setUnavailable(LibFunc_acoshf_finite);
+ TLI.setUnavailable(LibFunc_acoshl_finite);
+ TLI.setUnavailable(LibFunc_asin_finite);
+ TLI.setUnavailable(LibFunc_asinf_finite);
+ TLI.setUnavailable(LibFunc_asinl_finite);
+ TLI.setUnavailable(LibFunc_atan2_finite);
+ TLI.setUnavailable(LibFunc_atan2f_finite);
+ TLI.setUnavailable(LibFunc_atan2l_finite);
+ TLI.setUnavailable(LibFunc_atanh_finite);
+ TLI.setUnavailable(LibFunc_atanhf_finite);
+ TLI.setUnavailable(LibFunc_atanhl_finite);
+ TLI.setUnavailable(LibFunc_cosh_finite);
+ TLI.setUnavailable(LibFunc_coshf_finite);
+ TLI.setUnavailable(LibFunc_coshl_finite);
+ TLI.setUnavailable(LibFunc_exp10_finite);
+ TLI.setUnavailable(LibFunc_exp10f_finite);
+ TLI.setUnavailable(LibFunc_exp10l_finite);
+ TLI.setUnavailable(LibFunc_exp2_finite);
+ TLI.setUnavailable(LibFunc_exp2f_finite);
+ TLI.setUnavailable(LibFunc_exp2l_finite);
+ TLI.setUnavailable(LibFunc_exp_finite);
+ TLI.setUnavailable(LibFunc_expf_finite);
+ TLI.setUnavailable(LibFunc_expl_finite);
+ TLI.setUnavailable(LibFunc_log10_finite);
+ TLI.setUnavailable(LibFunc_log10f_finite);
+ TLI.setUnavailable(LibFunc_log10l_finite);
+ TLI.setUnavailable(LibFunc_log2_finite);
+ TLI.setUnavailable(LibFunc_log2f_finite);
+ TLI.setUnavailable(LibFunc_log2l_finite);
+ TLI.setUnavailable(LibFunc_log_finite);
+ TLI.setUnavailable(LibFunc_logf_finite);
+ TLI.setUnavailable(LibFunc_logl_finite);
+ TLI.setUnavailable(LibFunc_pow_finite);
+ TLI.setUnavailable(LibFunc_powf_finite);
+ TLI.setUnavailable(LibFunc_powl_finite);
+ TLI.setUnavailable(LibFunc_sinh_finite);
+ TLI.setUnavailable(LibFunc_sinhf_finite);
+ TLI.setUnavailable(LibFunc_sinhl_finite);
+ }
+
+ if ((T.isOSLinux() && T.isGNUEnvironment()) ||
+ (T.isAndroid() && !T.isAndroidVersionLT(28))) {
+    // IO-unlocked variants are available on GNU/Linux and Android P or later.
+ TLI.setAvailable(LibFunc_getc_unlocked);
+ TLI.setAvailable(LibFunc_getchar_unlocked);
+ TLI.setAvailable(LibFunc_putc_unlocked);
+ TLI.setAvailable(LibFunc_putchar_unlocked);
+ TLI.setAvailable(LibFunc_fputc_unlocked);
+ TLI.setAvailable(LibFunc_fgetc_unlocked);
+ TLI.setAvailable(LibFunc_fread_unlocked);
+ TLI.setAvailable(LibFunc_fwrite_unlocked);
+ TLI.setAvailable(LibFunc_fputs_unlocked);
+ TLI.setAvailable(LibFunc_fgets_unlocked);
+ }
+
+ // As currently implemented in clang, NVPTX code has no standard library to
+ // speak of. Headers provide a standard-ish library implementation, but many
+ // of the signatures are wrong -- for example, many libm functions are not
+ // extern "C".
+ //
+  // libdevice, an IR library provided by NVIDIA, is linked in by the
+  // front-end, but only the functions that are actually used are provided to
+  // LLVM. Moreover, most of the functions in libdevice don't map precisely to
+  // standard library functions.
+ //
+ // FIXME: Having no standard library prevents e.g. many fastmath
+ // optimizations, so this situation should be fixed.
+ if (T.isNVPTX()) {
+ TLI.disableAllFunctions();
+ TLI.setAvailable(LibFunc_nvvm_reflect);
+ } else {
+ TLI.setUnavailable(LibFunc_nvvm_reflect);
+ }
+
+ TLI.addVectorizableFunctionsFromVecLib(ClVectorLibrary);
+}
+
+TargetLibraryInfoImpl::TargetLibraryInfoImpl() {
+ // Default to everything being available.
+ memset(AvailableArray, -1, sizeof(AvailableArray));
+
+ initialize(*this, Triple(), StandardNames);
+}
+
+TargetLibraryInfoImpl::TargetLibraryInfoImpl(const Triple &T) {
+ // Default to everything being available.
+ memset(AvailableArray, -1, sizeof(AvailableArray));
+
+ initialize(*this, T, StandardNames);
+}
+
+TargetLibraryInfoImpl::TargetLibraryInfoImpl(const TargetLibraryInfoImpl &TLI)
+ : CustomNames(TLI.CustomNames), ShouldExtI32Param(TLI.ShouldExtI32Param),
+ ShouldExtI32Return(TLI.ShouldExtI32Return),
+ ShouldSignExtI32Param(TLI.ShouldSignExtI32Param) {
+ memcpy(AvailableArray, TLI.AvailableArray, sizeof(AvailableArray));
+ VectorDescs = TLI.VectorDescs;
+ ScalarDescs = TLI.ScalarDescs;
+}
+
+TargetLibraryInfoImpl::TargetLibraryInfoImpl(TargetLibraryInfoImpl &&TLI)
+ : CustomNames(std::move(TLI.CustomNames)),
+ ShouldExtI32Param(TLI.ShouldExtI32Param),
+ ShouldExtI32Return(TLI.ShouldExtI32Return),
+ ShouldSignExtI32Param(TLI.ShouldSignExtI32Param) {
+ std::move(std::begin(TLI.AvailableArray), std::end(TLI.AvailableArray),
+ AvailableArray);
+ VectorDescs = TLI.VectorDescs;
+ ScalarDescs = TLI.ScalarDescs;
+}
+
+TargetLibraryInfoImpl &TargetLibraryInfoImpl::operator=(const TargetLibraryInfoImpl &TLI) {
+ CustomNames = TLI.CustomNames;
+ ShouldExtI32Param = TLI.ShouldExtI32Param;
+ ShouldExtI32Return = TLI.ShouldExtI32Return;
+ ShouldSignExtI32Param = TLI.ShouldSignExtI32Param;
+ memcpy(AvailableArray, TLI.AvailableArray, sizeof(AvailableArray));
+ return *this;
+}
+
+TargetLibraryInfoImpl &TargetLibraryInfoImpl::operator=(TargetLibraryInfoImpl &&TLI) {
+ CustomNames = std::move(TLI.CustomNames);
+ ShouldExtI32Param = TLI.ShouldExtI32Param;
+ ShouldExtI32Return = TLI.ShouldExtI32Return;
+ ShouldSignExtI32Param = TLI.ShouldSignExtI32Param;
+ std::move(std::begin(TLI.AvailableArray), std::end(TLI.AvailableArray),
+ AvailableArray);
+ return *this;
+}
+
+static StringRef sanitizeFunctionName(StringRef funcName) {
+  // Filter out empty names and names containing null bytes; those can't be in
+  // our table.
+ if (funcName.empty() || funcName.find('\0') != StringRef::npos)
+ return StringRef();
+
+ // Check for \01 prefix that is used to mangle __asm declarations and
+ // strip it if present.
+ return GlobalValue::dropLLVMManglingEscape(funcName);
+}
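+
+// For example, a declaration named "\01_foo" (the \01 escape used for __asm
+// renames) is looked up as "_foo", while an empty name or a name containing
+// a null byte is rejected outright.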
+
+bool TargetLibraryInfoImpl::getLibFunc(StringRef funcName, LibFunc &F) const {
+ funcName = sanitizeFunctionName(funcName);
+ if (funcName.empty())
+ return false;
+
+ const auto *Start = std::begin(StandardNames);
+ const auto *End = std::end(StandardNames);
+ const auto *I = std::lower_bound(Start, End, funcName);
+ if (I != End && *I == funcName) {
+ F = (LibFunc)(I - Start);
+ return true;
+ }
+ return false;
+}
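+
+// Illustration only: typical client code pairs this lookup with the
+// availability check on TargetLibraryInfo (TLI and Callee are placeholder
+// names; both APIs are declared in TargetLibraryInfo.h):
+//
+//   LibFunc TheLibFunc;
+//   if (TLI.getLibFunc(Callee->getName(), TheLibFunc) && TLI.has(TheLibFunc))
+//     ...; // Treat the call as the recognized library function.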
+
+bool TargetLibraryInfoImpl::isValidProtoForLibFunc(const FunctionType &FTy,
+ LibFunc F,
+ const DataLayout *DL) const {
+ LLVMContext &Ctx = FTy.getContext();
+ Type *PCharTy = Type::getInt8PtrTy(Ctx);
+ Type *SizeTTy = DL ? DL->getIntPtrType(Ctx, /*AS=*/0) : nullptr;
+ auto IsSizeTTy = [SizeTTy](Type *Ty) {
+ return SizeTTy ? Ty == SizeTTy : Ty->isIntegerTy();
+ };
+ unsigned NumParams = FTy.getNumParams();
+
+ switch (F) {
+ case LibFunc_execl:
+ case LibFunc_execlp:
+ case LibFunc_execle:
+ return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy() &&
+ FTy.getReturnType()->isIntegerTy(32));
+ case LibFunc_execv:
+ case LibFunc_execvp:
+ return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy() &&
+ FTy.getReturnType()->isIntegerTy(32));
+ case LibFunc_execvP:
+ case LibFunc_execvpe:
+ case LibFunc_execve:
+ return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy() &&
+ FTy.getParamType(2)->isPointerTy() &&
+ FTy.getReturnType()->isIntegerTy(32));
+ case LibFunc_strlen:
+ return (NumParams == 1 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getReturnType()->isIntegerTy());
+
+ case LibFunc_strchr:
+ case LibFunc_strrchr:
+ return (NumParams == 2 && FTy.getReturnType()->isPointerTy() &&
+ FTy.getParamType(0) == FTy.getReturnType() &&
+ FTy.getParamType(1)->isIntegerTy());
+
+ case LibFunc_strtol:
+ case LibFunc_strtod:
+ case LibFunc_strtof:
+ case LibFunc_strtoul:
+ case LibFunc_strtoll:
+ case LibFunc_strtold:
+ case LibFunc_strtoull:
+ return ((NumParams == 2 || NumParams == 3) &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_strcat_chk:
+ --NumParams;
+ if (!IsSizeTTy(FTy.getParamType(NumParams)))
+ return false;
+ LLVM_FALLTHROUGH;
+ case LibFunc_strcat:
+ return (NumParams == 2 && FTy.getReturnType()->isPointerTy() &&
+ FTy.getParamType(0) == FTy.getReturnType() &&
+ FTy.getParamType(1) == FTy.getReturnType());
+
+ case LibFunc_strncat_chk:
+ --NumParams;
+ if (!IsSizeTTy(FTy.getParamType(NumParams)))
+ return false;
+ LLVM_FALLTHROUGH;
+ case LibFunc_strncat:
+ return (NumParams == 3 && FTy.getReturnType()->isPointerTy() &&
+ FTy.getParamType(0) == FTy.getReturnType() &&
+ FTy.getParamType(1) == FTy.getReturnType() &&
+ IsSizeTTy(FTy.getParamType(2)));
+
+ case LibFunc_strcpy_chk:
+ case LibFunc_stpcpy_chk:
+ --NumParams;
+ if (!IsSizeTTy(FTy.getParamType(NumParams)))
+ return false;
+ LLVM_FALLTHROUGH;
+ case LibFunc_strcpy:
+ case LibFunc_stpcpy:
+ return (NumParams == 2 && FTy.getReturnType() == FTy.getParamType(0) &&
+ FTy.getParamType(0) == FTy.getParamType(1) &&
+ FTy.getParamType(0) == PCharTy);
+
+ case LibFunc_strlcat_chk:
+ case LibFunc_strlcpy_chk:
+ --NumParams;
+ if (!IsSizeTTy(FTy.getParamType(NumParams)))
+ return false;
+ LLVM_FALLTHROUGH;
+ case LibFunc_strlcat:
+ case LibFunc_strlcpy:
+ return NumParams == 3 && IsSizeTTy(FTy.getReturnType()) &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy() &&
+ IsSizeTTy(FTy.getParamType(2));
+
+ case LibFunc_strncpy_chk:
+ case LibFunc_stpncpy_chk:
+ --NumParams;
+ if (!IsSizeTTy(FTy.getParamType(NumParams)))
+ return false;
+ LLVM_FALLTHROUGH;
+ case LibFunc_strncpy:
+ case LibFunc_stpncpy:
+ return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) &&
+ FTy.getParamType(0) == FTy.getParamType(1) &&
+ FTy.getParamType(0) == PCharTy &&
+ IsSizeTTy(FTy.getParamType(2)));
+
+ case LibFunc_strxfrm:
+ return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+
+ case LibFunc_strcmp:
+ return (NumParams == 2 && FTy.getReturnType()->isIntegerTy(32) &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(0) == FTy.getParamType(1));
+
+ case LibFunc_strncmp:
+ return (NumParams == 3 && FTy.getReturnType()->isIntegerTy(32) &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(0) == FTy.getParamType(1) &&
+ IsSizeTTy(FTy.getParamType(2)));
+
+ case LibFunc_strspn:
+ case LibFunc_strcspn:
+ return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(0) == FTy.getParamType(1) &&
+ FTy.getReturnType()->isIntegerTy());
+
+ case LibFunc_strcoll:
+ case LibFunc_strcasecmp:
+ case LibFunc_strncasecmp:
+ return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+
+ case LibFunc_strstr:
+ return (NumParams == 2 && FTy.getReturnType()->isPointerTy() &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+
+ case LibFunc_strpbrk:
+ return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getReturnType() == FTy.getParamType(0) &&
+ FTy.getParamType(0) == FTy.getParamType(1));
+
+ case LibFunc_strtok:
+ case LibFunc_strtok_r:
+ return (NumParams >= 2 && FTy.getParamType(1)->isPointerTy());
+ case LibFunc_scanf:
+ case LibFunc_setbuf:
+ case LibFunc_setvbuf:
+ return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy());
+ case LibFunc_strdup:
+ case LibFunc_strndup:
+ return (NumParams >= 1 && FTy.getReturnType()->isPointerTy() &&
+ FTy.getParamType(0)->isPointerTy());
+ case LibFunc_sscanf:
+ case LibFunc_stat:
+ case LibFunc_statvfs:
+ case LibFunc_siprintf:
+ case LibFunc_small_sprintf:
+ case LibFunc_sprintf:
+ return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy() &&
+ FTy.getReturnType()->isIntegerTy(32));
+
+ case LibFunc_sprintf_chk:
+ return NumParams == 4 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isIntegerTy(32) &&
+ IsSizeTTy(FTy.getParamType(2)) &&
+ FTy.getParamType(3)->isPointerTy() &&
+ FTy.getReturnType()->isIntegerTy(32);
+
+ case LibFunc_snprintf:
+ return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(2)->isPointerTy() &&
+ FTy.getReturnType()->isIntegerTy(32));
+
+ case LibFunc_snprintf_chk:
+ return NumParams == 5 && FTy.getParamType(0)->isPointerTy() &&
+ IsSizeTTy(FTy.getParamType(1)) &&
+ FTy.getParamType(2)->isIntegerTy(32) &&
+ IsSizeTTy(FTy.getParamType(3)) &&
+ FTy.getParamType(4)->isPointerTy() &&
+ FTy.getReturnType()->isIntegerTy(32);
+
+ case LibFunc_setitimer:
+ return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() &&
+ FTy.getParamType(2)->isPointerTy());
+ case LibFunc_system:
+ return (NumParams == 1 && FTy.getParamType(0)->isPointerTy());
+ case LibFunc_malloc:
+ return (NumParams == 1 && FTy.getReturnType()->isPointerTy());
+ case LibFunc_memcmp:
+ return (NumParams == 3 && FTy.getReturnType()->isIntegerTy(32) &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+
+ case LibFunc_memchr:
+ case LibFunc_memrchr:
+ return (NumParams == 3 && FTy.getReturnType()->isPointerTy() &&
+ FTy.getReturnType() == FTy.getParamType(0) &&
+ FTy.getParamType(1)->isIntegerTy(32) &&
+ IsSizeTTy(FTy.getParamType(2)));
+ case LibFunc_modf:
+ case LibFunc_modff:
+ case LibFunc_modfl:
+ return (NumParams >= 2 && FTy.getParamType(1)->isPointerTy());
+
+ case LibFunc_memcpy_chk:
+ case LibFunc_memmove_chk:
+ --NumParams;
+ if (!IsSizeTTy(FTy.getParamType(NumParams)))
+ return false;
+ LLVM_FALLTHROUGH;
+ case LibFunc_memcpy:
+ case LibFunc_mempcpy:
+ case LibFunc_memmove:
+ return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy() &&
+ IsSizeTTy(FTy.getParamType(2)));
+
+ case LibFunc_memset_chk:
+ --NumParams;
+ if (!IsSizeTTy(FTy.getParamType(NumParams)))
+ return false;
+ LLVM_FALLTHROUGH;
+ case LibFunc_memset:
+ return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isIntegerTy() &&
+ IsSizeTTy(FTy.getParamType(2)));
+
+ case LibFunc_memccpy_chk:
+ --NumParams;
+ if (!IsSizeTTy(FTy.getParamType(NumParams)))
+ return false;
+ LLVM_FALLTHROUGH;
+ case LibFunc_memccpy:
+ return (NumParams >= 2 && FTy.getParamType(1)->isPointerTy());
+ case LibFunc_memalign:
+ return (FTy.getReturnType()->isPointerTy());
+ case LibFunc_realloc:
+ case LibFunc_reallocf:
+ return (NumParams == 2 && FTy.getReturnType() == PCharTy &&
+ FTy.getParamType(0) == FTy.getReturnType() &&
+ IsSizeTTy(FTy.getParamType(1)));
+ case LibFunc_read:
+ return (NumParams == 3 && FTy.getParamType(1)->isPointerTy());
+ case LibFunc_rewind:
+ case LibFunc_rmdir:
+ case LibFunc_remove:
+ case LibFunc_realpath:
+ return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy());
+ case LibFunc_rename:
+ return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_readlink:
+ return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_write:
+ return (NumParams == 3 && FTy.getParamType(1)->isPointerTy());
+ case LibFunc_bcopy:
+ case LibFunc_bcmp:
+ return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_bzero:
+ return (NumParams == 2 && FTy.getParamType(0)->isPointerTy());
+ case LibFunc_calloc:
+ return (NumParams == 2 && FTy.getReturnType()->isPointerTy());
+
+ case LibFunc_atof:
+ case LibFunc_atoi:
+ case LibFunc_atol:
+ case LibFunc_atoll:
+ case LibFunc_ferror:
+ case LibFunc_getenv:
+ case LibFunc_getpwnam:
+ case LibFunc_iprintf:
+ case LibFunc_small_printf:
+ case LibFunc_pclose:
+ case LibFunc_perror:
+ case LibFunc_printf:
+ case LibFunc_puts:
+ case LibFunc_uname:
+ case LibFunc_under_IO_getc:
+ case LibFunc_unlink:
+ case LibFunc_unsetenv:
+ return (NumParams == 1 && FTy.getParamType(0)->isPointerTy());
+
+ case LibFunc_access:
+ case LibFunc_chmod:
+ case LibFunc_chown:
+ case LibFunc_clearerr:
+ case LibFunc_closedir:
+ case LibFunc_ctermid:
+ case LibFunc_fclose:
+ case LibFunc_feof:
+ case LibFunc_fflush:
+ case LibFunc_fgetc:
+ case LibFunc_fgetc_unlocked:
+ case LibFunc_fileno:
+ case LibFunc_flockfile:
+ case LibFunc_free:
+ case LibFunc_fseek:
+ case LibFunc_fseeko64:
+ case LibFunc_fseeko:
+ case LibFunc_fsetpos:
+ case LibFunc_ftell:
+ case LibFunc_ftello64:
+ case LibFunc_ftello:
+ case LibFunc_ftrylockfile:
+ case LibFunc_funlockfile:
+ case LibFunc_getc:
+ case LibFunc_getc_unlocked:
+ case LibFunc_getlogin_r:
+ case LibFunc_mkdir:
+ case LibFunc_mktime:
+ case LibFunc_times:
+ return (NumParams != 0 && FTy.getParamType(0)->isPointerTy());
+
+ case LibFunc_fopen:
+ return (NumParams == 2 && FTy.getReturnType()->isPointerTy() &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_fork:
+ return (NumParams == 0 && FTy.getReturnType()->isIntegerTy(32));
+ case LibFunc_fdopen:
+ return (NumParams == 2 && FTy.getReturnType()->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_fputc:
+ case LibFunc_fputc_unlocked:
+ case LibFunc_fstat:
+ case LibFunc_frexp:
+ case LibFunc_frexpf:
+ case LibFunc_frexpl:
+ case LibFunc_fstatvfs:
+ return (NumParams == 2 && FTy.getParamType(1)->isPointerTy());
+ case LibFunc_fgets:
+ case LibFunc_fgets_unlocked:
+ return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(2)->isPointerTy());
+ case LibFunc_fread:
+ case LibFunc_fread_unlocked:
+ return (NumParams == 4 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(3)->isPointerTy());
+ case LibFunc_fwrite:
+ case LibFunc_fwrite_unlocked:
+ return (NumParams == 4 && FTy.getReturnType()->isIntegerTy() &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isIntegerTy() &&
+ FTy.getParamType(2)->isIntegerTy() &&
+ FTy.getParamType(3)->isPointerTy());
+ case LibFunc_fputs:
+ case LibFunc_fputs_unlocked:
+ return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_fscanf:
+ case LibFunc_fiprintf:
+ case LibFunc_small_fprintf:
+ case LibFunc_fprintf:
+ return (NumParams >= 2 && FTy.getReturnType()->isIntegerTy() &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_fgetpos:
+ return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_getchar:
+ case LibFunc_getchar_unlocked:
+ return (NumParams == 0 && FTy.getReturnType()->isIntegerTy());
+ case LibFunc_gets:
+ return (NumParams == 1 && FTy.getParamType(0) == PCharTy);
+ case LibFunc_getitimer:
+ return (NumParams == 2 && FTy.getParamType(1)->isPointerTy());
+ case LibFunc_ungetc:
+ return (NumParams == 2 && FTy.getParamType(1)->isPointerTy());
+ case LibFunc_utime:
+ case LibFunc_utimes:
+ return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_putc:
+ case LibFunc_putc_unlocked:
+ return (NumParams == 2 && FTy.getParamType(1)->isPointerTy());
+ case LibFunc_pread:
+ case LibFunc_pwrite:
+ return (NumParams == 4 && FTy.getParamType(1)->isPointerTy());
+ case LibFunc_popen:
+ return (NumParams == 2 && FTy.getReturnType()->isPointerTy() &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_vscanf:
+ return (NumParams == 2 && FTy.getParamType(1)->isPointerTy());
+ case LibFunc_vsscanf:
+ return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() &&
+ FTy.getParamType(2)->isPointerTy());
+ case LibFunc_vfscanf:
+ return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() &&
+ FTy.getParamType(2)->isPointerTy());
+ case LibFunc_valloc:
+ return (FTy.getReturnType()->isPointerTy());
+ case LibFunc_vprintf:
+ return (NumParams == 2 && FTy.getParamType(0)->isPointerTy());
+ case LibFunc_vfprintf:
+ case LibFunc_vsprintf:
+ return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_vsprintf_chk:
+ return NumParams == 5 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isIntegerTy(32) &&
+ IsSizeTTy(FTy.getParamType(2)) && FTy.getParamType(3)->isPointerTy();
+ case LibFunc_vsnprintf:
+ return (NumParams == 4 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(2)->isPointerTy());
+ case LibFunc_vsnprintf_chk:
+ return NumParams == 6 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(2)->isIntegerTy(32) &&
+ IsSizeTTy(FTy.getParamType(3)) && FTy.getParamType(4)->isPointerTy();
+ case LibFunc_open:
+ return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy());
+ case LibFunc_opendir:
+ return (NumParams == 1 && FTy.getReturnType()->isPointerTy() &&
+ FTy.getParamType(0)->isPointerTy());
+ case LibFunc_tmpfile:
+ return (FTy.getReturnType()->isPointerTy());
+ case LibFunc_htonl:
+ case LibFunc_ntohl:
+ return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(32) &&
+ FTy.getReturnType() == FTy.getParamType(0));
+ case LibFunc_htons:
+ case LibFunc_ntohs:
+ return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(16) &&
+ FTy.getReturnType() == FTy.getParamType(0));
+ case LibFunc_lstat:
+ return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_lchown:
+ return (NumParams == 3 && FTy.getParamType(0)->isPointerTy());
+ case LibFunc_qsort:
+ return (NumParams == 4 && FTy.getParamType(3)->isPointerTy());
+ case LibFunc_dunder_strdup:
+ case LibFunc_dunder_strndup:
+ return (NumParams >= 1 && FTy.getReturnType()->isPointerTy() &&
+ FTy.getParamType(0)->isPointerTy());
+ case LibFunc_dunder_strtok_r:
+ return (NumParams == 3 && FTy.getParamType(1)->isPointerTy());
+ case LibFunc_under_IO_putc:
+ return (NumParams == 2 && FTy.getParamType(1)->isPointerTy());
+ case LibFunc_dunder_isoc99_scanf:
+ return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy());
+ case LibFunc_stat64:
+ case LibFunc_lstat64:
+ case LibFunc_statvfs64:
+ return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_dunder_isoc99_sscanf:
+ return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_fopen64:
+ return (NumParams == 2 && FTy.getReturnType()->isPointerTy() &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+ case LibFunc_tmpfile64:
+ return (FTy.getReturnType()->isPointerTy());
+ case LibFunc_fstat64:
+ case LibFunc_fstatvfs64:
+ return (NumParams == 2 && FTy.getParamType(1)->isPointerTy());
+ case LibFunc_open64:
+ return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy());
+ case LibFunc_gettimeofday:
+ return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy());
+
+ // new(unsigned int);
+ case LibFunc_Znwj:
+ // new(unsigned long);
+ case LibFunc_Znwm:
+ // new[](unsigned int);
+ case LibFunc_Znaj:
+ // new[](unsigned long);
+ case LibFunc_Znam:
+ // new(unsigned int);
+ case LibFunc_msvc_new_int:
+ // new(unsigned long long);
+ case LibFunc_msvc_new_longlong:
+ // new[](unsigned int);
+ case LibFunc_msvc_new_array_int:
+ // new[](unsigned long long);
+ case LibFunc_msvc_new_array_longlong:
+ return (NumParams == 1 && FTy.getReturnType()->isPointerTy());
+
+ // new(unsigned int, nothrow);
+ case LibFunc_ZnwjRKSt9nothrow_t:
+ // new(unsigned long, nothrow);
+ case LibFunc_ZnwmRKSt9nothrow_t:
+ // new[](unsigned int, nothrow);
+ case LibFunc_ZnajRKSt9nothrow_t:
+ // new[](unsigned long, nothrow);
+ case LibFunc_ZnamRKSt9nothrow_t:
+ // new(unsigned int, nothrow);
+ case LibFunc_msvc_new_int_nothrow:
+ // new(unsigned long long, nothrow);
+ case LibFunc_msvc_new_longlong_nothrow:
+ // new[](unsigned int, nothrow);
+ case LibFunc_msvc_new_array_int_nothrow:
+ // new[](unsigned long long, nothrow);
+ case LibFunc_msvc_new_array_longlong_nothrow:
+ // new(unsigned int, align_val_t)
+ case LibFunc_ZnwjSt11align_val_t:
+ // new(unsigned long, align_val_t)
+ case LibFunc_ZnwmSt11align_val_t:
+ // new[](unsigned int, align_val_t)
+ case LibFunc_ZnajSt11align_val_t:
+ // new[](unsigned long, align_val_t)
+ case LibFunc_ZnamSt11align_val_t:
+ return (NumParams == 2 && FTy.getReturnType()->isPointerTy());
+
+ // new(unsigned int, align_val_t, nothrow)
+ case LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t:
+ // new(unsigned long, align_val_t, nothrow)
+ case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t:
+ // new[](unsigned int, align_val_t, nothrow)
+ case LibFunc_ZnajSt11align_val_tRKSt9nothrow_t:
+ // new[](unsigned long, align_val_t, nothrow)
+ case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t:
+ return (NumParams == 3 && FTy.getReturnType()->isPointerTy());
+
+ // void operator delete[](void*);
+ case LibFunc_ZdaPv:
+ // void operator delete(void*);
+ case LibFunc_ZdlPv:
+ // void operator delete[](void*);
+ case LibFunc_msvc_delete_array_ptr32:
+ // void operator delete[](void*);
+ case LibFunc_msvc_delete_array_ptr64:
+ // void operator delete(void*);
+ case LibFunc_msvc_delete_ptr32:
+ // void operator delete(void*);
+ case LibFunc_msvc_delete_ptr64:
+ return (NumParams == 1 && FTy.getParamType(0)->isPointerTy());
+
+ // void operator delete[](void*, nothrow);
+ case LibFunc_ZdaPvRKSt9nothrow_t:
+ // void operator delete[](void*, unsigned int);
+ case LibFunc_ZdaPvj:
+ // void operator delete[](void*, unsigned long);
+ case LibFunc_ZdaPvm:
+ // void operator delete(void*, nothrow);
+ case LibFunc_ZdlPvRKSt9nothrow_t:
+ // void operator delete(void*, unsigned int);
+ case LibFunc_ZdlPvj:
+ // void operator delete(void*, unsigned long);
+ case LibFunc_ZdlPvm:
+ // void operator delete(void*, align_val_t)
+ case LibFunc_ZdlPvSt11align_val_t:
+ // void operator delete[](void*, align_val_t)
+ case LibFunc_ZdaPvSt11align_val_t:
+ // void operator delete[](void*, unsigned int);
+ case LibFunc_msvc_delete_array_ptr32_int:
+ // void operator delete[](void*, nothrow);
+ case LibFunc_msvc_delete_array_ptr32_nothrow:
+ // void operator delete[](void*, unsigned long long);
+ case LibFunc_msvc_delete_array_ptr64_longlong:
+ // void operator delete[](void*, nothrow);
+ case LibFunc_msvc_delete_array_ptr64_nothrow:
+ // void operator delete(void*, unsigned int);
+ case LibFunc_msvc_delete_ptr32_int:
+ // void operator delete(void*, nothrow);
+ case LibFunc_msvc_delete_ptr32_nothrow:
+ // void operator delete(void*, unsigned long long);
+ case LibFunc_msvc_delete_ptr64_longlong:
+ // void operator delete(void*, nothrow);
+ case LibFunc_msvc_delete_ptr64_nothrow:
+ return (NumParams == 2 && FTy.getParamType(0)->isPointerTy());
+
+ // void operator delete(void*, align_val_t, nothrow)
+ case LibFunc_ZdlPvSt11align_val_tRKSt9nothrow_t:
+ // void operator delete[](void*, align_val_t, nothrow)
+ case LibFunc_ZdaPvSt11align_val_tRKSt9nothrow_t:
+ return (NumParams == 3 && FTy.getParamType(0)->isPointerTy());
+
+ case LibFunc_memset_pattern16:
+ return (!FTy.isVarArg() && NumParams == 3 &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy() &&
+ FTy.getParamType(2)->isIntegerTy());
+
+ case LibFunc_cxa_guard_abort:
+ case LibFunc_cxa_guard_acquire:
+ case LibFunc_cxa_guard_release:
+ case LibFunc_nvvm_reflect:
+ return (NumParams == 1 && FTy.getParamType(0)->isPointerTy());
+
+ case LibFunc_sincospi_stret:
+ case LibFunc_sincospif_stret:
+ return (NumParams == 1 && FTy.getParamType(0)->isFloatingPointTy());
+
+ case LibFunc_acos:
+ case LibFunc_acos_finite:
+ case LibFunc_acosf:
+ case LibFunc_acosf_finite:
+ case LibFunc_acosh:
+ case LibFunc_acosh_finite:
+ case LibFunc_acoshf:
+ case LibFunc_acoshf_finite:
+ case LibFunc_acoshl:
+ case LibFunc_acoshl_finite:
+ case LibFunc_acosl:
+ case LibFunc_acosl_finite:
+ case LibFunc_asin:
+ case LibFunc_asin_finite:
+ case LibFunc_asinf:
+ case LibFunc_asinf_finite:
+ case LibFunc_asinh:
+ case LibFunc_asinhf:
+ case LibFunc_asinhl:
+ case LibFunc_asinl:
+ case LibFunc_asinl_finite:
+ case LibFunc_atan:
+ case LibFunc_atanf:
+ case LibFunc_atanh:
+ case LibFunc_atanh_finite:
+ case LibFunc_atanhf:
+ case LibFunc_atanhf_finite:
+ case LibFunc_atanhl:
+ case LibFunc_atanhl_finite:
+ case LibFunc_atanl:
+ case LibFunc_cbrt:
+ case LibFunc_cbrtf:
+ case LibFunc_cbrtl:
+ case LibFunc_ceil:
+ case LibFunc_ceilf:
+ case LibFunc_ceill:
+ case LibFunc_cos:
+ case LibFunc_cosf:
+ case LibFunc_cosh:
+ case LibFunc_cosh_finite:
+ case LibFunc_coshf:
+ case LibFunc_coshf_finite:
+ case LibFunc_coshl:
+ case LibFunc_coshl_finite:
+ case LibFunc_cosl:
+ case LibFunc_exp10:
+ case LibFunc_exp10_finite:
+ case LibFunc_exp10f:
+ case LibFunc_exp10f_finite:
+ case LibFunc_exp10l:
+ case LibFunc_exp10l_finite:
+ case LibFunc_exp2:
+ case LibFunc_exp2_finite:
+ case LibFunc_exp2f:
+ case LibFunc_exp2f_finite:
+ case LibFunc_exp2l:
+ case LibFunc_exp2l_finite:
+ case LibFunc_exp:
+ case LibFunc_exp_finite:
+ case LibFunc_expf:
+ case LibFunc_expf_finite:
+ case LibFunc_expl:
+ case LibFunc_expl_finite:
+ case LibFunc_expm1:
+ case LibFunc_expm1f:
+ case LibFunc_expm1l:
+ case LibFunc_fabs:
+ case LibFunc_fabsf:
+ case LibFunc_fabsl:
+ case LibFunc_floor:
+ case LibFunc_floorf:
+ case LibFunc_floorl:
+ case LibFunc_log10:
+ case LibFunc_log10_finite:
+ case LibFunc_log10f:
+ case LibFunc_log10f_finite:
+ case LibFunc_log10l:
+ case LibFunc_log10l_finite:
+ case LibFunc_log1p:
+ case LibFunc_log1pf:
+ case LibFunc_log1pl:
+ case LibFunc_log2:
+ case LibFunc_log2_finite:
+ case LibFunc_log2f:
+ case LibFunc_log2f_finite:
+ case LibFunc_log2l:
+ case LibFunc_log2l_finite:
+ case LibFunc_log:
+ case LibFunc_log_finite:
+ case LibFunc_logb:
+ case LibFunc_logbf:
+ case LibFunc_logbl:
+ case LibFunc_logf:
+ case LibFunc_logf_finite:
+ case LibFunc_logl:
+ case LibFunc_logl_finite:
+ case LibFunc_nearbyint:
+ case LibFunc_nearbyintf:
+ case LibFunc_nearbyintl:
+ case LibFunc_rint:
+ case LibFunc_rintf:
+ case LibFunc_rintl:
+ case LibFunc_round:
+ case LibFunc_roundf:
+ case LibFunc_roundl:
+ case LibFunc_sin:
+ case LibFunc_sinf:
+ case LibFunc_sinh:
+ case LibFunc_sinh_finite:
+ case LibFunc_sinhf:
+ case LibFunc_sinhf_finite:
+ case LibFunc_sinhl:
+ case LibFunc_sinhl_finite:
+ case LibFunc_sinl:
+ case LibFunc_sqrt:
+ case LibFunc_sqrt_finite:
+ case LibFunc_sqrtf:
+ case LibFunc_sqrtf_finite:
+ case LibFunc_sqrtl:
+ case LibFunc_sqrtl_finite:
+ case LibFunc_tan:
+ case LibFunc_tanf:
+ case LibFunc_tanh:
+ case LibFunc_tanhf:
+ case LibFunc_tanhl:
+ case LibFunc_tanl:
+ case LibFunc_trunc:
+ case LibFunc_truncf:
+ case LibFunc_truncl:
+ return (NumParams == 1 && FTy.getReturnType()->isFloatingPointTy() &&
+ FTy.getReturnType() == FTy.getParamType(0));
+
+ case LibFunc_atan2:
+ case LibFunc_atan2_finite:
+ case LibFunc_atan2f:
+ case LibFunc_atan2f_finite:
+ case LibFunc_atan2l:
+ case LibFunc_atan2l_finite:
+ case LibFunc_fmin:
+ case LibFunc_fminf:
+ case LibFunc_fminl:
+ case LibFunc_fmax:
+ case LibFunc_fmaxf:
+ case LibFunc_fmaxl:
+ case LibFunc_fmod:
+ case LibFunc_fmodf:
+ case LibFunc_fmodl:
+ case LibFunc_copysign:
+ case LibFunc_copysignf:
+ case LibFunc_copysignl:
+ case LibFunc_pow:
+ case LibFunc_pow_finite:
+ case LibFunc_powf:
+ case LibFunc_powf_finite:
+ case LibFunc_powl:
+ case LibFunc_powl_finite:
+ return (NumParams == 2 && FTy.getReturnType()->isFloatingPointTy() &&
+ FTy.getReturnType() == FTy.getParamType(0) &&
+ FTy.getReturnType() == FTy.getParamType(1));
+
+ case LibFunc_ldexp:
+ case LibFunc_ldexpf:
+ case LibFunc_ldexpl:
+ return (NumParams == 2 && FTy.getReturnType()->isFloatingPointTy() &&
+ FTy.getReturnType() == FTy.getParamType(0) &&
+ FTy.getParamType(1)->isIntegerTy(32));
+
+ case LibFunc_ffs:
+ case LibFunc_ffsl:
+ case LibFunc_ffsll:
+ case LibFunc_fls:
+ case LibFunc_flsl:
+ case LibFunc_flsll:
+ return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(32) &&
+ FTy.getParamType(0)->isIntegerTy());
+
+ case LibFunc_isdigit:
+ case LibFunc_isascii:
+ case LibFunc_toascii:
+ case LibFunc_putchar:
+ case LibFunc_putchar_unlocked:
+ return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(32) &&
+ FTy.getReturnType() == FTy.getParamType(0));
+
+ case LibFunc_abs:
+ case LibFunc_labs:
+ case LibFunc_llabs:
+ return (NumParams == 1 && FTy.getReturnType()->isIntegerTy() &&
+ FTy.getReturnType() == FTy.getParamType(0));
+
+ case LibFunc_cxa_atexit:
+ return (NumParams == 3 && FTy.getReturnType()->isIntegerTy() &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1)->isPointerTy() &&
+ FTy.getParamType(2)->isPointerTy());
+
+ case LibFunc_sinpi:
+ case LibFunc_cospi:
+ return (NumParams == 1 && FTy.getReturnType()->isDoubleTy() &&
+ FTy.getReturnType() == FTy.getParamType(0));
+
+ case LibFunc_sinpif:
+ case LibFunc_cospif:
+ return (NumParams == 1 && FTy.getReturnType()->isFloatTy() &&
+ FTy.getReturnType() == FTy.getParamType(0));
+
+ case LibFunc_strnlen:
+ return (NumParams == 2 && FTy.getReturnType() == FTy.getParamType(1) &&
+ FTy.getParamType(0) == PCharTy &&
+ FTy.getParamType(1) == SizeTTy);
+
+ case LibFunc_posix_memalign:
+ return (NumParams == 3 && FTy.getReturnType()->isIntegerTy(32) &&
+ FTy.getParamType(0)->isPointerTy() &&
+ FTy.getParamType(1) == SizeTTy && FTy.getParamType(2) == SizeTTy);
+
+ case LibFunc_wcslen:
+ return (NumParams == 1 && FTy.getParamType(0)->isPointerTy() &&
+ FTy.getReturnType()->isIntegerTy());
+
+ case LibFunc_cabs:
+ case LibFunc_cabsf:
+ case LibFunc_cabsl: {
+    Type *RetTy = FTy.getReturnType();
+ if (!RetTy->isFloatingPointTy())
+ return false;
+
+ // NOTE: These prototypes are target specific and currently support
+ // "complex" passed as an array or discrete real & imaginary parameters.
+ // Add other calling conventions to enable libcall optimizations.
+ if (NumParams == 1)
+ return (FTy.getParamType(0)->isArrayTy() &&
+ FTy.getParamType(0)->getArrayNumElements() == 2 &&
+ FTy.getParamType(0)->getArrayElementType() == RetTy);
+ else if (NumParams == 2)
+ return (FTy.getParamType(0) == RetTy && FTy.getParamType(1) == RetTy);
+ else
+ return false;
+ }
+ case LibFunc::NumLibFuncs:
+ case LibFunc::NotLibFunc:
+ break;
+ }
+
+ llvm_unreachable("Invalid libfunc");
+}
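+
+// For example (illustrative): a declaration such as
+//   float sinf(float);
+// passes the prototype check above for LibFunc_sinf, since it has exactly one
+// floating-point parameter whose type matches the return type, whereas a
+// mismatched declaration such as
+//   double sinf(double, double);
+// fails the NumParams == 1 check and is rejected, so libcall-based
+// optimizations are never applied to it.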
+
+bool TargetLibraryInfoImpl::getLibFunc(const Function &FDecl,
+ LibFunc &F) const {
+  // Intrinsics don't overlap with libcalls; if our module has a large number
+  // of intrinsics, this ends up being a meaningful compile-time win since we
+  // avoid string normalization and comparison.
+ if (FDecl.isIntrinsic()) return false;
+
+ const DataLayout *DL =
+ FDecl.getParent() ? &FDecl.getParent()->getDataLayout() : nullptr;
+ return getLibFunc(FDecl.getName(), F) &&
+ isValidProtoForLibFunc(*FDecl.getFunctionType(), F, DL);
+}
+
+void TargetLibraryInfoImpl::disableAllFunctions() {
+ memset(AvailableArray, 0, sizeof(AvailableArray));
+}
+
+static bool compareByScalarFnName(const VecDesc &LHS, const VecDesc &RHS) {
+ return LHS.ScalarFnName < RHS.ScalarFnName;
+}
+
+static bool compareByVectorFnName(const VecDesc &LHS, const VecDesc &RHS) {
+ return LHS.VectorFnName < RHS.VectorFnName;
+}
+
+static bool compareWithScalarFnName(const VecDesc &LHS, StringRef S) {
+ return LHS.ScalarFnName < S;
+}
+
+static bool compareWithVectorFnName(const VecDesc &LHS, StringRef S) {
+ return LHS.VectorFnName < S;
+}
+
+void TargetLibraryInfoImpl::addVectorizableFunctions(ArrayRef<VecDesc> Fns) {
+ VectorDescs.insert(VectorDescs.end(), Fns.begin(), Fns.end());
+ llvm::sort(VectorDescs, compareByScalarFnName);
+
+ ScalarDescs.insert(ScalarDescs.end(), Fns.begin(), Fns.end());
+ llvm::sort(ScalarDescs, compareByVectorFnName);
+}
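+
+// Note: the same descriptors are deliberately kept in two sorted copies. Each
+// VecDesc is a {ScalarFnName, VectorFnName, VectorizationFactor} triple, e.g.
+// mapping "sinf" to some 4-wide vector variant (names here are hypothetical).
+// Sorting VectorDescs by scalar name and ScalarDescs by vector name lets
+// isFunctionVectorizable()/getVectorizedFunction() and getScalarizedFunction()
+// below each use llvm::lower_bound for a binary search in the appropriate
+// direction.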
+
+void TargetLibraryInfoImpl::addVectorizableFunctionsFromVecLib(
+ enum VectorLibrary VecLib) {
+ switch (VecLib) {
+ case Accelerate: {
+ const VecDesc VecFuncs[] = {
+ #define TLI_DEFINE_ACCELERATE_VECFUNCS
+ #include "llvm/Analysis/VecFuncs.def"
+ };
+ addVectorizableFunctions(VecFuncs);
+ break;
+ }
+ case MASSV: {
+ const VecDesc VecFuncs[] = {
+ #define TLI_DEFINE_MASSV_VECFUNCS
+ #include "llvm/Analysis/VecFuncs.def"
+ };
+ addVectorizableFunctions(VecFuncs);
+ break;
+ }
+ case SVML: {
+ const VecDesc VecFuncs[] = {
+ #define TLI_DEFINE_SVML_VECFUNCS
+ #include "llvm/Analysis/VecFuncs.def"
+ };
+ addVectorizableFunctions(VecFuncs);
+ break;
+ }
+ case NoLibrary:
+ break;
+ }
+}
+
+bool TargetLibraryInfoImpl::isFunctionVectorizable(StringRef funcName) const {
+ funcName = sanitizeFunctionName(funcName);
+ if (funcName.empty())
+ return false;
+
+ std::vector<VecDesc>::const_iterator I =
+ llvm::lower_bound(VectorDescs, funcName, compareWithScalarFnName);
+ return I != VectorDescs.end() && StringRef(I->ScalarFnName) == funcName;
+}
+
+StringRef TargetLibraryInfoImpl::getVectorizedFunction(StringRef F,
+ unsigned VF) const {
+ F = sanitizeFunctionName(F);
+ if (F.empty())
+ return F;
+ std::vector<VecDesc>::const_iterator I =
+ llvm::lower_bound(VectorDescs, F, compareWithScalarFnName);
+ while (I != VectorDescs.end() && StringRef(I->ScalarFnName) == F) {
+ if (I->VectorizationFactor == VF)
+ return I->VectorFnName;
+ ++I;
+ }
+ return StringRef();
+}
+
+StringRef TargetLibraryInfoImpl::getScalarizedFunction(StringRef F,
+ unsigned &VF) const {
+ F = sanitizeFunctionName(F);
+ if (F.empty())
+ return F;
+
+ std::vector<VecDesc>::const_iterator I =
+ llvm::lower_bound(ScalarDescs, F, compareWithVectorFnName);
+  if (I == ScalarDescs.end() || StringRef(I->VectorFnName) != F)
+ return StringRef();
+ VF = I->VectorizationFactor;
+ return I->ScalarFnName;
+}
+
+TargetLibraryInfo TargetLibraryAnalysis::run(Function &F,
+ FunctionAnalysisManager &) {
+ if (PresetInfoImpl)
+ return TargetLibraryInfo(*PresetInfoImpl);
+
+ return TargetLibraryInfo(
+ lookupInfoImpl(Triple(F.getParent()->getTargetTriple())));
+}
+
+TargetLibraryInfoImpl &TargetLibraryAnalysis::lookupInfoImpl(const Triple &T) {
+ std::unique_ptr<TargetLibraryInfoImpl> &Impl =
+ Impls[T.normalize()];
+ if (!Impl)
+ Impl.reset(new TargetLibraryInfoImpl(T));
+
+ return *Impl;
+}
+
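+// For reference, the "wchar_size" flag read below is module-level metadata,
+// e.g. (illustrative IR):
+//   !llvm.module.flags = !{!0}
+//   !0 = !{i32 1, !"wchar_size", i32 4}
+// in which case getWCharSize() returns 4; a module without the flag yields 0.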
+unsigned TargetLibraryInfoImpl::getWCharSize(const Module &M) const {
+ if (auto *ShortWChar = cast_or_null<ConstantAsMetadata>(
+ M.getModuleFlag("wchar_size")))
+ return cast<ConstantInt>(ShortWChar->getValue())->getZExtValue();
+ return 0;
+}
+
+TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass()
+ : ImmutablePass(ID), TLIImpl(), TLI(TLIImpl) {
+ initializeTargetLibraryInfoWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass(const Triple &T)
+ : ImmutablePass(ID), TLIImpl(T), TLI(TLIImpl) {
+ initializeTargetLibraryInfoWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass(
+ const TargetLibraryInfoImpl &TLIImpl)
+ : ImmutablePass(ID), TLIImpl(TLIImpl), TLI(this->TLIImpl) {
+ initializeTargetLibraryInfoWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+AnalysisKey TargetLibraryAnalysis::Key;
+
+// Register the basic pass.
+INITIALIZE_PASS(TargetLibraryInfoWrapperPass, "targetlibinfo",
+ "Target Library Information", false, true)
+char TargetLibraryInfoWrapperPass::ID = 0;
+
+void TargetLibraryInfoWrapperPass::anchor() {}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
new file mode 100644
index 000000000000..c9c294873ea6
--- /dev/null
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -0,0 +1,1386 @@
+//===- llvm/Analysis/TargetTransformInfo.cpp ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/TargetTransformInfoImpl.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/LoopIterator.h"
+#include <utility>
+
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "tti"
+
+static cl::opt<bool> EnableReduxCost("costmodel-reduxcost", cl::init(false),
+ cl::Hidden,
+ cl::desc("Recognize reduction patterns."));
+
+namespace {
+/// No-op implementation of the TTI interface using the utility base
+/// classes.
+///
+/// This is used when no target specific information is available.
+struct NoTTIImpl : TargetTransformInfoImplCRTPBase<NoTTIImpl> {
+ explicit NoTTIImpl(const DataLayout &DL)
+ : TargetTransformInfoImplCRTPBase<NoTTIImpl>(DL) {}
+};
+}
+
+bool HardwareLoopInfo::canAnalyze(LoopInfo &LI) {
+  // If the loop has irreducible control flow, it cannot be converted to a
+  // hardware loop.
+ LoopBlocksRPO RPOT(L);
+ RPOT.perform(&LI);
+ if (containsIrreducibleCFG<const BasicBlock *>(RPOT, LI))
+ return false;
+ return true;
+}
+
+bool HardwareLoopInfo::isHardwareLoopCandidate(ScalarEvolution &SE,
+ LoopInfo &LI, DominatorTree &DT,
+ bool ForceNestedLoop,
+ bool ForceHardwareLoopPHI) {
+ SmallVector<BasicBlock *, 4> ExitingBlocks;
+ L->getExitingBlocks(ExitingBlocks);
+
+ for (BasicBlock *BB : ExitingBlocks) {
+ // If we pass the updated counter back through a phi, we need to know
+ // which latch the updated value will be coming from.
+ if (!L->isLoopLatch(BB)) {
+ if (ForceHardwareLoopPHI || CounterInReg)
+ continue;
+ }
+
+ const SCEV *EC = SE.getExitCount(L, BB);
+ if (isa<SCEVCouldNotCompute>(EC))
+ continue;
+ if (const SCEVConstant *ConstEC = dyn_cast<SCEVConstant>(EC)) {
+ if (ConstEC->getValue()->isZero())
+ continue;
+ } else if (!SE.isLoopInvariant(EC, L))
+ continue;
+
+ if (SE.getTypeSizeInBits(EC->getType()) > CountType->getBitWidth())
+ continue;
+
+ // If this exiting block is contained in a nested loop, it is not eligible
+ // for insertion of the branch-and-decrement since the inner loop would
+ // end up messing up the value in the CTR.
+ if (!IsNestingLegal && LI.getLoopFor(BB) != L && !ForceNestedLoop)
+ continue;
+
+ // We now have a loop-invariant count of loop iterations (which is not the
+ // constant zero) for which we know that this loop will not exit via this
+    // exiting block.
+
+ // We need to make sure that this block will run on every loop iteration.
+    // For this to be true, it must dominate all blocks with backedges. Such
+ // blocks are in-loop predecessors to the header block.
+ bool NotAlways = false;
+ for (BasicBlock *Pred : predecessors(L->getHeader())) {
+ if (!L->contains(Pred))
+ continue;
+
+ if (!DT.dominates(BB, Pred)) {
+ NotAlways = true;
+ break;
+ }
+ }
+
+ if (NotAlways)
+ continue;
+
+    // Make sure this block ends with a conditional branch.
+ Instruction *TI = BB->getTerminator();
+ if (!TI)
+ continue;
+
+ if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
+ if (!BI->isConditional())
+ continue;
+
+ ExitBranch = BI;
+ } else
+ continue;
+
+ // Note that this block may not be the loop latch block, even if the loop
+ // has a latch block.
+ ExitBlock = BB;
+ ExitCount = EC;
+ break;
+ }
+
+ if (!ExitBlock)
+ return false;
+ return true;
+}
+
+TargetTransformInfo::TargetTransformInfo(const DataLayout &DL)
+ : TTIImpl(new Model<NoTTIImpl>(NoTTIImpl(DL))) {}
+
+TargetTransformInfo::~TargetTransformInfo() {}
+
+TargetTransformInfo::TargetTransformInfo(TargetTransformInfo &&Arg)
+ : TTIImpl(std::move(Arg.TTIImpl)) {}
+
+TargetTransformInfo &TargetTransformInfo::operator=(TargetTransformInfo &&RHS) {
+ TTIImpl = std::move(RHS.TTIImpl);
+ return *this;
+}
+
+int TargetTransformInfo::getOperationCost(unsigned Opcode, Type *Ty,
+ Type *OpTy) const {
+ int Cost = TTIImpl->getOperationCost(Opcode, Ty, OpTy);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getCallCost(FunctionType *FTy, int NumArgs,
+ const User *U) const {
+ int Cost = TTIImpl->getCallCost(FTy, NumArgs, U);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getCallCost(const Function *F,
+ ArrayRef<const Value *> Arguments,
+ const User *U) const {
+ int Cost = TTIImpl->getCallCost(F, Arguments, U);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+unsigned TargetTransformInfo::getInliningThresholdMultiplier() const {
+ return TTIImpl->getInliningThresholdMultiplier();
+}
+
+int TargetTransformInfo::getInlinerVectorBonusPercent() const {
+ return TTIImpl->getInlinerVectorBonusPercent();
+}
+
+int TargetTransformInfo::getGEPCost(Type *PointeeType, const Value *Ptr,
+ ArrayRef<const Value *> Operands) const {
+ return TTIImpl->getGEPCost(PointeeType, Ptr, Operands);
+}
+
+int TargetTransformInfo::getExtCost(const Instruction *I,
+ const Value *Src) const {
+ return TTIImpl->getExtCost(I, Src);
+}
+
+int TargetTransformInfo::getIntrinsicCost(
+ Intrinsic::ID IID, Type *RetTy, ArrayRef<const Value *> Arguments,
+ const User *U) const {
+ int Cost = TTIImpl->getIntrinsicCost(IID, RetTy, Arguments, U);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+unsigned
+TargetTransformInfo::getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
+ unsigned &JTSize) const {
+ return TTIImpl->getEstimatedNumberOfCaseClusters(SI, JTSize);
+}
+
+int TargetTransformInfo::getUserCost(const User *U,
+ ArrayRef<const Value *> Operands) const {
+ int Cost = TTIImpl->getUserCost(U, Operands);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+bool TargetTransformInfo::hasBranchDivergence() const {
+ return TTIImpl->hasBranchDivergence();
+}
+
+bool TargetTransformInfo::isSourceOfDivergence(const Value *V) const {
+ return TTIImpl->isSourceOfDivergence(V);
+}
+
+bool llvm::TargetTransformInfo::isAlwaysUniform(const Value *V) const {
+ return TTIImpl->isAlwaysUniform(V);
+}
+
+unsigned TargetTransformInfo::getFlatAddressSpace() const {
+ return TTIImpl->getFlatAddressSpace();
+}
+
+bool TargetTransformInfo::collectFlatAddressOperands(
+ SmallVectorImpl<int> &OpIndexes, Intrinsic::ID IID) const {
+ return TTIImpl->collectFlatAddressOperands(OpIndexes, IID);
+}
+
+bool TargetTransformInfo::rewriteIntrinsicWithAddressSpace(
+ IntrinsicInst *II, Value *OldV, Value *NewV) const {
+ return TTIImpl->rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
+}
+
+bool TargetTransformInfo::isLoweredToCall(const Function *F) const {
+ return TTIImpl->isLoweredToCall(F);
+}
+
+bool TargetTransformInfo::isHardwareLoopProfitable(
+ Loop *L, ScalarEvolution &SE, AssumptionCache &AC,
+ TargetLibraryInfo *LibInfo, HardwareLoopInfo &HWLoopInfo) const {
+ return TTIImpl->isHardwareLoopProfitable(L, SE, AC, LibInfo, HWLoopInfo);
+}
+
+void TargetTransformInfo::getUnrollingPreferences(
+ Loop *L, ScalarEvolution &SE, UnrollingPreferences &UP) const {
+ return TTIImpl->getUnrollingPreferences(L, SE, UP);
+}
+
+bool TargetTransformInfo::isLegalAddImmediate(int64_t Imm) const {
+ return TTIImpl->isLegalAddImmediate(Imm);
+}
+
+bool TargetTransformInfo::isLegalICmpImmediate(int64_t Imm) const {
+ return TTIImpl->isLegalICmpImmediate(Imm);
+}
+
+bool TargetTransformInfo::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV,
+ int64_t BaseOffset,
+ bool HasBaseReg,
+ int64_t Scale,
+ unsigned AddrSpace,
+ Instruction *I) const {
+ return TTIImpl->isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg,
+ Scale, AddrSpace, I);
+}
+
+bool TargetTransformInfo::isLSRCostLess(LSRCost &C1, LSRCost &C2) const {
+ return TTIImpl->isLSRCostLess(C1, C2);
+}
+
+bool TargetTransformInfo::canMacroFuseCmp() const {
+ return TTIImpl->canMacroFuseCmp();
+}
+
+bool TargetTransformInfo::canSaveCmp(Loop *L, BranchInst **BI,
+ ScalarEvolution *SE, LoopInfo *LI,
+ DominatorTree *DT, AssumptionCache *AC,
+ TargetLibraryInfo *LibInfo) const {
+ return TTIImpl->canSaveCmp(L, BI, SE, LI, DT, AC, LibInfo);
+}
+
+bool TargetTransformInfo::shouldFavorPostInc() const {
+ return TTIImpl->shouldFavorPostInc();
+}
+
+bool TargetTransformInfo::shouldFavorBackedgeIndex(const Loop *L) const {
+ return TTIImpl->shouldFavorBackedgeIndex(L);
+}
+
+bool TargetTransformInfo::isLegalMaskedStore(Type *DataType,
+ MaybeAlign Alignment) const {
+ return TTIImpl->isLegalMaskedStore(DataType, Alignment);
+}
+
+bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType,
+ MaybeAlign Alignment) const {
+ return TTIImpl->isLegalMaskedLoad(DataType, Alignment);
+}
+
+bool TargetTransformInfo::isLegalNTStore(Type *DataType,
+ Align Alignment) const {
+ return TTIImpl->isLegalNTStore(DataType, Alignment);
+}
+
+bool TargetTransformInfo::isLegalNTLoad(Type *DataType, Align Alignment) const {
+ return TTIImpl->isLegalNTLoad(DataType, Alignment);
+}
+
+bool TargetTransformInfo::isLegalMaskedGather(Type *DataType) const {
+ return TTIImpl->isLegalMaskedGather(DataType);
+}
+
+bool TargetTransformInfo::isLegalMaskedScatter(Type *DataType) const {
+ return TTIImpl->isLegalMaskedScatter(DataType);
+}
+
+bool TargetTransformInfo::isLegalMaskedCompressStore(Type *DataType) const {
+ return TTIImpl->isLegalMaskedCompressStore(DataType);
+}
+
+bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const {
+ return TTIImpl->isLegalMaskedExpandLoad(DataType);
+}
+
+bool TargetTransformInfo::hasDivRemOp(Type *DataType, bool IsSigned) const {
+ return TTIImpl->hasDivRemOp(DataType, IsSigned);
+}
+
+bool TargetTransformInfo::hasVolatileVariant(Instruction *I,
+ unsigned AddrSpace) const {
+ return TTIImpl->hasVolatileVariant(I, AddrSpace);
+}
+
+bool TargetTransformInfo::prefersVectorizedAddressing() const {
+ return TTIImpl->prefersVectorizedAddressing();
+}
+
+int TargetTransformInfo::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
+ int64_t BaseOffset,
+ bool HasBaseReg,
+ int64_t Scale,
+ unsigned AddrSpace) const {
+ int Cost = TTIImpl->getScalingFactorCost(Ty, BaseGV, BaseOffset, HasBaseReg,
+ Scale, AddrSpace);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+bool TargetTransformInfo::LSRWithInstrQueries() const {
+ return TTIImpl->LSRWithInstrQueries();
+}
+
+bool TargetTransformInfo::isTruncateFree(Type *Ty1, Type *Ty2) const {
+ return TTIImpl->isTruncateFree(Ty1, Ty2);
+}
+
+bool TargetTransformInfo::isProfitableToHoist(Instruction *I) const {
+ return TTIImpl->isProfitableToHoist(I);
+}
+
+bool TargetTransformInfo::useAA() const { return TTIImpl->useAA(); }
+
+bool TargetTransformInfo::isTypeLegal(Type *Ty) const {
+ return TTIImpl->isTypeLegal(Ty);
+}
+
+bool TargetTransformInfo::shouldBuildLookupTables() const {
+  return TTIImpl->shouldBuildLookupTables();
+}
+
+bool TargetTransformInfo::shouldBuildLookupTablesForConstant(
+    Constant *C) const {
+  return TTIImpl->shouldBuildLookupTablesForConstant(C);
+}
+
+bool TargetTransformInfo::useColdCCForColdCall(Function &F) const {
+ return TTIImpl->useColdCCForColdCall(F);
+}
+
+unsigned TargetTransformInfo::
+getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) const {
+ return TTIImpl->getScalarizationOverhead(Ty, Insert, Extract);
+}
+
+unsigned TargetTransformInfo::
+getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
+ unsigned VF) const {
+ return TTIImpl->getOperandsScalarizationOverhead(Args, VF);
+}
+
+bool TargetTransformInfo::supportsEfficientVectorElementLoadStore() const {
+ return TTIImpl->supportsEfficientVectorElementLoadStore();
+}
+
+bool TargetTransformInfo::enableAggressiveInterleaving(
+    bool LoopHasReductions) const {
+ return TTIImpl->enableAggressiveInterleaving(LoopHasReductions);
+}
+
+TargetTransformInfo::MemCmpExpansionOptions
+TargetTransformInfo::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
+ return TTIImpl->enableMemCmpExpansion(OptSize, IsZeroCmp);
+}
+
+bool TargetTransformInfo::enableInterleavedAccessVectorization() const {
+ return TTIImpl->enableInterleavedAccessVectorization();
+}
+
+bool TargetTransformInfo::enableMaskedInterleavedAccessVectorization() const {
+ return TTIImpl->enableMaskedInterleavedAccessVectorization();
+}
+
+bool TargetTransformInfo::isFPVectorizationPotentiallyUnsafe() const {
+ return TTIImpl->isFPVectorizationPotentiallyUnsafe();
+}
+
+bool TargetTransformInfo::allowsMisalignedMemoryAccesses(LLVMContext &Context,
+ unsigned BitWidth,
+ unsigned AddressSpace,
+ unsigned Alignment,
+ bool *Fast) const {
+ return TTIImpl->allowsMisalignedMemoryAccesses(Context, BitWidth, AddressSpace,
+ Alignment, Fast);
+}
+
+TargetTransformInfo::PopcntSupportKind
+TargetTransformInfo::getPopcntSupport(unsigned IntTyWidthInBit) const {
+ return TTIImpl->getPopcntSupport(IntTyWidthInBit);
+}
+
+bool TargetTransformInfo::haveFastSqrt(Type *Ty) const {
+ return TTIImpl->haveFastSqrt(Ty);
+}
+
+bool TargetTransformInfo::isFCmpOrdCheaperThanFCmpZero(Type *Ty) const {
+ return TTIImpl->isFCmpOrdCheaperThanFCmpZero(Ty);
+}
+
+int TargetTransformInfo::getFPOpCost(Type *Ty) const {
+ int Cost = TTIImpl->getFPOpCost(Ty);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getIntImmCodeSizeCost(unsigned Opcode, unsigned Idx,
+ const APInt &Imm,
+ Type *Ty) const {
+ int Cost = TTIImpl->getIntImmCodeSizeCost(Opcode, Idx, Imm, Ty);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getIntImmCost(const APInt &Imm, Type *Ty) const {
+ int Cost = TTIImpl->getIntImmCost(Imm, Ty);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getIntImmCost(unsigned Opcode, unsigned Idx,
+ const APInt &Imm, Type *Ty) const {
+ int Cost = TTIImpl->getIntImmCost(Opcode, Idx, Imm, Ty);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getIntImmCost(Intrinsic::ID IID, unsigned Idx,
+ const APInt &Imm, Type *Ty) const {
+ int Cost = TTIImpl->getIntImmCost(IID, Idx, Imm, Ty);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+unsigned TargetTransformInfo::getNumberOfRegisters(unsigned ClassID) const {
+ return TTIImpl->getNumberOfRegisters(ClassID);
+}
+
+unsigned
+TargetTransformInfo::getRegisterClassForType(bool Vector, Type *Ty) const {
+ return TTIImpl->getRegisterClassForType(Vector, Ty);
+}
+
+const char *TargetTransformInfo::getRegisterClassName(unsigned ClassID) const {
+ return TTIImpl->getRegisterClassName(ClassID);
+}
+
+unsigned TargetTransformInfo::getRegisterBitWidth(bool Vector) const {
+ return TTIImpl->getRegisterBitWidth(Vector);
+}
+
+unsigned TargetTransformInfo::getMinVectorRegisterBitWidth() const {
+ return TTIImpl->getMinVectorRegisterBitWidth();
+}
+
+bool TargetTransformInfo::shouldMaximizeVectorBandwidth(bool OptSize) const {
+ return TTIImpl->shouldMaximizeVectorBandwidth(OptSize);
+}
+
+unsigned TargetTransformInfo::getMinimumVF(unsigned ElemWidth) const {
+ return TTIImpl->getMinimumVF(ElemWidth);
+}
+
+bool TargetTransformInfo::shouldConsiderAddressTypePromotion(
+ const Instruction &I, bool &AllowPromotionWithoutCommonHeader) const {
+ return TTIImpl->shouldConsiderAddressTypePromotion(
+ I, AllowPromotionWithoutCommonHeader);
+}
+
+unsigned TargetTransformInfo::getCacheLineSize() const {
+ return TTIImpl->getCacheLineSize();
+}
+
+llvm::Optional<unsigned> TargetTransformInfo::getCacheSize(CacheLevel Level)
+ const {
+ return TTIImpl->getCacheSize(Level);
+}
+
+llvm::Optional<unsigned> TargetTransformInfo::getCacheAssociativity(
+ CacheLevel Level) const {
+ return TTIImpl->getCacheAssociativity(Level);
+}
+
+unsigned TargetTransformInfo::getPrefetchDistance() const {
+ return TTIImpl->getPrefetchDistance();
+}
+
+unsigned TargetTransformInfo::getMinPrefetchStride() const {
+ return TTIImpl->getMinPrefetchStride();
+}
+
+unsigned TargetTransformInfo::getMaxPrefetchIterationsAhead() const {
+ return TTIImpl->getMaxPrefetchIterationsAhead();
+}
+
+unsigned TargetTransformInfo::getMaxInterleaveFactor(unsigned VF) const {
+ return TTIImpl->getMaxInterleaveFactor(VF);
+}
+
+TargetTransformInfo::OperandValueKind
+TargetTransformInfo::getOperandInfo(Value *V, OperandValueProperties &OpProps) {
+ OperandValueKind OpInfo = OK_AnyValue;
+ OpProps = OP_None;
+
+ if (auto *CI = dyn_cast<ConstantInt>(V)) {
+ if (CI->getValue().isPowerOf2())
+ OpProps = OP_PowerOf2;
+ return OK_UniformConstantValue;
+ }
+
+ // A broadcast shuffle creates a uniform value.
+ // TODO: Add support for non-zero index broadcasts.
+ // TODO: Add support for different source vector width.
+ if (auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V))
+ if (ShuffleInst->isZeroEltSplat())
+ OpInfo = OK_UniformValue;
+
+ const Value *Splat = getSplatValue(V);
+
+  // Check for a splat of a constant or for a non-uniform vector of constants,
+  // and check if the constant(s) are all powers of two.
+ if (isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) {
+ OpInfo = OK_NonUniformConstantValue;
+ if (Splat) {
+ OpInfo = OK_UniformConstantValue;
+ if (auto *CI = dyn_cast<ConstantInt>(Splat))
+ if (CI->getValue().isPowerOf2())
+ OpProps = OP_PowerOf2;
+ } else if (auto *CDS = dyn_cast<ConstantDataSequential>(V)) {
+ OpProps = OP_PowerOf2;
+ for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) {
+ if (auto *CI = dyn_cast<ConstantInt>(CDS->getElementAsConstant(I)))
+ if (CI->getValue().isPowerOf2())
+ continue;
+ OpProps = OP_None;
+ break;
+ }
+ }
+ }
+
+ // Check for a splat of a uniform value. This is not loop aware, so return
+ // true only for the obviously uniform cases (argument, globalvalue)
+ if (Splat && (isa<Argument>(Splat) || isa<GlobalValue>(Splat)))
+ OpInfo = OK_UniformValue;
+
+ return OpInfo;
+}
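+
+// Some concrete classifications implied by the logic above (illustrative):
+//   i32 7                                   -> OK_UniformConstantValue, OP_None
+//   i32 8                                   -> OK_UniformConstantValue, OP_PowerOf2
+//   <4 x i32> <i32 2, i32 4, i32 8, i32 2>  -> OK_NonUniformConstantValue, OP_PowerOf2
+//   a splat shuffle of a function argument  -> OK_UniformValue
+//   anything else (e.g. a loaded vector)    -> OK_AnyValue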
+
+int TargetTransformInfo::getArithmeticInstrCost(
+ unsigned Opcode, Type *Ty, OperandValueKind Opd1Info,
+ OperandValueKind Opd2Info, OperandValueProperties Opd1PropInfo,
+ OperandValueProperties Opd2PropInfo,
+ ArrayRef<const Value *> Args) const {
+ int Cost = TTIImpl->getArithmeticInstrCost(Opcode, Ty, Opd1Info, Opd2Info,
+ Opd1PropInfo, Opd2PropInfo, Args);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getShuffleCost(ShuffleKind Kind, Type *Ty, int Index,
+ Type *SubTp) const {
+ int Cost = TTIImpl->getShuffleCost(Kind, Ty, Index, SubTp);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getCastInstrCost(unsigned Opcode, Type *Dst,
+ Type *Src, const Instruction *I) const {
+  assert((I == nullptr || I->getOpcode() == Opcode) &&
+         "Opcode should reflect passed instruction.");
+ int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, I);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
+ VectorType *VecTy,
+ unsigned Index) const {
+ int Cost = TTIImpl->getExtractWithExtendCost(Opcode, Dst, VecTy, Index);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getCFInstrCost(unsigned Opcode) const {
+ int Cost = TTIImpl->getCFInstrCost(Opcode);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
+ Type *CondTy, const Instruction *I) const {
+  assert((I == nullptr || I->getOpcode() == Opcode) &&
+         "Opcode should reflect passed instruction.");
+ int Cost = TTIImpl->getCmpSelInstrCost(Opcode, ValTy, CondTy, I);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getVectorInstrCost(unsigned Opcode, Type *Val,
+ unsigned Index) const {
+ int Cost = TTIImpl->getVectorInstrCost(Opcode, Val, Index);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getMemoryOpCost(unsigned Opcode, Type *Src,
+ unsigned Alignment,
+ unsigned AddressSpace,
+ const Instruction *I) const {
+  assert((I == nullptr || I->getOpcode() == Opcode) &&
+         "Opcode should reflect passed instruction.");
+ int Cost = TTIImpl->getMemoryOpCost(Opcode, Src, Alignment, AddressSpace, I);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
+ unsigned Alignment,
+ unsigned AddressSpace) const {
+ int Cost =
+ TTIImpl->getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
+ Value *Ptr, bool VariableMask,
+ unsigned Alignment) const {
+ int Cost = TTIImpl->getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
+ Alignment);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getInterleavedMemoryOpCost(
+ unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
+ unsigned Alignment, unsigned AddressSpace, bool UseMaskForCond,
+ bool UseMaskForGaps) const {
+ int Cost = TTIImpl->getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
+ Alignment, AddressSpace,
+ UseMaskForCond,
+ UseMaskForGaps);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy,
+ ArrayRef<Type *> Tys, FastMathFlags FMF,
+ unsigned ScalarizationCostPassed) const {
+ int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Tys, FMF,
+ ScalarizationCostPassed);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy,
+ ArrayRef<Value *> Args, FastMathFlags FMF, unsigned VF) const {
+ int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Args, FMF, VF);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getCallInstrCost(Function *F, Type *RetTy,
+ ArrayRef<Type *> Tys) const {
+ int Cost = TTIImpl->getCallInstrCost(F, RetTy, Tys);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+unsigned TargetTransformInfo::getNumberOfParts(Type *Tp) const {
+ return TTIImpl->getNumberOfParts(Tp);
+}
+
+int TargetTransformInfo::getAddressComputationCost(Type *Tp,
+ ScalarEvolution *SE,
+ const SCEV *Ptr) const {
+ int Cost = TTIImpl->getAddressComputationCost(Tp, SE, Ptr);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getMemcpyCost(const Instruction *I) const {
+ int Cost = TTIImpl->getMemcpyCost(I);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getArithmeticReductionCost(unsigned Opcode, Type *Ty,
+ bool IsPairwiseForm) const {
+ int Cost = TTIImpl->getArithmeticReductionCost(Opcode, Ty, IsPairwiseForm);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+int TargetTransformInfo::getMinMaxReductionCost(Type *Ty, Type *CondTy,
+ bool IsPairwiseForm,
+ bool IsUnsigned) const {
+ int Cost =
+ TTIImpl->getMinMaxReductionCost(Ty, CondTy, IsPairwiseForm, IsUnsigned);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
+unsigned
+TargetTransformInfo::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) const {
+ return TTIImpl->getCostOfKeepingLiveOverCall(Tys);
+}
+
+bool TargetTransformInfo::getTgtMemIntrinsic(IntrinsicInst *Inst,
+ MemIntrinsicInfo &Info) const {
+ return TTIImpl->getTgtMemIntrinsic(Inst, Info);
+}
+
+unsigned TargetTransformInfo::getAtomicMemIntrinsicMaxElementSize() const {
+ return TTIImpl->getAtomicMemIntrinsicMaxElementSize();
+}
+
+Value *TargetTransformInfo::getOrCreateResultFromMemIntrinsic(
+ IntrinsicInst *Inst, Type *ExpectedType) const {
+ return TTIImpl->getOrCreateResultFromMemIntrinsic(Inst, ExpectedType);
+}
+
+Type *TargetTransformInfo::getMemcpyLoopLoweringType(LLVMContext &Context,
+ Value *Length,
+ unsigned SrcAlign,
+ unsigned DestAlign) const {
+ return TTIImpl->getMemcpyLoopLoweringType(Context, Length, SrcAlign,
+ DestAlign);
+}
+
+void TargetTransformInfo::getMemcpyLoopResidualLoweringType(
+ SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
+ unsigned RemainingBytes, unsigned SrcAlign, unsigned DestAlign) const {
+ TTIImpl->getMemcpyLoopResidualLoweringType(OpsOut, Context, RemainingBytes,
+ SrcAlign, DestAlign);
+}
+
+bool TargetTransformInfo::areInlineCompatible(const Function *Caller,
+ const Function *Callee) const {
+ return TTIImpl->areInlineCompatible(Caller, Callee);
+}
+
+bool TargetTransformInfo::areFunctionArgsABICompatible(
+ const Function *Caller, const Function *Callee,
+ SmallPtrSetImpl<Argument *> &Args) const {
+ return TTIImpl->areFunctionArgsABICompatible(Caller, Callee, Args);
+}
+
+bool TargetTransformInfo::isIndexedLoadLegal(MemIndexedMode Mode,
+ Type *Ty) const {
+ return TTIImpl->isIndexedLoadLegal(Mode, Ty);
+}
+
+bool TargetTransformInfo::isIndexedStoreLegal(MemIndexedMode Mode,
+ Type *Ty) const {
+ return TTIImpl->isIndexedStoreLegal(Mode, Ty);
+}
+
+unsigned TargetTransformInfo::getLoadStoreVecRegBitWidth(unsigned AS) const {
+ return TTIImpl->getLoadStoreVecRegBitWidth(AS);
+}
+
+bool TargetTransformInfo::isLegalToVectorizeLoad(LoadInst *LI) const {
+ return TTIImpl->isLegalToVectorizeLoad(LI);
+}
+
+bool TargetTransformInfo::isLegalToVectorizeStore(StoreInst *SI) const {
+ return TTIImpl->isLegalToVectorizeStore(SI);
+}
+
+bool TargetTransformInfo::isLegalToVectorizeLoadChain(
+ unsigned ChainSizeInBytes, unsigned Alignment, unsigned AddrSpace) const {
+ return TTIImpl->isLegalToVectorizeLoadChain(ChainSizeInBytes, Alignment,
+ AddrSpace);
+}
+
+bool TargetTransformInfo::isLegalToVectorizeStoreChain(
+ unsigned ChainSizeInBytes, unsigned Alignment, unsigned AddrSpace) const {
+ return TTIImpl->isLegalToVectorizeStoreChain(ChainSizeInBytes, Alignment,
+ AddrSpace);
+}
+
+unsigned TargetTransformInfo::getLoadVectorFactor(unsigned VF,
+ unsigned LoadSize,
+ unsigned ChainSizeInBytes,
+ VectorType *VecTy) const {
+ return TTIImpl->getLoadVectorFactor(VF, LoadSize, ChainSizeInBytes, VecTy);
+}
+
+unsigned TargetTransformInfo::getStoreVectorFactor(unsigned VF,
+ unsigned StoreSize,
+ unsigned ChainSizeInBytes,
+ VectorType *VecTy) const {
+ return TTIImpl->getStoreVectorFactor(VF, StoreSize, ChainSizeInBytes, VecTy);
+}
+
+bool TargetTransformInfo::useReductionIntrinsic(unsigned Opcode,
+ Type *Ty, ReductionFlags Flags) const {
+ return TTIImpl->useReductionIntrinsic(Opcode, Ty, Flags);
+}
+
+bool TargetTransformInfo::shouldExpandReduction(const IntrinsicInst *II) const {
+ return TTIImpl->shouldExpandReduction(II);
+}
+
+unsigned TargetTransformInfo::getGISelRematGlobalCost() const {
+ return TTIImpl->getGISelRematGlobalCost();
+}
+
+int TargetTransformInfo::getInstructionLatency(const Instruction *I) const {
+ return TTIImpl->getInstructionLatency(I);
+}
+
+static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft,
+ unsigned Level) {
+ // We don't need a shuffle if we just want to have element 0 in position 0 of
+ // the vector.
+ if (!SI && Level == 0 && IsLeft)
+ return true;
+ else if (!SI)
+ return false;
+
+ SmallVector<int, 32> Mask(SI->getType()->getVectorNumElements(), -1);
+
+ // Build a mask of 0, 2, ... (left) or 1, 3, ... (right) depending on whether
+ // we look at the left or right side.
+ for (unsigned i = 0, e = (1 << Level), val = !IsLeft; i != e; ++i, val += 2)
+ Mask[i] = val;
+
+ SmallVector<int, 16> ActualMask = SI->getShuffleMask();
+ return Mask == ActualMask;
+}
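+
+// For a <4 x float> pairwise reduction matched from the final binop downwards,
+// the expected masks are (illustrative):
+//   Level 0: left <0, -1, -1, -1>, right <1, -1, -1, -1>
+//   Level 1: left <0, 2, -1, -1>,  right <1, 3, -1, -1>
+// i.e. the first (1 << Level) lanes take even (left) or odd (right) element
+// indices and the remaining lanes stay undef (-1).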
+
+namespace {
+/// Kind of the reduction data.
+enum ReductionKind {
+ RK_None, /// Not a reduction.
+ RK_Arithmetic, /// Binary reduction data.
+ RK_MinMax, /// Min/max reduction data.
+ RK_UnsignedMinMax, /// Unsigned min/max reduction data.
+};
+/// Contains opcode + LHS/RHS parts of the reduction operations.
+struct ReductionData {
+ ReductionData() = delete;
+ ReductionData(ReductionKind Kind, unsigned Opcode, Value *LHS, Value *RHS)
+ : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) {
+ assert(Kind != RK_None && "expected binary or min/max reduction only.");
+ }
+ unsigned Opcode = 0;
+ Value *LHS = nullptr;
+ Value *RHS = nullptr;
+ ReductionKind Kind = RK_None;
+ bool hasSameData(ReductionData &RD) const {
+ return Kind == RD.Kind && Opcode == RD.Opcode;
+ }
+};
+} // namespace
+
+static Optional<ReductionData> getReductionData(Instruction *I) {
+ Value *L, *R;
+ if (m_BinOp(m_Value(L), m_Value(R)).match(I))
+ return ReductionData(RK_Arithmetic, I->getOpcode(), L, R);
+ if (auto *SI = dyn_cast<SelectInst>(I)) {
+ if (m_SMin(m_Value(L), m_Value(R)).match(SI) ||
+ m_SMax(m_Value(L), m_Value(R)).match(SI) ||
+ m_OrdFMin(m_Value(L), m_Value(R)).match(SI) ||
+ m_OrdFMax(m_Value(L), m_Value(R)).match(SI) ||
+ m_UnordFMin(m_Value(L), m_Value(R)).match(SI) ||
+ m_UnordFMax(m_Value(L), m_Value(R)).match(SI)) {
+ auto *CI = cast<CmpInst>(SI->getCondition());
+ return ReductionData(RK_MinMax, CI->getOpcode(), L, R);
+ }
+ if (m_UMin(m_Value(L), m_Value(R)).match(SI) ||
+ m_UMax(m_Value(L), m_Value(R)).match(SI)) {
+ auto *CI = cast<CmpInst>(SI->getCondition());
+ return ReductionData(RK_UnsignedMinMax, CI->getOpcode(), L, R);
+ }
+ }
+ return llvm::None;
+}
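+
+// For instance (illustrative IR), given
+//   %c = icmp slt <4 x i32> %a, %b
+//   %m = select <4 x i1> %c, <4 x i32> %a, <4 x i32> %b
+// getReductionData(%m) yields RK_MinMax with the compare's opcode, while a
+// plain binary operation such as
+//   %s = add <4 x i32> %x, %y
+// yields RK_Arithmetic with the add opcode itself.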
+
+static ReductionKind matchPairwiseReductionAtLevel(Instruction *I,
+ unsigned Level,
+ unsigned NumLevels) {
+ // Match one level of pairwise operations.
+ // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
+ // <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
+ // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
+ // <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
+ // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
+ if (!I)
+ return RK_None;
+
+ assert(I->getType()->isVectorTy() && "Expecting a vector type");
+
+ Optional<ReductionData> RD = getReductionData(I);
+ if (!RD)
+ return RK_None;
+
+ ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(RD->LHS);
+ if (!LS && Level)
+ return RK_None;
+ ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(RD->RHS);
+ if (!RS && Level)
+ return RK_None;
+
+ // On level 0 we can omit one shufflevector instruction.
+ if (!Level && !RS && !LS)
+ return RK_None;
+
+ // Shuffle inputs must match.
+ Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr;
+ Value *NextLevelOpR = RS ? RS->getOperand(0) : nullptr;
+ Value *NextLevelOp = nullptr;
+ if (NextLevelOpR && NextLevelOpL) {
+    // If we have two shuffles, their operands must match.
+ if (NextLevelOpL != NextLevelOpR)
+ return RK_None;
+
+ NextLevelOp = NextLevelOpL;
+ } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) {
+ // On the first level we can omit the shufflevector <0, undef,...>. So the
+    // input to the other shufflevector <1, undef> must match one of the
+    // inputs to the current binary operation.
+ // Example:
+ // %NextLevelOpL = shufflevector %R, <1, undef ...>
+ // %BinOp = fadd %NextLevelOpL, %R
+ if (NextLevelOpL && NextLevelOpL != RD->RHS)
+ return RK_None;
+ else if (NextLevelOpR && NextLevelOpR != RD->LHS)
+ return RK_None;
+
+ NextLevelOp = NextLevelOpL ? RD->RHS : RD->LHS;
+ } else
+ return RK_None;
+
+  // Check that the next level's binary operation exists and matches the
+  // current one.
+ if (Level + 1 != NumLevels) {
+ Optional<ReductionData> NextLevelRD =
+ getReductionData(cast<Instruction>(NextLevelOp));
+ if (!NextLevelRD || !RD->hasSameData(*NextLevelRD))
+ return RK_None;
+ }
+
+ // Shuffle mask for pairwise operation must match.
+ if (matchPairwiseShuffleMask(LS, /*IsLeft=*/true, Level)) {
+ if (!matchPairwiseShuffleMask(RS, /*IsLeft=*/false, Level))
+ return RK_None;
+ } else if (matchPairwiseShuffleMask(RS, /*IsLeft=*/true, Level)) {
+ if (!matchPairwiseShuffleMask(LS, /*IsLeft=*/false, Level))
+ return RK_None;
+ } else {
+ return RK_None;
+ }
+
+ if (++Level == NumLevels)
+ return RD->Kind;
+
+ // Match next level.
+ return matchPairwiseReductionAtLevel(cast<Instruction>(NextLevelOp), Level,
+ NumLevels);
+}
+
+static ReductionKind matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
+ unsigned &Opcode, Type *&Ty) {
+ if (!EnableReduxCost)
+ return RK_None;
+
+ // Need to extract the first element.
+ ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
+ unsigned Idx = ~0u;
+ if (CI)
+ Idx = CI->getZExtValue();
+ if (Idx != 0)
+ return RK_None;
+
+ auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0));
+ if (!RdxStart)
+ return RK_None;
+ Optional<ReductionData> RD = getReductionData(RdxStart);
+ if (!RD)
+ return RK_None;
+
+ Type *VecTy = RdxStart->getType();
+ unsigned NumVecElems = VecTy->getVectorNumElements();
+ if (!isPowerOf2_32(NumVecElems))
+ return RK_None;
+
+ // We look for a sequence of shuffle,shuffle,add triples like the following
+ // that builds a pairwise reduction tree.
+ //
+ // (X0, X1, X2, X3)
+ // (X0 + X1, X2 + X3, undef, undef)
+ // ((X0 + X1) + (X2 + X3), undef, undef, undef)
+ //
+ // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
+ // <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
+ // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
+ // <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
+ // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
+ // %rdx.shuf.1.0 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef,
+ // <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef>
+ // %rdx.shuf.1.1 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef,
+ // <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
+ // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1
+ // %r = extractelement <4 x float> %bin.rdx8, i32 0
+ if (matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems)) ==
+ RK_None)
+ return RK_None;
+
+ Opcode = RD->Opcode;
+ Ty = VecTy;
+
+ return RD->Kind;
+}
+
+static std::pair<Value *, ShuffleVectorInst *>
+getShuffleAndOtherOprd(Value *L, Value *R) {
+ ShuffleVectorInst *S = nullptr;
+
+ if ((S = dyn_cast<ShuffleVectorInst>(L)))
+ return std::make_pair(R, S);
+
+ S = dyn_cast<ShuffleVectorInst>(R);
+ return std::make_pair(L, S);
+}
+
+static ReductionKind
+matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
+ unsigned &Opcode, Type *&Ty) {
+ if (!EnableReduxCost)
+ return RK_None;
+
+ // Need to extract the first element.
+ ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
+ unsigned Idx = ~0u;
+ if (CI)
+ Idx = CI->getZExtValue();
+ if (Idx != 0)
+ return RK_None;
+
+ auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0));
+ if (!RdxStart)
+ return RK_None;
+ Optional<ReductionData> RD = getReductionData(RdxStart);
+ if (!RD)
+ return RK_None;
+
+ Type *VecTy = ReduxRoot->getOperand(0)->getType();
+ unsigned NumVecElems = VecTy->getVectorNumElements();
+ if (!isPowerOf2_32(NumVecElems))
+ return RK_None;
+
+  // We look for a sequence of shuffles and adds like the following, matching
+  // one fadd/shufflevector pair at a time.
+ //
+ // %rdx.shuf = shufflevector <4 x float> %rdx, <4 x float> undef,
+ // <4 x i32> <i32 2, i32 3, i32 undef, i32 undef>
+ // %bin.rdx = fadd <4 x float> %rdx, %rdx.shuf
+ // %rdx.shuf7 = shufflevector <4 x float> %bin.rdx, <4 x float> undef,
+ // <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
+ // %bin.rdx8 = fadd <4 x float> %bin.rdx, %rdx.shuf7
+ // %r = extractelement <4 x float> %bin.rdx8, i32 0
+
+ unsigned MaskStart = 1;
+ Instruction *RdxOp = RdxStart;
+ SmallVector<int, 32> ShuffleMask(NumVecElems, 0);
+ unsigned NumVecElemsRemain = NumVecElems;
+ while (NumVecElemsRemain - 1) {
+ // Check for the right reduction operation.
+ if (!RdxOp)
+ return RK_None;
+ Optional<ReductionData> RDLevel = getReductionData(RdxOp);
+ if (!RDLevel || !RDLevel->hasSameData(*RD))
+ return RK_None;
+
+ Value *NextRdxOp;
+ ShuffleVectorInst *Shuffle;
+ std::tie(NextRdxOp, Shuffle) =
+ getShuffleAndOtherOprd(RDLevel->LHS, RDLevel->RHS);
+
+    // Check that the current reduction operation and the shuffle use the same
+    // value.
+ if (Shuffle == nullptr)
+ return RK_None;
+ if (Shuffle->getOperand(0) != NextRdxOp)
+ return RK_None;
+
+    // Check that the shuffle mask matches.
+ for (unsigned j = 0; j != MaskStart; ++j)
+ ShuffleMask[j] = MaskStart + j;
+ // Fill the rest of the mask with -1 for undef.
+ std::fill(&ShuffleMask[MaskStart], ShuffleMask.end(), -1);
+
+ SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
+ if (ShuffleMask != Mask)
+ return RK_None;
+
+ RdxOp = dyn_cast<Instruction>(NextRdxOp);
+ NumVecElemsRemain /= 2;
+ MaskStart *= 2;
+ }
+
+ Opcode = RD->Opcode;
+ Ty = VecTy;
+ return RD->Kind;
+}
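+
+// Walking the IR example above backwards from %bin.rdx8, the masks this loop
+// expects are <1, -1, -1, -1> for the last step and <2, 3, -1, -1> for the
+// step before it: at each level the first MaskStart lanes select the upper
+// half of the still-live elements, and MaskStart doubles as the live width
+// halves.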
+
+int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const {
+ switch (I->getOpcode()) {
+ case Instruction::GetElementPtr:
+ return getUserCost(I);
+
+ case Instruction::Ret:
+ case Instruction::PHI:
+ case Instruction::Br: {
+ return getCFInstrCost(I->getOpcode());
+ }
+ case Instruction::Add:
+ case Instruction::FAdd:
+ case Instruction::Sub:
+ case Instruction::FSub:
+ case Instruction::Mul:
+ case Instruction::FMul:
+ case Instruction::UDiv:
+ case Instruction::SDiv:
+ case Instruction::FDiv:
+ case Instruction::URem:
+ case Instruction::SRem:
+ case Instruction::FRem:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Xor: {
+ TargetTransformInfo::OperandValueKind Op1VK, Op2VK;
+ TargetTransformInfo::OperandValueProperties Op1VP, Op2VP;
+ Op1VK = getOperandInfo(I->getOperand(0), Op1VP);
+ Op2VK = getOperandInfo(I->getOperand(1), Op2VP);
+ SmallVector<const Value *, 2> Operands(I->operand_values());
+ return getArithmeticInstrCost(I->getOpcode(), I->getType(), Op1VK, Op2VK,
+ Op1VP, Op2VP, Operands);
+ }
+ case Instruction::FNeg: {
+ TargetTransformInfo::OperandValueKind Op1VK, Op2VK;
+ TargetTransformInfo::OperandValueProperties Op1VP, Op2VP;
+ Op1VK = getOperandInfo(I->getOperand(0), Op1VP);
+ Op2VK = OK_AnyValue;
+ Op2VP = OP_None;
+ SmallVector<const Value *, 2> Operands(I->operand_values());
+ return getArithmeticInstrCost(I->getOpcode(), I->getType(), Op1VK, Op2VK,
+ Op1VP, Op2VP, Operands);
+ }
+ case Instruction::Select: {
+ const SelectInst *SI = cast<SelectInst>(I);
+ Type *CondTy = SI->getCondition()->getType();
+ return getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy, I);
+ }
+ case Instruction::ICmp:
+ case Instruction::FCmp: {
+ Type *ValTy = I->getOperand(0)->getType();
+ return getCmpSelInstrCost(I->getOpcode(), ValTy, I->getType(), I);
+ }
+ case Instruction::Store: {
+ const StoreInst *SI = cast<StoreInst>(I);
+ Type *ValTy = SI->getValueOperand()->getType();
+ return getMemoryOpCost(I->getOpcode(), ValTy,
+ SI->getAlignment(),
+ SI->getPointerAddressSpace(), I);
+ }
+ case Instruction::Load: {
+ const LoadInst *LI = cast<LoadInst>(I);
+ return getMemoryOpCost(I->getOpcode(), I->getType(),
+ LI->getAlignment(),
+ LI->getPointerAddressSpace(), I);
+ }
+ case Instruction::ZExt:
+ case Instruction::SExt:
+ case Instruction::FPToUI:
+ case Instruction::FPToSI:
+ case Instruction::FPExt:
+ case Instruction::PtrToInt:
+ case Instruction::IntToPtr:
+ case Instruction::SIToFP:
+ case Instruction::UIToFP:
+ case Instruction::Trunc:
+ case Instruction::FPTrunc:
+ case Instruction::BitCast:
+ case Instruction::AddrSpaceCast: {
+ Type *SrcTy = I->getOperand(0)->getType();
+ return getCastInstrCost(I->getOpcode(), I->getType(), SrcTy, I);
+ }
+ case Instruction::ExtractElement: {
+ const ExtractElementInst * EEI = cast<ExtractElementInst>(I);
+ ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
+ unsigned Idx = -1;
+ if (CI)
+ Idx = CI->getZExtValue();
+
+    // Try to match a reduction sequence (a series of shufflevector and vector
+    // adds followed by an extractelement).
+ unsigned ReduxOpCode;
+ Type *ReduxType;
+
+ switch (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) {
+ case RK_Arithmetic:
+ return getArithmeticReductionCost(ReduxOpCode, ReduxType,
+ /*IsPairwiseForm=*/false);
+ case RK_MinMax:
+ return getMinMaxReductionCost(
+ ReduxType, CmpInst::makeCmpResultType(ReduxType),
+ /*IsPairwiseForm=*/false, /*IsUnsigned=*/false);
+ case RK_UnsignedMinMax:
+ return getMinMaxReductionCost(
+ ReduxType, CmpInst::makeCmpResultType(ReduxType),
+ /*IsPairwiseForm=*/false, /*IsUnsigned=*/true);
+ case RK_None:
+ break;
+ }
+
+ switch (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) {
+ case RK_Arithmetic:
+ return getArithmeticReductionCost(ReduxOpCode, ReduxType,
+ /*IsPairwiseForm=*/true);
+ case RK_MinMax:
+ return getMinMaxReductionCost(
+ ReduxType, CmpInst::makeCmpResultType(ReduxType),
+ /*IsPairwiseForm=*/true, /*IsUnsigned=*/false);
+ case RK_UnsignedMinMax:
+ return getMinMaxReductionCost(
+ ReduxType, CmpInst::makeCmpResultType(ReduxType),
+ /*IsPairwiseForm=*/true, /*IsUnsigned=*/true);
+ case RK_None:
+ break;
+ }
+
+ return getVectorInstrCost(I->getOpcode(),
+ EEI->getOperand(0)->getType(), Idx);
+ }
+ case Instruction::InsertElement: {
+ const InsertElementInst * IE = cast<InsertElementInst>(I);
+ ConstantInt *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
+ unsigned Idx = -1;
+ if (CI)
+ Idx = CI->getZExtValue();
+ return getVectorInstrCost(I->getOpcode(),
+ IE->getType(), Idx);
+ }
+ case Instruction::ExtractValue:
+ return 0; // Model all ExtractValue nodes as free.
+ case Instruction::ShuffleVector: {
+ const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
+ Type *Ty = Shuffle->getType();
+ Type *SrcTy = Shuffle->getOperand(0)->getType();
+
+ // TODO: Identify and add costs for insert subvector, etc.
+ int SubIndex;
+ if (Shuffle->isExtractSubvectorMask(SubIndex))
+ return TTIImpl->getShuffleCost(SK_ExtractSubvector, SrcTy, SubIndex, Ty);
+
+ if (Shuffle->changesLength())
+ return -1;
+
+ if (Shuffle->isIdentity())
+ return 0;
+
+ if (Shuffle->isReverse())
+ return TTIImpl->getShuffleCost(SK_Reverse, Ty, 0, nullptr);
+
+ if (Shuffle->isSelect())
+ return TTIImpl->getShuffleCost(SK_Select, Ty, 0, nullptr);
+
+ if (Shuffle->isTranspose())
+ return TTIImpl->getShuffleCost(SK_Transpose, Ty, 0, nullptr);
+
+ if (Shuffle->isZeroEltSplat())
+ return TTIImpl->getShuffleCost(SK_Broadcast, Ty, 0, nullptr);
+
+ if (Shuffle->isSingleSource())
+ return TTIImpl->getShuffleCost(SK_PermuteSingleSrc, Ty, 0, nullptr);
+
+ return TTIImpl->getShuffleCost(SK_PermuteTwoSrc, Ty, 0, nullptr);
+ }
+ case Instruction::Call:
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+ SmallVector<Value *, 4> Args(II->arg_operands());
+
+ FastMathFlags FMF;
+ if (auto *FPMO = dyn_cast<FPMathOperator>(II))
+ FMF = FPMO->getFastMathFlags();
+
+ return getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(),
+ Args, FMF);
+ }
+ return -1;
+ default:
+ // We don't have any information on this instruction.
+ return -1;
+ }
+}
+
+TargetTransformInfo::Concept::~Concept() {}
+
+TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
+
+TargetIRAnalysis::TargetIRAnalysis(
+ std::function<Result(const Function &)> TTICallback)
+ : TTICallback(std::move(TTICallback)) {}
+
+TargetIRAnalysis::Result TargetIRAnalysis::run(const Function &F,
+ FunctionAnalysisManager &) {
+ return TTICallback(F);
+}
+
+AnalysisKey TargetIRAnalysis::Key;
+
+TargetIRAnalysis::Result TargetIRAnalysis::getDefaultTTI(const Function &F) {
+ return Result(F.getParent()->getDataLayout());
+}
+
+// Register the basic pass.
+INITIALIZE_PASS(TargetTransformInfoWrapperPass, "tti",
+ "Target Transform Information", false, true)
+char TargetTransformInfoWrapperPass::ID = 0;
+
+void TargetTransformInfoWrapperPass::anchor() {}
+
+TargetTransformInfoWrapperPass::TargetTransformInfoWrapperPass()
+ : ImmutablePass(ID) {
+ initializeTargetTransformInfoWrapperPassPass(
+ *PassRegistry::getPassRegistry());
+}
+
+TargetTransformInfoWrapperPass::TargetTransformInfoWrapperPass(
+ TargetIRAnalysis TIRA)
+ : ImmutablePass(ID), TIRA(std::move(TIRA)) {
+ initializeTargetTransformInfoWrapperPassPass(
+ *PassRegistry::getPassRegistry());
+}
+
+TargetTransformInfo &TargetTransformInfoWrapperPass::getTTI(const Function &F) {
+ FunctionAnalysisManager DummyFAM;
+ TTI = TIRA.run(F, DummyFAM);
+ return *TTI;
+}
+
+ImmutablePass *
+llvm::createTargetTransformInfoWrapperPass(TargetIRAnalysis TIRA) {
+ return new TargetTransformInfoWrapperPass(std::move(TIRA));
+}
diff --git a/llvm/lib/Analysis/Trace.cpp b/llvm/lib/Analysis/Trace.cpp
new file mode 100644
index 000000000000..879c7172d038
--- /dev/null
+++ b/llvm/lib/Analysis/Trace.cpp
@@ -0,0 +1,53 @@
+//===- Trace.cpp - Implementation of Trace class --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This class represents a single trace of LLVM basic blocks. A trace is a
+// single-entry, multiple-exit region of code that is often hot. Trace-based
+// optimizations treat traces almost like they are a large, strange basic
+// block: because the trace path is assumed to be hot, optimizations for the
+// fall-through path are made at the expense of the non-fall-through paths.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/Trace.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+Function *Trace::getFunction() const {
+ return getEntryBasicBlock()->getParent();
+}
+
+Module *Trace::getModule() const {
+ return getFunction()->getParent();
+}
+
+/// print - Write trace to output stream.
+void Trace::print(raw_ostream &O) const {
+ Function *F = getFunction();
+ O << "; Trace from function " << F->getName() << ", blocks:\n";
+ for (const_iterator i = begin(), e = end(); i != e; ++i) {
+ O << "; ";
+ (*i)->printAsOperand(O, true, getModule());
+ O << "\n";
+ }
+ O << "; Trace parent function: \n" << *F;
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+/// dump - Debugger convenience method; writes trace to standard error
+/// output stream.
+LLVM_DUMP_METHOD void Trace::dump() const {
+ print(dbgs());
+}
+#endif
diff --git a/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp b/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp
new file mode 100644
index 000000000000..3b9040aa0f52
--- /dev/null
+++ b/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp
@@ -0,0 +1,740 @@
+//===- TypeBasedAliasAnalysis.cpp - Type-Based Alias Analysis -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the TypeBasedAliasAnalysis pass, which implements
+// metadata-based TBAA.
+//
+// In LLVM IR, memory does not have types, so LLVM's own type system is not
+// suitable for doing TBAA. Instead, metadata is added to the IR to describe
+// a type system of a higher level language. This can be used to implement
+// typical C/C++ TBAA, but it can also be used to implement custom alias
+// analysis behavior for other languages.
+//
+// We currently support two metadata formats: scalar TBAA and struct-path
+// aware TBAA. Once all test cases have been upgraded to use struct-path aware
+// TBAA and we can auto-upgrade existing bc files, the support for scalar TBAA
+// can be dropped.
+//
+// The scalar TBAA metadata format is very simple. TBAA MDNodes have up to
+// three fields, e.g.:
+// !0 = !{ !"an example type tree" }
+// !1 = !{ !"int", !0 }
+// !2 = !{ !"float", !0 }
+// !3 = !{ !"const float", !2, i64 1 }
+//
+// The first field is an identity field. It can be any value, usually
+// an MDString, which uniquely identifies the type. The most important
+// name in the tree is the name of the root node. Two trees with
+// different root node names are entirely disjoint, even if they
+// have leaves with common names.
+//
+// The second field identifies the type's parent node in the tree, or
+// is null or omitted for a root node. A type is considered to alias
+// all of its descendants and all of its ancestors in the tree. Also,
+// a type is considered to alias all types in other trees, so that
+// bitcode produced from multiple front-ends is handled conservatively.
+//
+// If the third field is present, it's an integer which if equal to 1
+// indicates that the type is "constant" (meaning pointsToConstantMemory
+// should return true; see
+// http://llvm.org/docs/AliasAnalysis.html#OtherItfs).
+//
+// With struct-path aware TBAA, the MDNodes attached to an instruction using
+// "!tbaa" are called path tag nodes.
+//
+// The path tag node has 4 fields with the last field being optional.
+//
+// The first field is the base type node; it can be a struct type node
+// or a scalar type node. The second field is the access type node; it
+// must be a scalar type node. The third field is the offset into the base type.
+// The last field has the same meaning as the last field of our scalar TBAA:
+// it's an integer which if equal to 1 indicates that the access is "constant".
+//
+// The struct type node has a name and a list of pairs, one pair for each member
+// of the struct. The first element of each pair is a type node (a struct type
+// node or a scalar type node), specifying the type of the member; the second
+// element of each pair is the offset of the member.
+//
+// Given an example
+// typedef struct {
+// short s;
+// } A;
+// typedef struct {
+// uint16_t s;
+// A a;
+// } B;
+//
+// For an access to B.a.s, we attach !5 (a path tag node) to the load/store
+// instruction. The base type is !4 (struct B), the access type is !2 (scalar
+// type short) and the offset is 4.
+//
+// !0 = !{!"Simple C/C++ TBAA"}
+// !1 = !{!"omnipotent char", !0} // Scalar type node
+// !2 = !{!"short", !1} // Scalar type node
+// !3 = !{!"A", !2, i64 0} // Struct type node
+// !4 = !{!"B", !2, i64 0, !3, i64 4}
+// // Struct type node
+// !5 = !{!4, !2, i64 4} // Path tag node
+//
+// The struct type nodes and the scalar type nodes form a type DAG.
+// Root (!0)
+// char (!1) -- edge to Root
+// short (!2) -- edge to char
+// A (!3) -- edge with offset 0 to short
+// B (!4) -- edge with offset 0 to short and edge with offset 4 to A
+//
+// To check if two tags (tagX and tagY) can alias, we start from the base type
+// of tagX, follow the edge with the correct offset in the type DAG and adjust
+// the offset until we reach the base type of tagY or until we reach the Root
+// node.
+// If we reach the base type of tagY, we compare the adjusted offset with the
+// offset of tagY and return Alias if the offsets are the same, NoAlias
+// otherwise.
+// If we reach the Root node, we perform the above starting from the base type
+// of tagY to see if we reach the base type of tagX.
+//
+// If they have different roots, they're part of different potentially
+// unrelated type systems, so we return Alias to be conservative.
+// If neither node is an ancestor of the other and they have the same root,
+// then we say NoAlias.
+//
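+// As a worked illustration using the example nodes above: checking the access
+// tag !5 = !{!4, !2, i64 4} against a hypothetical tag !{!3, !2, i64 0} (a
+// direct access to A.s), we start at !4 (struct B) with offset 4, follow the
+// edge with offset 4 to !3 (struct A), and adjust the offset to 0. We have
+// reached the base type of the other tag and the adjusted offset matches its
+// offset of 0, so the two accesses may alias.
+//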
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/TypeBasedAliasAnalysis.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <cassert>
+#include <cstdint>
+
+using namespace llvm;
+
+// A handy option for disabling TBAA functionality. The same effect can also be
+// achieved by stripping the !tbaa tags from IR, but this option is sometimes
+// more convenient.
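+// Illustrative use (assuming the standard 'opt' driver, which honors cl::opt
+// flags):
+//   opt -enable-tbaa=false -O2 input.ll -S -o output.ll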
+static cl::opt<bool> EnableTBAA("enable-tbaa", cl::init(true), cl::Hidden);
+
+namespace {
+
+/// isNewFormatTypeNode - Return true iff the given type node is in the new
+/// size-aware format.
+static bool isNewFormatTypeNode(const MDNode *N) {
+ if (N->getNumOperands() < 3)
+ return false;
+ // In the old format the first operand is a string.
+ if (!isa<MDNode>(N->getOperand(0)))
+ return false;
+ return true;
+}
+
+/// This is a simple wrapper around an MDNode which provides a higher-level
+/// interface by hiding the details of how alias analysis information is encoded
+/// in its operands.
+template<typename MDNodeTy>
+class TBAANodeImpl {
+ MDNodeTy *Node = nullptr;
+
+public:
+ TBAANodeImpl() = default;
+ explicit TBAANodeImpl(MDNodeTy *N) : Node(N) {}
+
+ /// getNode - Get the MDNode for this TBAANode.
+ MDNodeTy *getNode() const { return Node; }
+
+ /// isNewFormat - Return true iff the wrapped type node is in the new
+ /// size-aware format.
+ bool isNewFormat() const { return isNewFormatTypeNode(Node); }
+
+ /// getParent - Get this TBAANode's Alias tree parent.
+ TBAANodeImpl<MDNodeTy> getParent() const {
+ if (isNewFormat())
+ return TBAANodeImpl(cast<MDNodeTy>(Node->getOperand(0)));
+
+ if (Node->getNumOperands() < 2)
+ return TBAANodeImpl<MDNodeTy>();
+ MDNodeTy *P = dyn_cast_or_null<MDNodeTy>(Node->getOperand(1));
+ if (!P)
+ return TBAANodeImpl<MDNodeTy>();
+ // Ok, this node has a valid parent. Return it.
+ return TBAANodeImpl<MDNodeTy>(P);
+ }
+
+ /// Test if this TBAANode represents a type for objects which are
+ /// not modified (by any means) in the context where this
+ /// AliasAnalysis is relevant.
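+ /// In the scalar TBAA format this is encoded as a third operand equal to 1,
+ /// e.g. !3 = !{!"const float", !2, i64 1} in the file header example.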
+ bool isTypeImmutable() const {
+ if (Node->getNumOperands() < 3)
+ return false;
+ ConstantInt *CI = mdconst::dyn_extract<ConstantInt>(Node->getOperand(2));
+ if (!CI)
+ return false;
+ return CI->getValue()[0];
+ }
+};
+
+/// \name Specializations of \c TBAANodeImpl for const and non-const qualified
+/// \c MDNode.
+/// @{
+using TBAANode = TBAANodeImpl<const MDNode>;
+using MutableTBAANode = TBAANodeImpl<MDNode>;
+/// @}
+
+/// This is a simple wrapper around an MDNode which provides a
+/// higher-level interface by hiding the details of how alias analysis
+/// information is encoded in its operands.
+template<typename MDNodeTy>
+class TBAAStructTagNodeImpl {
+ /// This node should be created with createTBAAAccessTag().
+ MDNodeTy *Node;
+
+public:
+ explicit TBAAStructTagNodeImpl(MDNodeTy *N) : Node(N) {}
+
+ /// Get the MDNode for this TBAAStructTagNode.
+ MDNodeTy *getNode() const { return Node; }
+
+ /// isNewFormat - Return true iff the wrapped access tag is in the new
+ /// size-aware format.
+ bool isNewFormat() const {
+ if (Node->getNumOperands() < 4)
+ return false;
+ if (MDNodeTy *AccessType = getAccessType())
+ if (!TBAANodeImpl<MDNodeTy>(AccessType).isNewFormat())
+ return false;
+ return true;
+ }
+
+ MDNodeTy *getBaseType() const {
+ return dyn_cast_or_null<MDNode>(Node->getOperand(0));
+ }
+
+ MDNodeTy *getAccessType() const {
+ return dyn_cast_or_null<MDNode>(Node->getOperand(1));
+ }
+
+ uint64_t getOffset() const {
+ return mdconst::extract<ConstantInt>(Node->getOperand(2))->getZExtValue();
+ }
+
+ uint64_t getSize() const {
+ if (!isNewFormat())
+ return UINT64_MAX;
+ return mdconst::extract<ConstantInt>(Node->getOperand(3))->getZExtValue();
+ }
+
+ /// Test if this TBAAStructTagNode represents a type for objects
+ /// which are not modified (by any means) in the context where this
+ /// AliasAnalysis is relevant.
+ bool isTypeImmutable() const {
+ unsigned OpNo = isNewFormat() ? 4 : 3;
+ if (Node->getNumOperands() < OpNo + 1)
+ return false;
+ ConstantInt *CI = mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpNo));
+ if (!CI)
+ return false;
+ return CI->getValue()[0];
+ }
+};
+
+/// \name Specializations of \c TBAAStructTagNodeImpl for const and non-const
+/// qualified \c MDNodes.
+/// @{
+using TBAAStructTagNode = TBAAStructTagNodeImpl<const MDNode>;
+using MutableTBAAStructTagNode = TBAAStructTagNodeImpl<MDNode>;
+/// @}
+
+/// This is a simple wrapper around an MDNode which provides a
+/// higher-level interface by hiding the details of how alias analysis
+/// information is encoded in its operands.
+class TBAAStructTypeNode {
+ /// This node should be created with createTBAATypeNode().
+ const MDNode *Node = nullptr;
+
+public:
+ TBAAStructTypeNode() = default;
+ explicit TBAAStructTypeNode(const MDNode *N) : Node(N) {}
+
+ /// Get the MDNode for this TBAAStructTypeNode.
+ const MDNode *getNode() const { return Node; }
+
+ /// isNewFormat - Return true iff the wrapped type node is in the new
+ /// size-aware format.
+ bool isNewFormat() const { return isNewFormatTypeNode(Node); }
+
+ bool operator==(const TBAAStructTypeNode &Other) const {
+ return getNode() == Other.getNode();
+ }
+
+ /// getId - Return type identifier.
+ Metadata *getId() const {
+ return Node->getOperand(isNewFormat() ? 2 : 0);
+ }
+
+ unsigned getNumFields() const {
+ unsigned FirstFieldOpNo = isNewFormat() ? 3 : 1;
+ unsigned NumOpsPerField = isNewFormat() ? 3 : 2;
+ return (getNode()->getNumOperands() - FirstFieldOpNo) / NumOpsPerField;
+ }
+
+ TBAAStructTypeNode getFieldType(unsigned FieldIndex) const {
+ unsigned FirstFieldOpNo = isNewFormat() ? 3 : 1;
+ unsigned NumOpsPerField = isNewFormat() ? 3 : 2;
+ unsigned OpIndex = FirstFieldOpNo + FieldIndex * NumOpsPerField;
+ auto *TypeNode = cast<MDNode>(getNode()->getOperand(OpIndex));
+ return TBAAStructTypeNode(TypeNode);
+ }
+
+ /// Get this TBAAStructTypeNode's field in the type DAG with
+ /// given offset. Update the offset to be relative to the field type.
+ TBAAStructTypeNode getField(uint64_t &Offset) const {
+ bool NewFormat = isNewFormat();
+ if (NewFormat) {
+ // New-format root and scalar type nodes have no fields.
+ if (Node->getNumOperands() < 6)
+ return TBAAStructTypeNode();
+ } else {
+ // Parent can be omitted for the root node.
+ if (Node->getNumOperands() < 2)
+ return TBAAStructTypeNode();
+
+ // Fast path for a scalar type node and a struct type node with a single
+ // field.
+ if (Node->getNumOperands() <= 3) {
+ uint64_t Cur = Node->getNumOperands() == 2
+ ? 0
+ : mdconst::extract<ConstantInt>(Node->getOperand(2))
+ ->getZExtValue();
+ Offset -= Cur;
+ MDNode *P = dyn_cast_or_null<MDNode>(Node->getOperand(1));
+ if (!P)
+ return TBAAStructTypeNode();
+ return TBAAStructTypeNode(P);
+ }
+ }
+
+ // Assume the offsets are in order. We return the previous field if
+ // the current offset is bigger than the given offset.
+ unsigned FirstFieldOpNo = NewFormat ? 3 : 1;
+ unsigned NumOpsPerField = NewFormat ? 3 : 2;
+ unsigned TheIdx = 0;
+ for (unsigned Idx = FirstFieldOpNo; Idx < Node->getNumOperands();
+ Idx += NumOpsPerField) {
+ uint64_t Cur = mdconst::extract<ConstantInt>(Node->getOperand(Idx + 1))
+ ->getZExtValue();
+ if (Cur > Offset) {
+ assert(Idx >= FirstFieldOpNo + NumOpsPerField &&
+ "TBAAStructTypeNode::getField should have an offset match!");
+ TheIdx = Idx - NumOpsPerField;
+ break;
+ }
+ }
+ // Move along the last field.
+ if (TheIdx == 0)
+ TheIdx = Node->getNumOperands() - NumOpsPerField;
+ uint64_t Cur = mdconst::extract<ConstantInt>(Node->getOperand(TheIdx + 1))
+ ->getZExtValue();
+ Offset -= Cur;
+ MDNode *P = dyn_cast_or_null<MDNode>(Node->getOperand(TheIdx));
+ if (!P)
+ return TBAAStructTypeNode();
+ return TBAAStructTypeNode(P);
+ }
+};
+
+} // end anonymous namespace
+
+/// Check the first operand of the tbaa tag node. If it is an MDNode, we treat
+/// it as struct-path aware TBAA format; otherwise, we treat it as scalar TBAA
+/// format.
+static bool isStructPathTBAA(const MDNode *MD) {
+ // An anonymous TBAA root starts with an MDNode, and dragonegg uses it as
+ // a TBAA tag.
+ return isa<MDNode>(MD->getOperand(0)) && MD->getNumOperands() >= 3;
+}
+
+AliasResult TypeBasedAAResult::alias(const MemoryLocation &LocA,
+ const MemoryLocation &LocB,
+ AAQueryInfo &AAQI) {
+ if (!EnableTBAA)
+ return AAResultBase::alias(LocA, LocB, AAQI);
+
+ // If accesses may alias, chain to the next AliasAnalysis.
+ if (Aliases(LocA.AATags.TBAA, LocB.AATags.TBAA))
+ return AAResultBase::alias(LocA, LocB, AAQI);
+
+ // Otherwise return a definitive result.
+ return NoAlias;
+}
+
+bool TypeBasedAAResult::pointsToConstantMemory(const MemoryLocation &Loc,
+ AAQueryInfo &AAQI,
+ bool OrLocal) {
+ if (!EnableTBAA)
+ return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal);
+
+ const MDNode *M = Loc.AATags.TBAA;
+ if (!M)
+ return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal);
+
+ // If this is an "immutable" type, we can assume the pointer is pointing
+ // to constant memory.
+ if ((!isStructPathTBAA(M) && TBAANode(M).isTypeImmutable()) ||
+ (isStructPathTBAA(M) && TBAAStructTagNode(M).isTypeImmutable()))
+ return true;
+
+ return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal);
+}
+
+FunctionModRefBehavior
+TypeBasedAAResult::getModRefBehavior(const CallBase *Call) {
+ if (!EnableTBAA)
+ return AAResultBase::getModRefBehavior(Call);
+
+ FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior;
+
+ // If this is an "immutable" type, we can assume the call doesn't write
+ // to memory.
+ if (const MDNode *M = Call->getMetadata(LLVMContext::MD_tbaa))
+ if ((!isStructPathTBAA(M) && TBAANode(M).isTypeImmutable()) ||
+ (isStructPathTBAA(M) && TBAAStructTagNode(M).isTypeImmutable()))
+ Min = FMRB_OnlyReadsMemory;
+
+ return FunctionModRefBehavior(AAResultBase::getModRefBehavior(Call) & Min);
+}
+
+FunctionModRefBehavior TypeBasedAAResult::getModRefBehavior(const Function *F) {
+ // Functions don't have metadata. Just chain to the next implementation.
+ return AAResultBase::getModRefBehavior(F);
+}
+
+ModRefInfo TypeBasedAAResult::getModRefInfo(const CallBase *Call,
+ const MemoryLocation &Loc,
+ AAQueryInfo &AAQI) {
+ if (!EnableTBAA)
+ return AAResultBase::getModRefInfo(Call, Loc, AAQI);
+
+ if (const MDNode *L = Loc.AATags.TBAA)
+ if (const MDNode *M = Call->getMetadata(LLVMContext::MD_tbaa))
+ if (!Aliases(L, M))
+ return ModRefInfo::NoModRef;
+
+ return AAResultBase::getModRefInfo(Call, Loc, AAQI);
+}
+
+ModRefInfo TypeBasedAAResult::getModRefInfo(const CallBase *Call1,
+ const CallBase *Call2,
+ AAQueryInfo &AAQI) {
+ if (!EnableTBAA)
+ return AAResultBase::getModRefInfo(Call1, Call2, AAQI);
+
+ if (const MDNode *M1 = Call1->getMetadata(LLVMContext::MD_tbaa))
+ if (const MDNode *M2 = Call2->getMetadata(LLVMContext::MD_tbaa))
+ if (!Aliases(M1, M2))
+ return ModRefInfo::NoModRef;
+
+ return AAResultBase::getModRefInfo(Call1, Call2, AAQI);
+}
+
+bool MDNode::isTBAAVtableAccess() const {
+ if (!isStructPathTBAA(this)) {
+ if (getNumOperands() < 1)
+ return false;
+ if (MDString *Tag1 = dyn_cast<MDString>(getOperand(0))) {
+ if (Tag1->getString() == "vtable pointer")
+ return true;
+ }
+ return false;
+ }
+
+ // For struct-path aware TBAA, we use the access type of the tag.
+ TBAAStructTagNode Tag(this);
+ TBAAStructTypeNode AccessType(Tag.getAccessType());
+ if (auto *Id = dyn_cast<MDString>(AccessType.getId()))
+ if (Id->getString() == "vtable pointer")
+ return true;
+ return false;
+}
+
+static bool matchAccessTags(const MDNode *A, const MDNode *B,
+ const MDNode **GenericTag = nullptr);
+
+MDNode *MDNode::getMostGenericTBAA(MDNode *A, MDNode *B) {
+ const MDNode *GenericTag;
+ matchAccessTags(A, B, &GenericTag);
+ return const_cast<MDNode*>(GenericTag);
+}
+
+static const MDNode *getLeastCommonType(const MDNode *A, const MDNode *B) {
+ if (!A || !B)
+ return nullptr;
+
+ if (A == B)
+ return A;
+
+ SmallSetVector<const MDNode *, 4> PathA;
+ TBAANode TA(A);
+ while (TA.getNode()) {
+ if (PathA.count(TA.getNode()))
+ report_fatal_error("Cycle found in TBAA metadata.");
+ PathA.insert(TA.getNode());
+ TA = TA.getParent();
+ }
+
+ SmallSetVector<const MDNode *, 4> PathB;
+ TBAANode TB(B);
+ while (TB.getNode()) {
+ if (PathB.count(TB.getNode()))
+ report_fatal_error("Cycle found in TBAA metadata.");
+ PathB.insert(TB.getNode());
+ TB = TB.getParent();
+ }
+
+ int IA = PathA.size() - 1;
+ int IB = PathB.size() - 1;
+
+ const MDNode *Ret = nullptr;
+ while (IA >= 0 && IB >= 0) {
+ if (PathA[IA] == PathB[IB])
+ Ret = PathA[IA];
+ else
+ break;
+ --IA;
+ --IB;
+ }
+
+ return Ret;
+}
+
+void Instruction::getAAMetadata(AAMDNodes &N, bool Merge) const {
+ if (Merge)
+ N.TBAA =
+ MDNode::getMostGenericTBAA(N.TBAA, getMetadata(LLVMContext::MD_tbaa));
+ else
+ N.TBAA = getMetadata(LLVMContext::MD_tbaa);
+
+ if (Merge)
+ N.Scope = MDNode::getMostGenericAliasScope(
+ N.Scope, getMetadata(LLVMContext::MD_alias_scope));
+ else
+ N.Scope = getMetadata(LLVMContext::MD_alias_scope);
+
+ if (Merge)
+ N.NoAlias =
+ MDNode::intersect(N.NoAlias, getMetadata(LLVMContext::MD_noalias));
+ else
+ N.NoAlias = getMetadata(LLVMContext::MD_noalias);
+}
+
+static const MDNode *createAccessTag(const MDNode *AccessType) {
+ // If there is no access type or the access type is the root node, then
+ // we don't have any useful access tag to return.
+ if (!AccessType || AccessType->getNumOperands() < 2)
+ return nullptr;
+
+ Type *Int64 = IntegerType::get(AccessType->getContext(), 64);
+ auto *OffsetNode = ConstantAsMetadata::get(ConstantInt::get(Int64, 0));
+
+ if (TBAAStructTypeNode(AccessType).isNewFormat()) {
+ // TODO: Take access ranges into account when matching access tags and
+ // fix this code to generate actual access sizes for generic tags.
+ uint64_t AccessSize = UINT64_MAX;
+ auto *SizeNode =
+ ConstantAsMetadata::get(ConstantInt::get(Int64, AccessSize));
+ Metadata *Ops[] = {const_cast<MDNode*>(AccessType),
+ const_cast<MDNode*>(AccessType),
+ OffsetNode, SizeNode};
+ return MDNode::get(AccessType->getContext(), Ops);
+ }
+
+ Metadata *Ops[] = {const_cast<MDNode*>(AccessType),
+ const_cast<MDNode*>(AccessType),
+ OffsetNode};
+ return MDNode::get(AccessType->getContext(), Ops);
+}
+
+static bool hasField(TBAAStructTypeNode BaseType,
+ TBAAStructTypeNode FieldType) {
+ for (unsigned I = 0, E = BaseType.getNumFields(); I != E; ++I) {
+ TBAAStructTypeNode T = BaseType.getFieldType(I);
+ if (T == FieldType || hasField(T, FieldType))
+ return true;
+ }
+ return false;
+}
+
+/// Return true if for two given accesses, one of the accessed objects may be a
+/// subobject of the other. The \p BaseTag and \p SubobjectTag parameters
+/// describe the accesses to the base object and the subobject respectively.
+/// \p CommonType must be the metadata node describing the common type of the
+/// accessed objects. On return, \p MayAlias is set to true iff these accesses
+/// may alias, and \p GenericTag, if not null, points to the most generic access
+/// tag for the given two accesses.
+static bool mayBeAccessToSubobjectOf(TBAAStructTagNode BaseTag,
+ TBAAStructTagNode SubobjectTag,
+ const MDNode *CommonType,
+ const MDNode **GenericTag,
+ bool &MayAlias) {
+ // If the base object is of the least common type, then this may be an access
+ // to its subobject.
+ if (BaseTag.getAccessType() == BaseTag.getBaseType() &&
+ BaseTag.getAccessType() == CommonType) {
+ if (GenericTag)
+ *GenericTag = createAccessTag(CommonType);
+ MayAlias = true;
+ return true;
+ }
+
+ // If the access to the base object is through a field of the subobject's
+ // type, then this may be an access to that field. To check for that we start
+ // from the base type, follow the edge with the correct offset in the type DAG
+ // and adjust the offset until we reach the field type or until we reach the
+ // access type.
+ bool NewFormat = BaseTag.isNewFormat();
+ TBAAStructTypeNode BaseType(BaseTag.getBaseType());
+ uint64_t OffsetInBase = BaseTag.getOffset();
+
+ for (;;) {
+ // In the old format there is no distinction between fields and parent
+ // types, so in this case we consider all nodes up to the root.
+ if (!BaseType.getNode()) {
+ assert(!NewFormat && "Did not see access type in access path!");
+ break;
+ }
+
+ if (BaseType.getNode() == SubobjectTag.getBaseType()) {
+ bool SameMemberAccess = OffsetInBase == SubobjectTag.getOffset();
+ if (GenericTag) {
+ *GenericTag = SameMemberAccess ? SubobjectTag.getNode() :
+ createAccessTag(CommonType);
+ }
+ MayAlias = SameMemberAccess;
+ return true;
+ }
+
+ // With new-format nodes we stop at the access type.
+ if (NewFormat && BaseType.getNode() == BaseTag.getAccessType())
+ break;
+
+ // Follow the edge with the correct offset. Offset will be adjusted to
+ // be relative to the field type.
+ BaseType = BaseType.getField(OffsetInBase);
+ }
+
+ // If the base object has a direct or indirect field of the subobject's type,
+ // then this may be an access to that field. We need this check now that
+ // we support aggregates as access types.
+ if (NewFormat) {
+ // TBAAStructTypeNode BaseAccessType(BaseTag.getAccessType());
+ TBAAStructTypeNode FieldType(SubobjectTag.getBaseType());
+ if (hasField(BaseType, FieldType)) {
+ if (GenericTag)
+ *GenericTag = createAccessTag(CommonType);
+ MayAlias = true;
+ return true;
+ }
+ }
+
+ return false;
+}
+
+/// matchAccessTags - Return true if the given pair of accesses is allowed to
+/// overlap. If \p GenericTag is not null, then on return it points to the
+/// most generic access descriptor for the given two.
+static bool matchAccessTags(const MDNode *A, const MDNode *B,
+ const MDNode **GenericTag) {
+ if (A == B) {
+ if (GenericTag)
+ *GenericTag = A;
+ return true;
+ }
+
+ // Accesses with no TBAA information may alias with any other accesses.
+ if (!A || !B) {
+ if (GenericTag)
+ *GenericTag = nullptr;
+ return true;
+ }
+
+ // Verify that both input nodes are struct-path aware. Auto-upgrade should
+ // have taken care of this.
+ assert(isStructPathTBAA(A) && "Access A is not struct-path aware!");
+ assert(isStructPathTBAA(B) && "Access B is not struct-path aware!");
+
+ TBAAStructTagNode TagA(A), TagB(B);
+ const MDNode *CommonType = getLeastCommonType(TagA.getAccessType(),
+ TagB.getAccessType());
+
+ // If the final access types have different roots, they're part of different
+ // potentially unrelated type systems, so we must be conservative.
+ if (!CommonType) {
+ if (GenericTag)
+ *GenericTag = nullptr;
+ return true;
+ }
+
+ // If one of the accessed objects may be a subobject of the other, then such
+ // accesses may alias.
+ bool MayAlias;
+ if (mayBeAccessToSubobjectOf(/* BaseTag= */ TagA, /* SubobjectTag= */ TagB,
+ CommonType, GenericTag, MayAlias) ||
+ mayBeAccessToSubobjectOf(/* BaseTag= */ TagB, /* SubobjectTag= */ TagA,
+ CommonType, GenericTag, MayAlias))
+ return MayAlias;
+
+ // Otherwise, we've proved there's no alias.
+ if (GenericTag)
+ *GenericTag = createAccessTag(CommonType);
+ return false;
+}
+
+/// Aliases - Test whether the access represented by tag A may alias the
+/// access represented by tag B.
+bool TypeBasedAAResult::Aliases(const MDNode *A, const MDNode *B) const {
+ return matchAccessTags(A, B);
+}
+
+AnalysisKey TypeBasedAA::Key;
+
+TypeBasedAAResult TypeBasedAA::run(Function &F, FunctionAnalysisManager &AM) {
+ return TypeBasedAAResult();
+}
+
+char TypeBasedAAWrapperPass::ID = 0;
+INITIALIZE_PASS(TypeBasedAAWrapperPass, "tbaa", "Type-Based Alias Analysis",
+ false, true)
+
+ImmutablePass *llvm::createTypeBasedAAWrapperPass() {
+ return new TypeBasedAAWrapperPass();
+}
+
+TypeBasedAAWrapperPass::TypeBasedAAWrapperPass() : ImmutablePass(ID) {
+ initializeTypeBasedAAWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+bool TypeBasedAAWrapperPass::doInitialization(Module &M) {
+ Result.reset(new TypeBasedAAResult());
+ return false;
+}
+
+bool TypeBasedAAWrapperPass::doFinalization(Module &M) {
+ Result.reset();
+ return false;
+}
+
+void TypeBasedAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+}
diff --git a/llvm/lib/Analysis/TypeMetadataUtils.cpp b/llvm/lib/Analysis/TypeMetadataUtils.cpp
new file mode 100644
index 000000000000..072d291f3f93
--- /dev/null
+++ b/llvm/lib/Analysis/TypeMetadataUtils.cpp
@@ -0,0 +1,161 @@
+//===- TypeMetadataUtils.cpp - Utilities related to type metadata ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains functions that make it easier to manipulate type metadata
+// for devirtualization.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/TypeMetadataUtils.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+
+using namespace llvm;
+
+// Search for virtual calls that call FPtr and add them to DevirtCalls.
+static void
+findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls,
+ bool *HasNonCallUses, Value *FPtr, uint64_t Offset,
+ const CallInst *CI, DominatorTree &DT) {
+ for (const Use &U : FPtr->uses()) {
+ Instruction *User = cast<Instruction>(U.getUser());
+ // Ignore this instruction if it is not dominated by the type intrinsic
+ // being analyzed. Otherwise we may transform a call sharing the same
+ // vtable pointer incorrectly. Specifically, this situation can arise
+ // after indirect call promotion and inlining, where we may have uses
+ // of the vtable pointer guarded by a function pointer check, and a fallback
+ // indirect call.
+ if (!DT.dominates(CI, User))
+ continue;
+ if (isa<BitCastInst>(User)) {
+ findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset, CI,
+ DT);
+ } else if (auto CI = dyn_cast<CallInst>(User)) {
+ DevirtCalls.push_back({Offset, CI});
+ } else if (auto II = dyn_cast<InvokeInst>(User)) {
+ DevirtCalls.push_back({Offset, II});
+ } else if (HasNonCallUses) {
+ *HasNonCallUses = true;
+ }
+ }
+}
+
+// Search for virtual calls that load from VPtr and add them to DevirtCalls.
+static void findLoadCallsAtConstantOffset(
+ const Module *M, SmallVectorImpl<DevirtCallSite> &DevirtCalls, Value *VPtr,
+ int64_t Offset, const CallInst *CI, DominatorTree &DT) {
+ for (const Use &U : VPtr->uses()) {
+ Value *User = U.getUser();
+ if (isa<BitCastInst>(User)) {
+ findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset, CI, DT);
+ } else if (isa<LoadInst>(User)) {
+ findCallsAtConstantOffset(DevirtCalls, nullptr, User, Offset, CI, DT);
+ } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) {
+ // Take into account the GEP offset.
+ if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) {
+ SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end());
+ int64_t GEPOffset = M->getDataLayout().getIndexedOffsetInType(
+ GEP->getSourceElementType(), Indices);
+ findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset + GEPOffset,
+ CI, DT);
+ }
+ }
+ }
+}
+
+void llvm::findDevirtualizableCallsForTypeTest(
+ SmallVectorImpl<DevirtCallSite> &DevirtCalls,
+ SmallVectorImpl<CallInst *> &Assumes, const CallInst *CI,
+ DominatorTree &DT) {
+ assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_test);
+
+ const Module *M = CI->getParent()->getParent()->getParent();
+
+ // Find llvm.assume intrinsics for this llvm.type.test call.
+ for (const Use &CIU : CI->uses()) {
+ if (auto *AssumeCI = dyn_cast<CallInst>(CIU.getUser())) {
+ Function *F = AssumeCI->getCalledFunction();
+ if (F && F->getIntrinsicID() == Intrinsic::assume)
+ Assumes.push_back(AssumeCI);
+ }
+ }
+
+ // If we found any, search for virtual calls based on the pointer operand of
+ // the llvm.type.test (typically the vtable pointer) and add them to
+ // DevirtCalls.
+ if (!Assumes.empty())
+ findLoadCallsAtConstantOffset(
+ M, DevirtCalls, CI->getArgOperand(0)->stripPointerCasts(), 0, CI, DT);
+}
+
+void llvm::findDevirtualizableCallsForTypeCheckedLoad(
+ SmallVectorImpl<DevirtCallSite> &DevirtCalls,
+ SmallVectorImpl<Instruction *> &LoadedPtrs,
+ SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses,
+ const CallInst *CI, DominatorTree &DT) {
+ assert(CI->getCalledFunction()->getIntrinsicID() ==
+ Intrinsic::type_checked_load);
+
+ auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+ if (!Offset) {
+ HasNonCallUses = true;
+ return;
+ }
+
+ for (const Use &U : CI->uses()) {
+ auto CIU = U.getUser();
+ if (auto EVI = dyn_cast<ExtractValueInst>(CIU)) {
+ if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 0) {
+ LoadedPtrs.push_back(EVI);
+ continue;
+ }
+ if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 1) {
+ Preds.push_back(EVI);
+ continue;
+ }
+ }
+ HasNonCallUses = true;
+ }
+
+ for (Value *LoadedPtr : LoadedPtrs)
+ findCallsAtConstantOffset(DevirtCalls, &HasNonCallUses, LoadedPtr,
+ Offset->getZExtValue(), CI, DT);
+}
+
+Constant *llvm::getPointerAtOffset(Constant *I, uint64_t Offset, Module &M) {
+ if (I->getType()->isPointerTy()) {
+ if (Offset == 0)
+ return I;
+ return nullptr;
+ }
+
+ const DataLayout &DL = M.getDataLayout();
+
+ if (auto *C = dyn_cast<ConstantStruct>(I)) {
+ const StructLayout *SL = DL.getStructLayout(C->getType());
+ if (Offset >= SL->getSizeInBytes())
+ return nullptr;
+
+ unsigned Op = SL->getElementContainingOffset(Offset);
+ return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
+ Offset - SL->getElementOffset(Op), M);
+ }
+ if (auto *C = dyn_cast<ConstantArray>(I)) {
+ ArrayType *VTableTy = C->getType();
+ uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType());
+
+ unsigned Op = Offset / ElemSize;
+ if (Op >= C->getNumOperands())
+ return nullptr;
+
+ return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
+ Offset % ElemSize, M);
+ }
+ return nullptr;
+}
diff --git a/llvm/lib/Analysis/VFABIDemangling.cpp b/llvm/lib/Analysis/VFABIDemangling.cpp
new file mode 100644
index 000000000000..6fd8ae63f5f0
--- /dev/null
+++ b/llvm/lib/Analysis/VFABIDemangling.cpp
@@ -0,0 +1,418 @@
+//===- VFABIDemangling.cpp - Vector Function ABI demangling utilities. ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/VectorUtils.h"
+
+using namespace llvm;
+
+namespace {
+/// Utilities for the Vector Function ABI name parser.
+
+/// Return types for the parser functions.
+enum class ParseRet {
+ OK, // Found.
+ None, // Not found.
+ Error // Syntax error.
+};
+
+/// Extracts the `<isa>` information from the mangled string, and
+/// sets the `ISA` accordingly.
+ParseRet tryParseISA(StringRef &MangledName, VFISAKind &ISA) {
+ if (MangledName.empty())
+ return ParseRet::Error;
+
+ ISA = StringSwitch<VFISAKind>(MangledName.take_front(1))
+ .Case("n", VFISAKind::AdvancedSIMD)
+ .Case("s", VFISAKind::SVE)
+ .Case("b", VFISAKind::SSE)
+ .Case("c", VFISAKind::AVX)
+ .Case("d", VFISAKind::AVX2)
+ .Case("e", VFISAKind::AVX512)
+ .Default(VFISAKind::Unknown);
+
+ MangledName = MangledName.drop_front(1);
+
+ return ParseRet::OK;
+}
+
+/// Extracts the `<mask>` information from the mangled string, and
+/// sets `IsMasked` accordingly. On success, the parsed token is removed
+/// from the input string `MangledName`.
+ParseRet tryParseMask(StringRef &MangledName, bool &IsMasked) {
+ if (MangledName.consume_front("M")) {
+ IsMasked = true;
+ return ParseRet::OK;
+ }
+
+ if (MangledName.consume_front("N")) {
+ IsMasked = false;
+ return ParseRet::OK;
+ }
+
+ return ParseRet::Error;
+}
+
+/// Extracts the `<vlen>` information from the mangled string, and
+/// sets `VF` accordingly. A `<vlen> == "x"` token is interpreted as a scalable
+/// vector length. On success, the `<vlen>` token is removed from
+/// the input string `ParseString`.
+///
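+/// For example, a <vlen> of "4" yields VF == 4 with IsScalable == false,
+/// while "x" yields IsScalable == true (VF itself is set to 0).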
+ParseRet tryParseVLEN(StringRef &ParseString, unsigned &VF, bool &IsScalable) {
+ if (ParseString.consume_front("x")) {
+ VF = 0;
+ IsScalable = true;
+ return ParseRet::OK;
+ }
+
+ if (ParseString.consumeInteger(10, VF))
+ return ParseRet::Error;
+
+ IsScalable = false;
+ return ParseRet::OK;
+}
+
+/// The function looks for the following strings at the beginning of
+/// the input string `ParseString`:
+///
+/// <token> <number>
+///
+/// On success, it removes the parsed parameter from `ParseString`,
+/// sets `PKind` to the corresponding enum value, sets `Pos` to
+/// <number>, and returns success. On a syntax error, it returns a
+/// parsing error. If nothing is parsed, it returns None.
+///
+/// The function expects <token> to be one of "ls", "Rs", "Us" or
+/// "Ls".
+ParseRet tryParseLinearTokenWithRuntimeStep(StringRef &ParseString,
+ VFParamKind &PKind, int &Pos,
+ const StringRef Token) {
+ if (ParseString.consume_front(Token)) {
+ PKind = VFABI::getVFParamKindFromString(Token);
+ if (ParseString.consumeInteger(10, Pos))
+ return ParseRet::Error;
+ return ParseRet::OK;
+ }
+
+ return ParseRet::None;
+}
+
+/// The function looks for the following strings at the beginning of
+/// the input string `ParseString`:
+///
+/// <token> <number>
+///
+/// <token> is one of "ls", "Rs", "Us" or "Ls".
+///
+/// On success, it removes the parsed parameter from `ParseString`,
+/// sets `PKind` to the corresponding enum value, sets `StepOrPos` to
+/// <number>, and returns success. On a syntax error, it returns a
+/// parsing error. If nothing is parsed, it returns None.
+ParseRet tryParseLinearWithRuntimeStep(StringRef &ParseString,
+ VFParamKind &PKind, int &StepOrPos) {
+ ParseRet Ret;
+
+ // "ls" <RuntimeStepPos>
+ Ret = tryParseLinearTokenWithRuntimeStep(ParseString, PKind, StepOrPos, "ls");
+ if (Ret != ParseRet::None)
+ return Ret;
+
+ // "Rs" <RuntimeStepPos>
+ Ret = tryParseLinearTokenWithRuntimeStep(ParseString, PKind, StepOrPos, "Rs");
+ if (Ret != ParseRet::None)
+ return Ret;
+
+ // "Ls" <RuntimeStepPos>
+ Ret = tryParseLinearTokenWithRuntimeStep(ParseString, PKind, StepOrPos, "Ls");
+ if (Ret != ParseRet::None)
+ return Ret;
+
+ // "Us" <RuntimeStepPos>
+ Ret = tryParseLinearTokenWithRuntimeStep(ParseString, PKind, StepOrPos, "Us");
+ if (Ret != ParseRet::None)
+ return Ret;
+
+ return ParseRet::None;
+}
+
+/// The function looks for the following strings at the beginning of
+/// the input string `ParseString`:
+///
+/// <token> {"n"} <number>
+///
+/// On success, it removes the parsed parameter from `ParseString`,
+/// sets `PKind` to the corresponding enum value, sets `LinearStep` to
+/// <number>, and returns success. On a syntax error, it returns a
+/// parsing error. If nothing is parsed, it returns None.
+///
+/// The function expects <token> to be one of "l", "R", "U" or
+/// "L".
+ParseRet tryParseCompileTimeLinearToken(StringRef &ParseString,
+ VFParamKind &PKind, int &LinearStep,
+ const StringRef Token) {
+ if (ParseString.consume_front(Token)) {
+ PKind = VFABI::getVFParamKindFromString(Token);
+ const bool Negate = ParseString.consume_front("n");
+ if (ParseString.consumeInteger(10, LinearStep))
+ LinearStep = 1;
+ if (Negate)
+ LinearStep *= -1;
+ return ParseRet::OK;
+ }
+
+ return ParseRet::None;
+}
+
+/// The function looks for the following strings at the beginning of
+/// the input string `ParseString`:
+///
+/// ["l" | "R" | "U" | "L"] {"n"} <number>
+///
+/// On success, it removes the parsed parameter from `ParseString`,
+/// sets `PKind` to the corresponding enum value, sets `LinearStep` to
+/// <number>, and returns success. On a syntax error, it returns a
+/// parsing error. If nothing is parsed, it returns None.
+ParseRet tryParseLinearWithCompileTimeStep(StringRef &ParseString,
+ VFParamKind &PKind, int &StepOrPos) {
+ // "l" {"n"} <CompileTimeStep>
+ if (tryParseCompileTimeLinearToken(ParseString, PKind, StepOrPos, "l") ==
+ ParseRet::OK)
+ return ParseRet::OK;
+
+ // "R" {"n"} <CompileTimeStep>
+ if (tryParseCompileTimeLinearToken(ParseString, PKind, StepOrPos, "R") ==
+ ParseRet::OK)
+ return ParseRet::OK;
+
+ // "L" {"n"} <CompileTimeStep>
+ if (tryParseCompileTimeLinearToken(ParseString, PKind, StepOrPos, "L") ==
+ ParseRet::OK)
+ return ParseRet::OK;
+
+ // "U" {"n"} <CompileTimeStep>
+ if (tryParseCompileTimeLinearToken(ParseString, PKind, StepOrPos, "U") ==
+ ParseRet::OK)
+ return ParseRet::OK;
+
+ return ParseRet::None;
+}
+
+/// The function looks for the following strings at the beginning of
+/// the input string `ParseString`:
+///
+/// "u" <number>
+///
+/// On success, it removes the parsed parameter from `ParseString`,
+/// sets `PKind` to the corresponding enum value, sets `Pos` to
+/// <number>, and returns success. On a syntax error, it returns a
+/// parsing error. If nothing is parsed, it returns None.
+ParseRet tryParseUniform(StringRef &ParseString, VFParamKind &PKind, int &Pos) {
+ // "u" <Pos>
+ const char *UniformToken = "u";
+ if (ParseString.consume_front(UniformToken)) {
+ PKind = VFABI::getVFParamKindFromString(UniformToken);
+ if (ParseString.consumeInteger(10, Pos))
+ return ParseRet::Error;
+
+ return ParseRet::OK;
+ }
+ return ParseRet::None;
+}
+
+/// Looks into the <parameters> part of the mangled name in search
+/// for valid parameters at the beginning of the string
+/// `ParseString`.
+///
+/// On success, it removes the parsed parameter from `ParseString`,
+/// sets `PKind` to the corresponding enum value, sets `StepOrPos`
+/// accordingly, and returns success. On a syntax error, it returns a
+/// parsing error. If nothing is parsed, it returns None.
+ParseRet tryParseParameter(StringRef &ParseString, VFParamKind &PKind,
+ int &StepOrPos) {
+ if (ParseString.consume_front("v")) {
+ PKind = VFParamKind::Vector;
+ StepOrPos = 0;
+ return ParseRet::OK;
+ }
+
+ const ParseRet HasLinearRuntime =
+ tryParseLinearWithRuntimeStep(ParseString, PKind, StepOrPos);
+ if (HasLinearRuntime != ParseRet::None)
+ return HasLinearRuntime;
+
+ const ParseRet HasLinearCompileTime =
+ tryParseLinearWithCompileTimeStep(ParseString, PKind, StepOrPos);
+ if (HasLinearCompileTime != ParseRet::None)
+ return HasLinearCompileTime;
+
+ const ParseRet HasUniform = tryParseUniform(ParseString, PKind, StepOrPos);
+ if (HasUniform != ParseRet::None)
+ return HasUniform;
+
+ return ParseRet::None;
+}
+
+/// Looks into the <parameters> part of the mangled name in search
+/// of a valid 'aligned' clause. The function should be invoked
+/// after parsing a parameter via `tryParseParameter`.
+///
+/// On success, it removes the parsed token from `ParseString` and sets
+/// `Alignment` to the parsed power-of-two value. On a syntax error, it
+/// returns a parsing error. If nothing is parsed, it returns None.
+ParseRet tryParseAlign(StringRef &ParseString, Align &Alignment) {
+ uint64_t Val;
+ // "a" <number>
+ if (ParseString.consume_front("a")) {
+ if (ParseString.consumeInteger(10, Val))
+ return ParseRet::Error;
+
+ if (!isPowerOf2_64(Val))
+ return ParseRet::Error;
+
+ Alignment = Align(Val);
+
+ return ParseRet::OK;
+ }
+
+ return ParseRet::None;
+}
+} // namespace
+
+// Format of the ABI name:
+// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
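+//
+// For instance (illustrative), "_ZGVnN2v_sin" decodes as ISA 'n'
+// (AdvancedSIMD), 'N' (unmasked), VF = 2, one vector parameter ('v'), and
+// scalar name "sin"; with no redirection present, the vector name is the
+// whole mangled string.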
+Optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName) {
+ // Assume there is no custom name <redirection>, and therefore the
+ // vector name consists of
+ // _ZGV<isa><mask><vlen><parameters>_<scalarname>.
+ StringRef VectorName = MangledName;
+
+ // Parse the fixed-size part of the mangled name.
+ if (!MangledName.consume_front("_ZGV"))
+ return None;
+
+ // Extract ISA. An unknown ISA is also supported, so we accept all
+ // values.
+ VFISAKind ISA;
+ if (tryParseISA(MangledName, ISA) != ParseRet::OK)
+ return None;
+
+ // Extract <mask>.
+ bool IsMasked;
+ if (tryParseMask(MangledName, IsMasked) != ParseRet::OK)
+ return None;
+
+ // Parse the variable size, starting from <vlen>.
+ unsigned VF;
+ bool IsScalable;
+ if (tryParseVLEN(MangledName, VF, IsScalable) != ParseRet::OK)
+ return None;
+
+ // Parse the <parameters>.
+ ParseRet ParamFound;
+ SmallVector<VFParameter, 8> Parameters;
+ do {
+ const unsigned ParameterPos = Parameters.size();
+ VFParamKind PKind;
+ int StepOrPos;
+ ParamFound = tryParseParameter(MangledName, PKind, StepOrPos);
+
+ // Bail out if there is a parsing error while parsing the parameter.
+ if (ParamFound == ParseRet::Error)
+ return None;
+
+ if (ParamFound == ParseRet::OK) {
+ Align Alignment;
+ // Look for the alignment token "a <number>".
+ const ParseRet AlignFound = tryParseAlign(MangledName, Alignment);
+ // Bail out if there is a syntax error in the align token.
+ if (AlignFound == ParseRet::Error)
+ return None;
+
+ // Add the parameter.
+ Parameters.push_back({ParameterPos, PKind, StepOrPos, Alignment});
+ }
+ } while (ParamFound == ParseRet::OK);
+
+ // A valid MangledName must have at least one valid entry in the
+ // <parameters>.
+ if (Parameters.empty())
+ return None;
+
+ // Check for the <scalarname> and the optional <redirection>, which
+ // are separated from the prefix with "_".
+ if (!MangledName.consume_front("_"))
+ return None;
+
+ // The rest of the string must be in the format:
+ // <scalarname>[(<redirection>)]
+ const StringRef ScalarName =
+ MangledName.take_while([](char In) { return In != '('; });
+
+ if (ScalarName.empty())
+ return None;
+
+ // Reduce MangledName to [(<redirection>)].
+ MangledName = MangledName.ltrim(ScalarName);
+ // Find the optional custom name redirection.
+ if (MangledName.consume_front("(")) {
+ if (!MangledName.consume_back(")"))
+ return None;
+ // Update the vector variant with the one specified by the user.
+ VectorName = MangledName;
+ // If the vector name is missing, bail out.
+ if (VectorName.empty())
+ return None;
+ }
+
+ // When <mask> is "M", we need to add a parameter that is used as
+ // global predicate for the function.
+ if (IsMasked) {
+ const unsigned Pos = Parameters.size();
+ Parameters.push_back({Pos, VFParamKind::GlobalPredicate});
+ }
+
+ // Asserts for parameters of type `VFParamKind::GlobalPredicate`, as
+ // prescribed by the Vector Function ABI specifications supported by
+ // this parser:
+ // 1. Uniqueness.
+ // 2. Must be the last in the parameter list.
+ const auto NGlobalPreds = std::count_if(
+ Parameters.begin(), Parameters.end(), [](const VFParameter PK) {
+ return PK.ParamKind == VFParamKind::GlobalPredicate;
+ });
+ assert(NGlobalPreds < 2 && "Cannot have more than one global predicate.");
+ if (NGlobalPreds)
+ assert(Parameters.back().ParamKind == VFParamKind::GlobalPredicate &&
+ "The global predicate must be the last parameter");
+
+ const VFShape Shape({VF, IsScalable, ISA, Parameters});
+ return VFInfo({Shape, ScalarName, VectorName});
+}
+
+VFParamKind VFABI::getVFParamKindFromString(const StringRef Token) {
+ const VFParamKind ParamKind = StringSwitch<VFParamKind>(Token)
+ .Case("v", VFParamKind::Vector)
+ .Case("l", VFParamKind::OMP_Linear)
+ .Case("R", VFParamKind::OMP_LinearRef)
+ .Case("L", VFParamKind::OMP_LinearVal)
+ .Case("U", VFParamKind::OMP_LinearUVal)
+ .Case("ls", VFParamKind::OMP_LinearPos)
+ .Case("Ls", VFParamKind::OMP_LinearValPos)
+ .Case("Rs", VFParamKind::OMP_LinearRefPos)
+ .Case("Us", VFParamKind::OMP_LinearUValPos)
+ .Case("u", VFParamKind::OMP_Uniform)
+ .Default(VFParamKind::Unknown);
+
+ if (ParamKind != VFParamKind::Unknown)
+ return ParamKind;
+
+ // This function should never be invoked with an invalid input.
+ llvm_unreachable("This fuction should be invoken only on parameters"
+ " that have a textual representation in the mangled name"
+ " of the Vector Function ABI");
+}
diff --git a/llvm/lib/Analysis/ValueLattice.cpp b/llvm/lib/Analysis/ValueLattice.cpp
new file mode 100644
index 000000000000..a0115a0eec36
--- /dev/null
+++ b/llvm/lib/Analysis/ValueLattice.cpp
@@ -0,0 +1,25 @@
+//===- ValueLattice.cpp - Value constraint analysis -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ValueLattice.h"
+
+namespace llvm {
+raw_ostream &operator<<(raw_ostream &OS, const ValueLatticeElement &Val) {
+ if (Val.isUndefined())
+ return OS << "undefined";
+ if (Val.isOverdefined())
+ return OS << "overdefined";
+
+ if (Val.isNotConstant())
+ return OS << "notconstant<" << *Val.getNotConstant() << ">";
+ if (Val.isConstantRange())
+ return OS << "constantrange<" << Val.getConstantRange().getLower() << ", "
+ << Val.getConstantRange().getUpper() << ">";
+ return OS << "constant<" << *Val.getConstant() << ">";
+}
+} // end namespace llvm
diff --git a/llvm/lib/Analysis/ValueLatticeUtils.cpp b/llvm/lib/Analysis/ValueLatticeUtils.cpp
new file mode 100644
index 000000000000..3f9287e26ce7
--- /dev/null
+++ b/llvm/lib/Analysis/ValueLatticeUtils.cpp
@@ -0,0 +1,43 @@
+//===-- ValueLatticeUtils.cpp - Utils for solving lattices ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements common functions useful for performing data-flow
+// analyses that propagate values across function boundaries.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ValueLatticeUtils.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Instructions.h"
+using namespace llvm;
+
+bool llvm::canTrackArgumentsInterprocedurally(Function *F) {
+ return F->hasLocalLinkage() && !F->hasAddressTaken();
+}
+
+bool llvm::canTrackReturnsInterprocedurally(Function *F) {
+ return F->hasExactDefinition() && !F->hasFnAttribute(Attribute::Naked);
+}
+
+bool llvm::canTrackGlobalVariableInterprocedurally(GlobalVariable *GV) {
+ if (GV->isConstant() || !GV->hasLocalLinkage() ||
+ !GV->hasDefinitiveInitializer())
+ return false;
+ return !any_of(GV->users(), [&](User *U) {
+ if (auto *Store = dyn_cast<StoreInst>(U)) {
+ if (Store->getValueOperand() == GV || Store->isVolatile())
+ return true;
+ } else if (auto *Load = dyn_cast<LoadInst>(U)) {
+ if (Load->isVolatile())
+ return true;
+ } else {
+ return true;
+ }
+ return false;
+ });
+}
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
new file mode 100644
index 000000000000..bbf389991836
--- /dev/null
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -0,0 +1,5820 @@
+//===- ValueTracking.cpp - Walk computations to compute properties --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains routines that help analyze properties that chains of
+// computations have.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/GuardUtils.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
+#include "llvm/IR/GlobalAlias.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/KnownBits.h"
+#include "llvm/Support/MathExtras.h"
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cstdint>
+#include <iterator>
+#include <utility>
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+const unsigned MaxDepth = 6;
+
+// Controls the number of uses of the value searched for possible
+// dominating comparisons.
+static cl::opt<unsigned> DomConditionsMaxUses("dom-conditions-max-uses",
+ cl::Hidden, cl::init(20));
+
+/// Returns the bitwidth of the given scalar or pointer type. For vector types,
+/// returns the element type's bitwidth.
+static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
+ if (unsigned BitWidth = Ty->getScalarSizeInBits())
+ return BitWidth;
+
+ return DL.getIndexTypeSizeInBits(Ty);
+}
+
+namespace {
+
+// Simplifying using an assume can only be done in a particular control-flow
+// context (the context instruction provides that context). If an assume and
+// the context instruction are not in the same block then the DT helps in
+// figuring out if we can use it.
+struct Query {
+ const DataLayout &DL;
+ AssumptionCache *AC;
+ const Instruction *CxtI;
+ const DominatorTree *DT;
+
+ // Unlike the other analyses, this may be a nullptr because not all clients
+ // provide it currently.
+ OptimizationRemarkEmitter *ORE;
+
+ /// Set of assumptions that should be excluded from further queries.
+ /// This is because of the potential for mutual recursion to cause
+ /// computeKnownBits to repeatedly visit the same assume intrinsic. The
+ /// classic case of this is assume(x == y), which will attempt to determine
+ /// bits in x from bits in y, which will attempt to determine bits in y from
+ /// bits in x, etc. Regarding the mutual recursion, computeKnownBits can call
+ /// isKnownNonZero, which calls computeKnownBits and isKnownToBeAPowerOfTwo
+ /// (all of which can call computeKnownBits), and so on.
+ std::array<const Value *, MaxDepth> Excluded;
+
+ /// If true, it is safe to use metadata during simplification.
+ InstrInfoQuery IIQ;
+
+ unsigned NumExcluded = 0;
+
+ Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI,
+ const DominatorTree *DT, bool UseInstrInfo,
+ OptimizationRemarkEmitter *ORE = nullptr)
+ : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE), IIQ(UseInstrInfo) {}
+
+ Query(const Query &Q, const Value *NewExcl)
+ : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), IIQ(Q.IIQ),
+ NumExcluded(Q.NumExcluded) {
+ Excluded = Q.Excluded;
+ Excluded[NumExcluded++] = NewExcl;
+ assert(NumExcluded <= Excluded.size());
+ }
+
+ bool isExcluded(const Value *Value) const {
+ if (NumExcluded == 0)
+ return false;
+ auto End = Excluded.begin() + NumExcluded;
+ return std::find(Excluded.begin(), End, Value) != End;
+ }
+};
+
+} // end anonymous namespace
+
+// Given the provided Value and, potentially, a context instruction, return
+// the preferred context instruction (if any).
+static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) {
+ // If we've been provided with a context instruction, then use that (provided
+ // it has been inserted).
+ if (CxtI && CxtI->getParent())
+ return CxtI;
+
+ // If the value is really an already-inserted instruction, then use that.
+ CxtI = dyn_cast<Instruction>(V);
+ if (CxtI && CxtI->getParent())
+ return CxtI;
+
+ return nullptr;
+}
+
+static void computeKnownBits(const Value *V, KnownBits &Known,
+ unsigned Depth, const Query &Q);
+
+void llvm::computeKnownBits(const Value *V, KnownBits &Known,
+ const DataLayout &DL, unsigned Depth,
+ AssumptionCache *AC, const Instruction *CxtI,
+ const DominatorTree *DT,
+ OptimizationRemarkEmitter *ORE, bool UseInstrInfo) {
+ ::computeKnownBits(V, Known, Depth,
+ Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
+}
+
+static KnownBits computeKnownBits(const Value *V, unsigned Depth,
+ const Query &Q);
+
+KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
+ unsigned Depth, AssumptionCache *AC,
+ const Instruction *CxtI,
+ const DominatorTree *DT,
+ OptimizationRemarkEmitter *ORE,
+ bool UseInstrInfo) {
+ return ::computeKnownBits(
+ V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
+}
+
+bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
+ const DataLayout &DL, AssumptionCache *AC,
+ const Instruction *CxtI, const DominatorTree *DT,
+ bool UseInstrInfo) {
+ assert(LHS->getType() == RHS->getType() &&
+ "LHS and RHS should have the same type");
+ assert(LHS->getType()->isIntOrIntVectorTy() &&
+ "LHS and RHS should be integers");
+ // Look for an inverted mask: (X & ~M) op (Y & M).
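+ // (Illustration: with M == 0x0F the two sides reduce to X & 0xF0 and
+ // Y & 0x0F, which can never have a set bit in common.)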
+ Value *M;
+ if (match(LHS, m_c_And(m_Not(m_Value(M)), m_Value())) &&
+ match(RHS, m_c_And(m_Specific(M), m_Value())))
+ return true;
+ if (match(RHS, m_c_And(m_Not(m_Value(M)), m_Value())) &&
+ match(LHS, m_c_And(m_Specific(M), m_Value())))
+ return true;
+ IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType());
+ KnownBits LHSKnown(IT->getBitWidth());
+ KnownBits RHSKnown(IT->getBitWidth());
+ computeKnownBits(LHS, LHSKnown, DL, 0, AC, CxtI, DT, nullptr, UseInstrInfo);
+ computeKnownBits(RHS, RHSKnown, DL, 0, AC, CxtI, DT, nullptr, UseInstrInfo);
+ return (LHSKnown.Zero | RHSKnown.Zero).isAllOnesValue();
+}
+
+bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *CxtI) {
+ for (const User *U : CxtI->users()) {
+ if (const ICmpInst *IC = dyn_cast<ICmpInst>(U))
+ if (IC->isEquality())
+ if (Constant *C = dyn_cast<Constant>(IC->getOperand(1)))
+ if (C->isNullValue())
+ continue;
+ return false;
+ }
+ return true;
+}
+
+static bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth,
+ const Query &Q);
+
+bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL,
+ bool OrZero, unsigned Depth,
+ AssumptionCache *AC, const Instruction *CxtI,
+ const DominatorTree *DT, bool UseInstrInfo) {
+ return ::isKnownToBeAPowerOfTwo(
+ V, OrZero, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo));
+}
+
+static bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q);
+
+bool llvm::isKnownNonZero(const Value *V, const DataLayout &DL, unsigned Depth,
+ AssumptionCache *AC, const Instruction *CxtI,
+ const DominatorTree *DT, bool UseInstrInfo) {
+ return ::isKnownNonZero(V, Depth,
+ Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo));
+}
+
+bool llvm::isKnownNonNegative(const Value *V, const DataLayout &DL,
+ unsigned Depth, AssumptionCache *AC,
+ const Instruction *CxtI, const DominatorTree *DT,
+ bool UseInstrInfo) {
+ KnownBits Known =
+ computeKnownBits(V, DL, Depth, AC, CxtI, DT, nullptr, UseInstrInfo);
+ return Known.isNonNegative();
+}
+
+bool llvm::isKnownPositive(const Value *V, const DataLayout &DL, unsigned Depth,
+ AssumptionCache *AC, const Instruction *CxtI,
+ const DominatorTree *DT, bool UseInstrInfo) {
+ if (auto *CI = dyn_cast<ConstantInt>(V))
+ return CI->getValue().isStrictlyPositive();
+
+ // TODO: We're doing two recursive queries here. We should factor this such
+ // that only a single query is needed.
+ return isKnownNonNegative(V, DL, Depth, AC, CxtI, DT, UseInstrInfo) &&
+ isKnownNonZero(V, DL, Depth, AC, CxtI, DT, UseInstrInfo);
+}
+
+bool llvm::isKnownNegative(const Value *V, const DataLayout &DL, unsigned Depth,
+ AssumptionCache *AC, const Instruction *CxtI,
+ const DominatorTree *DT, bool UseInstrInfo) {
+ KnownBits Known =
+ computeKnownBits(V, DL, Depth, AC, CxtI, DT, nullptr, UseInstrInfo);
+ return Known.isNegative();
+}
+
+static bool isKnownNonEqual(const Value *V1, const Value *V2, const Query &Q);
+
+bool llvm::isKnownNonEqual(const Value *V1, const Value *V2,
+ const DataLayout &DL, AssumptionCache *AC,
+ const Instruction *CxtI, const DominatorTree *DT,
+ bool UseInstrInfo) {
+ return ::isKnownNonEqual(V1, V2,
+ Query(DL, AC, safeCxtI(V1, safeCxtI(V2, CxtI)), DT,
+ UseInstrInfo, /*ORE=*/nullptr));
+}
+
+static bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth,
+ const Query &Q);
+
+bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask,
+ const DataLayout &DL, unsigned Depth,
+ AssumptionCache *AC, const Instruction *CxtI,
+ const DominatorTree *DT, bool UseInstrInfo) {
+ return ::MaskedValueIsZero(
+ V, Mask, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo));
+}
+
+static unsigned ComputeNumSignBits(const Value *V, unsigned Depth,
+ const Query &Q);
+
+unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL,
+ unsigned Depth, AssumptionCache *AC,
+ const Instruction *CxtI,
+ const DominatorTree *DT, bool UseInstrInfo) {
+ return ::ComputeNumSignBits(
+ V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo));
+}
+
+static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
+ bool NSW,
+ KnownBits &KnownOut, KnownBits &Known2,
+ unsigned Depth, const Query &Q) {
+ unsigned BitWidth = KnownOut.getBitWidth();
+
+ // If an initial sequence of bits in the result is not needed, the
+ // corresponding bits in the operands are not needed.
+ KnownBits LHSKnown(BitWidth);
+ computeKnownBits(Op0, LHSKnown, Depth + 1, Q);
+ computeKnownBits(Op1, Known2, Depth + 1, Q);
+
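+ // KnownBits::computeForAddSub combines the operand bits, propagating known
+ // carries from the low end and using the NSW flag to reason about the sign
+ // bit.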
+ KnownOut = KnownBits::computeForAddSub(Add, NSW, LHSKnown, Known2);
+}
+
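+ /// Compute the known bits of a multiplication of Op0 and Op1. If NSW is set,
+ /// the no-signed-wrap flag is additionally used to try to derive the sign bit
+ /// of the result.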
+static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
+ KnownBits &Known, KnownBits &Known2,
+ unsigned Depth, const Query &Q) {
+ unsigned BitWidth = Known.getBitWidth();
+ computeKnownBits(Op1, Known, Depth + 1, Q);
+ computeKnownBits(Op0, Known2, Depth + 1, Q);
+
+ bool isKnownNegative = false;
+ bool isKnownNonNegative = false;
+ // If the multiplication is known not to overflow, compute the sign bit.
+ if (NSW) {
+ if (Op0 == Op1) {
+ // The product of a number with itself is non-negative.
+ isKnownNonNegative = true;
+ } else {
+ bool isKnownNonNegativeOp1 = Known.isNonNegative();
+ bool isKnownNonNegativeOp0 = Known2.isNonNegative();
+ bool isKnownNegativeOp1 = Known.isNegative();
+ bool isKnownNegativeOp0 = Known2.isNegative();
+ // The product of two numbers with the same sign is non-negative.
+ isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) ||
+ (isKnownNonNegativeOp1 && isKnownNonNegativeOp0);
+ // The product of a negative number and a non-negative number is either
+ // negative or zero.
+ if (!isKnownNonNegative)
+ isKnownNegative = (isKnownNegativeOp1 && isKnownNonNegativeOp0 &&
+ isKnownNonZero(Op0, Depth, Q)) ||
+ (isKnownNegativeOp0 && isKnownNonNegativeOp1 &&
+ isKnownNonZero(Op1, Depth, Q));
+ }
+ }
+
+ assert(!Known.hasConflict() && !Known2.hasConflict());
+ // Compute a conservative estimate for high known-0 bits.
+ unsigned LeadZ = std::max(Known.countMinLeadingZeros() +
+ Known2.countMinLeadingZeros(),
+ BitWidth) - BitWidth;
+ LeadZ = std::min(LeadZ, BitWidth);
+
+ // The result of the bottom bits of an integer multiply can be
+ // inferred by looking at the bottom bits of both operands and
+ // multiplying them together.
+ // We can infer at least the minimum number of known trailing bits
+ // of both operands. Depending on number of trailing zeros, we can
+ // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
+ // a and b are divisible by m and n respectively.
+ // We then calculate how many of those bits are inferrable and set
+ // the output. For example, the i8 mul:
+ // a = XXXX1100 (12)
+ // b = XXXX1110 (14)
+ // We know the bottom 3 bits are zero since the first can be divided by
+ // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
+ // Applying the multiplication to the trimmed arguments gets:
+ // XX11 (3)
+ // X111 (7)
+ // -------
+ // XX11
+ // XX11
+ // XX11
+ // XX11
+ // -------
+ // XXXXX01
+ // Which allows us to infer the 2 LSBs. Since we're multiplying the result
+ // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
+ // The proof for this can be described as:
+ // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
+ // (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
+ // umin(countTrailingZeros(C2), C6) +
+ // umin(C5 - umin(countTrailingZeros(C1), C5),
+ // C6 - umin(countTrailingZeros(C2), C6)))) - 1)
+ // %aa = shl i8 %a, C5
+ // %bb = shl i8 %b, C6
+ // %aaa = or i8 %aa, C1
+ // %bbb = or i8 %bb, C2
+ // %mul = mul i8 %aaa, %bbb
+ // %mask = and i8 %mul, C7
+ // =>
+ // %mask = i8 ((C1*C2)&C7)
+ // Where C5, C6 describe the known bits of %a, %b
+ // C1, C2 describe the known bottom bits of %a, %b.
+ // C7 describes the mask of the known bits of the result.
+ APInt Bottom0 = Known.One;
+ APInt Bottom1 = Known2.One;
+
+ // How many times we'd be able to divide each argument by 2 (shr by 1).
+ // This gives us the number of trailing zeros on the multiplication result.
+ unsigned TrailBitsKnown0 = (Known.Zero | Known.One).countTrailingOnes();
+ unsigned TrailBitsKnown1 = (Known2.Zero | Known2.One).countTrailingOnes();
+ unsigned TrailZero0 = Known.countMinTrailingZeros();
+ unsigned TrailZero1 = Known2.countMinTrailingZeros();
+ unsigned TrailZ = TrailZero0 + TrailZero1;
+
+ // Figure out the fewest known-bits operand.
+ unsigned SmallestOperand = std::min(TrailBitsKnown0 - TrailZero0,
+ TrailBitsKnown1 - TrailZero1);
+ unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);
+
+ APInt BottomKnown = Bottom0.getLoBits(TrailBitsKnown0) *
+ Bottom1.getLoBits(TrailBitsKnown1);
+
+ Known.resetAll();
+ Known.Zero.setHighBits(LeadZ);
+ Known.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
+ Known.One |= BottomKnown.getLoBits(ResultBitsKnown);
+
+ // Only make use of no-wrap flags if we failed to compute the sign bit
+ // directly. This matters if the multiplication always overflows, in
+ // which case we prefer to follow the result of the direct computation,
+ // though as the program is invoking undefined behavior we can choose
+ // whatever we like here.
+ if (isKnownNonNegative && !Known.isNegative())
+ Known.makeNonNegative();
+ else if (isKnownNegative && !Known.isNonNegative())
+ Known.makeNegative();
+}
+
+void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
+ KnownBits &Known) {
+ unsigned BitWidth = Known.getBitWidth();
+ unsigned NumRanges = Ranges.getNumOperands() / 2;
+ assert(NumRanges >= 1);
+
+ Known.Zero.setAllBits();
+ Known.One.setAllBits();
+
+ for (unsigned i = 0; i < NumRanges; ++i) {
+ ConstantInt *Lower =
+ mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 0));
+ ConstantInt *Upper =
+ mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 1));
+ ConstantRange Range(Lower->getValue(), Upper->getValue());
+
+ // The first CommonPrefixBits of all values in Range are equal.
+ unsigned CommonPrefixBits =
+ (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countLeadingZeros();
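+ // For example, for an i8 range of [0x60, 0x70) the unsigned min and max
+ // differ only in their low four bits, so every value in the range has the
+ // form 0110xxxx and this range contributes the top four bits as known.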
+
+ APInt Mask = APInt::getHighBitsSet(BitWidth, CommonPrefixBits);
+ Known.One &= Range.getUnsignedMax() & Mask;
+ Known.Zero &= ~Range.getUnsignedMax() & Mask;
+ }
+}
+
+static bool isEphemeralValueOf(const Instruction *I, const Value *E) {
+ SmallVector<const Value *, 16> WorkSet(1, I);
+ SmallPtrSet<const Value *, 32> Visited;
+ SmallPtrSet<const Value *, 16> EphValues;
+
+ // The instruction defining an assumption's condition itself is always
+ // considered ephemeral to that assumption (even if it has other
+ // non-ephemeral users). See r246696's test case for an example.
+ if (is_contained(I->operands(), E))
+ return true;
+
+ while (!WorkSet.empty()) {
+ const Value *V = WorkSet.pop_back_val();
+ if (!Visited.insert(V).second)
+ continue;
+
+ // If all uses of this value are ephemeral, then so is this value.
+ if (llvm::all_of(V->users(), [&](const User *U) {
+ return EphValues.count(U);
+ })) {
+ if (V == E)
+ return true;
+
+ if (V == I || isSafeToSpeculativelyExecute(V)) {
+ EphValues.insert(V);
+ if (const User *U = dyn_cast<User>(V))
+ for (User::const_op_iterator J = U->op_begin(), JE = U->op_end();
+ J != JE; ++J)
+ WorkSet.push_back(*J);
+ }
+ }
+ }
+
+ return false;
+}
+
+// Is this an intrinsic that cannot be speculated but also cannot trap?
+bool llvm::isAssumeLikeIntrinsic(const Instruction *I) {
+ if (const CallInst *CI = dyn_cast<CallInst>(I))
+ if (Function *F = CI->getCalledFunction())
+ switch (F->getIntrinsicID()) {
+ default: break;
+ // FIXME: This list is repeated from NoTTI::getIntrinsicCost.
+ case Intrinsic::assume:
+ case Intrinsic::sideeffect:
+ case Intrinsic::dbg_declare:
+ case Intrinsic::dbg_value:
+ case Intrinsic::dbg_label:
+ case Intrinsic::invariant_start:
+ case Intrinsic::invariant_end:
+ case Intrinsic::lifetime_start:
+ case Intrinsic::lifetime_end:
+ case Intrinsic::objectsize:
+ case Intrinsic::ptr_annotation:
+ case Intrinsic::var_annotation:
+ return true;
+ }
+
+ return false;
+}
+
+bool llvm::isValidAssumeForContext(const Instruction *Inv,
+ const Instruction *CxtI,
+ const DominatorTree *DT) {
+ // There are two restrictions on the use of an assume:
+ // 1. The assume must dominate the context (or the control flow must
+ // reach the assume whenever it reaches the context).
+ // 2. The context must not be in the assume's set of ephemeral values
+ // (otherwise we will use the assume to prove that the condition
+ // feeding the assume is trivially true, thus causing the removal of
+ // the assume).
+
+ if (DT) {
+ if (DT->dominates(Inv, CxtI))
+ return true;
+ } else if (Inv->getParent() == CxtI->getParent()->getSinglePredecessor()) {
+ // We don't have a DT, but this trivially dominates.
+ return true;
+ }
+
+ // With or without a DT, the only remaining case we will check is if the
+ // instructions are in the same BB. Give up if that is not the case.
+ if (Inv->getParent() != CxtI->getParent())
+ return false;
+
+ // If we have a dom tree, then we now know that the assume doesn't dominate
+ // the other instruction. If we don't have a dom tree then we can check if
+ // the assume is first in the BB.
+ if (!DT) {
+ // Search forward from the assume until we reach the context (or the end
+ // of the block); the common case is that the assume will come first.
+ for (auto I = std::next(BasicBlock::const_iterator(Inv)),
+ IE = Inv->getParent()->end(); I != IE; ++I)
+ if (&*I == CxtI)
+ return true;
+ }
+
+ // Don't let an assume affect itself - this would cause the problems
+ // `isEphemeralValueOf` is trying to prevent, and it would also make
+ // the loop below go out of bounds.
+ if (Inv == CxtI)
+ return false;
+
+ // The context comes first, but they're both in the same block. Make sure
+ // there is nothing in between that might interrupt the control flow.
+ for (BasicBlock::const_iterator I =
+ std::next(BasicBlock::const_iterator(CxtI)), IE(Inv);
+ I != IE; ++I)
+ if (!isGuaranteedToTransferExecutionToSuccessor(&*I))
+ return false;
+
+ return !isEphemeralValueOf(Inv, CxtI);
+}
+
+static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
+ unsigned Depth, const Query &Q) {
+ // Use of assumptions is context-sensitive. If we don't have a context, we
+ // cannot use them!
+ if (!Q.AC || !Q.CxtI)
+ return;
+
+ unsigned BitWidth = Known.getBitWidth();
+
+ // Note that the patterns below need to be kept in sync with the code
+ // in AssumptionCache::updateAffectedValues.
+
+ for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
+ if (!AssumeVH)
+ continue;
+ CallInst *I = cast<CallInst>(AssumeVH);
+ assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
+ "Got assumption for the wrong function!");
+ if (Q.isExcluded(I))
+ continue;
+
+ // Warning: This loop can end up being somewhat performance sensitive.
+ // We're running this loop once for each value queried, resulting in a
+ // runtime of ~O(#assumes * #values).
+
+ assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume &&
+ "must be an assume intrinsic");
+
+ Value *Arg = I->getArgOperand(0);
+
+ if (Arg == V && isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ assert(BitWidth == 1 && "assume operand is not i1?");
+ Known.setAllOnes();
+ return;
+ }
+ if (match(Arg, m_Not(m_Specific(V))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ assert(BitWidth == 1 && "assume operand is not i1?");
+ Known.setAllZero();
+ return;
+ }
+
+ // The remaining tests are all recursive, so bail out if we hit the limit.
+ if (Depth == MaxDepth)
+ continue;
+
+ ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
+ if (!Cmp)
+ continue;
+
+ Value *A, *B;
+ auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
+
+ CmpInst::Predicate Pred;
+ uint64_t C;
+ switch (Cmp->getPredicate()) {
+ default:
+ break;
+ case ICmpInst::ICMP_EQ:
+ // assume(v = a)
+ if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+ Known.Zero |= RHSKnown.Zero;
+ Known.One |= RHSKnown.One;
+ // assume(v & b = a)
+ } else if (match(Cmp,
+ m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+ KnownBits MaskKnown(BitWidth);
+ computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I));
+
+ // For those bits in the mask that are known to be one, we can propagate
+ // known bits from the RHS to V.
+ Known.Zero |= RHSKnown.Zero & MaskKnown.One;
+ Known.One |= RHSKnown.One & MaskKnown.One;
+ // assume(~(v & b) = a)
+ } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))),
+ m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+ KnownBits MaskKnown(BitWidth);
+ computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I));
+
+ // For those bits in the mask that are known to be one, we can propagate
+ // inverted known bits from the RHS to V.
+ Known.Zero |= RHSKnown.One & MaskKnown.One;
+ Known.One |= RHSKnown.Zero & MaskKnown.One;
+ // assume(v | b = a)
+ } else if (match(Cmp,
+ m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+ KnownBits BKnown(BitWidth);
+ computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
+
+ // For those bits in B that are known to be zero, we can propagate known
+ // bits from the RHS to V.
+ Known.Zero |= RHSKnown.Zero & BKnown.Zero;
+ Known.One |= RHSKnown.One & BKnown.Zero;
+ // assume(~(v | b) = a)
+ } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))),
+ m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+ KnownBits BKnown(BitWidth);
+ computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
+
+ // For those bits in B that are known to be zero, we can propagate
+ // inverted known bits from the RHS to V.
+ Known.Zero |= RHSKnown.One & BKnown.Zero;
+ Known.One |= RHSKnown.Zero & BKnown.Zero;
+ // assume(v ^ b = a)
+ } else if (match(Cmp,
+ m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+ KnownBits BKnown(BitWidth);
+ computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
+
+ // For those bits in B that are known to be zero, we can propagate known
+ // bits from the RHS to V. For those bits in B that are known to be one,
+ // we can propagate inverted known bits from the RHS to V.
+ Known.Zero |= RHSKnown.Zero & BKnown.Zero;
+ Known.One |= RHSKnown.One & BKnown.Zero;
+ Known.Zero |= RHSKnown.One & BKnown.One;
+ Known.One |= RHSKnown.Zero & BKnown.One;
+ // assume(~(v ^ b) = a)
+ } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))),
+ m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+ KnownBits BKnown(BitWidth);
+ computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
+
+ // For those bits in B that are known to be zero, we can propagate
+ // inverted known bits from the RHS to V. For those bits in B that are
+ // known to be one, we can propagate known bits from the RHS to V.
+ Known.Zero |= RHSKnown.One & BKnown.Zero;
+ Known.One |= RHSKnown.Zero & BKnown.Zero;
+ Known.Zero |= RHSKnown.Zero & BKnown.One;
+ Known.One |= RHSKnown.One & BKnown.One;
+ // assume(v << c = a)
+ } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)),
+ m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+ // For those bits in RHS that are known, we can propagate them to known
+ // bits in V shifted to the right by C.
+ RHSKnown.Zero.lshrInPlace(C);
+ Known.Zero |= RHSKnown.Zero;
+ RHSKnown.One.lshrInPlace(C);
+ Known.One |= RHSKnown.One;
+ // assume(~(v << c) = a)
+ } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))),
+ m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+ // For those bits in RHS that are known, we can propagate them inverted
+ // to known bits in V shifted to the right by C.
+ RHSKnown.One.lshrInPlace(C);
+ Known.Zero |= RHSKnown.One;
+ RHSKnown.Zero.lshrInPlace(C);
+ Known.One |= RHSKnown.Zero;
+ // assume(v >> c = a)
+ } else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)),
+ m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+ // For those bits in RHS that are known, we can propagate them to known
+ // bits in V shifted to the right by C.
+ Known.Zero |= RHSKnown.Zero << C;
+ Known.One |= RHSKnown.One << C;
+ // assume(~(v >> c) = a)
+ } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shr(m_V, m_ConstantInt(C))),
+ m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+ // For those bits in RHS that are known, we can propagate them inverted
+ // to known bits in V shifted to the right by C.
+ Known.Zero |= RHSKnown.One << C;
+ Known.One |= RHSKnown.Zero << C;
+ }
+ break;
+ case ICmpInst::ICMP_SGE:
+ // assume(v >=_s c) where c is non-negative
+ if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I));
+
+ if (RHSKnown.isNonNegative()) {
+ // We know that the sign bit is zero.
+ Known.makeNonNegative();
+ }
+ }
+ break;
+ case ICmpInst::ICMP_SGT:
+ // assume(v >_s c) where c is at least -1.
+ if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I));
+
+ if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) {
+ // We know that the sign bit is zero.
+ Known.makeNonNegative();
+ }
+ }
+ break;
+ case ICmpInst::ICMP_SLE:
+ // assume(v <=_s c) where c is negative
+ if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I));
+
+ if (RHSKnown.isNegative()) {
+ // We know that the sign bit is one.
+ Known.makeNegative();
+ }
+ }
+ break;
+ case ICmpInst::ICMP_SLT:
+ // assume(v <_s c) where c is non-positive
+ if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+
+ if (RHSKnown.isZero() || RHSKnown.isNegative()) {
+ // We know that the sign bit is one.
+ Known.makeNegative();
+ }
+ }
+ break;
+ case ICmpInst::ICMP_ULE:
+ // assume(v <=_u c)
+ if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+
+ // Whatever high bits in c are zero are known to be zero.
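+ // E.g. if the top four bits of c are known to be zero, then v <=_u c
+ // forces the top four bits of v to be zero as well.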
+ Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());
+ }
+ break;
+ case ICmpInst::ICMP_ULT:
+ // assume(v <_u c)
+ if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
+ isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
+ KnownBits RHSKnown(BitWidth);
+ computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+
+ // If the RHS is known zero, then this assumption must be wrong (nothing
+ // is unsigned less than zero). Signal a conflict and get out of here.
+ if (RHSKnown.isZero()) {
+ Known.Zero.setAllBits();
+ Known.One.setAllBits();
+ break;
+ }
+
+ // Whatever high bits in c are zero are known to be zero (if c is a power
+ // of 2, then one more).
+ if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I)))
+ Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros() + 1);
+ else
+ Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());
+ }
+ break;
+ }
+ }
+
+ // If assumptions conflict with each other or previous known bits, then we
+ // have a logical fallacy. It's possible that the assumption is not reachable,
+ // so this isn't a real bug. On the other hand, the program may have undefined
+ // behavior, or we might have a bug in the compiler. We can't assert/crash, so
+ // clear out the known bits, try to warn the user, and hope for the best.
+ if (Known.Zero.intersects(Known.One)) {
+ Known.resetAll();
+
+ if (Q.ORE)
+ Q.ORE->emit([&]() {
+ auto *CxtI = const_cast<Instruction *>(Q.CxtI);
+ return OptimizationRemarkAnalysis("value-tracking", "BadAssumption",
+ CxtI)
+ << "Detected conflicting code assumptions. Program may "
+ "have undefined behavior, or compiler may have "
+ "internal error.";
+ });
+ }
+}
+
+/// Compute known bits from a shift operator, including those with a
+/// non-constant shift amount. Known is the output of this function. Known2 is a
+/// pre-allocated temporary with the same bit width as Known. KZF and KOF are
+/// operator-specific functions that, given the known-zero or known-one bits
+/// respectively, and a shift amount, compute the implied known-zero or
+/// known-one bits of the shift operator's result respectively for that shift
+/// amount. The results from calling KZF and KOF are conservatively combined for
+/// all permitted shift amounts.
+static void computeKnownBitsFromShiftOperator(
+ const Operator *I, KnownBits &Known, KnownBits &Known2,
+ unsigned Depth, const Query &Q,
+ function_ref<APInt(const APInt &, unsigned)> KZF,
+ function_ref<APInt(const APInt &, unsigned)> KOF) {
+ unsigned BitWidth = Known.getBitWidth();
+
+ if (auto *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
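+ // A constant shift amount of BitWidth or more would make the result poison;
+ // clamp it so the KZF/KOF callbacks only ever see an in-range amount.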
+ unsigned ShiftAmt = SA->getLimitedValue(BitWidth-1);
+
+ computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ Known.Zero = KZF(Known.Zero, ShiftAmt);
+ Known.One = KOF(Known.One, ShiftAmt);
+ // If the known bits conflict, this must be an overflowing left shift, so
+ // the shift result is poison. We can return anything we want. Choose 0 for
+ // the best folding opportunity.
+ if (Known.hasConflict())
+ Known.setAllZero();
+
+ return;
+ }
+
+ computeKnownBits(I->getOperand(1), Known, Depth + 1, Q);
+
+ // If the shift amount could be greater than or equal to the bit-width of the
+ // LHS, the value could be poison, but bail out because the check below is
+ // expensive. TODO: Should we just carry on?
+ if ((~Known.Zero).uge(BitWidth)) {
+ Known.resetAll();
+ return;
+ }
+
+ // Note: We cannot use Known.Zero.getLimitedValue() here, because if
+ // BitWidth > 64 and any upper bits are known, we'll end up returning the
+ // limit value (which implies all bits are known).
+ uint64_t ShiftAmtKZ = Known.Zero.zextOrTrunc(64).getZExtValue();
+ uint64_t ShiftAmtKO = Known.One.zextOrTrunc(64).getZExtValue();
+
+ // It would be clearer to use two separate temporaries for this calculation;
+ // we reuse the APInts here to prevent unnecessary allocations.
+ Known.resetAll();
+
+ // If we know the shifter operand is nonzero, we can sometimes infer more
+ // known bits. However this is expensive to compute, so be lazy about it and
+ // only compute it when absolutely necessary.
+ Optional<bool> ShifterOperandIsNonZero;
+
+ // Early exit if we can't constrain any well-defined shift amount.
+ if (!(ShiftAmtKZ & (PowerOf2Ceil(BitWidth) - 1)) &&
+ !(ShiftAmtKO & (PowerOf2Ceil(BitWidth) - 1))) {
+ ShifterOperandIsNonZero = isKnownNonZero(I->getOperand(1), Depth + 1, Q);
+ if (!*ShifterOperandIsNonZero)
+ return;
+ }
+
+ computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+
+ Known.Zero.setAllBits();
+ Known.One.setAllBits();
+ for (unsigned ShiftAmt = 0; ShiftAmt < BitWidth; ++ShiftAmt) {
+ // Combine the shifted known input bits only for those shift amounts
+ // compatible with its known constraints.
+ if ((ShiftAmt & ~ShiftAmtKZ) != ShiftAmt)
+ continue;
+ if ((ShiftAmt | ShiftAmtKO) != ShiftAmt)
+ continue;
+ // If we know the shifter is nonzero, we may be able to infer more known
+ // bits. This check is sunk down as far as possible to avoid the expensive
+ // call to isKnownNonZero if the cheaper checks above fail.
+ if (ShiftAmt == 0) {
+ if (!ShifterOperandIsNonZero.hasValue())
+ ShifterOperandIsNonZero =
+ isKnownNonZero(I->getOperand(1), Depth + 1, Q);
+ if (*ShifterOperandIsNonZero)
+ continue;
+ }
+
+ Known.Zero &= KZF(Known2.Zero, ShiftAmt);
+ Known.One &= KOF(Known2.One, ShiftAmt);
+ }
+
+ // If the known bits conflict, the result is poison. Return a 0 and hope the
+ // caller can further optimize that.
+ if (Known.hasConflict())
+ Known.setAllZero();
+}
+
+static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
+ unsigned Depth, const Query &Q) {
+ unsigned BitWidth = Known.getBitWidth();
+
+ KnownBits Known2(Known);
+ switch (I->getOpcode()) {
+ default: break;
+ case Instruction::Load:
+ if (MDNode *MD =
+ Q.IIQ.getMetadata(cast<LoadInst>(I), LLVMContext::MD_range))
+ computeKnownBitsFromRangeMetadata(*MD, Known);
+ break;
+ case Instruction::And: {
+ // If either the LHS or the RHS are Zero, the result is zero.
+ computeKnownBits(I->getOperand(1), Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+
+ // Output known-1 bits are only known if set in both the LHS & RHS.
+ Known.One &= Known2.One;
+ // Output known-0 bits are known to be clear if zero in either the LHS or RHS.
+ Known.Zero |= Known2.Zero;
+
+ // and(x, add (x, -1)) is a common idiom that always clears the low bit;
+ // here we handle the more general case of adding any odd number by
+ // matching the form and(x, add(x, y)) where y is odd.
+ // TODO: This could be generalized to clearing any bit set in y where the
+ // following bit is known to be unset in y.
+ Value *X = nullptr, *Y = nullptr;
+ if (!Known.Zero[0] && !Known.One[0] &&
+ match(I, m_c_BinOp(m_Value(X), m_Add(m_Deferred(X), m_Value(Y))))) {
+ Known2.resetAll();
+ computeKnownBits(Y, Known2, Depth + 1, Q);
+ if (Known2.countMinTrailingOnes() > 0)
+ Known.Zero.setBit(0);
+ }
+ break;
+ }
+ case Instruction::Or:
+ computeKnownBits(I->getOperand(1), Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+
+ // Output known-0 bits are only known if clear in both the LHS & RHS.
+ Known.Zero &= Known2.Zero;
+ // Output known-1 bits are known to be set if set in either the LHS or RHS.
+ Known.One |= Known2.One;
+ break;
+ case Instruction::Xor: {
+ computeKnownBits(I->getOperand(1), Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+
+ // Output known-0 bits are known if clear or set in both the LHS & RHS.
+ APInt KnownZeroOut = (Known.Zero & Known2.Zero) | (Known.One & Known2.One);
+ // Output known-1 bits are known to be set if set in exactly one of the LHS and RHS.
+ Known.One = (Known.Zero & Known2.One) | (Known.One & Known2.Zero);
+ Known.Zero = std::move(KnownZeroOut);
+ break;
+ }
+ case Instruction::Mul: {
+ bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
+ computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, Known,
+ Known2, Depth, Q);
+ break;
+ }
+ case Instruction::UDiv: {
+ // For the purposes of computing leading zeros we can conservatively
+ // treat a udiv as a logical right shift by the power of 2 known to
+ // be less than the denominator.
+ computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ unsigned LeadZ = Known2.countMinLeadingZeros();
+
+ Known2.resetAll();
+ computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ unsigned RHSMaxLeadingZeros = Known2.countMaxLeadingZeros();
+ if (RHSMaxLeadingZeros != BitWidth)
+ LeadZ = std::min(BitWidth, LeadZ + BitWidth - RHSMaxLeadingZeros - 1);
+
+ Known.Zero.setHighBits(LeadZ);
+ break;
+ }
+ case Instruction::Select: {
+ const Value *LHS = nullptr, *RHS = nullptr;
+ SelectPatternFlavor SPF = matchSelectPattern(I, LHS, RHS).Flavor;
+ if (SelectPatternResult::isMinOrMax(SPF)) {
+ computeKnownBits(RHS, Known, Depth + 1, Q);
+ computeKnownBits(LHS, Known2, Depth + 1, Q);
+ } else {
+ computeKnownBits(I->getOperand(2), Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ }
+
+ unsigned MaxHighOnes = 0;
+ unsigned MaxHighZeros = 0;
+ if (SPF == SPF_SMAX) {
+ // If both sides are negative, the result is negative.
+ if (Known.isNegative() && Known2.isNegative())
+ // We can derive a lower bound on the result by taking the max of the
+ // leading one bits.
+ MaxHighOnes =
+ std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes());
+ // If either side is non-negative, the result is non-negative.
+ else if (Known.isNonNegative() || Known2.isNonNegative())
+ MaxHighZeros = 1;
+ } else if (SPF == SPF_SMIN) {
+ // If both sides are non-negative, the result is non-negative.
+ if (Known.isNonNegative() && Known2.isNonNegative())
+ // We can derive an upper bound on the result by taking the max of the
+ // leading zero bits.
+ MaxHighZeros = std::max(Known.countMinLeadingZeros(),
+ Known2.countMinLeadingZeros());
+ // If either side is negative, the result is negative.
+ else if (Known.isNegative() || Known2.isNegative())
+ MaxHighOnes = 1;
+ } else if (SPF == SPF_UMAX) {
+ // We can derive a lower bound on the result by taking the max of the
+ // leading one bits.
+ MaxHighOnes =
+ std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes());
+ } else if (SPF == SPF_UMIN) {
+ // We can derive an upper bound on the result by taking the max of the
+ // leading zero bits.
+ MaxHighZeros =
+ std::max(Known.countMinLeadingZeros(), Known2.countMinLeadingZeros());
+ } else if (SPF == SPF_ABS) {
+ // RHS from matchSelectPattern returns the negation part of abs pattern.
+ // If the negate has an NSW flag we can assume the sign bit of the result
+ // will be 0 because that makes abs(INT_MIN) undefined.
+ if (match(RHS, m_Neg(m_Specific(LHS))) &&
+ Q.IIQ.hasNoSignedWrap(cast<Instruction>(RHS)))
+ MaxHighZeros = 1;
+ }
+
+ // Only known if known in both the LHS and RHS.
+ Known.One &= Known2.One;
+ Known.Zero &= Known2.Zero;
+ if (MaxHighOnes > 0)
+ Known.One.setHighBits(MaxHighOnes);
+ if (MaxHighZeros > 0)
+ Known.Zero.setHighBits(MaxHighZeros);
+ break;
+ }
+ case Instruction::FPTrunc:
+ case Instruction::FPExt:
+ case Instruction::FPToUI:
+ case Instruction::FPToSI:
+ case Instruction::SIToFP:
+ case Instruction::UIToFP:
+ break; // Can't work with floating point.
+ case Instruction::PtrToInt:
+ case Instruction::IntToPtr:
+ // Fall through and handle them the same as zext/trunc.
+ LLVM_FALLTHROUGH;
+ case Instruction::ZExt:
+ case Instruction::Trunc: {
+ Type *SrcTy = I->getOperand(0)->getType();
+
+ unsigned SrcBitWidth;
+ // Note that we handle pointer operands here because of inttoptr/ptrtoint
+ // which fall through here.
+ Type *ScalarTy = SrcTy->getScalarType();
+ SrcBitWidth = ScalarTy->isPointerTy() ?
+ Q.DL.getIndexTypeSizeInBits(ScalarTy) :
+ Q.DL.getTypeSizeInBits(ScalarTy);
+
+ assert(SrcBitWidth && "SrcBitWidth can't be zero");
+ Known = Known.zextOrTrunc(SrcBitWidth, false);
+ computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ Known = Known.zextOrTrunc(BitWidth, true /* ExtendedBitsAreKnownZero */);
+ break;
+ }
+ case Instruction::BitCast: {
+ Type *SrcTy = I->getOperand(0)->getType();
+ if (SrcTy->isIntOrPtrTy() &&
+ // TODO: For now, not handling conversions like:
+ // (bitcast i64 %x to <2 x i32>)
+ !I->getType()->isVectorTy()) {
+ computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ break;
+ }
+ break;
+ }
+ case Instruction::SExt: {
+ // Compute the bits in the result that are not present in the input.
+ unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
+
+ Known = Known.trunc(SrcBitWidth);
+ computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ // If the sign bit of the input is known set or clear, then we know the
+ // top bits of the result.
+ Known = Known.sext(BitWidth);
+ break;
+ }
+ case Instruction::Shl: {
+ // (shl X, C1) & C2 == 0 iff (X & C2 >>u C1) == 0
+ bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
+ auto KZF = [NSW](const APInt &KnownZero, unsigned ShiftAmt) {
+ APInt KZResult = KnownZero << ShiftAmt;
+ KZResult.setLowBits(ShiftAmt); // Low bits known 0.
+ // If this shift has "nsw" keyword, then the result is either a poison
+ // value or has the same sign bit as the first operand.
+ if (NSW && KnownZero.isSignBitSet())
+ KZResult.setSignBit();
+ return KZResult;
+ };
+
+ auto KOF = [NSW](const APInt &KnownOne, unsigned ShiftAmt) {
+ APInt KOResult = KnownOne << ShiftAmt;
+ if (NSW && KnownOne.isSignBitSet())
+ KOResult.setSignBit();
+ return KOResult;
+ };
+
+ computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF);
+ break;
+ }
+ case Instruction::LShr: {
+ // (lshr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0
+ auto KZF = [](const APInt &KnownZero, unsigned ShiftAmt) {
+ APInt KZResult = KnownZero.lshr(ShiftAmt);
+ // High bits known zero.
+ KZResult.setHighBits(ShiftAmt);
+ return KZResult;
+ };
+
+ auto KOF = [](const APInt &KnownOne, unsigned ShiftAmt) {
+ return KnownOne.lshr(ShiftAmt);
+ };
+
+ computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF);
+ break;
+ }
+ case Instruction::AShr: {
+ // (ashr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0
+ auto KZF = [](const APInt &KnownZero, unsigned ShiftAmt) {
+ return KnownZero.ashr(ShiftAmt);
+ };
+
+ auto KOF = [](const APInt &KnownOne, unsigned ShiftAmt) {
+ return KnownOne.ashr(ShiftAmt);
+ };
+
+ computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF);
+ break;
+ }
+ case Instruction::Sub: {
+ bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
+ computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW,
+ Known, Known2, Depth, Q);
+ break;
+ }
+ case Instruction::Add: {
+ bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
+ computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW,
+ Known, Known2, Depth, Q);
+ break;
+ }
+ case Instruction::SRem:
+ if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+ APInt RA = Rem->getValue().abs();
+ if (RA.isPowerOf2()) {
+ APInt LowBits = RA - 1;
+ computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+
+ // The low bits of the first operand are unchanged by the srem.
+ Known.Zero = Known2.Zero & LowBits;
+ Known.One = Known2.One & LowBits;
+
+ // If the first operand is non-negative or has all low bits zero, then
+ // the upper bits are all zero.
+ if (Known2.isNonNegative() || LowBits.isSubsetOf(Known2.Zero))
+ Known.Zero |= ~LowBits;
+
+ // If the first operand is negative and not all low bits are zero, then
+ // the upper bits are all one.
+ if (Known2.isNegative() && LowBits.intersects(Known2.One))
+ Known.One |= ~LowBits;
+
+ assert((Known.Zero & Known.One) == 0 && "Bits known to be one AND zero?");
+ break;
+ }
+ }
+
+ // The sign bit is the LHS's sign bit, except when the result of the
+ // remainder is zero.
+ computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ // If it's known zero, our sign bit is also zero.
+ if (Known2.isNonNegative())
+ Known.makeNonNegative();
+
+ break;
+ case Instruction::URem: {
+ if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+ const APInt &RA = Rem->getValue();
+ if (RA.isPowerOf2()) {
+ APInt LowBits = (RA - 1);
+ computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ Known.Zero |= ~LowBits;
+ Known.One &= LowBits;
+ break;
+ }
+ }
+
+ // Since the result is less than or equal to either operand, any leading
+ // zero bits in either operand must also exist in the result.
+ computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+
+ unsigned Leaders =
+ std::max(Known.countMinLeadingZeros(), Known2.countMinLeadingZeros());
+ Known.resetAll();
+ Known.Zero.setHighBits(Leaders);
+ break;
+ }
+
+ case Instruction::Alloca: {
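+ // An alloca is aligned to its specified alignment (or the ABI alignment of
+ // the allocated type), so the low bits of its address are known to be zero.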
+ const AllocaInst *AI = cast<AllocaInst>(I);
+ unsigned Align = AI->getAlignment();
+ if (Align == 0)
+ Align = Q.DL.getABITypeAlignment(AI->getAllocatedType());
+
+ if (Align > 0)
+ Known.Zero.setLowBits(countTrailingZeros(Align));
+ break;
+ }
+ case Instruction::GetElementPtr: {
+ // Analyze all of the subscripts of this getelementptr instruction
+ // to determine if we can prove known low zero bits.
+ KnownBits LocalKnown(BitWidth);
+ computeKnownBits(I->getOperand(0), LocalKnown, Depth + 1, Q);
+ unsigned TrailZ = LocalKnown.countMinTrailingZeros();
+
+ gep_type_iterator GTI = gep_type_begin(I);
+ for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) {
+ Value *Index = I->getOperand(i);
+ if (StructType *STy = GTI.getStructTypeOrNull()) {
+ // Handle struct member offset arithmetic.
+
+ // Handle the case when the index is a vector zeroinitializer.
+ Constant *CIndex = cast<Constant>(Index);
+ if (CIndex->isZeroValue())
+ continue;
+
+ if (CIndex->getType()->isVectorTy())
+ Index = CIndex->getSplatValue();
+
+ unsigned Idx = cast<ConstantInt>(Index)->getZExtValue();
+ const StructLayout *SL = Q.DL.getStructLayout(STy);
+ uint64_t Offset = SL->getElementOffset(Idx);
+ TrailZ = std::min<unsigned>(TrailZ,
+ countTrailingZeros(Offset));
+ } else {
+ // Handle array index arithmetic.
+ Type *IndexedTy = GTI.getIndexedType();
+ if (!IndexedTy->isSized()) {
+ TrailZ = 0;
+ break;
+ }
+ unsigned GEPOpiBits = Index->getType()->getScalarSizeInBits();
+ uint64_t TypeSize = Q.DL.getTypeAllocSize(IndexedTy);
+ LocalKnown.Zero = LocalKnown.One = APInt(GEPOpiBits, 0);
+ computeKnownBits(Index, LocalKnown, Depth + 1, Q);
+ TrailZ = std::min(TrailZ,
+ unsigned(countTrailingZeros(TypeSize) +
+ LocalKnown.countMinTrailingZeros()));
+ }
+ }
+
+ Known.Zero.setLowBits(TrailZ);
+ break;
+ }
+ case Instruction::PHI: {
+ const PHINode *P = cast<PHINode>(I);
+ // Handle the case of a simple two-predecessor recurrence PHI.
+ // There's a lot more that could theoretically be done here, but
+ // this is sufficient to catch some interesting cases.
+ if (P->getNumIncomingValues() == 2) {
+ for (unsigned i = 0; i != 2; ++i) {
+ Value *L = P->getIncomingValue(i);
+ Value *R = P->getIncomingValue(!i);
+ Operator *LU = dyn_cast<Operator>(L);
+ if (!LU)
+ continue;
+ unsigned Opcode = LU->getOpcode();
+ // Check for operations that have the property that if
+ // both their operands have low zero bits, the result
+ // will have low zero bits.
+ if (Opcode == Instruction::Add ||
+ Opcode == Instruction::Sub ||
+ Opcode == Instruction::And ||
+ Opcode == Instruction::Or ||
+ Opcode == Instruction::Mul) {
+ Value *LL = LU->getOperand(0);
+ Value *LR = LU->getOperand(1);
+ // Find a recurrence.
+ if (LL == I)
+ L = LR;
+ else if (LR == I)
+ L = LL;
+ else
+ continue; // Check for recurrence with L and R flipped.
+ // Ok, we have a PHI of the form L op= R. Check for low
+ // zero bits.
+ computeKnownBits(R, Known2, Depth + 1, Q);
+
+ // We need to take the minimum number of known bits
+ KnownBits Known3(Known);
+ computeKnownBits(L, Known3, Depth + 1, Q);
+
+ Known.Zero.setLowBits(std::min(Known2.countMinTrailingZeros(),
+ Known3.countMinTrailingZeros()));
+
+ auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(LU);
+ if (OverflowOp && Q.IIQ.hasNoSignedWrap(OverflowOp)) {
+ // If the initial value of the recurrence is non-negative and we are adding
+ // a non-negative number with nsw, the result can only be non-negative or a
+ // poison value, regardless of the number of times we execute the add in the
+ // phi recurrence. If the initial value is negative and we are adding a
+ // negative number with nsw, the result can only be negative or a poison
+ // value. Similar arguments apply to sub and mul.
+ //
+ // (add non-negative, non-negative) --> non-negative
+ // (add negative, negative) --> negative
+ if (Opcode == Instruction::Add) {
+ if (Known2.isNonNegative() && Known3.isNonNegative())
+ Known.makeNonNegative();
+ else if (Known2.isNegative() && Known3.isNegative())
+ Known.makeNegative();
+ }
+
+ // (sub nsw non-negative, negative) --> non-negative
+ // (sub nsw negative, non-negative) --> negative
+ else if (Opcode == Instruction::Sub && LL == I) {
+ if (Known2.isNonNegative() && Known3.isNegative())
+ Known.makeNonNegative();
+ else if (Known2.isNegative() && Known3.isNonNegative())
+ Known.makeNegative();
+ }
+
+ // (mul nsw non-negative, non-negative) --> non-negative
+ else if (Opcode == Instruction::Mul && Known2.isNonNegative() &&
+ Known3.isNonNegative())
+ Known.makeNonNegative();
+ }
+
+ break;
+ }
+ }
+ }
+
+ // Unreachable blocks may have zero-operand PHI nodes.
+ if (P->getNumIncomingValues() == 0)
+ break;
+
+ // Otherwise take the unions of the known bit sets of the operands,
+ // taking conservative care to avoid excessive recursion.
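+ // Only attempt the merge when no bits are known yet: the code below resets
+ // Known to all-bits-known before intersecting over the incoming values, so
+ // running it now would drop anything already derived above.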
+ if (Depth < MaxDepth - 1 && !Known.Zero && !Known.One) {
+ // Skip if every incoming value references to ourself.
+ if (dyn_cast_or_null<UndefValue>(P->hasConstantValue()))
+ break;
+
+ Known.Zero.setAllBits();
+ Known.One.setAllBits();
+ for (Value *IncValue : P->incoming_values()) {
+ // Skip direct self references.
+ if (IncValue == P) continue;
+
+ Known2 = KnownBits(BitWidth);
+ // Recurse, but cap the recursion to one level, because we don't
+ // want to waste time spinning around in loops.
+ computeKnownBits(IncValue, Known2, MaxDepth - 1, Q);
+ Known.Zero &= Known2.Zero;
+ Known.One &= Known2.One;
+ // If all bits have been ruled out, there's no need to check
+ // more operands.
+ if (!Known.Zero && !Known.One)
+ break;
+ }
+ }
+ break;
+ }
+ case Instruction::Call:
+ case Instruction::Invoke:
+ // If range metadata is attached to this call, set known bits from that,
+ // and then intersect with known bits based on other properties of the
+ // function.
+ if (MDNode *MD =
+ Q.IIQ.getMetadata(cast<Instruction>(I), LLVMContext::MD_range))
+ computeKnownBitsFromRangeMetadata(*MD, Known);
+ if (const Value *RV = ImmutableCallSite(I).getReturnedArgOperand()) {
+ computeKnownBits(RV, Known2, Depth + 1, Q);
+ Known.Zero |= Known2.Zero;
+ Known.One |= Known2.One;
+ }
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+ switch (II->getIntrinsicID()) {
+ default: break;
+ case Intrinsic::bitreverse:
+ computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ Known.Zero |= Known2.Zero.reverseBits();
+ Known.One |= Known2.One.reverseBits();
+ break;
+ case Intrinsic::bswap:
+ computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ Known.Zero |= Known2.Zero.byteSwap();
+ Known.One |= Known2.One.byteSwap();
+ break;
+ case Intrinsic::ctlz: {
+ computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ // If we have a known 1, its position is our upper bound.
+ unsigned PossibleLZ = Known2.One.countLeadingZeros();
+ // If this call is undefined for 0, the result will be less than 2^n.
+ if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext()))
+ PossibleLZ = std::min(PossibleLZ, BitWidth - 1);
+ unsigned LowBits = Log2_32(PossibleLZ)+1;
+ Known.Zero.setBitsFrom(LowBits);
+ break;
+ }
+ case Intrinsic::cttz: {
+ computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ // If we have a known 1, its position is our upper bound.
+ unsigned PossibleTZ = Known2.One.countTrailingZeros();
+ // If this call is undefined for 0, the result will be less than 2^n.
+ if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext()))
+ PossibleTZ = std::min(PossibleTZ, BitWidth - 1);
+ unsigned LowBits = Log2_32(PossibleTZ)+1;
+ Known.Zero.setBitsFrom(LowBits);
+ break;
+ }
+ case Intrinsic::ctpop: {
+ computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ // We can bound the space the count needs. Also, bits known to be zero
+ // can't contribute to the population.
+ unsigned BitsPossiblySet = Known2.countMaxPopulation();
+ unsigned LowBits = Log2_32(BitsPossiblySet)+1;
+ Known.Zero.setBitsFrom(LowBits);
+ // TODO: we could bound KnownOne using the lower bound on the number
+ // of bits which might be set provided by popcnt KnownOne2.
+ break;
+ }
+ case Intrinsic::fshr:
+ case Intrinsic::fshl: {
+ const APInt *SA;
+ if (!match(I->getOperand(2), m_APInt(SA)))
+ break;
+
+ // Normalize to funnel shift left.
+ uint64_t ShiftAmt = SA->urem(BitWidth);
+ if (II->getIntrinsicID() == Intrinsic::fshr)
+ ShiftAmt = BitWidth - ShiftAmt;
+
+ KnownBits Known3(Known);
+ computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), Known3, Depth + 1, Q);
+
+ Known.Zero =
+ Known2.Zero.shl(ShiftAmt) | Known3.Zero.lshr(BitWidth - ShiftAmt);
+ Known.One =
+ Known2.One.shl(ShiftAmt) | Known3.One.lshr(BitWidth - ShiftAmt);
+ break;
+ }
+ case Intrinsic::uadd_sat:
+ case Intrinsic::usub_sat: {
+ bool IsAdd = II->getIntrinsicID() == Intrinsic::uadd_sat;
+ computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+
+ // Add: Leading ones of either operand are preserved.
+ // Sub: Leading zeros of LHS and leading ones of RHS are preserved
+ // as leading zeros in the result.
+ unsigned LeadingKnown;
+ if (IsAdd)
+ LeadingKnown = std::max(Known.countMinLeadingOnes(),
+ Known2.countMinLeadingOnes());
+ else
+ LeadingKnown = std::max(Known.countMinLeadingZeros(),
+ Known2.countMinLeadingOnes());
+
+ Known = KnownBits::computeForAddSub(
+ IsAdd, /* NSW */ false, Known, Known2);
+
+ // We select between the operation result and all-ones/zero
+ // respectively, so we can preserve known ones/zeros.
+ if (IsAdd) {
+ Known.One.setHighBits(LeadingKnown);
+ Known.Zero.clearAllBits();
+ } else {
+ Known.Zero.setHighBits(LeadingKnown);
+ Known.One.clearAllBits();
+ }
+ break;
+ }
+ case Intrinsic::x86_sse42_crc32_64_64:
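+ // The 64-bit crc32 instruction produces a 32-bit result that is
+ // zero-extended into the 64-bit destination, so the upper 32 bits of
+ // the result are known zero.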
+ Known.Zero.setBitsFrom(32);
+ break;
+ }
+ }
+ break;
+ case Instruction::ExtractElement:
+ // Look through extract element. At the moment we keep this simple and skip
+ // tracking the specific element. But at least we might find information
+ // valid for all elements of the vector (for example if the vector is sign
+ // extended, shifted, etc.).
+ computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ break;
+ case Instruction::ExtractValue:
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I->getOperand(0))) {
+ const ExtractValueInst *EVI = cast<ExtractValueInst>(I);
+ if (EVI->getNumIndices() != 1) break;
+ if (EVI->getIndices()[0] == 0) {
+ switch (II->getIntrinsicID()) {
+ default: break;
+ case Intrinsic::uadd_with_overflow:
+ case Intrinsic::sadd_with_overflow:
+ computeKnownBitsAddSub(true, II->getArgOperand(0),
+ II->getArgOperand(1), false, Known, Known2,
+ Depth, Q);
+ break;
+ case Intrinsic::usub_with_overflow:
+ case Intrinsic::ssub_with_overflow:
+ computeKnownBitsAddSub(false, II->getArgOperand(0),
+ II->getArgOperand(1), false, Known, Known2,
+ Depth, Q);
+ break;
+ case Intrinsic::umul_with_overflow:
+ case Intrinsic::smul_with_overflow:
+ computeKnownBitsMul(II->getArgOperand(0), II->getArgOperand(1), false,
+ Known, Known2, Depth, Q);
+ break;
+ }
+ }
+ }
+ }
+}
+
+/// Determine which bits of V are known to be either zero or one and return
+/// them.
+KnownBits computeKnownBits(const Value *V, unsigned Depth, const Query &Q) {
+ KnownBits Known(getBitWidth(V->getType(), Q.DL));
+ computeKnownBits(V, Known, Depth, Q);
+ return Known;
+}
+
+/// Determine which bits of V are known to be either zero or one and return
+/// them in the Known bit set.
+///
+/// NOTE: we cannot consider 'undef' to be "IsZero" here. The problem is that
+/// we cannot optimize based on the assumption that it is zero without changing
+/// it to be an explicit zero. If we don't change it to zero, other code could
+ /// be optimized based on the contradictory assumption that it is non-zero.
+/// Because instcombine aggressively folds operations with undef args anyway,
+/// this won't lose us code quality.
+///
+/// This function is defined on values with integer type, values with pointer
+ /// type, and vectors of integers. In the case where V is a vector, the known
+ /// zero and known one values are the same width as the vector element, and a
+ /// bit is set only if it is true for all of the elements in the vector.
+void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
+ const Query &Q) {
+ assert(V && "No Value?");
+ assert(Depth <= MaxDepth && "Limit Search Depth");
+ unsigned BitWidth = Known.getBitWidth();
+
+ assert((V->getType()->isIntOrIntVectorTy(BitWidth) ||
+ V->getType()->isPtrOrPtrVectorTy()) &&
+ "Not integer or pointer type!");
+
+ Type *ScalarTy = V->getType()->getScalarType();
+ unsigned ExpectedWidth = ScalarTy->isPointerTy() ?
+ Q.DL.getIndexTypeSizeInBits(ScalarTy) : Q.DL.getTypeSizeInBits(ScalarTy);
+ assert(ExpectedWidth == BitWidth && "V and Known should have same BitWidth");
+ (void)BitWidth;
+ (void)ExpectedWidth;
+
+ const APInt *C;
+ if (match(V, m_APInt(C))) {
+ // We know all of the bits for a scalar constant or a splat vector constant!
+ Known.One = *C;
+ Known.Zero = ~Known.One;
+ return;
+ }
+ // Null and aggregate-zero are all-zeros.
+ if (isa<ConstantPointerNull>(V) || isa<ConstantAggregateZero>(V)) {
+ Known.setAllZero();
+ return;
+ }
+ // Handle a constant vector by taking the intersection of the known bits of
+ // each element.
+ if (const ConstantDataSequential *CDS = dyn_cast<ConstantDataSequential>(V)) {
+ // We know that CDS must be a vector of integers. Take the intersection of
+ // each element.
+ Known.Zero.setAllBits(); Known.One.setAllBits();
+ for (unsigned i = 0, e = CDS->getNumElements(); i != e; ++i) {
+ APInt Elt = CDS->getElementAsAPInt(i);
+ Known.Zero &= ~Elt;
+ Known.One &= Elt;
+ }
+ return;
+ }
+
+ if (const auto *CV = dyn_cast<ConstantVector>(V)) {
+ // We know that CV must be a vector of integers. Take the intersection of
+ // each element.
+ Known.Zero.setAllBits(); Known.One.setAllBits();
+ for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) {
+ Constant *Element = CV->getAggregateElement(i);
+ auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element);
+ if (!ElementCI) {
+ Known.resetAll();
+ return;
+ }
+ const APInt &Elt = ElementCI->getValue();
+ Known.Zero &= ~Elt;
+ Known.One &= Elt;
+ }
+ return;
+ }
+
+ // Start out not knowing anything.
+ Known.resetAll();
+
+ // We can't imply anything about undefs.
+ if (isa<UndefValue>(V))
+ return;
+
+ // There's no point in looking through other users of ConstantData for
+ // assumptions. Confirm that we've handled them all.
+ assert(!isa<ConstantData>(V) && "Unhandled constant data!");
+
+ // Limit search depth.
+ // All recursive calls that increase depth must come after this.
+ if (Depth == MaxDepth)
+ return;
+
+ // A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has
+ // the bits of its aliasee.
+ if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
+ if (!GA->isInterposable())
+ computeKnownBits(GA->getAliasee(), Known, Depth + 1, Q);
+ return;
+ }
+
+ if (const Operator *I = dyn_cast<Operator>(V))
+ computeKnownBitsFromOperator(I, Known, Depth, Q);
+
+ // Aligned pointers have trailing zero bits; refine the Known.Zero set.
+ if (V->getType()->isPointerTy()) {
+ const MaybeAlign Align = V->getPointerAlignment(Q.DL);
+ if (Align)
+ Known.Zero.setLowBits(countTrailingZeros(Align->value()));
+ }
+
+ // computeKnownBitsFromAssume strictly refines Known.
+ // Therefore, we run it after computeKnownBitsFromOperator.
+
+ // Check whether a nearby assume intrinsic can determine some known bits.
+ computeKnownBitsFromAssume(V, Known, Depth, Q);
+
+ assert((Known.Zero & Known.One) == 0 && "Bits known to be one AND zero?");
+}
+
+/// Return true if the given value is known to have exactly one
+/// bit set when defined. For vectors return true if every element is known to
+/// be a power of two when defined. Supports values with integer or pointer
+/// types and vectors of integers.
+bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth,
+ const Query &Q) {
+ assert(Depth <= MaxDepth && "Limit Search Depth");
+
+ // Attempt to match against constants.
+ if (OrZero && match(V, m_Power2OrZero()))
+ return true;
+ if (match(V, m_Power2()))
+ return true;
+
+ // 1 << X is clearly a power of two if the one is not shifted off the end. If
+ // it is shifted off the end then the result is undefined.
+ if (match(V, m_Shl(m_One(), m_Value())))
+ return true;
+
+ // (signmask) >>l X is clearly a power of two if the one is not shifted off
+ // the bottom. If it is shifted off the bottom then the result is undefined.
+ if (match(V, m_LShr(m_SignMask(), m_Value())))
+ return true;
+
+ // The remaining tests are all recursive, so bail out if we hit the limit.
+ if (Depth++ == MaxDepth)
+ return false;
+
+ Value *X = nullptr, *Y = nullptr;
+ // A shift left or a logical shift right of a power of two is a power of two
+ // or zero.
+ if (OrZero && (match(V, m_Shl(m_Value(X), m_Value())) ||
+ match(V, m_LShr(m_Value(X), m_Value()))))
+ return isKnownToBeAPowerOfTwo(X, /*OrZero*/ true, Depth, Q);
+
+ if (const ZExtInst *ZI = dyn_cast<ZExtInst>(V))
+ return isKnownToBeAPowerOfTwo(ZI->getOperand(0), OrZero, Depth, Q);
+
+ if (const SelectInst *SI = dyn_cast<SelectInst>(V))
+ return isKnownToBeAPowerOfTwo(SI->getTrueValue(), OrZero, Depth, Q) &&
+ isKnownToBeAPowerOfTwo(SI->getFalseValue(), OrZero, Depth, Q);
+
+ if (OrZero && match(V, m_And(m_Value(X), m_Value(Y)))) {
+ // A power of two and'd with anything is a power of two or zero.
+ if (isKnownToBeAPowerOfTwo(X, /*OrZero*/ true, Depth, Q) ||
+ isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, Depth, Q))
+ return true;
+ // X & (-X) is always a power of two or zero.
+ if (match(X, m_Neg(m_Specific(Y))) || match(Y, m_Neg(m_Specific(X))))
+ return true;
+ return false;
+ }
+
+ // Adding a power-of-two or zero to the same power-of-two or zero yields
+ // either the original power-of-two, a larger power-of-two or zero.
+ if (match(V, m_Add(m_Value(X), m_Value(Y)))) {
+ const OverflowingBinaryOperator *VOBO = cast<OverflowingBinaryOperator>(V);
+ if (OrZero || Q.IIQ.hasNoUnsignedWrap(VOBO) ||
+ Q.IIQ.hasNoSignedWrap(VOBO)) {
+ if (match(X, m_And(m_Specific(Y), m_Value())) ||
+ match(X, m_And(m_Value(), m_Specific(Y))))
+ if (isKnownToBeAPowerOfTwo(Y, OrZero, Depth, Q))
+ return true;
+ if (match(Y, m_And(m_Specific(X), m_Value())) ||
+ match(Y, m_And(m_Value(), m_Specific(X))))
+ if (isKnownToBeAPowerOfTwo(X, OrZero, Depth, Q))
+ return true;
+
+ unsigned BitWidth = V->getType()->getScalarSizeInBits();
+ KnownBits LHSBits(BitWidth);
+ computeKnownBits(X, LHSBits, Depth, Q);
+
+ KnownBits RHSBits(BitWidth);
+ computeKnownBits(Y, RHSBits, Depth, Q);
+ // If i8 V is a power of two or zero:
+ // ZeroBits: 1 1 1 0 1 1 1 1
+ // ~ZeroBits: 0 0 0 1 0 0 0 0
+ if ((~(LHSBits.Zero & RHSBits.Zero)).isPowerOf2())
+ // If OrZero isn't set, we cannot give back a zero result.
+ // Make sure either the LHS or RHS has a bit set.
+ if (OrZero || RHSBits.One.getBoolValue() || LHSBits.One.getBoolValue())
+ return true;
+ }
+ }
+
+ // An exact divide or right shift can only shift off zero bits, so the result
+ // is a power of two only if the first operand is a power of two and not
+ // copying a sign bit (sdiv int_min, 2).
+ if (match(V, m_Exact(m_LShr(m_Value(), m_Value()))) ||
+ match(V, m_Exact(m_UDiv(m_Value(), m_Value())))) {
+ return isKnownToBeAPowerOfTwo(cast<Operator>(V)->getOperand(0), OrZero,
+ Depth, Q);
+ }
+
+ return false;
+}
+
+/// Test whether a GEP's result is known to be non-null.
+///
+/// Uses properties inherent in a GEP to try to determine whether it is known
+/// to be non-null.
+///
+/// Currently this routine does not support vector GEPs.
+static bool isGEPKnownNonNull(const GEPOperator *GEP, unsigned Depth,
+ const Query &Q) {
+ const Function *F = nullptr;
+ if (const Instruction *I = dyn_cast<Instruction>(GEP))
+ F = I->getFunction();
+
+ if (!GEP->isInBounds() ||
+ NullPointerIsDefined(F, GEP->getPointerAddressSpace()))
+ return false;
+
+ // FIXME: Support vector-GEPs.
+ assert(GEP->getType()->isPointerTy() && "We only support plain pointer GEP");
+
+ // If the base pointer is non-null, we cannot walk to a null address with an
+ // inbounds GEP in address space zero.
+ if (isKnownNonZero(GEP->getPointerOperand(), Depth, Q))
+ return true;
+
+ // Walk the GEP operands and see if any operand introduces a non-zero offset.
+ // If so, then the GEP cannot produce a null pointer, as doing so would
+ // inherently violate the inbounds contract within address space zero.
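+  // E.g. (illustrative IR, address space zero):
+  //   %p = getelementptr inbounds {i32, i32}, {i32, i32}* %base, i64 0, i32 1
+  // has a constant non-zero struct offset, so %p cannot be null here.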
+ for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP);
+ GTI != GTE; ++GTI) {
+ // Struct types are easy -- they must always be indexed by a constant.
+ if (StructType *STy = GTI.getStructTypeOrNull()) {
+ ConstantInt *OpC = cast<ConstantInt>(GTI.getOperand());
+ unsigned ElementIdx = OpC->getZExtValue();
+ const StructLayout *SL = Q.DL.getStructLayout(STy);
+ uint64_t ElementOffset = SL->getElementOffset(ElementIdx);
+ if (ElementOffset > 0)
+ return true;
+ continue;
+ }
+
+ // If we have a zero-sized type, the index doesn't matter. Keep looping.
+ if (Q.DL.getTypeAllocSize(GTI.getIndexedType()) == 0)
+ continue;
+
+ // Fast path the constant operand case both for efficiency and so we don't
+ // increment Depth when just zipping down an all-constant GEP.
+ if (ConstantInt *OpC = dyn_cast<ConstantInt>(GTI.getOperand())) {
+ if (!OpC->isZero())
+ return true;
+ continue;
+ }
+
+ // We post-increment Depth here because while isKnownNonZero increments it
+ // as well, when we pop back up that increment won't persist. We don't want
+ // to recurse 10k times just because we have 10k GEP operands. We don't
+ // bail completely out because we want to handle constant GEPs regardless
+ // of depth.
+ if (Depth++ >= MaxDepth)
+ continue;
+
+ if (isKnownNonZero(GTI.getOperand(), Depth, Q))
+ return true;
+ }
+
+ return false;
+}
+
+static bool isKnownNonNullFromDominatingCondition(const Value *V,
+ const Instruction *CtxI,
+ const DominatorTree *DT) {
+ assert(V->getType()->isPointerTy() && "V must be pointer type");
+ assert(!isa<ConstantData>(V) && "Did not expect ConstantPointerNull");
+
+ if (!CtxI || !DT)
+ return false;
+
+ unsigned NumUsesExplored = 0;
+ for (auto *U : V->users()) {
+ // Avoid massive lists
+ if (NumUsesExplored >= DomConditionsMaxUses)
+ break;
+ NumUsesExplored++;
+
+ // If the value is used as an argument to a call or invoke, then argument
+ // attributes may provide an answer about null-ness.
+ if (auto CS = ImmutableCallSite(U))
+ if (auto *CalledFunc = CS.getCalledFunction())
+ for (const Argument &Arg : CalledFunc->args())
+ if (CS.getArgOperand(Arg.getArgNo()) == V &&
+ Arg.hasNonNullAttr() && DT->dominates(CS.getInstruction(), CtxI))
+ return true;
+
+ // Consider only compare instructions uniquely controlling a branch
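+    // E.g. (illustrative IR):
+    //   %c = icmp ne i8* %p, null
+    //   br i1 %c, label %use, label %bail
+    // Inside %use, %p is known to be non-null because that edge dominates it.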
+ CmpInst::Predicate Pred;
+ if (!match(const_cast<User *>(U),
+ m_c_ICmp(Pred, m_Specific(V), m_Zero())) ||
+ (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE))
+ continue;
+
+ SmallVector<const User *, 4> WorkList;
+ SmallPtrSet<const User *, 4> Visited;
+ for (auto *CmpU : U->users()) {
+ assert(WorkList.empty() && "Should be!");
+ if (Visited.insert(CmpU).second)
+ WorkList.push_back(CmpU);
+
+ while (!WorkList.empty()) {
+ auto *Curr = WorkList.pop_back_val();
+
+ // If a user is an AND, add all its users to the work list. We only
+ // propagate "pred != null" condition through AND because it is only
+ // correct to assume that all conditions of AND are met in true branch.
+ // TODO: Support similar logic of OR and EQ predicate?
+ if (Pred == ICmpInst::ICMP_NE)
+ if (auto *BO = dyn_cast<BinaryOperator>(Curr))
+ if (BO->getOpcode() == Instruction::And) {
+ for (auto *BOU : BO->users())
+ if (Visited.insert(BOU).second)
+ WorkList.push_back(BOU);
+ continue;
+ }
+
+ if (const BranchInst *BI = dyn_cast<BranchInst>(Curr)) {
+ assert(BI->isConditional() && "uses a comparison!");
+
+ BasicBlock *NonNullSuccessor =
+ BI->getSuccessor(Pred == ICmpInst::ICMP_EQ ? 1 : 0);
+ BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor);
+ if (Edge.isSingleEdge() && DT->dominates(Edge, CtxI->getParent()))
+ return true;
+ } else if (Pred == ICmpInst::ICMP_NE && isGuard(Curr) &&
+ DT->dominates(cast<Instruction>(Curr), CtxI)) {
+ return true;
+ }
+ }
+ }
+ }
+
+ return false;
+}
+
+/// Does the 'Range' metadata (which must be a valid MD_range operand list)
+/// ensure that the value it's attached to is never equal to Value?
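+/// E.g. (illustrative) a value carrying !range !{i32 1, i32 256} metadata can
+/// never be 0, so a query for the value 0 returns true.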
+static bool rangeMetadataExcludesValue(const MDNode *Ranges,
+                                       const APInt &Value) {
+ const unsigned NumRanges = Ranges->getNumOperands() / 2;
+ assert(NumRanges >= 1);
+ for (unsigned i = 0; i < NumRanges; ++i) {
+ ConstantInt *Lower =
+ mdconst::extract<ConstantInt>(Ranges->getOperand(2 * i + 0));
+ ConstantInt *Upper =
+ mdconst::extract<ConstantInt>(Ranges->getOperand(2 * i + 1));
+ ConstantRange Range(Lower->getValue(), Upper->getValue());
+ if (Range.contains(Value))
+ return false;
+ }
+ return true;
+}
+
+/// Return true if the given value is known to be non-zero when defined. For
+/// vectors, return true if every element is known to be non-zero when
+/// defined. For pointers, if the context instruction and dominator tree are
+/// specified, perform context-sensitive analysis and return true if the
+/// pointer couldn't possibly be null at the specified instruction.
+/// Supports values with integer or pointer type and vectors of integers.
+bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {
+ if (auto *C = dyn_cast<Constant>(V)) {
+ if (C->isNullValue())
+ return false;
+ if (isa<ConstantInt>(C))
+ // Must be non-zero due to null test above.
+ return true;
+
+ if (auto *CE = dyn_cast<ConstantExpr>(C)) {
+ // See the comment for IntToPtr/PtrToInt instructions below.
+ if (CE->getOpcode() == Instruction::IntToPtr ||
+ CE->getOpcode() == Instruction::PtrToInt)
+ if (Q.DL.getTypeSizeInBits(CE->getOperand(0)->getType()) <=
+ Q.DL.getTypeSizeInBits(CE->getType()))
+ return isKnownNonZero(CE->getOperand(0), Depth, Q);
+ }
+
+ // For constant vectors, check that all elements are undefined or known
+ // non-zero to determine that the whole vector is known non-zero.
+ if (auto *VecTy = dyn_cast<VectorType>(C->getType())) {
+ for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) {
+ Constant *Elt = C->getAggregateElement(i);
+ if (!Elt || Elt->isNullValue())
+ return false;
+ if (!isa<UndefValue>(Elt) && !isa<ConstantInt>(Elt))
+ return false;
+ }
+ return true;
+ }
+
+    // A global variable in address space 0 is non-null unless it is extern
+    // weak or an absolute symbol reference. Other address spaces may have
+    // null as a valid address for a global, so we can't assume anything.
+ if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
+ if (!GV->isAbsoluteSymbolRef() && !GV->hasExternalWeakLinkage() &&
+ GV->getType()->getAddressSpace() == 0)
+ return true;
+ } else
+ return false;
+ }
+
+ if (auto *I = dyn_cast<Instruction>(V)) {
+ if (MDNode *Ranges = Q.IIQ.getMetadata(I, LLVMContext::MD_range)) {
+ // If the possible ranges don't contain zero, then the value is
+ // definitely non-zero.
+ if (auto *Ty = dyn_cast<IntegerType>(V->getType())) {
+ const APInt ZeroValue(Ty->getBitWidth(), 0);
+ if (rangeMetadataExcludesValue(Ranges, ZeroValue))
+ return true;
+ }
+ }
+ }
+
+ // Some of the tests below are recursive, so bail out if we hit the limit.
+ if (Depth++ >= MaxDepth)
+ return false;
+
+ // Check for pointer simplifications.
+ if (V->getType()->isPointerTy()) {
+ // Alloca never returns null, malloc might.
+ if (isa<AllocaInst>(V) && Q.DL.getAllocaAddrSpace() == 0)
+ return true;
+
+ // A byval, inalloca, or nonnull argument is never null.
+ if (const Argument *A = dyn_cast<Argument>(V))
+ if (A->hasByValOrInAllocaAttr() || A->hasNonNullAttr())
+ return true;
+
+ // A Load tagged with nonnull metadata is never null.
+ if (const LoadInst *LI = dyn_cast<LoadInst>(V))
+ if (Q.IIQ.getMetadata(LI, LLVMContext::MD_nonnull))
+ return true;
+
+ if (const auto *Call = dyn_cast<CallBase>(V)) {
+ if (Call->isReturnNonNull())
+ return true;
+ if (const auto *RP = getArgumentAliasingToReturnedPointer(Call, true))
+ return isKnownNonZero(RP, Depth, Q);
+ }
+ }
+
+ // Check for recursive pointer simplifications.
+ if (V->getType()->isPointerTy()) {
+ if (isKnownNonNullFromDominatingCondition(V, Q.CxtI, Q.DT))
+ return true;
+
+ // Look through bitcast operations, GEPs, and int2ptr instructions as they
+ // do not alter the value, or at least not the nullness property of the
+ // value, e.g., int2ptr is allowed to zero/sign extend the value.
+ //
+ // Note that we have to take special care to avoid looking through
+ // truncating casts, e.g., int2ptr/ptr2int with appropriate sizes, as well
+ // as casts that can alter the value, e.g., AddrSpaceCasts.
+ if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V))
+ if (isGEPKnownNonNull(GEP, Depth, Q))
+ return true;
+
+ if (auto *BCO = dyn_cast<BitCastOperator>(V))
+ return isKnownNonZero(BCO->getOperand(0), Depth, Q);
+
+ if (auto *I2P = dyn_cast<IntToPtrInst>(V))
+ if (Q.DL.getTypeSizeInBits(I2P->getSrcTy()) <=
+ Q.DL.getTypeSizeInBits(I2P->getDestTy()))
+ return isKnownNonZero(I2P->getOperand(0), Depth, Q);
+ }
+
+ // Similar to int2ptr above, we can look through ptr2int here if the cast
+ // is a no-op or an extend and not a truncate.
+ if (auto *P2I = dyn_cast<PtrToIntInst>(V))
+ if (Q.DL.getTypeSizeInBits(P2I->getSrcTy()) <=
+ Q.DL.getTypeSizeInBits(P2I->getDestTy()))
+ return isKnownNonZero(P2I->getOperand(0), Depth, Q);
+
+ unsigned BitWidth = getBitWidth(V->getType()->getScalarType(), Q.DL);
+
+ // X | Y != 0 if X != 0 or Y != 0.
+ Value *X = nullptr, *Y = nullptr;
+ if (match(V, m_Or(m_Value(X), m_Value(Y))))
+ return isKnownNonZero(X, Depth, Q) || isKnownNonZero(Y, Depth, Q);
+
+ // ext X != 0 if X != 0.
+ if (isa<SExtInst>(V) || isa<ZExtInst>(V))
+ return isKnownNonZero(cast<Instruction>(V)->getOperand(0), Depth, Q);
+
+ // shl X, Y != 0 if X is odd. Note that the value of the shift is undefined
+ // if the lowest bit is shifted off the end.
+ if (match(V, m_Shl(m_Value(X), m_Value(Y)))) {
+ // shl nuw can't remove any non-zero bits.
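+    // E.g. (illustrative IR): %v = shl nuw i32 %x, %n is non-zero whenever %x
+    // is non-zero.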
+ const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V);
+ if (Q.IIQ.hasNoUnsignedWrap(BO))
+ return isKnownNonZero(X, Depth, Q);
+
+ KnownBits Known(BitWidth);
+ computeKnownBits(X, Known, Depth, Q);
+ if (Known.One[0])
+ return true;
+ }
+ // shr X, Y != 0 if X is negative. Note that the value of the shift is not
+ // defined if the sign bit is shifted off the end.
+ else if (match(V, m_Shr(m_Value(X), m_Value(Y)))) {
+ // shr exact can only shift out zero bits.
+ const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(V);
+ if (BO->isExact())
+ return isKnownNonZero(X, Depth, Q);
+
+ KnownBits Known = computeKnownBits(X, Depth, Q);
+ if (Known.isNegative())
+ return true;
+
+ // If the shifter operand is a constant, and all of the bits shifted
+ // out are known to be zero, and X is known non-zero then at least one
+ // non-zero bit must remain.
+ if (ConstantInt *Shift = dyn_cast<ConstantInt>(Y)) {
+ auto ShiftVal = Shift->getLimitedValue(BitWidth - 1);
+ // Is there a known one in the portion not shifted out?
+ if (Known.countMaxLeadingZeros() < BitWidth - ShiftVal)
+ return true;
+ // Are all the bits to be shifted out known zero?
+ if (Known.countMinTrailingZeros() >= ShiftVal)
+ return isKnownNonZero(X, Depth, Q);
+ }
+ }
+ // div exact can only produce a zero if the dividend is zero.
+ else if (match(V, m_Exact(m_IDiv(m_Value(X), m_Value())))) {
+ return isKnownNonZero(X, Depth, Q);
+ }
+ // X + Y.
+ else if (match(V, m_Add(m_Value(X), m_Value(Y)))) {
+ KnownBits XKnown = computeKnownBits(X, Depth, Q);
+ KnownBits YKnown = computeKnownBits(Y, Depth, Q);
+
+ // If X and Y are both non-negative (as signed values) then their sum is not
+ // zero unless both X and Y are zero.
+ if (XKnown.isNonNegative() && YKnown.isNonNegative())
+ if (isKnownNonZero(X, Depth, Q) || isKnownNonZero(Y, Depth, Q))
+ return true;
+
+ // If X and Y are both negative (as signed values) then their sum is not
+ // zero unless both X and Y equal INT_MIN.
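+    // (If the sum wrapped to zero we would need X == -Y, and for negative Y
+    // the value -Y is positive unless Y is INT_MIN, so both must be INT_MIN.)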
+ if (XKnown.isNegative() && YKnown.isNegative()) {
+ APInt Mask = APInt::getSignedMaxValue(BitWidth);
+ // The sign bit of X is set. If some other bit is set then X is not equal
+ // to INT_MIN.
+ if (XKnown.One.intersects(Mask))
+ return true;
+ // The sign bit of Y is set. If some other bit is set then Y is not equal
+ // to INT_MIN.
+ if (YKnown.One.intersects(Mask))
+ return true;
+ }
+
+ // The sum of a non-negative number and a power of two is not zero.
+ if (XKnown.isNonNegative() &&
+ isKnownToBeAPowerOfTwo(Y, /*OrZero*/ false, Depth, Q))
+ return true;
+ if (YKnown.isNonNegative() &&
+ isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Depth, Q))
+ return true;
+ }
+ // X * Y.
+ else if (match(V, m_Mul(m_Value(X), m_Value(Y)))) {
+ const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V);
+ // If X and Y are non-zero then so is X * Y as long as the multiplication
+ // does not overflow.
+ if ((Q.IIQ.hasNoSignedWrap(BO) || Q.IIQ.hasNoUnsignedWrap(BO)) &&
+ isKnownNonZero(X, Depth, Q) && isKnownNonZero(Y, Depth, Q))
+ return true;
+ }
+ // (C ? X : Y) != 0 if X != 0 and Y != 0.
+ else if (const SelectInst *SI = dyn_cast<SelectInst>(V)) {
+ if (isKnownNonZero(SI->getTrueValue(), Depth, Q) &&
+ isKnownNonZero(SI->getFalseValue(), Depth, Q))
+ return true;
+ }
+ // PHI
+ else if (const PHINode *PN = dyn_cast<PHINode>(V)) {
+ // Try and detect a recurrence that monotonically increases from a
+ // starting value, as these are common as induction variables.
+ if (PN->getNumIncomingValues() == 2) {
+ Value *Start = PN->getIncomingValue(0);
+ Value *Induction = PN->getIncomingValue(1);
+ if (isa<ConstantInt>(Induction) && !isa<ConstantInt>(Start))
+ std::swap(Start, Induction);
+ if (ConstantInt *C = dyn_cast<ConstantInt>(Start)) {
+ if (!C->isZero() && !C->isNegative()) {
+ ConstantInt *X;
+ if (Q.IIQ.UseInstrInfo &&
+ (match(Induction, m_NSWAdd(m_Specific(PN), m_ConstantInt(X))) ||
+ match(Induction, m_NUWAdd(m_Specific(PN), m_ConstantInt(X)))) &&
+ !X->isNegative())
+ return true;
+ }
+ }
+ }
+    // Check if all incoming values are non-zero constants.
+ bool AllNonZeroConstants = llvm::all_of(PN->operands(), [](Value *V) {
+ return isa<ConstantInt>(V) && !cast<ConstantInt>(V)->isZero();
+ });
+ if (AllNonZeroConstants)
+ return true;
+ }
+
+ KnownBits Known(BitWidth);
+ computeKnownBits(V, Known, Depth, Q);
+ return Known.One != 0;
+}
+
+/// Return true if V2 == V1 + X, where X is known non-zero.
+static bool isAddOfNonZero(const Value *V1, const Value *V2, const Query &Q) {
+ const BinaryOperator *BO = dyn_cast<BinaryOperator>(V1);
+ if (!BO || BO->getOpcode() != Instruction::Add)
+ return false;
+ Value *Op = nullptr;
+ if (V2 == BO->getOperand(0))
+ Op = BO->getOperand(1);
+ else if (V2 == BO->getOperand(1))
+ Op = BO->getOperand(0);
+ else
+ return false;
+ return isKnownNonZero(Op, 0, Q);
+}
+
+/// Return true if it is known that V1 != V2.
+static bool isKnownNonEqual(const Value *V1, const Value *V2, const Query &Q) {
+ if (V1 == V2)
+ return false;
+ if (V1->getType() != V2->getType())
+ // We can't look through casts yet.
+ return false;
+ if (isAddOfNonZero(V1, V2, Q) || isAddOfNonZero(V2, V1, Q))
+ return true;
+
+ if (V1->getType()->isIntOrIntVectorTy()) {
+ // Are any known bits in V1 contradictory to known bits in V2? If V1
+ // has a known zero where V2 has a known one, they must not be equal.
+ KnownBits Known1 = computeKnownBits(V1, 0, Q);
+ KnownBits Known2 = computeKnownBits(V2, 0, Q);
+
+ if (Known1.Zero.intersects(Known2.One) ||
+ Known2.Zero.intersects(Known1.One))
+ return true;
+ }
+ return false;
+}
+
+/// Return true if 'V & Mask' is known to be zero. We use this predicate to
+/// simplify operations downstream. Mask is known to be zero for bits that V
+/// cannot have.
+///
+/// This function is defined on values with integer type, values with pointer
+/// type, and vectors of integers. In the case
+/// where V is a vector, the mask, known zero, and known one values are the
+/// same width as the vector element, and the bit is set only if it is true
+/// for all of the elements in the vector.
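+/// E.g. if the known bits of V prove that its two low bits are zero, then a
+/// query with Mask = 3 returns true.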
+bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth,
+ const Query &Q) {
+ KnownBits Known(Mask.getBitWidth());
+ computeKnownBits(V, Known, Depth, Q);
+ return Mask.isSubsetOf(Known.Zero);
+}
+
+// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow).
+// Returns the input and lower/upper bounds.
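+// E.g. (illustrative IR) a clamp of %in to the range [-8, 7]:
+//   %c1  = icmp slt i32 %in, 7
+//   %min = select i1 %c1, i32 %in, i32 7       ; smin(%in, 7)
+//   %c2  = icmp sgt i32 %min, -8
+//   %max = select i1 %c2, i32 %min, i32 -8     ; smax(smin(%in, 7), -8)
+// yields In = %in, CLow = -8 and CHigh = 7.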
+static bool isSignedMinMaxClamp(const Value *Select, const Value *&In,
+ const APInt *&CLow, const APInt *&CHigh) {
+ assert(isa<Operator>(Select) &&
+ cast<Operator>(Select)->getOpcode() == Instruction::Select &&
+ "Input should be a Select!");
+
+ const Value *LHS = nullptr, *RHS = nullptr;
+ SelectPatternFlavor SPF = matchSelectPattern(Select, LHS, RHS).Flavor;
+ if (SPF != SPF_SMAX && SPF != SPF_SMIN)
+ return false;
+
+ if (!match(RHS, m_APInt(CLow)))
+ return false;
+
+ const Value *LHS2 = nullptr, *RHS2 = nullptr;
+ SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor;
+ if (getInverseMinMaxFlavor(SPF) != SPF2)
+ return false;
+
+ if (!match(RHS2, m_APInt(CHigh)))
+ return false;
+
+ if (SPF == SPF_SMIN)
+ std::swap(CLow, CHigh);
+
+ In = LHS2;
+ return CLow->sle(*CHigh);
+}
+
+/// For vector constants, loop over the elements and find the constant with the
+/// minimum number of sign bits. Return 0 if the value is not a vector constant
+/// or if any element was not analyzed; otherwise, return the count for the
+/// element with the minimum number of sign bits.
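+/// E.g. for <2 x i8> <i8 -1, i8 3>, the elements have 8 and 6 sign bits
+/// respectively, so this returns 6.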
+static unsigned computeNumSignBitsVectorConstant(const Value *V,
+ unsigned TyBits) {
+ const auto *CV = dyn_cast<Constant>(V);
+ if (!CV || !CV->getType()->isVectorTy())
+ return 0;
+
+ unsigned MinSignBits = TyBits;
+ unsigned NumElts = CV->getType()->getVectorNumElements();
+ for (unsigned i = 0; i != NumElts; ++i) {
+ // If we find a non-ConstantInt, bail out.
+ auto *Elt = dyn_cast_or_null<ConstantInt>(CV->getAggregateElement(i));
+ if (!Elt)
+ return 0;
+
+ MinSignBits = std::min(MinSignBits, Elt->getValue().getNumSignBits());
+ }
+
+ return MinSignBits;
+}
+
+static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth,
+ const Query &Q);
+
+static unsigned ComputeNumSignBits(const Value *V, unsigned Depth,
+ const Query &Q) {
+ unsigned Result = ComputeNumSignBitsImpl(V, Depth, Q);
+ assert(Result > 0 && "At least one sign bit needs to be present!");
+ return Result;
+}
+
+/// Return the number of times the sign bit of the register is replicated into
+/// the other bits. We know that at least 1 bit is always equal to the sign bit
+/// (itself), but other cases can give us information. For example, immediately
+/// after an "ashr X, 2", we know that the top 3 bits are all equal to each
+/// other, so we return 3. For vectors, return the number of sign bits for the
+/// vector element with the minimum number of known sign bits.
+static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth,
+ const Query &Q) {
+ assert(Depth <= MaxDepth && "Limit Search Depth");
+
+ // We return the minimum number of sign bits that are guaranteed to be present
+ // in V, so for undef we have to conservatively return 1. We don't have the
+ // same behavior for poison though -- that's a FIXME today.
+
+ Type *ScalarTy = V->getType()->getScalarType();
+ unsigned TyBits = ScalarTy->isPointerTy() ?
+ Q.DL.getIndexTypeSizeInBits(ScalarTy) :
+ Q.DL.getTypeSizeInBits(ScalarTy);
+
+ unsigned Tmp, Tmp2;
+ unsigned FirstAnswer = 1;
+
+ // Note that ConstantInt is handled by the general computeKnownBits case
+ // below.
+
+ if (Depth == MaxDepth)
+ return 1; // Limit search depth.
+
+ if (auto *U = dyn_cast<Operator>(V)) {
+ switch (Operator::getOpcode(V)) {
+ default: break;
+ case Instruction::SExt:
+ Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits();
+ return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q) + Tmp;
+
+ case Instruction::SDiv: {
+ const APInt *Denominator;
+ // sdiv X, C -> adds log(C) sign bits.
+ if (match(U->getOperand(1), m_APInt(Denominator))) {
+
+ // Ignore non-positive denominator.
+ if (!Denominator->isStrictlyPositive())
+ break;
+
+ // Calculate the incoming numerator bits.
+ unsigned NumBits = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+
+ // Add floor(log(C)) bits to the numerator bits.
+ return std::min(TyBits, NumBits + Denominator->logBase2());
+ }
+ break;
+ }
+
+ case Instruction::SRem: {
+ const APInt *Denominator;
+      // srem X, C -> we know that the result is within [-C+1,C) when C is a
+      // positive constant. This lets us put a lower bound on the number of
+      // sign bits.
+ if (match(U->getOperand(1), m_APInt(Denominator))) {
+
+ // Ignore non-positive denominator.
+ if (!Denominator->isStrictlyPositive())
+ break;
+
+ // Calculate the incoming numerator bits. SRem by a positive constant
+ // can't lower the number of sign bits.
+ unsigned NumrBits = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+
+ // Calculate the leading sign bit constraints by examining the
+ // denominator. Given that the denominator is positive, there are two
+ // cases:
+ //
+ // 1. the numerator is positive. The result range is [0,C) and [0,C) u<
+ // (1 << ceilLogBase2(C)).
+ //
+ // 2. the numerator is negative. Then the result range is (-C,0] and
+ // integers in (-C,0] are either 0 or >u (-1 << ceilLogBase2(C)).
+ //
+ // Thus a lower bound on the number of sign bits is `TyBits -
+ // ceilLogBase2(C)`.
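+      // Worked i8 illustration: srem i8 %x, 8 has a result in (-8, 8), so at
+      // least 8 - ceilLogBase2(8) = 5 sign bits are guaranteed.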
+
+ unsigned ResBits = TyBits - Denominator->ceilLogBase2();
+ return std::max(NumrBits, ResBits);
+ }
+ break;
+ }
+
+ case Instruction::AShr: {
+ Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+ // ashr X, C -> adds C sign bits. Vectors too.
+ const APInt *ShAmt;
+ if (match(U->getOperand(1), m_APInt(ShAmt))) {
+ if (ShAmt->uge(TyBits))
+ break; // Bad shift.
+ unsigned ShAmtLimited = ShAmt->getZExtValue();
+ Tmp += ShAmtLimited;
+ if (Tmp > TyBits) Tmp = TyBits;
+ }
+ return Tmp;
+ }
+ case Instruction::Shl: {
+ const APInt *ShAmt;
+ if (match(U->getOperand(1), m_APInt(ShAmt))) {
+ // shl destroys sign bits.
+ Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+ if (ShAmt->uge(TyBits) || // Bad shift.
+ ShAmt->uge(Tmp)) break; // Shifted all sign bits out.
+ Tmp2 = ShAmt->getZExtValue();
+ return Tmp - Tmp2;
+ }
+ break;
+ }
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Xor: // NOT is handled here.
+ // Logical binary ops preserve the number of sign bits at the worst.
+ Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+ if (Tmp != 1) {
+ Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
+ FirstAnswer = std::min(Tmp, Tmp2);
+ // We computed what we know about the sign bits as our first
+ // answer. Now proceed to the generic code that uses
+ // computeKnownBits, and pick whichever answer is better.
+ }
+ break;
+
+ case Instruction::Select: {
+ // If we have a clamp pattern, we know that the number of sign bits will
+ // be the minimum of the clamp min/max range.
+ const Value *X;
+ const APInt *CLow, *CHigh;
+ if (isSignedMinMaxClamp(U, X, CLow, CHigh))
+ return std::min(CLow->getNumSignBits(), CHigh->getNumSignBits());
+
+ Tmp = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
+ if (Tmp == 1) break;
+ Tmp2 = ComputeNumSignBits(U->getOperand(2), Depth + 1, Q);
+ return std::min(Tmp, Tmp2);
+ }
+
+ case Instruction::Add:
+ // Add can have at most one carry bit. Thus we know that the output
+ // is, at worst, one more bit than the inputs.
+ Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+ if (Tmp == 1) break;
+
+ // Special case decrementing a value (ADD X, -1):
+ if (const auto *CRHS = dyn_cast<Constant>(U->getOperand(1)))
+ if (CRHS->isAllOnesValue()) {
+ KnownBits Known(TyBits);
+ computeKnownBits(U->getOperand(0), Known, Depth + 1, Q);
+
+ // If the input is known to be 0 or 1, the output is 0/-1, which is
+ // all sign bits set.
+ if ((Known.Zero | 1).isAllOnesValue())
+ return TyBits;
+
+ // If we are subtracting one from a positive number, there is no carry
+ // out of the result.
+ if (Known.isNonNegative())
+ return Tmp;
+ }
+
+ Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
+ if (Tmp2 == 1) break;
+ return std::min(Tmp, Tmp2) - 1;
+
+ case Instruction::Sub:
+ Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
+ if (Tmp2 == 1) break;
+
+ // Handle NEG.
+ if (const auto *CLHS = dyn_cast<Constant>(U->getOperand(0)))
+ if (CLHS->isNullValue()) {
+ KnownBits Known(TyBits);
+ computeKnownBits(U->getOperand(1), Known, Depth + 1, Q);
+ // If the input is known to be 0 or 1, the output is 0/-1, which is
+ // all sign bits set.
+ if ((Known.Zero | 1).isAllOnesValue())
+ return TyBits;
+
+ // If the input is known to be positive (the sign bit is known clear),
+ // the output of the NEG has the same number of sign bits as the
+ // input.
+ if (Known.isNonNegative())
+ return Tmp2;
+
+ // Otherwise, we treat this like a SUB.
+ }
+
+ // Sub can have at most one carry bit. Thus we know that the output
+ // is, at worst, one more bit than the inputs.
+ Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+ if (Tmp == 1) break;
+ return std::min(Tmp, Tmp2) - 1;
+
+ case Instruction::Mul: {
+ // The output of the Mul can be at most twice the valid bits in the
+ // inputs.
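+      // E.g. for i32 operands with 20 sign bits each, every input fits in
+      // 32 - 20 + 1 = 13 significant bits, the product fits in 26, leaving at
+      // least 32 - 26 + 1 = 7 sign bits.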
+ unsigned SignBitsOp0 = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+ if (SignBitsOp0 == 1) break;
+ unsigned SignBitsOp1 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
+ if (SignBitsOp1 == 1) break;
+ unsigned OutValidBits =
+ (TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1);
+ return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
+ }
+
+ case Instruction::PHI: {
+ const PHINode *PN = cast<PHINode>(U);
+ unsigned NumIncomingValues = PN->getNumIncomingValues();
+ // Don't analyze large in-degree PHIs.
+ if (NumIncomingValues > 4) break;
+ // Unreachable blocks may have zero-operand PHI nodes.
+ if (NumIncomingValues == 0) break;
+
+ // Take the minimum of all incoming values. This can't infinitely loop
+ // because of our depth threshold.
+ Tmp = ComputeNumSignBits(PN->getIncomingValue(0), Depth + 1, Q);
+ for (unsigned i = 1, e = NumIncomingValues; i != e; ++i) {
+ if (Tmp == 1) return Tmp;
+ Tmp = std::min(
+ Tmp, ComputeNumSignBits(PN->getIncomingValue(i), Depth + 1, Q));
+ }
+ return Tmp;
+ }
+
+ case Instruction::Trunc:
+ // FIXME: it's tricky to do anything useful for this, but it is an
+ // important case for targets like X86.
+ break;
+
+ case Instruction::ExtractElement:
+ // Look through extract element. At the moment we keep this simple and
+ // skip tracking the specific element. But at least we might find
+ // information valid for all elements of the vector (for example if vector
+ // is sign extended, shifted, etc).
+ return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+
+ case Instruction::ShuffleVector: {
+ // TODO: This is copied almost directly from the SelectionDAG version of
+ // ComputeNumSignBits. It would be better if we could share common
+ // code. If not, make sure that changes are translated to the DAG.
+
+ // Collect the minimum number of sign bits that are shared by every vector
+ // element referenced by the shuffle.
+ auto *Shuf = cast<ShuffleVectorInst>(U);
+ int NumElts = Shuf->getOperand(0)->getType()->getVectorNumElements();
+ int NumMaskElts = Shuf->getMask()->getType()->getVectorNumElements();
+ APInt DemandedLHS(NumElts, 0), DemandedRHS(NumElts, 0);
+ for (int i = 0; i != NumMaskElts; ++i) {
+ int M = Shuf->getMaskValue(i);
+ assert(M < NumElts * 2 && "Invalid shuffle mask constant");
+ // For undef elements, we don't know anything about the common state of
+ // the shuffle result.
+ if (M == -1)
+ return 1;
+ if (M < NumElts)
+ DemandedLHS.setBit(M % NumElts);
+ else
+ DemandedRHS.setBit(M % NumElts);
+ }
+ Tmp = std::numeric_limits<unsigned>::max();
+ if (!!DemandedLHS)
+ Tmp = ComputeNumSignBits(Shuf->getOperand(0), Depth + 1, Q);
+ if (!!DemandedRHS) {
+ Tmp2 = ComputeNumSignBits(Shuf->getOperand(1), Depth + 1, Q);
+ Tmp = std::min(Tmp, Tmp2);
+ }
+ // If we don't know anything, early out and try computeKnownBits
+ // fall-back.
+ if (Tmp == 1)
+ break;
+ assert(Tmp <= V->getType()->getScalarSizeInBits() &&
+ "Failed to determine minimum sign bits");
+ return Tmp;
+ }
+ }
+ }
+
+ // Finally, if we can prove that the top bits of the result are 0's or 1's,
+ // use this information.
+
+ // If we can examine all elements of a vector constant successfully, we're
+ // done (we can't do any better than that). If not, keep trying.
+ if (unsigned VecSignBits = computeNumSignBitsVectorConstant(V, TyBits))
+ return VecSignBits;
+
+ KnownBits Known(TyBits);
+ computeKnownBits(V, Known, Depth, Q);
+
+ // If we know that the sign bit is either zero or one, determine the number of
+ // identical bits in the top of the input value.
+ return std::max(FirstAnswer, Known.countMinSignBits());
+}
+
+/// This function computes the integer multiple of Base that equals V.
+/// If successful, it returns true and places the multiple in
+/// Multiple. If unsuccessful, it returns false. It looks
+/// through SExt instructions only if LookThroughSExt is true.
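+/// E.g. (illustrative) for V = (shl i32 %x, 1) and Base = 2, this sets
+/// Multiple to %x and returns true.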
+bool llvm::ComputeMultiple(Value *V, unsigned Base, Value *&Multiple,
+ bool LookThroughSExt, unsigned Depth) {
+ assert(V && "No Value?");
+ assert(Depth <= MaxDepth && "Limit Search Depth");
+  assert(V->getType()->isIntegerTy() && "Not an integer type!");
+
+ Type *T = V->getType();
+
+ ConstantInt *CI = dyn_cast<ConstantInt>(V);
+
+ if (Base == 0)
+ return false;
+
+ if (Base == 1) {
+ Multiple = V;
+ return true;
+ }
+
+ ConstantExpr *CO = dyn_cast<ConstantExpr>(V);
+ Constant *BaseVal = ConstantInt::get(T, Base);
+ if (CO && CO == BaseVal) {
+ // Multiple is 1.
+ Multiple = ConstantInt::get(T, 1);
+ return true;
+ }
+
+ if (CI && CI->getZExtValue() % Base == 0) {
+ Multiple = ConstantInt::get(T, CI->getZExtValue() / Base);
+ return true;
+ }
+
+ if (Depth == MaxDepth) return false; // Limit search depth.
+
+ Operator *I = dyn_cast<Operator>(V);
+ if (!I) return false;
+
+ switch (I->getOpcode()) {
+ default: break;
+ case Instruction::SExt:
+ if (!LookThroughSExt) return false;
+ // otherwise fall through to ZExt
+ LLVM_FALLTHROUGH;
+ case Instruction::ZExt:
+ return ComputeMultiple(I->getOperand(0), Base, Multiple,
+ LookThroughSExt, Depth+1);
+ case Instruction::Shl:
+ case Instruction::Mul: {
+ Value *Op0 = I->getOperand(0);
+ Value *Op1 = I->getOperand(1);
+
+ if (I->getOpcode() == Instruction::Shl) {
+ ConstantInt *Op1CI = dyn_cast<ConstantInt>(Op1);
+ if (!Op1CI) return false;
+ // Turn Op0 << Op1 into Op0 * 2^Op1
+ APInt Op1Int = Op1CI->getValue();
+ uint64_t BitToSet = Op1Int.getLimitedValue(Op1Int.getBitWidth() - 1);
+ APInt API(Op1Int.getBitWidth(), 0);
+ API.setBit(BitToSet);
+ Op1 = ConstantInt::get(V->getContext(), API);
+ }
+
+ Value *Mul0 = nullptr;
+ if (ComputeMultiple(Op0, Base, Mul0, LookThroughSExt, Depth+1)) {
+ if (Constant *Op1C = dyn_cast<Constant>(Op1))
+ if (Constant *MulC = dyn_cast<Constant>(Mul0)) {
+ if (Op1C->getType()->getPrimitiveSizeInBits() <
+ MulC->getType()->getPrimitiveSizeInBits())
+ Op1C = ConstantExpr::getZExt(Op1C, MulC->getType());
+ if (Op1C->getType()->getPrimitiveSizeInBits() >
+ MulC->getType()->getPrimitiveSizeInBits())
+ MulC = ConstantExpr::getZExt(MulC, Op1C->getType());
+
+ // V == Base * (Mul0 * Op1), so return (Mul0 * Op1)
+ Multiple = ConstantExpr::getMul(MulC, Op1C);
+ return true;
+ }
+
+ if (ConstantInt *Mul0CI = dyn_cast<ConstantInt>(Mul0))
+ if (Mul0CI->getValue() == 1) {
+ // V == Base * Op1, so return Op1
+ Multiple = Op1;
+ return true;
+ }
+ }
+
+ Value *Mul1 = nullptr;
+ if (ComputeMultiple(Op1, Base, Mul1, LookThroughSExt, Depth+1)) {
+ if (Constant *Op0C = dyn_cast<Constant>(Op0))
+ if (Constant *MulC = dyn_cast<Constant>(Mul1)) {
+ if (Op0C->getType()->getPrimitiveSizeInBits() <
+ MulC->getType()->getPrimitiveSizeInBits())
+ Op0C = ConstantExpr::getZExt(Op0C, MulC->getType());
+ if (Op0C->getType()->getPrimitiveSizeInBits() >
+ MulC->getType()->getPrimitiveSizeInBits())
+ MulC = ConstantExpr::getZExt(MulC, Op0C->getType());
+
+ // V == Base * (Mul1 * Op0), so return (Mul1 * Op0)
+ Multiple = ConstantExpr::getMul(MulC, Op0C);
+ return true;
+ }
+
+ if (ConstantInt *Mul1CI = dyn_cast<ConstantInt>(Mul1))
+ if (Mul1CI->getValue() == 1) {
+ // V == Base * Op0, so return Op0
+ Multiple = Op0;
+ return true;
+ }
+ }
+ }
+ }
+
+ // We could not determine if V is a multiple of Base.
+ return false;
+}
+
+Intrinsic::ID llvm::getIntrinsicForCallSite(ImmutableCallSite ICS,
+ const TargetLibraryInfo *TLI) {
+ const Function *F = ICS.getCalledFunction();
+ if (!F)
+ return Intrinsic::not_intrinsic;
+
+ if (F->isIntrinsic())
+ return F->getIntrinsicID();
+
+ if (!TLI)
+ return Intrinsic::not_intrinsic;
+
+ LibFunc Func;
+  // We're going to make assumptions about the semantics of the functions, so
+  // check that the target knows the function is available in this environment
+  // and that it does not have local linkage.
+ if (!F || F->hasLocalLinkage() || !TLI->getLibFunc(*F, Func))
+ return Intrinsic::not_intrinsic;
+
+ if (!ICS.onlyReadsMemory())
+ return Intrinsic::not_intrinsic;
+
+ // Otherwise check if we have a call to a function that can be turned into a
+ // vector intrinsic.
+ switch (Func) {
+ default:
+ break;
+ case LibFunc_sin:
+ case LibFunc_sinf:
+ case LibFunc_sinl:
+ return Intrinsic::sin;
+ case LibFunc_cos:
+ case LibFunc_cosf:
+ case LibFunc_cosl:
+ return Intrinsic::cos;
+ case LibFunc_exp:
+ case LibFunc_expf:
+ case LibFunc_expl:
+ return Intrinsic::exp;
+ case LibFunc_exp2:
+ case LibFunc_exp2f:
+ case LibFunc_exp2l:
+ return Intrinsic::exp2;
+ case LibFunc_log:
+ case LibFunc_logf:
+ case LibFunc_logl:
+ return Intrinsic::log;
+ case LibFunc_log10:
+ case LibFunc_log10f:
+ case LibFunc_log10l:
+ return Intrinsic::log10;
+ case LibFunc_log2:
+ case LibFunc_log2f:
+ case LibFunc_log2l:
+ return Intrinsic::log2;
+ case LibFunc_fabs:
+ case LibFunc_fabsf:
+ case LibFunc_fabsl:
+ return Intrinsic::fabs;
+ case LibFunc_fmin:
+ case LibFunc_fminf:
+ case LibFunc_fminl:
+ return Intrinsic::minnum;
+ case LibFunc_fmax:
+ case LibFunc_fmaxf:
+ case LibFunc_fmaxl:
+ return Intrinsic::maxnum;
+ case LibFunc_copysign:
+ case LibFunc_copysignf:
+ case LibFunc_copysignl:
+ return Intrinsic::copysign;
+ case LibFunc_floor:
+ case LibFunc_floorf:
+ case LibFunc_floorl:
+ return Intrinsic::floor;
+ case LibFunc_ceil:
+ case LibFunc_ceilf:
+ case LibFunc_ceill:
+ return Intrinsic::ceil;
+ case LibFunc_trunc:
+ case LibFunc_truncf:
+ case LibFunc_truncl:
+ return Intrinsic::trunc;
+ case LibFunc_rint:
+ case LibFunc_rintf:
+ case LibFunc_rintl:
+ return Intrinsic::rint;
+ case LibFunc_nearbyint:
+ case LibFunc_nearbyintf:
+ case LibFunc_nearbyintl:
+ return Intrinsic::nearbyint;
+ case LibFunc_round:
+ case LibFunc_roundf:
+ case LibFunc_roundl:
+ return Intrinsic::round;
+ case LibFunc_pow:
+ case LibFunc_powf:
+ case LibFunc_powl:
+ return Intrinsic::pow;
+ case LibFunc_sqrt:
+ case LibFunc_sqrtf:
+ case LibFunc_sqrtl:
+ return Intrinsic::sqrt;
+ }
+
+ return Intrinsic::not_intrinsic;
+}
+
+/// Return true if we can prove that the specified FP value is never equal to
+/// -0.0.
+///
+/// NOTE: this function will need to be revisited when we support non-default
+/// rounding modes!
+bool llvm::CannotBeNegativeZero(const Value *V, const TargetLibraryInfo *TLI,
+ unsigned Depth) {
+ if (auto *CFP = dyn_cast<ConstantFP>(V))
+ return !CFP->getValueAPF().isNegZero();
+
+ // Limit search depth.
+ if (Depth == MaxDepth)
+ return false;
+
+ auto *Op = dyn_cast<Operator>(V);
+ if (!Op)
+ return false;
+
+ // Check if the nsz fast-math flag is set.
+ if (auto *FPO = dyn_cast<FPMathOperator>(Op))
+ if (FPO->hasNoSignedZeros())
+ return true;
+
+ // (fadd x, 0.0) is guaranteed to return +0.0, not -0.0.
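+  // (Under the default rounding mode, (-0.0) + (+0.0) is +0.0, so adding a
+  // positive zero can never produce a negative zero.)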
+ if (match(Op, m_FAdd(m_Value(), m_PosZeroFP())))
+ return true;
+
+ // sitofp and uitofp turn into +0.0 for zero.
+ if (isa<SIToFPInst>(Op) || isa<UIToFPInst>(Op))
+ return true;
+
+ if (auto *Call = dyn_cast<CallInst>(Op)) {
+ Intrinsic::ID IID = getIntrinsicForCallSite(Call, TLI);
+ switch (IID) {
+ default:
+ break;
+ // sqrt(-0.0) = -0.0, no other negative results are possible.
+ case Intrinsic::sqrt:
+ case Intrinsic::canonicalize:
+ return CannotBeNegativeZero(Call->getArgOperand(0), TLI, Depth + 1);
+ // fabs(x) != -0.0
+ case Intrinsic::fabs:
+ return true;
+ }
+ }
+
+ return false;
+}
+
+/// If \p SignBitOnly is true, test for a known-zero sign bit rather than a
+/// standard ordered compare, e.g. treat -0.0 as ordered less than 0.0 because
+/// of its sign bit even though the two compare equal.
+static bool cannotBeOrderedLessThanZeroImpl(const Value *V,
+ const TargetLibraryInfo *TLI,
+ bool SignBitOnly,
+ unsigned Depth) {
+ // TODO: This function does not do the right thing when SignBitOnly is true
+ // and we're lowering to a hypothetical IEEE 754-compliant-but-evil platform
+ // which flips the sign bits of NaNs. See
+ // https://llvm.org/bugs/show_bug.cgi?id=31702.
+
+ if (const ConstantFP *CFP = dyn_cast<ConstantFP>(V)) {
+ return !CFP->getValueAPF().isNegative() ||
+ (!SignBitOnly && CFP->getValueAPF().isZero());
+ }
+
+ // Handle vector of constants.
+ if (auto *CV = dyn_cast<Constant>(V)) {
+ if (CV->getType()->isVectorTy()) {
+ unsigned NumElts = CV->getType()->getVectorNumElements();
+ for (unsigned i = 0; i != NumElts; ++i) {
+ auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i));
+ if (!CFP)
+ return false;
+ if (CFP->getValueAPF().isNegative() &&
+ (SignBitOnly || !CFP->getValueAPF().isZero()))
+ return false;
+ }
+
+ // All non-negative ConstantFPs.
+ return true;
+ }
+ }
+
+ if (Depth == MaxDepth)
+ return false; // Limit search depth.
+
+ const Operator *I = dyn_cast<Operator>(V);
+ if (!I)
+ return false;
+
+ switch (I->getOpcode()) {
+ default:
+ break;
+ // Unsigned integers are always nonnegative.
+ case Instruction::UIToFP:
+ return true;
+ case Instruction::FMul:
+ // x*x is always non-negative or a NaN.
+ if (I->getOperand(0) == I->getOperand(1) &&
+ (!SignBitOnly || cast<FPMathOperator>(I)->hasNoNaNs()))
+ return true;
+
+ LLVM_FALLTHROUGH;
+ case Instruction::FAdd:
+ case Instruction::FDiv:
+ case Instruction::FRem:
+ return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly,
+ Depth + 1) &&
+ cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly,
+ Depth + 1);
+ case Instruction::Select:
+ return cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly,
+ Depth + 1) &&
+ cannotBeOrderedLessThanZeroImpl(I->getOperand(2), TLI, SignBitOnly,
+ Depth + 1);
+ case Instruction::FPExt:
+ case Instruction::FPTrunc:
+ // Widening/narrowing never change sign.
+ return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly,
+ Depth + 1);
+ case Instruction::ExtractElement:
+ // Look through extract element. At the moment we keep this simple and skip
+ // tracking the specific element. But at least we might find information
+ // valid for all elements of the vector.
+ return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly,
+ Depth + 1);
+ case Instruction::Call:
+ const auto *CI = cast<CallInst>(I);
+ Intrinsic::ID IID = getIntrinsicForCallSite(CI, TLI);
+ switch (IID) {
+ default:
+ break;
+ case Intrinsic::maxnum:
+ return (isKnownNeverNaN(I->getOperand(0), TLI) &&
+ cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI,
+ SignBitOnly, Depth + 1)) ||
+ (isKnownNeverNaN(I->getOperand(1), TLI) &&
+ cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI,
+ SignBitOnly, Depth + 1));
+
+ case Intrinsic::maximum:
+ return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly,
+ Depth + 1) ||
+ cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly,
+ Depth + 1);
+ case Intrinsic::minnum:
+ case Intrinsic::minimum:
+ return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly,
+ Depth + 1) &&
+ cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly,
+ Depth + 1);
+ case Intrinsic::exp:
+ case Intrinsic::exp2:
+ case Intrinsic::fabs:
+ return true;
+
+ case Intrinsic::sqrt:
+ // sqrt(x) is always >= -0 or NaN. Moreover, sqrt(x) == -0 iff x == -0.
+ if (!SignBitOnly)
+ return true;
+ return CI->hasNoNaNs() && (CI->hasNoSignedZeros() ||
+ CannotBeNegativeZero(CI->getOperand(0), TLI));
+
+ case Intrinsic::powi:
+ if (ConstantInt *Exponent = dyn_cast<ConstantInt>(I->getOperand(1))) {
+ // powi(x,n) is non-negative if n is even.
+ if (Exponent->getBitWidth() <= 64 && Exponent->getSExtValue() % 2u == 0)
+ return true;
+ }
+ // TODO: This is not correct. Given that exp is an integer, here are the
+ // ways that pow can return a negative value:
+ //
+ // pow(x, exp) --> negative if exp is odd and x is negative.
+ // pow(-0, exp) --> -inf if exp is negative odd.
+ // pow(-0, exp) --> -0 if exp is positive odd.
+ // pow(-inf, exp) --> -0 if exp is negative odd.
+ // pow(-inf, exp) --> -inf if exp is positive odd.
+ //
+ // Therefore, if !SignBitOnly, we can return true if x >= +0 or x is NaN,
+ // but we must return false if x == -0. Unfortunately we do not currently
+ // have a way of expressing this constraint. See details in
+ // https://llvm.org/bugs/show_bug.cgi?id=31702.
+ return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly,
+ Depth + 1);
+
+ case Intrinsic::fma:
+ case Intrinsic::fmuladd:
+ // x*x+y is non-negative if y is non-negative.
+ return I->getOperand(0) == I->getOperand(1) &&
+ (!SignBitOnly || cast<FPMathOperator>(I)->hasNoNaNs()) &&
+ cannotBeOrderedLessThanZeroImpl(I->getOperand(2), TLI, SignBitOnly,
+ Depth + 1);
+ }
+ break;
+ }
+ return false;
+}
+
+bool llvm::CannotBeOrderedLessThanZero(const Value *V,
+ const TargetLibraryInfo *TLI) {
+ return cannotBeOrderedLessThanZeroImpl(V, TLI, false, 0);
+}
+
+bool llvm::SignBitMustBeZero(const Value *V, const TargetLibraryInfo *TLI) {
+ return cannotBeOrderedLessThanZeroImpl(V, TLI, true, 0);
+}
+
+bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI,
+ unsigned Depth) {
+ assert(V->getType()->isFPOrFPVectorTy() && "Querying for NaN on non-FP type");
+
+ // If we're told that NaNs won't happen, assume they won't.
+ if (auto *FPMathOp = dyn_cast<FPMathOperator>(V))
+ if (FPMathOp->hasNoNaNs())
+ return true;
+
+ // Handle scalar constants.
+ if (auto *CFP = dyn_cast<ConstantFP>(V))
+ return !CFP->isNaN();
+
+ if (Depth == MaxDepth)
+ return false;
+
+ if (auto *Inst = dyn_cast<Instruction>(V)) {
+ switch (Inst->getOpcode()) {
+ case Instruction::FAdd:
+ case Instruction::FMul:
+ case Instruction::FSub:
+ case Instruction::FDiv:
+ case Instruction::FRem: {
+ // TODO: Need isKnownNeverInfinity
+ return false;
+ }
+ case Instruction::Select: {
+ return isKnownNeverNaN(Inst->getOperand(1), TLI, Depth + 1) &&
+ isKnownNeverNaN(Inst->getOperand(2), TLI, Depth + 1);
+ }
+ case Instruction::SIToFP:
+ case Instruction::UIToFP:
+ return true;
+ case Instruction::FPTrunc:
+ case Instruction::FPExt:
+ return isKnownNeverNaN(Inst->getOperand(0), TLI, Depth + 1);
+ default:
+ break;
+ }
+ }
+
+ if (const auto *II = dyn_cast<IntrinsicInst>(V)) {
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::canonicalize:
+ case Intrinsic::fabs:
+ case Intrinsic::copysign:
+ case Intrinsic::exp:
+ case Intrinsic::exp2:
+ case Intrinsic::floor:
+ case Intrinsic::ceil:
+ case Intrinsic::trunc:
+ case Intrinsic::rint:
+ case Intrinsic::nearbyint:
+ case Intrinsic::round:
+ return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1);
+ case Intrinsic::sqrt:
+ return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1) &&
+ CannotBeOrderedLessThanZero(II->getArgOperand(0), TLI);
+ case Intrinsic::minnum:
+ case Intrinsic::maxnum:
+ // If either operand is not NaN, the result is not NaN.
+ return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1) ||
+ isKnownNeverNaN(II->getArgOperand(1), TLI, Depth + 1);
+ default:
+ return false;
+ }
+ }
+
+ // Bail out for constant expressions, but try to handle vector constants.
+ if (!V->getType()->isVectorTy() || !isa<Constant>(V))
+ return false;
+
+ // For vectors, verify that each element is not NaN.
+ unsigned NumElts = V->getType()->getVectorNumElements();
+ for (unsigned i = 0; i != NumElts; ++i) {
+ Constant *Elt = cast<Constant>(V)->getAggregateElement(i);
+ if (!Elt)
+ return false;
+ if (isa<UndefValue>(Elt))
+ continue;
+ auto *CElt = dyn_cast<ConstantFP>(Elt);
+ if (!CElt || CElt->isNaN())
+ return false;
+ }
+ // All elements were confirmed not-NaN or undefined.
+ return true;
+}
+
+Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) {
+
+ // All byte-wide stores are splatable, even of arbitrary variables.
+ if (V->getType()->isIntegerTy(8))
+ return V;
+
+ LLVMContext &Ctx = V->getContext();
+
+  // Undef: don't care.
+ auto *UndefInt8 = UndefValue::get(Type::getInt8Ty(Ctx));
+ if (isa<UndefValue>(V))
+ return UndefInt8;
+
+ const uint64_t Size = DL.getTypeStoreSize(V->getType());
+ if (!Size)
+ return UndefInt8;
+
+ Constant *C = dyn_cast<Constant>(V);
+ if (!C) {
+ // Conceptually, we could handle things like:
+ // %a = zext i8 %X to i16
+ // %b = shl i16 %a, 8
+ // %c = or i16 %a, %b
+ // but until there is an example that actually needs this, it doesn't seem
+ // worth worrying about.
+ return nullptr;
+ }
+
+  // Handle 'null' ConstantAggregateZero etc.
+ if (C->isNullValue())
+ return Constant::getNullValue(Type::getInt8Ty(Ctx));
+
+ // Constant floating-point values can be handled as integer values if the
+ // corresponding integer value is "byteable". An important case is 0.0.
+ if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
+ Type *Ty = nullptr;
+ if (CFP->getType()->isHalfTy())
+ Ty = Type::getInt16Ty(Ctx);
+ else if (CFP->getType()->isFloatTy())
+ Ty = Type::getInt32Ty(Ctx);
+ else if (CFP->getType()->isDoubleTy())
+ Ty = Type::getInt64Ty(Ctx);
+ // Don't handle long double formats, which have strange constraints.
+ return Ty ? isBytewiseValue(ConstantExpr::getBitCast(CFP, Ty), DL)
+ : nullptr;
+ }
+
+  // We can handle constant integers whose width is a multiple of 8 bits.
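+  // E.g. the i32 constant 0xAAAAAAAA splats to the i8 byte 0xAA, while
+  // 0xAABBCCDD has no common byte and is rejected below.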
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) {
+ if (CI->getBitWidth() % 8 == 0) {
+ assert(CI->getBitWidth() > 8 && "8 bits should be handled above!");
+ if (!CI->getValue().isSplat(8))
+ return nullptr;
+ return ConstantInt::get(Ctx, CI->getValue().trunc(8));
+ }
+ }
+
+ if (auto *CE = dyn_cast<ConstantExpr>(C)) {
+ if (CE->getOpcode() == Instruction::IntToPtr) {
+ auto PS = DL.getPointerSizeInBits(
+ cast<PointerType>(CE->getType())->getAddressSpace());
+ return isBytewiseValue(
+ ConstantExpr::getIntegerCast(CE->getOperand(0),
+ Type::getIntNTy(Ctx, PS), false),
+ DL);
+ }
+ }
+
+ auto Merge = [&](Value *LHS, Value *RHS) -> Value * {
+ if (LHS == RHS)
+ return LHS;
+ if (!LHS || !RHS)
+ return nullptr;
+ if (LHS == UndefInt8)
+ return RHS;
+ if (RHS == UndefInt8)
+ return LHS;
+ return nullptr;
+ };
+
+ if (ConstantDataSequential *CA = dyn_cast<ConstantDataSequential>(C)) {
+ Value *Val = UndefInt8;
+ for (unsigned I = 0, E = CA->getNumElements(); I != E; ++I)
+ if (!(Val = Merge(Val, isBytewiseValue(CA->getElementAsConstant(I), DL))))
+ return nullptr;
+ return Val;
+ }
+
+ if (isa<ConstantAggregate>(C)) {
+ Value *Val = UndefInt8;
+ for (unsigned I = 0, E = C->getNumOperands(); I != E; ++I)
+ if (!(Val = Merge(Val, isBytewiseValue(C->getOperand(I), DL))))
+ return nullptr;
+ return Val;
+ }
+
+ // Don't try to handle the handful of other constants.
+ return nullptr;
+}
+
+// This is the recursive version of BuildSubAggregate. It takes a few different
+// arguments. Idxs is the index within the nested struct From that we are
+// looking at now (which is of type IndexedType). IdxSkip is the number of
+// indices from Idxs that should be left out when inserting into the resulting
+// struct. To is the result struct built so far, new insertvalue instructions
+// build on that.
+static Value *BuildSubAggregate(Value *From, Value* To, Type *IndexedType,
+ SmallVectorImpl<unsigned> &Idxs,
+ unsigned IdxSkip,
+ Instruction *InsertBefore) {
+ StructType *STy = dyn_cast<StructType>(IndexedType);
+ if (STy) {
+ // Save the original To argument so we can modify it
+ Value *OrigTo = To;
+ // General case, the type indexed by Idxs is a struct
+ for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
+ // Process each struct element recursively
+ Idxs.push_back(i);
+ Value *PrevTo = To;
+ To = BuildSubAggregate(From, To, STy->getElementType(i), Idxs, IdxSkip,
+ InsertBefore);
+ Idxs.pop_back();
+ if (!To) {
+ // Couldn't find any inserted value for this index? Cleanup
+ while (PrevTo != OrigTo) {
+ InsertValueInst* Del = cast<InsertValueInst>(PrevTo);
+ PrevTo = Del->getAggregateOperand();
+ Del->eraseFromParent();
+ }
+ // Stop processing elements
+ break;
+ }
+ }
+ // If we successfully found a value for each of our subaggregates
+ if (To)
+ return To;
+ }
+ // Base case, the type indexed by SourceIdxs is not a struct, or not all of
+ // the struct's elements had a value that was inserted directly. In the latter
+ // case, perhaps we can't determine each of the subelements individually, but
+ // we might be able to find the complete struct somewhere.
+
+ // Find the value that is at that particular spot
+ Value *V = FindInsertedValue(From, Idxs);
+
+ if (!V)
+ return nullptr;
+
+ // Insert the value in the new (sub) aggregate
+ return InsertValueInst::Create(To, V, makeArrayRef(Idxs).slice(IdxSkip),
+ "tmp", InsertBefore);
+}
+
+// This helper takes a nested struct and extracts a part of it (which is again a
+// struct) into a new value. For example, given the struct:
+// { a, { b, { c, d }, e } }
+// and the indices "1, 1" this returns
+// { c, d }.
+//
+// It does this by inserting an insertvalue for each element in the resulting
+// struct, as opposed to just inserting a single struct. This will only work if
+// each of the elements of the substruct is known (i.e., inserted into From by
+// an insertvalue instruction somewhere).
+//
+// All inserted insertvalue instructions are inserted before InsertBefore
+static Value *BuildSubAggregate(Value *From, ArrayRef<unsigned> idx_range,
+ Instruction *InsertBefore) {
+ assert(InsertBefore && "Must have someplace to insert!");
+ Type *IndexedType = ExtractValueInst::getIndexedType(From->getType(),
+ idx_range);
+ Value *To = UndefValue::get(IndexedType);
+ SmallVector<unsigned, 10> Idxs(idx_range.begin(), idx_range.end());
+ unsigned IdxSkip = Idxs.size();
+
+ return BuildSubAggregate(From, To, IndexedType, Idxs, IdxSkip, InsertBefore);
+}
+
+/// Given an aggregate and a sequence of indices, see if the scalar value
+/// indexed is already around as a register, for example if it was inserted
+/// directly into the aggregate.
+///
+/// If InsertBefore is not null, this function will duplicate (modified)
+/// insertvalues when a part of a nested struct is extracted.
+Value *llvm::FindInsertedValue(Value *V, ArrayRef<unsigned> idx_range,
+ Instruction *InsertBefore) {
+ // Nothing to index? Just return V then (this is useful at the end of our
+ // recursion).
+ if (idx_range.empty())
+ return V;
+ // We have indices, so V should have an indexable type.
+ assert((V->getType()->isStructTy() || V->getType()->isArrayTy()) &&
+ "Not looking at a struct or array?");
+ assert(ExtractValueInst::getIndexedType(V->getType(), idx_range) &&
+ "Invalid indices for type?");
+
+ if (Constant *C = dyn_cast<Constant>(V)) {
+ C = C->getAggregateElement(idx_range[0]);
+ if (!C) return nullptr;
+ return FindInsertedValue(C, idx_range.slice(1), InsertBefore);
+ }
+
+ if (InsertValueInst *I = dyn_cast<InsertValueInst>(V)) {
+ // Loop the indices for the insertvalue instruction in parallel with the
+ // requested indices
+ const unsigned *req_idx = idx_range.begin();
+ for (const unsigned *i = I->idx_begin(), *e = I->idx_end();
+ i != e; ++i, ++req_idx) {
+ if (req_idx == idx_range.end()) {
+ // We can't handle this without inserting insertvalues
+ if (!InsertBefore)
+ return nullptr;
+
+ // The requested index identifies a part of a nested aggregate. Handle
+ // this specially. For example,
+ // %A = insertvalue { i32, {i32, i32 } } undef, i32 10, 1, 0
+ // %B = insertvalue { i32, {i32, i32 } } %A, i32 11, 1, 1
+ // %C = extractvalue {i32, { i32, i32 } } %B, 1
+ // This can be changed into
+ // %A = insertvalue {i32, i32 } undef, i32 10, 0
+ // %C = insertvalue {i32, i32 } %A, i32 11, 1
+ // which allows the unused 0,0 element from the nested struct to be
+ // removed.
+ return BuildSubAggregate(V, makeArrayRef(idx_range.begin(), req_idx),
+ InsertBefore);
+ }
+
+      // This insertvalue inserts something other than what we are looking for.
+      // See if the (aggregate) value it inserts into contains the value we are
+      // looking for instead.
+ if (*req_idx != *i)
+ return FindInsertedValue(I->getAggregateOperand(), idx_range,
+ InsertBefore);
+ }
+ // If we end up here, the indices of the insertvalue match with those
+ // requested (though possibly only partially). Now we recursively look at
+ // the inserted value, passing any remaining indices.
+ return FindInsertedValue(I->getInsertedValueOperand(),
+ makeArrayRef(req_idx, idx_range.end()),
+ InsertBefore);
+ }
+
+ if (ExtractValueInst *I = dyn_cast<ExtractValueInst>(V)) {
+ // If we're extracting a value from an aggregate that was extracted from
+ // something else, we can extract from that something else directly instead.
+ // However, we will need to chain I's indices with the requested indices.
+
+ // Calculate the number of indices required
+ unsigned size = I->getNumIndices() + idx_range.size();
+ // Allocate some space to put the new indices in
+ SmallVector<unsigned, 5> Idxs;
+ Idxs.reserve(size);
+ // Add indices from the extract value instruction
+ Idxs.append(I->idx_begin(), I->idx_end());
+
+ // Add requested indices
+ Idxs.append(idx_range.begin(), idx_range.end());
+
+ assert(Idxs.size() == size
+ && "Number of indices added not correct?");
+
+ return FindInsertedValue(I->getAggregateOperand(), Idxs, InsertBefore);
+ }
+  // Otherwise, we don't know (e.g., extracting from a function return value
+  // or a load instruction).
+ return nullptr;
+}
+
+bool llvm::isGEPBasedOnPointerToString(const GEPOperator *GEP,
+ unsigned CharSize) {
+ // Make sure the GEP has exactly three arguments.
+ if (GEP->getNumOperands() != 3)
+ return false;
+
+  // Make sure the index-ee is a pointer to an array of \p CharSize-wide
+  // integers.
+ ArrayType *AT = dyn_cast<ArrayType>(GEP->getSourceElementType());
+ if (!AT || !AT->getElementType()->isIntegerTy(CharSize))
+ return false;
+
+ // Check to make sure that the first operand of the GEP is an integer and
+ // has value 0 so that we are sure we're indexing into the initializer.
+ const ConstantInt *FirstIdx = dyn_cast<ConstantInt>(GEP->getOperand(1));
+ if (!FirstIdx || !FirstIdx->isZero())
+ return false;
+
+ return true;
+}
+
+bool llvm::getConstantDataArrayInfo(const Value *V,
+ ConstantDataArraySlice &Slice,
+ unsigned ElementSize, uint64_t Offset) {
+ assert(V);
+
+ // Look through bitcast instructions and geps.
+ V = V->stripPointerCasts();
+
+ // If the value is a GEP instruction or constant expression, treat it as an
+ // offset.
+ if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
+ // The GEP operator should be based on a pointer to string constant, and is
+ // indexing into the string constant.
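+    // E.g. (illustrative IR): getelementptr [5 x i8], [5 x i8]* @str, i64 0,
+    // i64 1 over @str = constant [5 x i8] c"abcd\00" yields a slice at
+    // Offset = 1 with Length = 4.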
+ if (!isGEPBasedOnPointerToString(GEP, ElementSize))
+ return false;
+
+ // If the second index isn't a ConstantInt, then this is a variable index
+ // into the array. If this occurs, we can't say anything meaningful about
+ // the string.
+ uint64_t StartIdx = 0;
+ if (const ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(2)))
+ StartIdx = CI->getZExtValue();
+ else
+ return false;
+ return getConstantDataArrayInfo(GEP->getOperand(0), Slice, ElementSize,
+ StartIdx + Offset);
+ }
+
+  // The GEP, whether a constant expression or an instruction, must reference a
+  // global variable that is a constant and has a definitive initializer. The
+  // referenced initializer is the array that we'll use for the optimization.
+ const GlobalVariable *GV = dyn_cast<GlobalVariable>(V);
+ if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
+ return false;
+
+ const ConstantDataArray *Array;
+ ArrayType *ArrayTy;
+ if (GV->getInitializer()->isNullValue()) {
+ Type *GVTy = GV->getValueType();
+    if ((ArrayTy = dyn_cast<ArrayType>(GVTy))) {
+ // A zeroinitializer for the array; there is no ConstantDataArray.
+ Array = nullptr;
+ } else {
+ const DataLayout &DL = GV->getParent()->getDataLayout();
+ uint64_t SizeInBytes = DL.getTypeStoreSize(GVTy);
+ uint64_t Length = SizeInBytes / (ElementSize / 8);
+ if (Length <= Offset)
+ return false;
+
+ Slice.Array = nullptr;
+ Slice.Offset = 0;
+ Slice.Length = Length - Offset;
+ return true;
+ }
+ } else {
+ // This must be a ConstantDataArray.
+ Array = dyn_cast<ConstantDataArray>(GV->getInitializer());
+ if (!Array)
+ return false;
+ ArrayTy = Array->getType();
+ }
+ if (!ArrayTy->getElementType()->isIntegerTy(ElementSize))
+ return false;
+
+ uint64_t NumElts = ArrayTy->getArrayNumElements();
+ if (Offset > NumElts)
+ return false;
+
+ Slice.Array = Array;
+ Slice.Offset = Offset;
+ Slice.Length = NumElts - Offset;
+ return true;
+}
+
+/// This function computes the length of a null-terminated C string pointed to
+/// by V. If successful, it returns true and returns the string in Str.
+/// If unsuccessful, it returns false.
+bool llvm::getConstantStringInfo(const Value *V, StringRef &Str,
+ uint64_t Offset, bool TrimAtNul) {
+ ConstantDataArraySlice Slice;
+ if (!getConstantDataArrayInfo(V, Slice, 8, Offset))
+ return false;
+
+ if (Slice.Array == nullptr) {
+ if (TrimAtNul) {
+ Str = StringRef();
+ return true;
+ }
+ if (Slice.Length == 1) {
+ Str = StringRef("", 1);
+ return true;
+ }
+ // We cannot instantiate a StringRef as we do not have an appropriate string
+ // of 0s at hand.
+ return false;
+ }
+
+ // Start out with the entire array in the StringRef.
+ Str = Slice.Array->getAsString();
+ // Skip over 'offset' bytes.
+ Str = Str.substr(Slice.Offset);
+
+ if (TrimAtNul) {
+ // Trim off the \0 and anything after it. If the array is not nul
+ // terminated, we just return the rest of the string. The client may know
+ // some other way that the string is length-bound.
+ Str = Str.substr(0, Str.find('\0'));
+ }
+ return true;
+}
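+
+// A minimal usage sketch for the function above (hypothetical caller and
+// global; assumes the default Offset = 0 and TrimAtNul = true):
+//   ; @msg = private constant [6 x i8] c"hello\00"
+//   StringRef S;
+//   if (getConstantStringInfo(PtrIntoMsg, S))
+//     ...; // S == "hello", trimmed at the nul terminator.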
+
+// These next two are very similar to the above, but also look through PHI
+// nodes.
+// TODO: See if we can integrate these two together.
+
+/// If we can compute the length of the string pointed to by
+/// the specified pointer, return 'len+1'. If we can't, return 0.
+static uint64_t GetStringLengthH(const Value *V,
+ SmallPtrSetImpl<const PHINode*> &PHIs,
+ unsigned CharSize) {
+ // Look through noop bitcast instructions.
+ V = V->stripPointerCasts();
+
+ // If this is a PHI node, there are two cases: either we have already seen it
+ // or we haven't.
+ if (const PHINode *PN = dyn_cast<PHINode>(V)) {
+ if (!PHIs.insert(PN).second)
+ return ~0ULL; // already in the set.
+
+ // If it was new, see if all the input strings are the same length.
+ uint64_t LenSoFar = ~0ULL;
+ for (Value *IncValue : PN->incoming_values()) {
+ uint64_t Len = GetStringLengthH(IncValue, PHIs, CharSize);
+ if (Len == 0) return 0; // Unknown length -> unknown.
+
+ if (Len == ~0ULL) continue;
+
+ if (Len != LenSoFar && LenSoFar != ~0ULL)
+ return 0; // Disagree -> unknown.
+ LenSoFar = Len;
+ }
+
+ // Success, all agree.
+ return LenSoFar;
+ }
+
+ // strlen(select(c,x,y)) -> strlen(x) ^ strlen(y)
+ if (const SelectInst *SI = dyn_cast<SelectInst>(V)) {
+ uint64_t Len1 = GetStringLengthH(SI->getTrueValue(), PHIs, CharSize);
+ if (Len1 == 0) return 0;
+ uint64_t Len2 = GetStringLengthH(SI->getFalseValue(), PHIs, CharSize);
+ if (Len2 == 0) return 0;
+ if (Len1 == ~0ULL) return Len2;
+ if (Len2 == ~0ULL) return Len1;
+ if (Len1 != Len2) return 0;
+ return Len1;
+ }
+
+ // Otherwise, see if we can read the string.
+ ConstantDataArraySlice Slice;
+ if (!getConstantDataArrayInfo(V, Slice, CharSize))
+ return 0;
+
+ if (Slice.Array == nullptr)
+ return 1;
+
+ // Search for nul characters
+ unsigned NullIndex = 0;
+ for (unsigned E = Slice.Length; NullIndex < E; ++NullIndex) {
+ if (Slice.Array->getElementAsInteger(Slice.Offset + NullIndex) == 0)
+ break;
+ }
+
+ return NullIndex + 1;
+}
+
+/// If we can compute the length of the string pointed to by
+/// the specified pointer, return 'len+1'. If we can't, return 0.
+uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) {
+ if (!V->getType()->isPointerTy())
+ return 0;
+
+ SmallPtrSet<const PHINode*, 32> PHIs;
+ uint64_t Len = GetStringLengthH(V, PHIs, CharSize);
+ // If Len is ~0ULL, we had an infinite phi cycle: this is dead code, so return
+ // 1, the length of an empty string (just the nul terminator).
+ return Len == ~0ULL ? 1 : Len;
+}
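+
+// For example (illustrative only): given
+//   @s = constant [4 x i8] c"abc\00"
+// GetStringLength on a pointer to @s returns 4, i.e. strlen("abc") + 1, while
+// a pointer whose string content cannot be determined yields 0.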
+
+const Value *
+llvm::getArgumentAliasingToReturnedPointer(const CallBase *Call,
+ bool MustPreserveNullness) {
+ assert(Call &&
+ "getArgumentAliasingToReturnedPointer only works on nonnull calls");
+ if (const Value *RV = Call->getReturnedArgOperand())
+ return RV;
+ // This can be used only as an aliasing property.
+ if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
+ Call, MustPreserveNullness))
+ return Call->getArgOperand(0);
+ return nullptr;
+}
+
+bool llvm::isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
+ const CallBase *Call, bool MustPreserveNullness) {
+ return Call->getIntrinsicID() == Intrinsic::launder_invariant_group ||
+ Call->getIntrinsicID() == Intrinsic::strip_invariant_group ||
+ Call->getIntrinsicID() == Intrinsic::aarch64_irg ||
+ Call->getIntrinsicID() == Intrinsic::aarch64_tagp ||
+ (!MustPreserveNullness &&
+ Call->getIntrinsicID() == Intrinsic::ptrmask);
+}
+
+/// \p PN defines a loop-variant pointer to an object. Check if the
+/// previous iteration of the loop was referring to the same object as \p PN.
+static bool isSameUnderlyingObjectInLoop(const PHINode *PN,
+ const LoopInfo *LI) {
+ // Find the loop-defined value.
+ Loop *L = LI->getLoopFor(PN->getParent());
+ if (PN->getNumIncomingValues() != 2)
+ return true;
+
+ // Find the value from previous iteration.
+ auto *PrevValue = dyn_cast<Instruction>(PN->getIncomingValue(0));
+ if (!PrevValue || LI->getLoopFor(PrevValue->getParent()) != L)
+ PrevValue = dyn_cast<Instruction>(PN->getIncomingValue(1));
+ if (!PrevValue || LI->getLoopFor(PrevValue->getParent()) != L)
+ return true;
+
+ // If a new pointer is loaded in the loop, the pointer references a different
+ // object in every iteration. E.g.:
+ // for (i)
+ // int *p = a[i];
+ // ...
+ if (auto *Load = dyn_cast<LoadInst>(PrevValue))
+ if (!L->isLoopInvariant(Load->getPointerOperand()))
+ return false;
+ return true;
+}
+
+Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL,
+ unsigned MaxLookup) {
+ if (!V->getType()->isPointerTy())
+ return V;
+ for (unsigned Count = 0; MaxLookup == 0 || Count < MaxLookup; ++Count) {
+ if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
+ V = GEP->getPointerOperand();
+ } else if (Operator::getOpcode(V) == Instruction::BitCast ||
+ Operator::getOpcode(V) == Instruction::AddrSpaceCast) {
+ V = cast<Operator>(V)->getOperand(0);
+ } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
+ if (GA->isInterposable())
+ return V;
+ V = GA->getAliasee();
+ } else if (isa<AllocaInst>(V)) {
+ // An alloca can't be further simplified.
+ return V;
+ } else {
+ if (auto *Call = dyn_cast<CallBase>(V)) {
+ // CaptureTracking can know about special capturing properties of some
+ // intrinsics like launder.invariant.group, which can't be expressed with
+ // attributes but which return a pointer aliasing their argument. Because
+ // some analyses may assume that a nocapture pointer is not returned from
+ // a special intrinsic (the function would otherwise have to be marked
+ // with the returned attribute), it is crucial to use this helper so that
+ // we stay in sync with CaptureTracking. Not using it may cause weird
+ // miscompilations where two aliasing pointers are assumed not to alias.
+ if (auto *RP = getArgumentAliasingToReturnedPointer(Call, false)) {
+ V = RP;
+ continue;
+ }
+ }
+
+ // See if InstructionSimplify knows any relevant tricks.
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ // TODO: Acquire a DominatorTree and AssumptionCache and use them.
+ if (Value *Simplified = SimplifyInstruction(I, {DL, I})) {
+ V = Simplified;
+ continue;
+ }
+
+ return V;
+ }
+ assert(V->getType()->isPointerTy() && "Unexpected operand type!");
+ }
+ return V;
+}
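+
+// Illustrative example (hypothetical IR): for
+//   %a = alloca [16 x i8]
+//   %p = getelementptr [16 x i8], [16 x i8]* %a, i64 0, i64 4
+//   %q = bitcast i8* %p to i32*
+// GetUnderlyingObject(%q, DL) strips the bitcast and the GEP and returns %a.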
+
+void llvm::GetUnderlyingObjects(const Value *V,
+ SmallVectorImpl<const Value *> &Objects,
+ const DataLayout &DL, LoopInfo *LI,
+ unsigned MaxLookup) {
+ SmallPtrSet<const Value *, 4> Visited;
+ SmallVector<const Value *, 4> Worklist;
+ Worklist.push_back(V);
+ do {
+ const Value *P = Worklist.pop_back_val();
+ P = GetUnderlyingObject(P, DL, MaxLookup);
+
+ if (!Visited.insert(P).second)
+ continue;
+
+ if (auto *SI = dyn_cast<SelectInst>(P)) {
+ Worklist.push_back(SI->getTrueValue());
+ Worklist.push_back(SI->getFalseValue());
+ continue;
+ }
+
+ if (auto *PN = dyn_cast<PHINode>(P)) {
+ // If this PHI changes the underlying object in every iteration of the
+ // loop, don't look through it. Consider:
+ // int **A;
+ // for (i) {
+ // Prev = Curr; // Prev = PHI (Prev_0, Curr)
+ // Curr = A[i];
+ // *Prev, *Curr;
+ //
+ // Prev is tracking Curr one iteration behind so they refer to different
+ // underlying objects.
+ if (!LI || !LI->isLoopHeader(PN->getParent()) ||
+ isSameUnderlyingObjectInLoop(PN, LI))
+ for (Value *IncValue : PN->incoming_values())
+ Worklist.push_back(IncValue);
+ continue;
+ }
+
+ Objects.push_back(P);
+ } while (!Worklist.empty());
+}
+
+/// This is the function that does the work of looking through basic
+/// ptrtoint+arithmetic+inttoptr sequences.
+static const Value *getUnderlyingObjectFromInt(const Value *V) {
+ do {
+ if (const Operator *U = dyn_cast<Operator>(V)) {
+ // If we find a ptrtoint, we can transfer control back to the
+ // regular getUnderlyingObjectFromInt.
+ if (U->getOpcode() == Instruction::PtrToInt)
+ return U->getOperand(0);
+ // If we find an add of a constant, a multiplied value, or a phi, it's
+ // likely that the other operand will lead us to the base
+ // object. We don't have to worry about the case where the
+ // object address is somehow being computed by the multiply,
+ // because our callers only care when the result is an
+ // identifiable object.
+ if (U->getOpcode() != Instruction::Add ||
+ (!isa<ConstantInt>(U->getOperand(1)) &&
+ Operator::getOpcode(U->getOperand(1)) != Instruction::Mul &&
+ !isa<PHINode>(U->getOperand(1))))
+ return V;
+ V = U->getOperand(0);
+ } else {
+ return V;
+ }
+ assert(V->getType()->isIntegerTy() && "Unexpected operand type!");
+ } while (true);
+}
+
+/// This is a wrapper around GetUnderlyingObjects and adds support for basic
+/// ptrtoint+arithmetic+inttoptr sequences.
+/// It returns false if an unidentified object is found in GetUnderlyingObjects.
+bool llvm::getUnderlyingObjectsForCodeGen(const Value *V,
+ SmallVectorImpl<Value *> &Objects,
+ const DataLayout &DL) {
+ SmallPtrSet<const Value *, 16> Visited;
+ SmallVector<const Value *, 4> Working(1, V);
+ do {
+ V = Working.pop_back_val();
+
+ SmallVector<const Value *, 4> Objs;
+ GetUnderlyingObjects(V, Objs, DL);
+
+ for (const Value *V : Objs) {
+ if (!Visited.insert(V).second)
+ continue;
+ if (Operator::getOpcode(V) == Instruction::IntToPtr) {
+ const Value *O =
+ getUnderlyingObjectFromInt(cast<User>(V)->getOperand(0));
+ if (O->getType()->isPointerTy()) {
+ Working.push_back(O);
+ continue;
+ }
+ }
+ // If GetUnderlyingObjects fails to find an identifiable object,
+ // getUnderlyingObjectsForCodeGen also fails for safety.
+ if (!isIdentifiedObject(V)) {
+ Objects.clear();
+ return false;
+ }
+ Objects.push_back(const_cast<Value *>(V));
+ }
+ } while (!Working.empty());
+ return true;
+}
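+
+// Illustrative example (hypothetical IR): for
+//   %i = ptrtoint i8* %a to i64
+//   %j = add i64 %i, 16
+//   %p = inttoptr i64 %j to i8*
+// the function above looks through the inttoptr/add/ptrtoint sequence and
+// reports %a, provided %a is an identifiable object (e.g. an alloca or a
+// global); otherwise it clears Objects and returns false.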
+
+/// Return true if the only users of this pointer are lifetime markers.
+bool llvm::onlyUsedByLifetimeMarkers(const Value *V) {
+ for (const User *U : V->users()) {
+ const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
+ if (!II) return false;
+
+ if (!II->isLifetimeStartOrEnd())
+ return false;
+ }
+ return true;
+}
+
+bool llvm::mustSuppressSpeculation(const LoadInst &LI) {
+ if (!LI.isUnordered())
+ return true;
+ const Function &F = *LI.getFunction();
+ // Speculative load may create a race that did not exist in the source.
+ return F.hasFnAttribute(Attribute::SanitizeThread) ||
+ // Speculative load may load data from dirty regions.
+ F.hasFnAttribute(Attribute::SanitizeAddress) ||
+ F.hasFnAttribute(Attribute::SanitizeHWAddress);
+}
+
+bool llvm::isSafeToSpeculativelyExecute(const Value *V,
+ const Instruction *CtxI,
+ const DominatorTree *DT) {
+ const Operator *Inst = dyn_cast<Operator>(V);
+ if (!Inst)
+ return false;
+
+ for (unsigned i = 0, e = Inst->getNumOperands(); i != e; ++i)
+ if (Constant *C = dyn_cast<Constant>(Inst->getOperand(i)))
+ if (C->canTrap())
+ return false;
+
+ switch (Inst->getOpcode()) {
+ default:
+ return true;
+ case Instruction::UDiv:
+ case Instruction::URem: {
+ // x / y is undefined if y == 0.
+ const APInt *V;
+ if (match(Inst->getOperand(1), m_APInt(V)))
+ return *V != 0;
+ return false;
+ }
+ case Instruction::SDiv:
+ case Instruction::SRem: {
+ // x / y is undefined if y == 0 or x == INT_MIN and y == -1
+ const APInt *Numerator, *Denominator;
+ if (!match(Inst->getOperand(1), m_APInt(Denominator)))
+ return false;
+ // We cannot hoist this division if the denominator is 0.
+ if (*Denominator == 0)
+ return false;
+ // It's safe to hoist if the denominator is not 0 or -1.
+ if (*Denominator != -1)
+ return true;
+ // At this point we know that the denominator is -1. It is safe to hoist as
+ // long as we know that the numerator is not INT_MIN.
+ if (match(Inst->getOperand(0), m_APInt(Numerator)))
+ return !Numerator->isMinSignedValue();
+ // The numerator *might* be MinSignedValue.
+ return false;
+ }
+ case Instruction::Load: {
+ const LoadInst *LI = cast<LoadInst>(Inst);
+ if (mustSuppressSpeculation(*LI))
+ return false;
+ const DataLayout &DL = LI->getModule()->getDataLayout();
+ return isDereferenceableAndAlignedPointer(
+ LI->getPointerOperand(), LI->getType(), MaybeAlign(LI->getAlignment()),
+ DL, CtxI, DT);
+ }
+ case Instruction::Call: {
+ auto *CI = cast<const CallInst>(Inst);
+ const Function *Callee = CI->getCalledFunction();
+
+ // The called function could have undefined behavior or side-effects, even
+ // if marked readnone nounwind.
+ return Callee && Callee->isSpeculatable();
+ }
+ case Instruction::VAArg:
+ case Instruction::Alloca:
+ case Instruction::Invoke:
+ case Instruction::CallBr:
+ case Instruction::PHI:
+ case Instruction::Store:
+ case Instruction::Ret:
+ case Instruction::Br:
+ case Instruction::IndirectBr:
+ case Instruction::Switch:
+ case Instruction::Unreachable:
+ case Instruction::Fence:
+ case Instruction::AtomicRMW:
+ case Instruction::AtomicCmpXchg:
+ case Instruction::LandingPad:
+ case Instruction::Resume:
+ case Instruction::CatchSwitch:
+ case Instruction::CatchPad:
+ case Instruction::CatchRet:
+ case Instruction::CleanupPad:
+ case Instruction::CleanupRet:
+ return false; // Misc instructions which have effects
+ }
+}
+
+bool llvm::mayBeMemoryDependent(const Instruction &I) {
+ return I.mayReadOrWriteMemory() || !isSafeToSpeculativelyExecute(&I);
+}
+
+/// Convert ConstantRange OverflowResult into ValueTracking OverflowResult.
+static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {
+ switch (OR) {
+ case ConstantRange::OverflowResult::MayOverflow:
+ return OverflowResult::MayOverflow;
+ case ConstantRange::OverflowResult::AlwaysOverflowsLow:
+ return OverflowResult::AlwaysOverflowsLow;
+ case ConstantRange::OverflowResult::AlwaysOverflowsHigh:
+ return OverflowResult::AlwaysOverflowsHigh;
+ case ConstantRange::OverflowResult::NeverOverflows:
+ return OverflowResult::NeverOverflows;
+ }
+ llvm_unreachable("Unknown OverflowResult");
+}
+
+/// Combine constant ranges from computeConstantRange() and computeKnownBits().
+static ConstantRange computeConstantRangeIncludingKnownBits(
+ const Value *V, bool ForSigned, const DataLayout &DL, unsigned Depth,
+ AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT,
+ OptimizationRemarkEmitter *ORE = nullptr, bool UseInstrInfo = true) {
+ KnownBits Known = computeKnownBits(
+ V, DL, Depth, AC, CxtI, DT, ORE, UseInstrInfo);
+ ConstantRange CR1 = ConstantRange::fromKnownBits(Known, ForSigned);
+ ConstantRange CR2 = computeConstantRange(V, UseInstrInfo);
+ ConstantRange::PreferredRangeType RangeType =
+ ForSigned ? ConstantRange::Signed : ConstantRange::Unsigned;
+ return CR1.intersectWith(CR2, RangeType);
+}
+
+OverflowResult llvm::computeOverflowForUnsignedMul(
+ const Value *LHS, const Value *RHS, const DataLayout &DL,
+ AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT,
+ bool UseInstrInfo) {
+ KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT,
+ nullptr, UseInstrInfo);
+ KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT,
+ nullptr, UseInstrInfo);
+ ConstantRange LHSRange = ConstantRange::fromKnownBits(LHSKnown, false);
+ ConstantRange RHSRange = ConstantRange::fromKnownBits(RHSKnown, false);
+ return mapOverflowResult(LHSRange.unsignedMulMayOverflow(RHSRange));
+}
+
+OverflowResult
+llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
+ const DataLayout &DL, AssumptionCache *AC,
+ const Instruction *CxtI,
+ const DominatorTree *DT, bool UseInstrInfo) {
+ // Multiplying n * m significant bits yields a result of n + m significant
+ // bits. If the total number of significant bits does not exceed the
+ // result bit width (minus 1), there is no overflow.
+ // This means if we have enough leading sign bits in the operands
+ // we can guarantee that the result does not overflow.
+ // Ref: "Hacker's Delight" by Henry Warren
+ unsigned BitWidth = LHS->getType()->getScalarSizeInBits();
+
+ // Note that underestimating the number of sign bits gives a more
+ // conservative answer.
+ unsigned SignBits = ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) +
+ ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT);
+
+ // First handle the easy case: if we have enough sign bits there's
+ // definitely no overflow.
+ if (SignBits > BitWidth + 1)
+ return OverflowResult::NeverOverflows;
+
+ // There are two ambiguous cases where there can be no overflow:
+ // SignBits == BitWidth + 1 and
+ // SignBits == BitWidth
+ // The second case is difficult to check, therefore we only handle the
+ // first case.
+ if (SignBits == BitWidth + 1) {
+ // It overflows only when both arguments are negative and the true
+ // product is exactly the minimum negative number.
+ // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
+ // For simplicity we just check if at least one side is not negative.
+ KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT,
+ nullptr, UseInstrInfo);
+ KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT,
+ nullptr, UseInstrInfo);
+ if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative())
+ return OverflowResult::NeverOverflows;
+ }
+ return OverflowResult::MayOverflow;
+}
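+
+// Worked instance of the sign-bit reasoning above (illustrative): for i16
+// operands that are each known to fit in i8, ComputeNumSignBits reports at
+// least 9 sign bits per operand, so SignBits >= 18 > BitWidth + 1 == 17 and
+// the multiplication never overflows (the product magnitude is at most
+// 16384, which fits in i16).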
+
+OverflowResult llvm::computeOverflowForUnsignedAdd(
+ const Value *LHS, const Value *RHS, const DataLayout &DL,
+ AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT,
+ bool UseInstrInfo) {
+ ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
+ LHS, /*ForSigned=*/false, DL, /*Depth=*/0, AC, CxtI, DT,
+ nullptr, UseInstrInfo);
+ ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
+ RHS, /*ForSigned=*/false, DL, /*Depth=*/0, AC, CxtI, DT,
+ nullptr, UseInstrInfo);
+ return mapOverflowResult(LHSRange.unsignedAddMayOverflow(RHSRange));
+}
+
+static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
+ const Value *RHS,
+ const AddOperator *Add,
+ const DataLayout &DL,
+ AssumptionCache *AC,
+ const Instruction *CxtI,
+ const DominatorTree *DT) {
+ if (Add && Add->hasNoSignedWrap()) {
+ return OverflowResult::NeverOverflows;
+ }
+
+ // If LHS and RHS each have at least two sign bits, the addition will look
+ // like
+ //
+ // XX..... +
+ // YY.....
+ //
+ // If the carry into the most significant position is 0, X and Y can't both
+ // be 1 and therefore the carry out of the addition is also 0.
+ //
+ // If the carry into the most significant position is 1, X and Y can't both
+ // be 0 and therefore the carry out of the addition is also 1.
+ //
+ // Since the carry into the most significant position is always equal to
+ // the carry out of the addition, there is no signed overflow.
+ if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 &&
+ ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT) > 1)
+ return OverflowResult::NeverOverflows;
+
+ ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
+ LHS, /*ForSigned=*/true, DL, /*Depth=*/0, AC, CxtI, DT);
+ ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
+ RHS, /*ForSigned=*/true, DL, /*Depth=*/0, AC, CxtI, DT);
+ OverflowResult OR =
+ mapOverflowResult(LHSRange.signedAddMayOverflow(RHSRange));
+ if (OR != OverflowResult::MayOverflow)
+ return OR;
+
+ // The remaining code needs Add to be available. Return early if it is not.
+ if (!Add)
+ return OverflowResult::MayOverflow;
+
+ // If the sign of Add is the same as at least one of the operands, this add
+ // CANNOT overflow. If this can be determined from the known bits of the
+ // operands the above signedAddMayOverflow() check will have already done so.
+ // The only other way to improve on the known bits is from an assumption, so
+ // call computeKnownBitsFromAssume() directly.
+ bool LHSOrRHSKnownNonNegative =
+ (LHSRange.isAllNonNegative() || RHSRange.isAllNonNegative());
+ bool LHSOrRHSKnownNegative =
+ (LHSRange.isAllNegative() || RHSRange.isAllNegative());
+ if (LHSOrRHSKnownNonNegative || LHSOrRHSKnownNegative) {
+ KnownBits AddKnown(LHSRange.getBitWidth());
+ computeKnownBitsFromAssume(
+ Add, AddKnown, /*Depth=*/0, Query(DL, AC, CxtI, DT, true));
+ if ((AddKnown.isNonNegative() && LHSOrRHSKnownNonNegative) ||
+ (AddKnown.isNegative() && LHSOrRHSKnownNegative))
+ return OverflowResult::NeverOverflows;
+ }
+
+ return OverflowResult::MayOverflow;
+}
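+
+// Small worked instance of the carry argument above (illustrative), using i4:
+// if both operands have at least two sign bits they lie in [-4, 3], so their
+// sum lies in [-8, 6] and always fits in i4, i.e. no signed overflow.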
+
+OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
+ const Value *RHS,
+ const DataLayout &DL,
+ AssumptionCache *AC,
+ const Instruction *CxtI,
+ const DominatorTree *DT) {
+ ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
+ LHS, /*ForSigned=*/false, DL, /*Depth=*/0, AC, CxtI, DT);
+ ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
+ RHS, /*ForSigned=*/false, DL, /*Depth=*/0, AC, CxtI, DT);
+ return mapOverflowResult(LHSRange.unsignedSubMayOverflow(RHSRange));
+}
+
+OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
+ const Value *RHS,
+ const DataLayout &DL,
+ AssumptionCache *AC,
+ const Instruction *CxtI,
+ const DominatorTree *DT) {
+ // If LHS and RHS each have at least two sign bits, the subtraction
+ // cannot overflow.
+ if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 &&
+ ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT) > 1)
+ return OverflowResult::NeverOverflows;
+
+ ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
+ LHS, /*ForSigned=*/true, DL, /*Depth=*/0, AC, CxtI, DT);
+ ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
+ RHS, /*ForSigned=*/true, DL, /*Depth=*/0, AC, CxtI, DT);
+ return mapOverflowResult(LHSRange.signedSubMayOverflow(RHSRange));
+}
+
+bool llvm::isOverflowIntrinsicNoWrap(const WithOverflowInst *WO,
+ const DominatorTree &DT) {
+ SmallVector<const BranchInst *, 2> GuardingBranches;
+ SmallVector<const ExtractValueInst *, 2> Results;
+
+ for (const User *U : WO->users()) {
+ if (const auto *EVI = dyn_cast<ExtractValueInst>(U)) {
+ assert(EVI->getNumIndices() == 1 && "Obvious from CI's type");
+
+ if (EVI->getIndices()[0] == 0)
+ Results.push_back(EVI);
+ else {
+ assert(EVI->getIndices()[0] == 1 && "Obvious from CI's type");
+
+ for (const auto *U : EVI->users())
+ if (const auto *B = dyn_cast<BranchInst>(U)) {
+ assert(B->isConditional() && "How else is it using an i1?");
+ GuardingBranches.push_back(B);
+ }
+ }
+ } else {
+ // We are using the aggregate directly in a way we don't want to analyze
+ // here (storing it to a global, say).
+ return false;
+ }
+ }
+
+ auto AllUsesGuardedByBranch = [&](const BranchInst *BI) {
+ BasicBlockEdge NoWrapEdge(BI->getParent(), BI->getSuccessor(1));
+ if (!NoWrapEdge.isSingleEdge())
+ return false;
+
+ // Check if all users of the add are provably no-wrap.
+ for (const auto *Result : Results) {
+ // If the extractvalue itself is not executed on overflow, then we don't
+ // need to check each use separately, since domination is transitive.
+ if (DT.dominates(NoWrapEdge, Result->getParent()))
+ continue;
+
+ for (auto &RU : Result->uses())
+ if (!DT.dominates(NoWrapEdge, RU))
+ return false;
+ }
+
+ return true;
+ };
+
+ return llvm::any_of(GuardingBranches, AllUsesGuardedByBranch);
+}
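+
+// Illustrative IR shape that the function above matches (hypothetical names):
+//   %s   = call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %a, i32 %b)
+//   %ov  = extractvalue { i32, i1 } %s, 1
+//   br i1 %ov, label %trap, label %cont
+// cont:
+//   %sum = extractvalue { i32, i1 } %s, 0
+// Every use of %sum is dominated by the no-wrap edge to %cont, so the add is
+// known not to wrap.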
+
+OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add,
+ const DataLayout &DL,
+ AssumptionCache *AC,
+ const Instruction *CxtI,
+ const DominatorTree *DT) {
+ return ::computeOverflowForSignedAdd(Add->getOperand(0), Add->getOperand(1),
+ Add, DL, AC, CxtI, DT);
+}
+
+OverflowResult llvm::computeOverflowForSignedAdd(const Value *LHS,
+ const Value *RHS,
+ const DataLayout &DL,
+ AssumptionCache *AC,
+ const Instruction *CxtI,
+ const DominatorTree *DT) {
+ return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, DL, AC, CxtI, DT);
+}
+
+bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) {
+ // Note: An atomic operation isn't guaranteed to return in a reasonable amount
+ // of time because it's possible for another thread to interfere with it for an
+ // arbitrary length of time, but programs aren't allowed to rely on that.
+
+ // If there is no successor, then execution can't transfer to it.
+ if (const auto *CRI = dyn_cast<CleanupReturnInst>(I))
+ return !CRI->unwindsToCaller();
+ if (const auto *CatchSwitch = dyn_cast<CatchSwitchInst>(I))
+ return !CatchSwitch->unwindsToCaller();
+ if (isa<ResumeInst>(I))
+ return false;
+ if (isa<ReturnInst>(I))
+ return false;
+ if (isa<UnreachableInst>(I))
+ return false;
+
+ // Calls can throw, or contain an infinite loop, or kill the process.
+ if (auto CS = ImmutableCallSite(I)) {
+ // Call sites that throw have implicit non-local control flow.
+ if (!CS.doesNotThrow())
+ return false;
+
+ // A function which doesn't throw and has the "willreturn" attribute will
+ // always return.
+ if (CS.hasFnAttr(Attribute::WillReturn))
+ return true;
+
+ // Non-throwing call sites can loop infinitely, call exit/pthread_exit
+ // etc. and thus not return. However, LLVM already assumes that
+ //
+ // - Thread exiting actions are modeled as writes to memory invisible to
+ // the program.
+ //
+ // - Loops that don't have side effects (side effects are volatile/atomic
+ // stores and IO) always terminate (see http://llvm.org/PR965).
+ // Furthermore IO itself is also modeled as writes to memory invisible to
+ // the program.
+ //
+ // We rely on those assumptions here, and use the memory effects of the call
+ // target as a proxy for checking that it always returns.
+
+ // FIXME: This isn't aggressive enough; a call which only writes to a global
+ // is guaranteed to return.
+ return CS.onlyReadsMemory() || CS.onlyAccessesArgMemory();
+ }
+
+ // Other instructions return normally.
+ return true;
+}
+
+bool llvm::isGuaranteedToTransferExecutionToSuccessor(const BasicBlock *BB) {
+ // TODO: This is slightly conservative for invoke instruction since exiting
+ // via an exception *is* normal control for them.
+ for (auto I = BB->begin(), E = BB->end(); I != E; ++I)
+ if (!isGuaranteedToTransferExecutionToSuccessor(&*I))
+ return false;
+ return true;
+}
+
+bool llvm::isGuaranteedToExecuteForEveryIteration(const Instruction *I,
+ const Loop *L) {
+ // The loop header is guaranteed to be executed for every iteration.
+ //
+ // FIXME: Relax this constraint to cover all basic blocks that are
+ // guaranteed to be executed at every iteration.
+ if (I->getParent() != L->getHeader()) return false;
+
+ for (const Instruction &LI : *L->getHeader()) {
+ if (&LI == I) return true;
+ if (!isGuaranteedToTransferExecutionToSuccessor(&LI)) return false;
+ }
+ llvm_unreachable("Instruction not contained in its own parent basic block.");
+}
+
+bool llvm::propagatesFullPoison(const Instruction *I) {
+ // TODO: This should include all instructions apart from phis, selects and
+ // call-like instructions.
+ switch (I->getOpcode()) {
+ case Instruction::Add:
+ case Instruction::Sub:
+ case Instruction::Xor:
+ case Instruction::Trunc:
+ case Instruction::BitCast:
+ case Instruction::AddrSpaceCast:
+ case Instruction::Mul:
+ case Instruction::Shl:
+ case Instruction::GetElementPtr:
+ // These operations all propagate poison unconditionally. Note that poison
+ // is not any particular value, so xor or subtraction of poison with
+ // itself still yields poison, not zero.
+ return true;
+
+ case Instruction::AShr:
+ case Instruction::SExt:
+ // For these operations, one bit of the input is replicated across
+ // multiple output bits. A replicated poison bit is still poison.
+ return true;
+
+ case Instruction::ICmp:
+ // Comparing poison with any value yields poison. This is why, for
+ // instance, x s< (x +nsw 1) can be folded to true.
+ return true;
+
+ default:
+ return false;
+ }
+}
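+
+// For example (illustrative): if %x is poison, then
+//   %y = add i32 %x, 1
+//   %z = xor i32 %y, %y
+// are both poison as well; the xor does not become 0 because poison is not
+// any particular value.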
+
+const Value *llvm::getGuaranteedNonFullPoisonOp(const Instruction *I) {
+ switch (I->getOpcode()) {
+ case Instruction::Store:
+ return cast<StoreInst>(I)->getPointerOperand();
+
+ case Instruction::Load:
+ return cast<LoadInst>(I)->getPointerOperand();
+
+ case Instruction::AtomicCmpXchg:
+ return cast<AtomicCmpXchgInst>(I)->getPointerOperand();
+
+ case Instruction::AtomicRMW:
+ return cast<AtomicRMWInst>(I)->getPointerOperand();
+
+ case Instruction::UDiv:
+ case Instruction::SDiv:
+ case Instruction::URem:
+ case Instruction::SRem:
+ return I->getOperand(1);
+
+ default:
+ // Note: It's really tempting to think that a conditional branch or
+ // switch should be listed here, but that's incorrect. It's not the
+ // branch on poison that is UB; it is executing a side-effecting
+ // instruction that follows the branch.
+ return nullptr;
+ }
+}
+
+bool llvm::mustTriggerUB(const Instruction *I,
+ const SmallSet<const Value *, 16>& KnownPoison) {
+ auto *NotPoison = getGuaranteedNonFullPoisonOp(I);
+ return (NotPoison && KnownPoison.count(NotPoison));
+}
+
+bool llvm::programUndefinedIfFullPoison(const Instruction *PoisonI) {
+ // We currently only look for uses of poison values within the same basic
+ // block, as that makes it easier to guarantee that the uses will be
+ // executed given that PoisonI is executed.
+ //
+ // FIXME: Expand this to consider uses beyond the same basic block. To do
+ // this, look out for the distinction between post-dominance and strong
+ // post-dominance.
+ const BasicBlock *BB = PoisonI->getParent();
+
+ // Set of instructions that we have proved will yield poison if PoisonI
+ // does.
+ SmallSet<const Value *, 16> YieldsPoison;
+ SmallSet<const BasicBlock *, 4> Visited;
+ YieldsPoison.insert(PoisonI);
+ Visited.insert(PoisonI->getParent());
+
+ BasicBlock::const_iterator Begin = PoisonI->getIterator(), End = BB->end();
+
+ unsigned Iter = 0;
+ while (Iter++ < MaxDepth) {
+ for (auto &I : make_range(Begin, End)) {
+ if (&I != PoisonI) {
+ if (mustTriggerUB(&I, YieldsPoison))
+ return true;
+ if (!isGuaranteedToTransferExecutionToSuccessor(&I))
+ return false;
+ }
+
+ // Mark poison that propagates from I through uses of I.
+ if (YieldsPoison.count(&I)) {
+ for (const User *User : I.users()) {
+ const Instruction *UserI = cast<Instruction>(User);
+ if (propagatesFullPoison(UserI))
+ YieldsPoison.insert(User);
+ }
+ }
+ }
+
+ if (auto *NextBB = BB->getSingleSuccessor()) {
+ if (Visited.insert(NextBB).second) {
+ BB = NextBB;
+ Begin = BB->getFirstNonPHI()->getIterator();
+ End = BB->end();
+ continue;
+ }
+ }
+
+ break;
+ }
+ return false;
+}
+
+static bool isKnownNonNaN(const Value *V, FastMathFlags FMF) {
+ if (FMF.noNaNs())
+ return true;
+
+ if (auto *C = dyn_cast<ConstantFP>(V))
+ return !C->isNaN();
+
+ if (auto *C = dyn_cast<ConstantDataVector>(V)) {
+ if (!C->getElementType()->isFloatingPointTy())
+ return false;
+ for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) {
+ if (C->getElementAsAPFloat(I).isNaN())
+ return false;
+ }
+ return true;
+ }
+
+ return false;
+}
+
+static bool isKnownNonZero(const Value *V) {
+ if (auto *C = dyn_cast<ConstantFP>(V))
+ return !C->isZero();
+
+ if (auto *C = dyn_cast<ConstantDataVector>(V)) {
+ if (!C->getElementType()->isFloatingPointTy())
+ return false;
+ for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) {
+ if (C->getElementAsAPFloat(I).isZero())
+ return false;
+ }
+ return true;
+ }
+
+ return false;
+}
+
+/// Match a clamp pattern for float types, without regard to NaNs or signed
+/// zeros. Given the non-min/max outer cmp/select from the clamp pattern, this
+/// function recognizes whether it can be substituted by a "canonical" min/max
+/// pattern.
+static SelectPatternResult matchFastFloatClamp(CmpInst::Predicate Pred,
+ Value *CmpLHS, Value *CmpRHS,
+ Value *TrueVal, Value *FalseVal,
+ Value *&LHS, Value *&RHS) {
+ // Try to match
+ // X < C1 ? C1 : Min(X, C2) --> Max(C1, Min(X, C2))
+ // X > C1 ? C1 : Max(X, C2) --> Min(C1, Max(X, C2))
+ // and return description of the outer Max/Min.
+
+ // First, check if select has inverse order:
+ if (CmpRHS == FalseVal) {
+ std::swap(TrueVal, FalseVal);
+ Pred = CmpInst::getInversePredicate(Pred);
+ }
+
+ // Assume success now. If there's no match, callers should not use these anyway.
+ LHS = TrueVal;
+ RHS = FalseVal;
+
+ const APFloat *FC1;
+ if (CmpRHS != TrueVal || !match(CmpRHS, m_APFloat(FC1)) || !FC1->isFinite())
+ return {SPF_UNKNOWN, SPNB_NA, false};
+
+ const APFloat *FC2;
+ switch (Pred) {
+ case CmpInst::FCMP_OLT:
+ case CmpInst::FCMP_OLE:
+ case CmpInst::FCMP_ULT:
+ case CmpInst::FCMP_ULE:
+ if (match(FalseVal,
+ m_CombineOr(m_OrdFMin(m_Specific(CmpLHS), m_APFloat(FC2)),
+ m_UnordFMin(m_Specific(CmpLHS), m_APFloat(FC2)))) &&
+ FC1->compare(*FC2) == APFloat::cmpResult::cmpLessThan)
+ return {SPF_FMAXNUM, SPNB_RETURNS_ANY, false};
+ break;
+ case CmpInst::FCMP_OGT:
+ case CmpInst::FCMP_OGE:
+ case CmpInst::FCMP_UGT:
+ case CmpInst::FCMP_UGE:
+ if (match(FalseVal,
+ m_CombineOr(m_OrdFMax(m_Specific(CmpLHS), m_APFloat(FC2)),
+ m_UnordFMax(m_Specific(CmpLHS), m_APFloat(FC2)))) &&
+ FC1->compare(*FC2) == APFloat::cmpResult::cmpGreaterThan)
+ return {SPF_FMINNUM, SPNB_RETURNS_ANY, false};
+ break;
+ default:
+ break;
+ }
+
+ return {SPF_UNKNOWN, SPNB_NA, false};
+}
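+
+// Illustrative instance of the pattern above (constants are examples only;
+// callers additionally guard on NaN/signed-zero behavior):
+//   x < 1.0 ? 1.0 : fmin(x, 2.0)  -->  fmax(1.0, fmin(x, 2.0))
+// i.e. x clamped to [1.0, 2.0]; with C1 == 1.0 and C2 == 2.0 the required
+// C1 < C2 ordering holds, so SPF_FMAXNUM is returned.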
+
+/// Recognize variations of:
+/// CLAMP(v,l,h) ==> ((v) < (l) ? (l) : ((v) > (h) ? (h) : (v)))
+static SelectPatternResult matchClamp(CmpInst::Predicate Pred,
+ Value *CmpLHS, Value *CmpRHS,
+ Value *TrueVal, Value *FalseVal) {
+ // Swap the select operands and predicate to match the patterns below.
+ if (CmpRHS != TrueVal) {
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ std::swap(TrueVal, FalseVal);
+ }
+ const APInt *C1;
+ if (CmpRHS == TrueVal && match(CmpRHS, m_APInt(C1))) {
+ const APInt *C2;
+ // (X <s C1) ? C1 : SMIN(X, C2) ==> SMAX(SMIN(X, C2), C1)
+ if (match(FalseVal, m_SMin(m_Specific(CmpLHS), m_APInt(C2))) &&
+ C1->slt(*C2) && Pred == CmpInst::ICMP_SLT)
+ return {SPF_SMAX, SPNB_NA, false};
+
+ // (X >s C1) ? C1 : SMAX(X, C2) ==> SMIN(SMAX(X, C2), C1)
+ if (match(FalseVal, m_SMax(m_Specific(CmpLHS), m_APInt(C2))) &&
+ C1->sgt(*C2) && Pred == CmpInst::ICMP_SGT)
+ return {SPF_SMIN, SPNB_NA, false};
+
+ // (X <u C1) ? C1 : UMIN(X, C2) ==> UMAX(UMIN(X, C2), C1)
+ if (match(FalseVal, m_UMin(m_Specific(CmpLHS), m_APInt(C2))) &&
+ C1->ult(*C2) && Pred == CmpInst::ICMP_ULT)
+ return {SPF_UMAX, SPNB_NA, false};
+
+ // (X >u C1) ? C1 : UMAX(X, C2) ==> UMIN(UMAX(X, C2), C1)
+ if (match(FalseVal, m_UMax(m_Specific(CmpLHS), m_APInt(C2))) &&
+ C1->ugt(*C2) && Pred == CmpInst::ICMP_UGT)
+ return {SPF_UMIN, SPNB_NA, false};
+ }
+ return {SPF_UNKNOWN, SPNB_NA, false};
+}
+
+/// Recognize variations of:
+/// a < c ? min(a,b) : min(b,c) ==> min(min(a,b),min(b,c))
+static SelectPatternResult matchMinMaxOfMinMax(CmpInst::Predicate Pred,
+ Value *CmpLHS, Value *CmpRHS,
+ Value *TVal, Value *FVal,
+ unsigned Depth) {
+ // TODO: Allow FP min/max with nnan/nsz.
+ assert(CmpInst::isIntPredicate(Pred) && "Expected integer comparison");
+
+ Value *A = nullptr, *B = nullptr;
+ SelectPatternResult L = matchSelectPattern(TVal, A, B, nullptr, Depth + 1);
+ if (!SelectPatternResult::isMinOrMax(L.Flavor))
+ return {SPF_UNKNOWN, SPNB_NA, false};
+
+ Value *C = nullptr, *D = nullptr;
+ SelectPatternResult R = matchSelectPattern(FVal, C, D, nullptr, Depth + 1);
+ if (L.Flavor != R.Flavor)
+ return {SPF_UNKNOWN, SPNB_NA, false};
+
+ // We have something like: x Pred y ? min(a, b) : min(c, d).
+ // Try to match the compare to the min/max operations of the select operands.
+ // First, make sure we have the right compare predicate.
+ switch (L.Flavor) {
+ case SPF_SMIN:
+ if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) {
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ std::swap(CmpLHS, CmpRHS);
+ }
+ if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
+ break;
+ return {SPF_UNKNOWN, SPNB_NA, false};
+ case SPF_SMAX:
+ if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) {
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ std::swap(CmpLHS, CmpRHS);
+ }
+ if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE)
+ break;
+ return {SPF_UNKNOWN, SPNB_NA, false};
+ case SPF_UMIN:
+ if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ std::swap(CmpLHS, CmpRHS);
+ }
+ if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
+ break;
+ return {SPF_UNKNOWN, SPNB_NA, false};
+ case SPF_UMAX:
+ if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) {
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ std::swap(CmpLHS, CmpRHS);
+ }
+ if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE)
+ break;
+ return {SPF_UNKNOWN, SPNB_NA, false};
+ default:
+ return {SPF_UNKNOWN, SPNB_NA, false};
+ }
+
+ // If there is a common operand in the already matched min/max and the other
+ // min/max operands match the compare operands (either directly or inverted),
+ // then this is min/max of the same flavor.
+
+ // a pred c ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b))
+ // ~c pred ~a ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b))
+ if (D == B) {
+ if ((CmpLHS == A && CmpRHS == C) || (match(C, m_Not(m_Specific(CmpLHS))) &&
+ match(A, m_Not(m_Specific(CmpRHS)))))
+ return {L.Flavor, SPNB_NA, false};
+ }
+ // a pred d ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d))
+ // ~d pred ~a ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d))
+ if (C == B) {
+ if ((CmpLHS == A && CmpRHS == D) || (match(D, m_Not(m_Specific(CmpLHS))) &&
+ match(A, m_Not(m_Specific(CmpRHS)))))
+ return {L.Flavor, SPNB_NA, false};
+ }
+ // b pred c ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a))
+ // ~c pred ~b ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a))
+ if (D == A) {
+ if ((CmpLHS == B && CmpRHS == C) || (match(C, m_Not(m_Specific(CmpLHS))) &&
+ match(B, m_Not(m_Specific(CmpRHS)))))
+ return {L.Flavor, SPNB_NA, false};
+ }
+ // b pred d ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d))
+ // ~d pred ~b ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d))
+ if (C == A) {
+ if ((CmpLHS == B && CmpRHS == D) || (match(D, m_Not(m_Specific(CmpLHS))) &&
+ match(B, m_Not(m_Specific(CmpRHS)))))
+ return {L.Flavor, SPNB_NA, false};
+ }
+
+ return {SPF_UNKNOWN, SPNB_NA, false};
+}
+
+/// Match non-obvious integer minimum and maximum sequences.
+static SelectPatternResult matchMinMax(CmpInst::Predicate Pred,
+ Value *CmpLHS, Value *CmpRHS,
+ Value *TrueVal, Value *FalseVal,
+ Value *&LHS, Value *&RHS,
+ unsigned Depth) {
+ // Assume success. If there's no match, callers should not use these anyway.
+ LHS = TrueVal;
+ RHS = FalseVal;
+
+ SelectPatternResult SPR = matchClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal);
+ if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN)
+ return SPR;
+
+ SPR = matchMinMaxOfMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, Depth);
+ if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN)
+ return SPR;
+
+ if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT)
+ return {SPF_UNKNOWN, SPNB_NA, false};
+
+ // Z = X -nsw Y
+ // (X >s Y) ? 0 : Z ==> (Z >s 0) ? 0 : Z ==> SMIN(Z, 0)
+ // (X <s Y) ? 0 : Z ==> (Z <s 0) ? 0 : Z ==> SMAX(Z, 0)
+ if (match(TrueVal, m_Zero()) &&
+ match(FalseVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS))))
+ return {Pred == CmpInst::ICMP_SGT ? SPF_SMIN : SPF_SMAX, SPNB_NA, false};
+
+ // Z = X -nsw Y
+ // (X >s Y) ? Z : 0 ==> (Z >s 0) ? Z : 0 ==> SMAX(Z, 0)
+ // (X <s Y) ? Z : 0 ==> (Z <s 0) ? Z : 0 ==> SMIN(Z, 0)
+ if (match(FalseVal, m_Zero()) &&
+ match(TrueVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS))))
+ return {Pred == CmpInst::ICMP_SGT ? SPF_SMAX : SPF_SMIN, SPNB_NA, false};
+
+ const APInt *C1;
+ if (!match(CmpRHS, m_APInt(C1)))
+ return {SPF_UNKNOWN, SPNB_NA, false};
+
+ // An unsigned min/max can be written with a signed compare.
+ const APInt *C2;
+ if ((CmpLHS == TrueVal && match(FalseVal, m_APInt(C2))) ||
+ (CmpLHS == FalseVal && match(TrueVal, m_APInt(C2)))) {
+ // Is the sign bit set?
+ // (X <s 0) ? X : MAXVAL ==> (X >u MAXVAL) ? X : MAXVAL ==> UMAX
+ // (X <s 0) ? MAXVAL : X ==> (X >u MAXVAL) ? MAXVAL : X ==> UMIN
+ if (Pred == CmpInst::ICMP_SLT && C1->isNullValue() &&
+ C2->isMaxSignedValue())
+ return {CmpLHS == TrueVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false};
+
+ // Is the sign bit clear?
+ // (X >s -1) ? MINVAL : X ==> (X <u MINVAL) ? MINVAL : X ==> UMAX
+ // (X >s -1) ? X : MINVAL ==> (X <u MINVAL) ? X : MINVAL ==> UMIN
+ if (Pred == CmpInst::ICMP_SGT && C1->isAllOnesValue() &&
+ C2->isMinSignedValue())
+ return {CmpLHS == FalseVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false};
+ }
+
+ // Look through 'not' ops to find disguised signed min/max.
+ // (X >s C) ? ~X : ~C ==> (~X <s ~C) ? ~X : ~C ==> SMIN(~X, ~C)
+ // (X <s C) ? ~X : ~C ==> (~X >s ~C) ? ~X : ~C ==> SMAX(~X, ~C)
+ if (match(TrueVal, m_Not(m_Specific(CmpLHS))) &&
+ match(FalseVal, m_APInt(C2)) && ~(*C1) == *C2)
+ return {Pred == CmpInst::ICMP_SGT ? SPF_SMIN : SPF_SMAX, SPNB_NA, false};
+
+ // (X >s C) ? ~C : ~X ==> (~X <s ~C) ? ~C : ~X ==> SMAX(~C, ~X)
+ // (X <s C) ? ~C : ~X ==> (~X >s ~C) ? ~C : ~X ==> SMIN(~C, ~X)
+ if (match(FalseVal, m_Not(m_Specific(CmpLHS))) &&
+ match(TrueVal, m_APInt(C2)) && ~(*C1) == *C2)
+ return {Pred == CmpInst::ICMP_SGT ? SPF_SMAX : SPF_SMIN, SPNB_NA, false};
+
+ return {SPF_UNKNOWN, SPNB_NA, false};
+}
+
+bool llvm::isKnownNegation(const Value *X, const Value *Y, bool NeedNSW) {
+ assert(X && Y && "Invalid operand");
+
+ // X = sub (0, Y) || X = sub nsw (0, Y)
+ if ((!NeedNSW && match(X, m_Sub(m_ZeroInt(), m_Specific(Y)))) ||
+ (NeedNSW && match(X, m_NSWSub(m_ZeroInt(), m_Specific(Y)))))
+ return true;
+
+ // Y = sub (0, X) || Y = sub nsw (0, X)
+ if ((!NeedNSW && match(Y, m_Sub(m_ZeroInt(), m_Specific(X)))) ||
+ (NeedNSW && match(Y, m_NSWSub(m_ZeroInt(), m_Specific(X)))))
+ return true;
+
+ // X = sub (A, B), Y = sub (B, A) || X = sub nsw (A, B), Y = sub nsw (B, A)
+ Value *A, *B;
+ return (!NeedNSW && (match(X, m_Sub(m_Value(A), m_Value(B))) &&
+ match(Y, m_Sub(m_Specific(B), m_Specific(A))))) ||
+ (NeedNSW && (match(X, m_NSWSub(m_Value(A), m_Value(B))) &&
+ match(Y, m_NSWSub(m_Specific(B), m_Specific(A)))));
+}
+
+static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
+ FastMathFlags FMF,
+ Value *CmpLHS, Value *CmpRHS,
+ Value *TrueVal, Value *FalseVal,
+ Value *&LHS, Value *&RHS,
+ unsigned Depth) {
+ if (CmpInst::isFPPredicate(Pred)) {
+ // IEEE-754 ignores the sign of 0.0 in comparisons. So if the select has one
+ // 0.0 operand, set the compare's 0.0 operands to that same value for the
+ // purpose of identifying min/max. Disregard vector constants with undefined
+ // elements because those cannot be back-propagated for analysis.
+ Value *OutputZeroVal = nullptr;
+ if (match(TrueVal, m_AnyZeroFP()) && !match(FalseVal, m_AnyZeroFP()) &&
+ !cast<Constant>(TrueVal)->containsUndefElement())
+ OutputZeroVal = TrueVal;
+ else if (match(FalseVal, m_AnyZeroFP()) && !match(TrueVal, m_AnyZeroFP()) &&
+ !cast<Constant>(FalseVal)->containsUndefElement())
+ OutputZeroVal = FalseVal;
+
+ if (OutputZeroVal) {
+ if (match(CmpLHS, m_AnyZeroFP()))
+ CmpLHS = OutputZeroVal;
+ if (match(CmpRHS, m_AnyZeroFP()))
+ CmpRHS = OutputZeroVal;
+ }
+ }
+
+ LHS = CmpLHS;
+ RHS = CmpRHS;
+
+ // Signed-zero handling may give inconsistent results between implementations.
+ // (0.0 <= -0.0) ? 0.0 : -0.0 // Returns 0.0
+ // minNum(0.0, -0.0) // May return -0.0 or 0.0 (IEEE 754-2008 5.3.1)
+ // Therefore, we behave conservatively and only proceed if at least one of the
+ // operands is known to not be zero or if we don't care about signed zero.
+ switch (Pred) {
+ default: break;
+ // FIXME: Include OGT/OLT/UGT/ULT.
+ case CmpInst::FCMP_OGE: case CmpInst::FCMP_OLE:
+ case CmpInst::FCMP_UGE: case CmpInst::FCMP_ULE:
+ if (!FMF.noSignedZeros() && !isKnownNonZero(CmpLHS) &&
+ !isKnownNonZero(CmpRHS))
+ return {SPF_UNKNOWN, SPNB_NA, false};
+ }
+
+ SelectPatternNaNBehavior NaNBehavior = SPNB_NA;
+ bool Ordered = false;
+
+ // When given one NaN and one non-NaN input:
+ // - maxnum/minnum (C99 fmaxf()/fminf()) return the non-NaN input.
+ // - A simple C99 (a < b ? a : b) construction will return 'b' (as the
+ // ordered comparison fails), which could be NaN or non-NaN.
+ // So here we discover exactly what NaN behavior is required/accepted.
+ if (CmpInst::isFPPredicate(Pred)) {
+ bool LHSSafe = isKnownNonNaN(CmpLHS, FMF);
+ bool RHSSafe = isKnownNonNaN(CmpRHS, FMF);
+
+ if (LHSSafe && RHSSafe) {
+ // Both operands are known non-NaN.
+ NaNBehavior = SPNB_RETURNS_ANY;
+ } else if (CmpInst::isOrdered(Pred)) {
+ // An ordered comparison will return false when given a NaN, so it
+ // returns the RHS.
+ Ordered = true;
+ if (LHSSafe)
+ // LHS is non-NaN, so if RHS is NaN then NaN will be returned.
+ NaNBehavior = SPNB_RETURNS_NAN;
+ else if (RHSSafe)
+ NaNBehavior = SPNB_RETURNS_OTHER;
+ else
+ // Completely unsafe.
+ return {SPF_UNKNOWN, SPNB_NA, false};
+ } else {
+ Ordered = false;
+ // An unordered comparison will return true when given a NaN, so it
+ // returns the LHS.
+ if (LHSSafe)
+ // LHS is non-NaN, so if RHS is NaN then non-NaN will be returned.
+ NaNBehavior = SPNB_RETURNS_OTHER;
+ else if (RHSSafe)
+ NaNBehavior = SPNB_RETURNS_NAN;
+ else
+ // Completely unsafe.
+ return {SPF_UNKNOWN, SPNB_NA, false};
+ }
+ }
+
+ if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
+ std::swap(CmpLHS, CmpRHS);
+ Pred = CmpInst::getSwappedPredicate(Pred);
+ if (NaNBehavior == SPNB_RETURNS_NAN)
+ NaNBehavior = SPNB_RETURNS_OTHER;
+ else if (NaNBehavior == SPNB_RETURNS_OTHER)
+ NaNBehavior = SPNB_RETURNS_NAN;
+ Ordered = !Ordered;
+ }
+
+ // ([if]cmp X, Y) ? X : Y
+ if (TrueVal == CmpLHS && FalseVal == CmpRHS) {
+ switch (Pred) {
+ default: return {SPF_UNKNOWN, SPNB_NA, false}; // Equality.
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_UGE: return {SPF_UMAX, SPNB_NA, false};
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_SGE: return {SPF_SMAX, SPNB_NA, false};
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULE: return {SPF_UMIN, SPNB_NA, false};
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_SLE: return {SPF_SMIN, SPNB_NA, false};
+ case FCmpInst::FCMP_UGT:
+ case FCmpInst::FCMP_UGE:
+ case FCmpInst::FCMP_OGT:
+ case FCmpInst::FCMP_OGE: return {SPF_FMAXNUM, NaNBehavior, Ordered};
+ case FCmpInst::FCMP_ULT:
+ case FCmpInst::FCMP_ULE:
+ case FCmpInst::FCMP_OLT:
+ case FCmpInst::FCMP_OLE: return {SPF_FMINNUM, NaNBehavior, Ordered};
+ }
+ }
+
+ if (isKnownNegation(TrueVal, FalseVal)) {
+ // Sign-extending LHS does not change its sign, so TrueVal/FalseVal can
+ // match against either LHS or sext(LHS).
+ auto MaybeSExtCmpLHS =
+ m_CombineOr(m_Specific(CmpLHS), m_SExt(m_Specific(CmpLHS)));
+ auto ZeroOrAllOnes = m_CombineOr(m_ZeroInt(), m_AllOnes());
+ auto ZeroOrOne = m_CombineOr(m_ZeroInt(), m_One());
+ if (match(TrueVal, MaybeSExtCmpLHS)) {
+ // Set the return values. If the compare uses the negated value (-X >s 0),
+ // swap the return values because the negated value is always 'RHS'.
+ LHS = TrueVal;
+ RHS = FalseVal;
+ if (match(CmpLHS, m_Neg(m_Specific(FalseVal))))
+ std::swap(LHS, RHS);
+
+ // (X >s 0) ? X : -X or (X >s -1) ? X : -X --> ABS(X)
+ // (-X >s 0) ? -X : X or (-X >s -1) ? -X : X --> ABS(X)
+ if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, ZeroOrAllOnes))
+ return {SPF_ABS, SPNB_NA, false};
+
+ // (X >=s 0) ? X : -X or (X >=s 1) ? X : -X --> ABS(X)
+ if (Pred == ICmpInst::ICMP_SGE && match(CmpRHS, ZeroOrOne))
+ return {SPF_ABS, SPNB_NA, false};
+
+ // (X <s 0) ? X : -X or (X <s 1) ? X : -X --> NABS(X)
+ // (-X <s 0) ? -X : X or (-X <s 1) ? -X : X --> NABS(X)
+ if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, ZeroOrOne))
+ return {SPF_NABS, SPNB_NA, false};
+ }
+ else if (match(FalseVal, MaybeSExtCmpLHS)) {
+ // Set the return values. If the compare uses the negated value (-X >s 0),
+ // swap the return values because the negated value is always 'RHS'.
+ LHS = FalseVal;
+ RHS = TrueVal;
+ if (match(CmpLHS, m_Neg(m_Specific(TrueVal))))
+ std::swap(LHS, RHS);
+
+ // (X >s 0) ? -X : X or (X >s -1) ? -X : X --> NABS(X)
+ // (-X >s 0) ? X : -X or (-X >s -1) ? X : -X --> NABS(X)
+ if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, ZeroOrAllOnes))
+ return {SPF_NABS, SPNB_NA, false};
+
+ // (X <s 0) ? -X : X or (X <s 1) ? -X : X --> ABS(X)
+ // (-X <s 0) ? X : -X or (-X <s 1) ? X : -X --> ABS(X)
+ if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, ZeroOrOne))
+ return {SPF_ABS, SPNB_NA, false};
+ }
+ }
+
+ if (CmpInst::isIntPredicate(Pred))
+ return matchMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS, Depth);
+
+ // According to (IEEE 754-2008 5.3.1), minNum(0.0, -0.0) and similar
+ // may return either -0.0 or 0.0, so fcmp/select pair has stricter
+ // semantics than minNum. Be conservative in such case.
+ if (NaNBehavior != SPNB_RETURNS_ANY ||
+ (!FMF.noSignedZeros() && !isKnownNonZero(CmpLHS) &&
+ !isKnownNonZero(CmpRHS)))
+ return {SPF_UNKNOWN, SPNB_NA, false};
+
+ return matchFastFloatClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS);
+}
+
+/// Helps to match a select pattern in case of a type mismatch.
+///
+/// The function handles the case where the type of the true and false values
+/// of a select instruction differs from the type of the cmp instruction
+/// operands because of a cast instruction. It checks whether it is legal to
+/// move the cast operation after the "select". If so, it returns the new
+/// second value of the "select" (under the assumption that the cast is moved):
+/// 1. As the operand of the cast instruction when both values of the "select"
+///    are the same cast instruction.
+/// 2. As a restored constant (by applying the reverse cast operation) when the
+///    first value of the "select" is a cast operation and the second value is
+///    a constant.
+/// NOTE: We return only the new second value because the first value can be
+/// accessed as the operand of the cast instruction.
+static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
+ Instruction::CastOps *CastOp) {
+ auto *Cast1 = dyn_cast<CastInst>(V1);
+ if (!Cast1)
+ return nullptr;
+
+ *CastOp = Cast1->getOpcode();
+ Type *SrcTy = Cast1->getSrcTy();
+ if (auto *Cast2 = dyn_cast<CastInst>(V2)) {
+ // If V1 and V2 are both the same cast from the same type, look through V1.
+ if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
+ return Cast2->getOperand(0);
+ return nullptr;
+ }
+
+ auto *C = dyn_cast<Constant>(V2);
+ if (!C)
+ return nullptr;
+
+ Constant *CastedTo = nullptr;
+ switch (*CastOp) {
+ case Instruction::ZExt:
+ if (CmpI->isUnsigned())
+ CastedTo = ConstantExpr::getTrunc(C, SrcTy);
+ break;
+ case Instruction::SExt:
+ if (CmpI->isSigned())
+ CastedTo = ConstantExpr::getTrunc(C, SrcTy, true);
+ break;
+ case Instruction::Trunc:
+ Constant *CmpConst;
+ if (match(CmpI->getOperand(1), m_Constant(CmpConst)) &&
+ CmpConst->getType() == SrcTy) {
+ // Here we have the following case:
+ //
+ // %cond = cmp iN %x, CmpConst
+ // %tr = trunc iN %x to iK
+ // %narrowsel = select i1 %cond, iK %t, iK C
+ //
+ // We can always move trunc after select operation:
+ //
+ // %cond = cmp iN %x, CmpConst
+ // %widesel = select i1 %cond, iN %x, iN CmpConst
+ // %tr = trunc iN %widesel to iK
+ //
+ // Note that C could be extended in any way because we don't care about
+ // upper bits after truncation. It can't be abs pattern, because it would
+ // look like:
+ //
+ // select i1 %cond, x, -x.
+ //
+ // So only a min/max pattern can be matched. Such a match requires the
+ // widened C to equal CmpConst, so we set the widened C to CmpConst; the
+ // condition trunc(CmpConst) == C is checked below.
+ CastedTo = CmpConst;
+ } else {
+ CastedTo = ConstantExpr::getIntegerCast(C, SrcTy, CmpI->isSigned());
+ }
+ break;
+ case Instruction::FPTrunc:
+ CastedTo = ConstantExpr::getFPExtend(C, SrcTy, true);
+ break;
+ case Instruction::FPExt:
+ CastedTo = ConstantExpr::getFPTrunc(C, SrcTy, true);
+ break;
+ case Instruction::FPToUI:
+ CastedTo = ConstantExpr::getUIToFP(C, SrcTy, true);
+ break;
+ case Instruction::FPToSI:
+ CastedTo = ConstantExpr::getSIToFP(C, SrcTy, true);
+ break;
+ case Instruction::UIToFP:
+ CastedTo = ConstantExpr::getFPToUI(C, SrcTy, true);
+ break;
+ case Instruction::SIToFP:
+ CastedTo = ConstantExpr::getFPToSI(C, SrcTy, true);
+ break;
+ default:
+ break;
+ }
+
+ if (!CastedTo)
+ return nullptr;
+
+ // Make sure the cast doesn't lose any information.
+ Constant *CastedBack =
+ ConstantExpr::getCast(*CastOp, CastedTo, C->getType(), true);
+ if (CastedBack != C)
+ return nullptr;
+
+ return CastedTo;
+}
+
+SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS,
+ Instruction::CastOps *CastOp,
+ unsigned Depth) {
+ if (Depth >= MaxDepth)
+ return {SPF_UNKNOWN, SPNB_NA, false};
+
+ SelectInst *SI = dyn_cast<SelectInst>(V);
+ if (!SI) return {SPF_UNKNOWN, SPNB_NA, false};
+
+ CmpInst *CmpI = dyn_cast<CmpInst>(SI->getCondition());
+ if (!CmpI) return {SPF_UNKNOWN, SPNB_NA, false};
+
+ Value *TrueVal = SI->getTrueValue();
+ Value *FalseVal = SI->getFalseValue();
+
+ return llvm::matchDecomposedSelectPattern(CmpI, TrueVal, FalseVal, LHS, RHS,
+ CastOp, Depth);
+}
+
+SelectPatternResult llvm::matchDecomposedSelectPattern(
+ CmpInst *CmpI, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS,
+ Instruction::CastOps *CastOp, unsigned Depth) {
+ CmpInst::Predicate Pred = CmpI->getPredicate();
+ Value *CmpLHS = CmpI->getOperand(0);
+ Value *CmpRHS = CmpI->getOperand(1);
+ FastMathFlags FMF;
+ if (isa<FPMathOperator>(CmpI))
+ FMF = CmpI->getFastMathFlags();
+
+ // Bail out early.
+ if (CmpI->isEquality())
+ return {SPF_UNKNOWN, SPNB_NA, false};
+
+ // Deal with type mismatches.
+ if (CastOp && CmpLHS->getType() != TrueVal->getType()) {
+ if (Value *C = lookThroughCast(CmpI, TrueVal, FalseVal, CastOp)) {
+ // If this is a potential fmin/fmax with a cast to integer, then ignore
+ // -0.0 because there is no corresponding integer value.
+ if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI)
+ FMF.setNoSignedZeros();
+ return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS,
+ cast<CastInst>(TrueVal)->getOperand(0), C,
+ LHS, RHS, Depth);
+ }
+ if (Value *C = lookThroughCast(CmpI, FalseVal, TrueVal, CastOp)) {
+ // If this is a potential fmin/fmax with a cast to integer, then ignore
+ // -0.0 because there is no corresponding integer value.
+ if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI)
+ FMF.setNoSignedZeros();
+ return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS,
+ C, cast<CastInst>(FalseVal)->getOperand(0),
+ LHS, RHS, Depth);
+ }
+ }
+ return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, TrueVal, FalseVal,
+ LHS, RHS, Depth);
+}
+
+CmpInst::Predicate llvm::getMinMaxPred(SelectPatternFlavor SPF, bool Ordered) {
+ if (SPF == SPF_SMIN) return ICmpInst::ICMP_SLT;
+ if (SPF == SPF_UMIN) return ICmpInst::ICMP_ULT;
+ if (SPF == SPF_SMAX) return ICmpInst::ICMP_SGT;
+ if (SPF == SPF_UMAX) return ICmpInst::ICMP_UGT;
+ if (SPF == SPF_FMINNUM)
+ return Ordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT;
+ if (SPF == SPF_FMAXNUM)
+ return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT;
+ llvm_unreachable("unhandled!");
+}
+
+SelectPatternFlavor llvm::getInverseMinMaxFlavor(SelectPatternFlavor SPF) {
+ if (SPF == SPF_SMIN) return SPF_SMAX;
+ if (SPF == SPF_UMIN) return SPF_UMAX;
+ if (SPF == SPF_SMAX) return SPF_SMIN;
+ if (SPF == SPF_UMAX) return SPF_UMIN;
+ llvm_unreachable("unhandled!");
+}
+
+CmpInst::Predicate llvm::getInverseMinMaxPred(SelectPatternFlavor SPF) {
+ return getMinMaxPred(getInverseMinMaxFlavor(SPF));
+}
+
+/// Return true if "icmp Pred LHS RHS" is always true.
+static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
+ const Value *RHS, const DataLayout &DL,
+ unsigned Depth) {
+ assert(!LHS->getType()->isVectorTy() && "TODO: extend to handle vectors!");
+ if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS)
+ return true;
+
+ switch (Pred) {
+ default:
+ return false;
+
+ case CmpInst::ICMP_SLE: {
+ const APInt *C;
+
+ // LHS s<= LHS +_{nsw} C if C >= 0
+ if (match(RHS, m_NSWAdd(m_Specific(LHS), m_APInt(C))))
+ return !C->isNegative();
+ return false;
+ }
+
+ case CmpInst::ICMP_ULE: {
+ const APInt *C;
+
+ // LHS u<= LHS +_{nuw} C for any C
+ if (match(RHS, m_NUWAdd(m_Specific(LHS), m_APInt(C))))
+ return true;
+
+ // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB)
+ auto MatchNUWAddsToSameValue = [&](const Value *A, const Value *B,
+ const Value *&X,
+ const APInt *&CA, const APInt *&CB) {
+ if (match(A, m_NUWAdd(m_Value(X), m_APInt(CA))) &&
+ match(B, m_NUWAdd(m_Specific(X), m_APInt(CB))))
+ return true;
+
+ // If X & C == 0 then (X | C) == X +_{nuw} C
+ if (match(A, m_Or(m_Value(X), m_APInt(CA))) &&
+ match(B, m_Or(m_Specific(X), m_APInt(CB)))) {
+ KnownBits Known(CA->getBitWidth());
+ computeKnownBits(X, Known, DL, Depth + 1, /*AC*/ nullptr,
+ /*CxtI*/ nullptr, /*DT*/ nullptr);
+ if (CA->isSubsetOf(Known.Zero) && CB->isSubsetOf(Known.Zero))
+ return true;
+ }
+
+ return false;
+ };
+
+ const Value *X;
+ const APInt *CLHS, *CRHS;
+ if (MatchNUWAddsToSameValue(LHS, RHS, X, CLHS, CRHS))
+ return CLHS->ule(*CRHS);
+
+ return false;
+ }
+ }
+}
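+
+// An illustrative instance of the ICMP_ULE case above (values chosen by hand,
+// not taken from a particular caller): if X is known to have its two low bits
+// clear (say X = Y << 2), then "X | 1" and "X | 3" act as "X +nuw 1" and
+// "X +nuw 3", so "icmp ule (X | 1), (X | 3)" is always true because 1 u<= 3.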
+
+/// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred
+/// ALHS ARHS" is true. Otherwise, return None.
+static Optional<bool>
+isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
+ const Value *ARHS, const Value *BLHS, const Value *BRHS,
+ const DataLayout &DL, unsigned Depth) {
+ switch (Pred) {
+ default:
+ return None;
+
+ case CmpInst::ICMP_SLT:
+ case CmpInst::ICMP_SLE:
+ if (isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS, DL, Depth) &&
+ isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS, DL, Depth))
+ return true;
+ return None;
+
+ case CmpInst::ICMP_ULT:
+ case CmpInst::ICMP_ULE:
+ if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS, DL, Depth) &&
+ isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS, DL, Depth))
+ return true;
+ return None;
+ }
+}
+
+/// Return true if the operands of the two compares match. IsSwappedOps is true
+/// when the operands match, but are swapped.
+static bool isMatchingOps(const Value *ALHS, const Value *ARHS,
+ const Value *BLHS, const Value *BRHS,
+ bool &IsSwappedOps) {
+
+ bool IsMatchingOps = (ALHS == BLHS && ARHS == BRHS);
+ IsSwappedOps = (ALHS == BRHS && ARHS == BLHS);
+ return IsMatchingOps || IsSwappedOps;
+}
+
+/// Return true if "icmp1 APred X, Y" implies "icmp2 BPred X, Y" is true.
+/// Return false if "icmp1 APred X, Y" implies "icmp2 BPred X, Y" is false.
+/// Otherwise, return None if we can't infer anything.
+static Optional<bool> isImpliedCondMatchingOperands(CmpInst::Predicate APred,
+ CmpInst::Predicate BPred,
+ bool AreSwappedOps) {
+ // Canonicalize the predicate as if the operands were not commuted.
+ if (AreSwappedOps)
+ BPred = ICmpInst::getSwappedPredicate(BPred);
+
+ if (CmpInst::isImpliedTrueByMatchingCmp(APred, BPred))
+ return true;
+ if (CmpInst::isImpliedFalseByMatchingCmp(APred, BPred))
+ return false;
+
+ return None;
+}
+
+/// Return true if "icmp APred X, C1" implies "icmp BPred X, C2" is true.
+/// Return false if "icmp APred X, C1" implies "icmp BPred X, C2" is false.
+/// Otherwise, return None if we can't infer anything.
+static Optional<bool>
+isImpliedCondMatchingImmOperands(CmpInst::Predicate APred,
+ const ConstantInt *C1,
+ CmpInst::Predicate BPred,
+ const ConstantInt *C2) {
+ ConstantRange DomCR =
+ ConstantRange::makeExactICmpRegion(APred, C1->getValue());
+ ConstantRange CR =
+ ConstantRange::makeAllowedICmpRegion(BPred, C2->getValue());
+ ConstantRange Intersection = DomCR.intersectWith(CR);
+ ConstantRange Difference = DomCR.difference(CR);
+ if (Intersection.isEmptySet())
+ return false;
+ if (Difference.isEmptySet())
+ return true;
+ return None;
+}
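+
+// Worked example with hand-picked i8 constants: if the known-true condition is
+// "icmp ult i8 %x, 8" and the query is "icmp ult i8 %x, 16", the exact region
+// of the first is [0, 8) and the allowed region of the second is [0, 16); the
+// difference is empty, so the implication holds and we return true. For the
+// query "icmp ugt i8 %x, 20" instead, the intersection with [0, 8) is empty,
+// so we return false.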
+
+/// Return true if LHS implies RHS is true. Return false if LHS implies RHS is
+/// false. Otherwise, return None if we can't infer anything.
+static Optional<bool> isImpliedCondICmps(const ICmpInst *LHS,
+ const ICmpInst *RHS,
+ const DataLayout &DL, bool LHSIsTrue,
+ unsigned Depth) {
+ Value *ALHS = LHS->getOperand(0);
+ Value *ARHS = LHS->getOperand(1);
+ // The rest of the logic assumes the LHS condition is true. If that's not the
+ // case, invert the predicate to make it so.
+ ICmpInst::Predicate APred =
+ LHSIsTrue ? LHS->getPredicate() : LHS->getInversePredicate();
+
+ Value *BLHS = RHS->getOperand(0);
+ Value *BRHS = RHS->getOperand(1);
+ ICmpInst::Predicate BPred = RHS->getPredicate();
+
+ // Can we infer anything when the two compares have matching operands?
+ bool AreSwappedOps;
+ if (isMatchingOps(ALHS, ARHS, BLHS, BRHS, AreSwappedOps)) {
+ if (Optional<bool> Implication = isImpliedCondMatchingOperands(
+ APred, BPred, AreSwappedOps))
+ return Implication;
+ // No amount of additional analysis will infer the second condition, so
+ // early exit.
+ return None;
+ }
+
+ // Can we infer anything when the LHS operands match and the RHS operands are
+ // constants (not necessarily matching)?
+ if (ALHS == BLHS && isa<ConstantInt>(ARHS) && isa<ConstantInt>(BRHS)) {
+ if (Optional<bool> Implication = isImpliedCondMatchingImmOperands(
+ APred, cast<ConstantInt>(ARHS), BPred, cast<ConstantInt>(BRHS)))
+ return Implication;
+ // No amount of additional analysis will infer the second condition, so
+ // early exit.
+ return None;
+ }
+
+ if (APred == BPred)
+ return isImpliedCondOperands(APred, ALHS, ARHS, BLHS, BRHS, DL, Depth);
+ return None;
+}
+
+/// Return true if LHS implies RHS is true. Return false if LHS implies RHS is
+/// false. Otherwise, return None if we can't infer anything. We expect the
+/// RHS to be an icmp and the LHS to be an 'and' or an 'or' instruction.
+static Optional<bool> isImpliedCondAndOr(const BinaryOperator *LHS,
+ const ICmpInst *RHS,
+ const DataLayout &DL, bool LHSIsTrue,
+ unsigned Depth) {
+ // The LHS must be an 'or' or an 'and' instruction.
+ assert((LHS->getOpcode() == Instruction::And ||
+ LHS->getOpcode() == Instruction::Or) &&
+ "Expected LHS to be 'and' or 'or'.");
+
+ assert(Depth <= MaxDepth && "Hit recursion limit");
+
+ // If the result of an 'or' is false, then we know both legs of the 'or' are
+ // false. Similarly, if the result of an 'and' is true, then we know both
+ // legs of the 'and' are true.
+ Value *ALHS, *ARHS;
+ if ((!LHSIsTrue && match(LHS, m_Or(m_Value(ALHS), m_Value(ARHS)))) ||
+ (LHSIsTrue && match(LHS, m_And(m_Value(ALHS), m_Value(ARHS))))) {
+ // FIXME: Make this non-recursive.
+ if (Optional<bool> Implication =
+ isImpliedCondition(ALHS, RHS, DL, LHSIsTrue, Depth + 1))
+ return Implication;
+ if (Optional<bool> Implication =
+ isImpliedCondition(ARHS, RHS, DL, LHSIsTrue, Depth + 1))
+ return Implication;
+ return None;
+ }
+ return None;
+}
+
+Optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS,
+ const DataLayout &DL, bool LHSIsTrue,
+ unsigned Depth) {
+ // Bail out when we hit the limit.
+ if (Depth == MaxDepth)
+ return None;
+
+ // A mismatch occurs when we compare a scalar cmp to a vector cmp, for
+ // example.
+ if (LHS->getType() != RHS->getType())
+ return None;
+
+ Type *OpTy = LHS->getType();
+ assert(OpTy->isIntOrIntVectorTy(1) && "Expected integer type only!");
+
+ // LHS ==> RHS by definition
+ if (LHS == RHS)
+ return LHSIsTrue;
+
+ // FIXME: Extend the code below to handle vectors.
+ if (OpTy->isVectorTy())
+ return None;
+
+ assert(OpTy->isIntegerTy(1) && "implied by above");
+
+ // Both LHS and RHS are icmps.
+ const ICmpInst *LHSCmp = dyn_cast<ICmpInst>(LHS);
+ const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(RHS);
+ if (LHSCmp && RHSCmp)
+ return isImpliedCondICmps(LHSCmp, RHSCmp, DL, LHSIsTrue, Depth);
+
+ // The LHS should be an 'or' or an 'and' instruction. We expect the RHS to be
+ // an icmp. FIXME: Add support for and/or on the RHS.
+ const BinaryOperator *LHSBO = dyn_cast<BinaryOperator>(LHS);
+ if (LHSBO && RHSCmp) {
+ if ((LHSBO->getOpcode() == Instruction::And ||
+ LHSBO->getOpcode() == Instruction::Or))
+ return isImpliedCondAndOr(LHSBO, RHSCmp, DL, LHSIsTrue, Depth);
+ }
+ return None;
+}
+
+Optional<bool> llvm::isImpliedByDomCondition(const Value *Cond,
+ const Instruction *ContextI,
+ const DataLayout &DL) {
+ assert(Cond->getType()->isIntOrIntVectorTy(1) && "Condition must be bool");
+ if (!ContextI || !ContextI->getParent())
+ return None;
+
+ // TODO: This is a poor/cheap way to determine dominance. Should we use a
+ // dominator tree (e.g., from a SimplifyQuery) instead?
+ const BasicBlock *ContextBB = ContextI->getParent();
+ const BasicBlock *PredBB = ContextBB->getSinglePredecessor();
+ if (!PredBB)
+ return None;
+
+ // We need a conditional branch in the predecessor.
+ Value *PredCond;
+ BasicBlock *TrueBB, *FalseBB;
+ if (!match(PredBB->getTerminator(), m_Br(m_Value(PredCond), TrueBB, FalseBB)))
+ return None;
+
+ // The branch should get simplified. Don't bother simplifying this condition.
+ if (TrueBB == FalseBB)
+ return None;
+
+ assert((TrueBB == ContextBB || FalseBB == ContextBB) &&
+ "Predecessor block does not point to successor?");
+
+ // Is this condition implied by the predecessor condition?
+ bool CondIsTrue = TrueBB == ContextBB;
+ return isImpliedCondition(PredCond, Cond, DL, CondIsTrue);
+}
+
+static void setLimitsForBinOp(const BinaryOperator &BO, APInt &Lower,
+ APInt &Upper, const InstrInfoQuery &IIQ) {
+ unsigned Width = Lower.getBitWidth();
+ const APInt *C;
+ switch (BO.getOpcode()) {
+ case Instruction::Add:
+ if (match(BO.getOperand(1), m_APInt(C)) && !C->isNullValue()) {
+ // FIXME: If we have both nuw and nsw, we should reduce the range further.
+ if (IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(&BO))) {
+ // 'add nuw x, C' produces [C, UINT_MAX].
+ Lower = *C;
+ } else if (IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(&BO))) {
+ if (C->isNegative()) {
+ // 'add nsw x, -C' produces [SINT_MIN, SINT_MAX - C].
+ Lower = APInt::getSignedMinValue(Width);
+ Upper = APInt::getSignedMaxValue(Width) + *C + 1;
+ } else {
+ // 'add nsw x, +C' produces [SINT_MIN + C, SINT_MAX].
+ Lower = APInt::getSignedMinValue(Width) + *C;
+ Upper = APInt::getSignedMaxValue(Width) + 1;
+ }
+ }
+ }
+ break;
+
+ case Instruction::And:
+ if (match(BO.getOperand(1), m_APInt(C)))
+ // 'and x, C' produces [0, C].
+ Upper = *C + 1;
+ break;
+
+ case Instruction::Or:
+ if (match(BO.getOperand(1), m_APInt(C)))
+ // 'or x, C' produces [C, UINT_MAX].
+ Lower = *C;
+ break;
+
+ case Instruction::AShr:
+ if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) {
+ // 'ashr x, C' produces [INT_MIN >> C, INT_MAX >> C].
+ Lower = APInt::getSignedMinValue(Width).ashr(*C);
+ Upper = APInt::getSignedMaxValue(Width).ashr(*C) + 1;
+ } else if (match(BO.getOperand(0), m_APInt(C))) {
+ unsigned ShiftAmount = Width - 1;
+ if (!C->isNullValue() && IIQ.isExact(&BO))
+ ShiftAmount = C->countTrailingZeros();
+ if (C->isNegative()) {
+ // 'ashr C, x' produces [C, C >> (Width-1)]
+ Lower = *C;
+ Upper = C->ashr(ShiftAmount) + 1;
+ } else {
+ // 'ashr C, x' produces [C >> (Width-1), C]
+ Lower = C->ashr(ShiftAmount);
+ Upper = *C + 1;
+ }
+ }
+ break;
+
+ case Instruction::LShr:
+ if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) {
+ // 'lshr x, C' produces [0, UINT_MAX >> C].
+ Upper = APInt::getAllOnesValue(Width).lshr(*C) + 1;
+ } else if (match(BO.getOperand(0), m_APInt(C))) {
+ // 'lshr C, x' produces [C >> (Width-1), C].
+ unsigned ShiftAmount = Width - 1;
+ if (!C->isNullValue() && IIQ.isExact(&BO))
+ ShiftAmount = C->countTrailingZeros();
+ Lower = C->lshr(ShiftAmount);
+ Upper = *C + 1;
+ }
+ break;
+
+ case Instruction::Shl:
+ if (match(BO.getOperand(0), m_APInt(C))) {
+ if (IIQ.hasNoUnsignedWrap(&BO)) {
+ // 'shl nuw C, x' produces [C, C << CLZ(C)]
+ Lower = *C;
+ Upper = Lower.shl(Lower.countLeadingZeros()) + 1;
+ } else if (BO.hasNoSignedWrap()) { // TODO: What if both nuw+nsw?
+ if (C->isNegative()) {
+ // 'shl nsw C, x' produces [C << CLO(C)-1, C]
+ unsigned ShiftAmount = C->countLeadingOnes() - 1;
+ Lower = C->shl(ShiftAmount);
+ Upper = *C + 1;
+ } else {
+ // 'shl nsw C, x' produces [C, C << CLZ(C)-1]
+ unsigned ShiftAmount = C->countLeadingZeros() - 1;
+ Lower = *C;
+ Upper = C->shl(ShiftAmount) + 1;
+ }
+ }
+ }
+ break;
+
+ case Instruction::SDiv:
+ if (match(BO.getOperand(1), m_APInt(C))) {
+ APInt IntMin = APInt::getSignedMinValue(Width);
+ APInt IntMax = APInt::getSignedMaxValue(Width);
+ if (C->isAllOnesValue()) {
+ // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX].
+ Lower = IntMin + 1;
+ Upper = IntMax + 1;
+ } else if (C->countLeadingZeros() < Width - 1) {
+ // 'sdiv x, C' produces [INT_MIN / C, INT_MAX / C]
+ // where C != -1 and C != 0 and C != 1
+ Lower = IntMin.sdiv(*C);
+ Upper = IntMax.sdiv(*C);
+ if (Lower.sgt(Upper))
+ std::swap(Lower, Upper);
+ Upper = Upper + 1;
+ assert(Upper != Lower && "Upper part of range has wrapped!");
+ }
+ } else if (match(BO.getOperand(0), m_APInt(C))) {
+ if (C->isMinSignedValue()) {
+ // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2].
+ Lower = *C;
+ Upper = Lower.lshr(1) + 1;
+ } else {
+ // 'sdiv C, x' produces [-|C|, |C|].
+ Upper = C->abs() + 1;
+ Lower = (-Upper) + 1;
+ }
+ }
+ break;
+
+ case Instruction::UDiv:
+ if (match(BO.getOperand(1), m_APInt(C)) && !C->isNullValue()) {
+ // 'udiv x, C' produces [0, UINT_MAX / C].
+ Upper = APInt::getMaxValue(Width).udiv(*C) + 1;
+ } else if (match(BO.getOperand(0), m_APInt(C))) {
+ // 'udiv C, x' produces [0, C].
+ Upper = *C + 1;
+ }
+ break;
+
+ case Instruction::SRem:
+ if (match(BO.getOperand(1), m_APInt(C))) {
+ // 'srem x, C' produces (-|C|, |C|).
+ Upper = C->abs();
+ Lower = (-Upper) + 1;
+ }
+ break;
+
+ case Instruction::URem:
+ if (match(BO.getOperand(1), m_APInt(C)))
+ // 'urem x, C' produces [0, C).
+ Upper = *C;
+ break;
+
+ default:
+ break;
+ }
+}
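+
+// Two concrete instances of the ranges above, using i8 constants picked purely
+// for illustration: 'add nuw i8 %x, 5' ends up with the range [5, 255] (only
+// Lower is set), and 'lshr i8 %x, 3' ends up with [0, 31] (Upper becomes
+// (255 >> 3) + 1 = 32).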
+
+static void setLimitsForIntrinsic(const IntrinsicInst &II, APInt &Lower,
+ APInt &Upper) {
+ unsigned Width = Lower.getBitWidth();
+ const APInt *C;
+ switch (II.getIntrinsicID()) {
+ case Intrinsic::uadd_sat:
+ // uadd.sat(x, C) produces [C, UINT_MAX].
+ if (match(II.getOperand(0), m_APInt(C)) ||
+ match(II.getOperand(1), m_APInt(C)))
+ Lower = *C;
+ break;
+ case Intrinsic::sadd_sat:
+ if (match(II.getOperand(0), m_APInt(C)) ||
+ match(II.getOperand(1), m_APInt(C))) {
+ if (C->isNegative()) {
+ // sadd.sat(x, -C) produces [SINT_MIN, SINT_MAX + (-C)].
+ Lower = APInt::getSignedMinValue(Width);
+ Upper = APInt::getSignedMaxValue(Width) + *C + 1;
+ } else {
+ // sadd.sat(x, +C) produces [SINT_MIN + C, SINT_MAX].
+ Lower = APInt::getSignedMinValue(Width) + *C;
+ Upper = APInt::getSignedMaxValue(Width) + 1;
+ }
+ }
+ break;
+ case Intrinsic::usub_sat:
+ // usub.sat(C, x) produces [0, C].
+ if (match(II.getOperand(0), m_APInt(C)))
+ Upper = *C + 1;
+ // usub.sat(x, C) produces [0, UINT_MAX - C].
+ else if (match(II.getOperand(1), m_APInt(C)))
+ Upper = APInt::getMaxValue(Width) - *C + 1;
+ break;
+ case Intrinsic::ssub_sat:
+ if (match(II.getOperand(0), m_APInt(C))) {
+ if (C->isNegative()) {
+ // ssub.sat(-C, x) produces [SINT_MIN, -SINT_MIN + (-C)].
+ Lower = APInt::getSignedMinValue(Width);
+ Upper = *C - APInt::getSignedMinValue(Width) + 1;
+ } else {
+ // ssub.sat(+C, x) produces [-SINT_MAX + C, SINT_MAX].
+ Lower = *C - APInt::getSignedMaxValue(Width);
+ Upper = APInt::getSignedMaxValue(Width) + 1;
+ }
+ } else if (match(II.getOperand(1), m_APInt(C))) {
+ if (C->isNegative()) {
+ // ssub.sat(x, -C) produces [SINT_MIN - (-C), SINT_MAX]:
+ Lower = APInt::getSignedMinValue(Width) - *C;
+ Upper = APInt::getSignedMaxValue(Width) + 1;
+ } else {
+ // ssub.sat(x, +C) produces [SINT_MIN, SINT_MAX - C].
+ Lower = APInt::getSignedMinValue(Width);
+ Upper = APInt::getSignedMaxValue(Width) - *C + 1;
+ }
+ }
+ break;
+ default:
+ break;
+ }
+}
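+
+// For illustration (i8, constants picked arbitrarily): 'uadd.sat(i8 %x, 5)'
+// is assigned the range [5, 255] because the saturating sum never drops below
+// the constant, and 'usub.sat(i8 %x, 5)' is assigned [0, 250] because at most
+// 255 - 5 can remain after the subtraction.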
+
+static void setLimitsForSelectPattern(const SelectInst &SI, APInt &Lower,
+ APInt &Upper, const InstrInfoQuery &IIQ) {
+ const Value *LHS = nullptr, *RHS = nullptr;
+ SelectPatternResult R = matchSelectPattern(&SI, LHS, RHS);
+ if (R.Flavor == SPF_UNKNOWN)
+ return;
+
+ unsigned BitWidth = SI.getType()->getScalarSizeInBits();
+
+ if (R.Flavor == SelectPatternFlavor::SPF_ABS) {
+ // If the negation part of the abs (in RHS) has the NSW flag,
+ // then the result of abs(X) is [0..SIGNED_MAX],
+ // otherwise it is [0..SIGNED_MIN], as -SIGNED_MIN == SIGNED_MIN.
+ Lower = APInt::getNullValue(BitWidth);
+ if (match(RHS, m_Neg(m_Specific(LHS))) &&
+ IIQ.hasNoSignedWrap(cast<Instruction>(RHS)))
+ Upper = APInt::getSignedMaxValue(BitWidth) + 1;
+ else
+ Upper = APInt::getSignedMinValue(BitWidth) + 1;
+ return;
+ }
+
+ if (R.Flavor == SelectPatternFlavor::SPF_NABS) {
+ // The result of -abs(X) is <= 0.
+ Lower = APInt::getSignedMinValue(BitWidth);
+ Upper = APInt(BitWidth, 1);
+ return;
+ }
+
+ const APInt *C;
+ if (!match(LHS, m_APInt(C)) && !match(RHS, m_APInt(C)))
+ return;
+
+ switch (R.Flavor) {
+ case SPF_UMIN:
+ Upper = *C + 1;
+ break;
+ case SPF_UMAX:
+ Lower = *C;
+ break;
+ case SPF_SMIN:
+ Lower = APInt::getSignedMinValue(BitWidth);
+ Upper = *C + 1;
+ break;
+ case SPF_SMAX:
+ Lower = *C;
+ Upper = APInt::getSignedMaxValue(BitWidth) + 1;
+ break;
+ default:
+ break;
+ }
+}
+
+ConstantRange llvm::computeConstantRange(const Value *V, bool UseInstrInfo) {
+ assert(V->getType()->isIntOrIntVectorTy() && "Expected integer instruction");
+
+ const APInt *C;
+ if (match(V, m_APInt(C)))
+ return ConstantRange(*C);
+
+ InstrInfoQuery IIQ(UseInstrInfo);
+ unsigned BitWidth = V->getType()->getScalarSizeInBits();
+ APInt Lower = APInt(BitWidth, 0);
+ APInt Upper = APInt(BitWidth, 0);
+ if (auto *BO = dyn_cast<BinaryOperator>(V))
+ setLimitsForBinOp(*BO, Lower, Upper, IIQ);
+ else if (auto *II = dyn_cast<IntrinsicInst>(V))
+ setLimitsForIntrinsic(*II, Lower, Upper);
+ else if (auto *SI = dyn_cast<SelectInst>(V))
+ setLimitsForSelectPattern(*SI, Lower, Upper, IIQ);
+
+ ConstantRange CR = ConstantRange::getNonEmpty(Lower, Upper);
+
+ if (auto *I = dyn_cast<Instruction>(V))
+ if (auto *Range = IIQ.getMetadata(I, LLVMContext::MD_range))
+ CR = CR.intersectWith(getConstantRangeFromMetadata(*Range));
+
+ return CR;
+}
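+
+// Hypothetical usage sketch (the caller, the variable names and the non-zero
+// check are illustrative, not taken from an existing pass): a client could use
+// the returned range to show that a udiv denominator cannot be zero:
+//   ConstantRange CR = computeConstantRange(Denom, /*UseInstrInfo=*/true);
+//   bool KnownNonZero = !CR.contains(APInt::getNullValue(CR.getBitWidth()));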
+
+static Optional<int64_t>
+getOffsetFromIndex(const GEPOperator *GEP, unsigned Idx, const DataLayout &DL) {
+ // Skip over the first indices.
+ gep_type_iterator GTI = gep_type_begin(GEP);
+ for (unsigned i = 1; i != Idx; ++i, ++GTI)
+ /*skip along*/;
+
+ // Compute the offset implied by the rest of the indices.
+ int64_t Offset = 0;
+ for (unsigned i = Idx, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
+ ConstantInt *OpC = dyn_cast<ConstantInt>(GEP->getOperand(i));
+ if (!OpC)
+ return None;
+ if (OpC->isZero())
+ continue; // No offset.
+
+ // Handle struct indices, which add their field offset to the pointer.
+ if (StructType *STy = GTI.getStructTypeOrNull()) {
+ Offset += DL.getStructLayout(STy)->getElementOffset(OpC->getZExtValue());
+ continue;
+ }
+
+ // Otherwise, we have a sequential type like an array or vector. Multiply
+ // the index by the ElementSize.
+ uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType());
+ Offset += Size * OpC->getSExtValue();
+ }
+
+ return Offset;
+}
+
+Optional<int64_t> llvm::isPointerOffset(const Value *Ptr1, const Value *Ptr2,
+ const DataLayout &DL) {
+ Ptr1 = Ptr1->stripPointerCasts();
+ Ptr2 = Ptr2->stripPointerCasts();
+
+ // Handle the trivial case first.
+ if (Ptr1 == Ptr2) {
+ return 0;
+ }
+
+ const GEPOperator *GEP1 = dyn_cast<GEPOperator>(Ptr1);
+ const GEPOperator *GEP2 = dyn_cast<GEPOperator>(Ptr2);
+
+ // If one pointer is a GEP, see if the GEP is a constant offset from the
+ // base, as in "P" and "gep P, 1".
+ // Also do this iteratively to handle the following case:
+ // Ptr_t1 = GEP Ptr1, c1
+ // Ptr_t2 = GEP Ptr_t1, c2
+ // Ptr2 = GEP Ptr_t2, c3
+ // where we will return c1+c2+c3.
+ // TODO: Handle the case when both Ptr1 and Ptr2 are GEPs of some common base
+ // -- replace getOffsetFromBase with getOffsetAndBase, check that the bases
+ // are the same, and return the difference between offsets.
+ auto getOffsetFromBase = [&DL](const GEPOperator *GEP,
+ const Value *Ptr) -> Optional<int64_t> {
+ const GEPOperator *GEP_T = GEP;
+ int64_t OffsetVal = 0;
+ bool HasSameBase = false;
+ while (GEP_T) {
+ auto Offset = getOffsetFromIndex(GEP_T, 1, DL);
+ if (!Offset)
+ return None;
+ OffsetVal += *Offset;
+ auto Op0 = GEP_T->getOperand(0)->stripPointerCasts();
+ if (Op0 == Ptr) {
+ HasSameBase = true;
+ break;
+ }
+ GEP_T = dyn_cast<GEPOperator>(Op0);
+ }
+ if (!HasSameBase)
+ return None;
+ return OffsetVal;
+ };
+
+ if (GEP1) {
+ auto Offset = getOffsetFromBase(GEP1, Ptr2);
+ if (Offset)
+ return -*Offset;
+ }
+ if (GEP2) {
+ auto Offset = getOffsetFromBase(GEP2, Ptr1);
+ if (Offset)
+ return Offset;
+ }
+
+ // Right now we handle only the case where Ptr1 and Ptr2 are both GEPs with an
+ // identical base. After that base, they may share some number of common (and
+ // potentially variable) indices, followed by a constant offset in each; that
+ // constant offset determines their distance from each other. We handle no
+ // other case.
+ if (!GEP1 || !GEP2 || GEP1->getOperand(0) != GEP2->getOperand(0))
+ return None;
+
+ // Skip any common indices and track the GEP types.
+ unsigned Idx = 1;
+ for (; Idx != GEP1->getNumOperands() && Idx != GEP2->getNumOperands(); ++Idx)
+ if (GEP1->getOperand(Idx) != GEP2->getOperand(Idx))
+ break;
+
+ auto Offset1 = getOffsetFromIndex(GEP1, Idx, DL);
+ auto Offset2 = getOffsetFromIndex(GEP2, Idx, DL);
+ if (!Offset1 || !Offset2)
+ return None;
+ return *Offset2 - *Offset1;
+}
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
new file mode 100644
index 000000000000..600f57ab9d71
--- /dev/null
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -0,0 +1,1161 @@
+//===----------- VectorUtils.cpp - Vectorizer utility functions -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines vectorizer utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/Analysis/DemandedBits.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopIterator.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Value.h"
+
+#define DEBUG_TYPE "vectorutils"
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+/// Maximum factor for an interleaved memory access.
+static cl::opt<unsigned> MaxInterleaveGroupFactor(
+ "max-interleave-group-factor", cl::Hidden,
+ cl::desc("Maximum factor for an interleaved access group (default = 8)"),
+ cl::init(8));
+
+/// Return true if all of the intrinsic's arguments and return type are scalars
+/// for the scalar form of the intrinsic, and vectors for the vector form of the
+/// intrinsic (except operands that are marked as always being scalar by
+/// hasVectorInstrinsicScalarOpd).
+bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
+ switch (ID) {
+ case Intrinsic::bswap: // Begin integer bit-manipulation.
+ case Intrinsic::bitreverse:
+ case Intrinsic::ctpop:
+ case Intrinsic::ctlz:
+ case Intrinsic::cttz:
+ case Intrinsic::fshl:
+ case Intrinsic::fshr:
+ case Intrinsic::sadd_sat:
+ case Intrinsic::ssub_sat:
+ case Intrinsic::uadd_sat:
+ case Intrinsic::usub_sat:
+ case Intrinsic::smul_fix:
+ case Intrinsic::smul_fix_sat:
+ case Intrinsic::umul_fix:
+ case Intrinsic::umul_fix_sat:
+ case Intrinsic::sqrt: // Begin floating-point.
+ case Intrinsic::sin:
+ case Intrinsic::cos:
+ case Intrinsic::exp:
+ case Intrinsic::exp2:
+ case Intrinsic::log:
+ case Intrinsic::log10:
+ case Intrinsic::log2:
+ case Intrinsic::fabs:
+ case Intrinsic::minnum:
+ case Intrinsic::maxnum:
+ case Intrinsic::minimum:
+ case Intrinsic::maximum:
+ case Intrinsic::copysign:
+ case Intrinsic::floor:
+ case Intrinsic::ceil:
+ case Intrinsic::trunc:
+ case Intrinsic::rint:
+ case Intrinsic::nearbyint:
+ case Intrinsic::round:
+ case Intrinsic::pow:
+ case Intrinsic::fma:
+ case Intrinsic::fmuladd:
+ case Intrinsic::powi:
+ case Intrinsic::canonicalize:
+ return true;
+ default:
+ return false;
+ }
+}
+
+/// Identifies if the vector form of the intrinsic has a scalar operand.
+bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID,
+ unsigned ScalarOpdIdx) {
+ switch (ID) {
+ case Intrinsic::ctlz:
+ case Intrinsic::cttz:
+ case Intrinsic::powi:
+ return (ScalarOpdIdx == 1);
+ case Intrinsic::smul_fix:
+ case Intrinsic::smul_fix_sat:
+ case Intrinsic::umul_fix:
+ case Intrinsic::umul_fix_sat:
+ return (ScalarOpdIdx == 2);
+ default:
+ return false;
+ }
+}
+
+/// Returns the intrinsic ID for the call.
+/// For the given call instruction, find the corresponding intrinsic and return
+/// its ID; if there is no such mapping, return not_intrinsic.
+Intrinsic::ID llvm::getVectorIntrinsicIDForCall(const CallInst *CI,
+ const TargetLibraryInfo *TLI) {
+ Intrinsic::ID ID = getIntrinsicForCallSite(CI, TLI);
+ if (ID == Intrinsic::not_intrinsic)
+ return Intrinsic::not_intrinsic;
+
+ if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start ||
+ ID == Intrinsic::lifetime_end || ID == Intrinsic::assume ||
+ ID == Intrinsic::sideeffect)
+ return ID;
+ return Intrinsic::not_intrinsic;
+}
+
+/// Find the operand of the GEP that should be checked for consecutive
+/// stores. This ignores trailing indices that have no effect on the final
+/// pointer.
+unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) {
+ const DataLayout &DL = Gep->getModule()->getDataLayout();
+ unsigned LastOperand = Gep->getNumOperands() - 1;
+ unsigned GEPAllocSize = DL.getTypeAllocSize(Gep->getResultElementType());
+
+ // Walk backwards and try to peel off zeros.
+ while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) {
+ // Find the type we're currently indexing into.
+ gep_type_iterator GEPTI = gep_type_begin(Gep);
+ std::advance(GEPTI, LastOperand - 2);
+
+ // If it's a type with the same allocation size as the result of the GEP we
+ // can peel off the zero index.
+ if (DL.getTypeAllocSize(GEPTI.getIndexedType()) != GEPAllocSize)
+ break;
+ --LastOperand;
+ }
+
+ return LastOperand;
+}
+
+/// If the argument is a GEP, then returns the operand identified by
+/// getGEPInductionOperand. However, if some other operand is not loop
+/// invariant, the GEP itself is returned instead.
+Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
+ GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+ if (!GEP)
+ return Ptr;
+
+ unsigned InductionOperand = getGEPInductionOperand(GEP);
+
+ // Check that all of the gep indices are uniform except for our induction
+ // operand.
+ for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i)
+ if (i != InductionOperand &&
+ !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp))
+ return Ptr;
+ return GEP->getOperand(InductionOperand);
+}
+
+/// If a value has only one user that is a CastInst, return it.
+Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
+ Value *UniqueCast = nullptr;
+ for (User *U : Ptr->users()) {
+ CastInst *CI = dyn_cast<CastInst>(U);
+ if (CI && CI->getType() == Ty) {
+ if (!UniqueCast)
+ UniqueCast = CI;
+ else
+ return nullptr;
+ }
+ }
+ return UniqueCast;
+}
+
+/// Get the stride of a pointer access in a loop. Looks for symbolic
+/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
+Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
+ auto *PtrTy = dyn_cast<PointerType>(Ptr->getType());
+ if (!PtrTy || PtrTy->isAggregateType())
+ return nullptr;
+
+ // Try to strip a GEP instruction to make the pointer (actually the index at
+ // this point) easier to analyze. If OrigPtr is equal to Ptr we are analyzing
+ // the pointer; otherwise we are analyzing the index.
+ Value *OrigPtr = Ptr;
+
+ // The size of the pointer access.
+ int64_t PtrAccessSize = 1;
+
+ Ptr = stripGetElementPtr(Ptr, SE, Lp);
+ const SCEV *V = SE->getSCEV(Ptr);
+
+ if (Ptr != OrigPtr)
+ // Strip off casts.
+ while (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V))
+ V = C->getOperand();
+
+ const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V);
+ if (!S)
+ return nullptr;
+
+ V = S->getStepRecurrence(*SE);
+ if (!V)
+ return nullptr;
+
+ // Strip off the size of access multiplication if we are still analyzing the
+ // pointer.
+ if (OrigPtr == Ptr) {
+ if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) {
+ if (M->getOperand(0)->getSCEVType() != scConstant)
+ return nullptr;
+
+ const APInt &APStepVal = cast<SCEVConstant>(M->getOperand(0))->getAPInt();
+
+ // Huge step value - give up.
+ if (APStepVal.getBitWidth() > 64)
+ return nullptr;
+
+ int64_t StepVal = APStepVal.getSExtValue();
+ if (PtrAccessSize != StepVal)
+ return nullptr;
+ V = M->getOperand(1);
+ }
+ }
+
+ // Strip off casts.
+ Type *StripedOffRecurrenceCast = nullptr;
+ if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) {
+ StripedOffRecurrenceCast = C->getType();
+ V = C->getOperand();
+ }
+
+ // Look for the loop invariant symbolic value.
+ const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V);
+ if (!U)
+ return nullptr;
+
+ Value *Stride = U->getValue();
+ if (!Lp->isLoopInvariant(Stride))
+ return nullptr;
+
+ // If we have stripped off the recurrence cast we have to make sure that we
+ // return the value that is used in this loop so that we can replace it later.
+ if (StripedOffRecurrenceCast)
+ Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast);
+
+ return Stride;
+}
+
+/// Given a vector and an element number, see if the scalar value is
+/// already around as a register, for example if it were inserted then extracted
+/// from the vector.
+Value *llvm::findScalarElement(Value *V, unsigned EltNo) {
+ assert(V->getType()->isVectorTy() && "Not looking at a vector?");
+ VectorType *VTy = cast<VectorType>(V->getType());
+ unsigned Width = VTy->getNumElements();
+ if (EltNo >= Width) // Out of range access.
+ return UndefValue::get(VTy->getElementType());
+
+ if (Constant *C = dyn_cast<Constant>(V))
+ return C->getAggregateElement(EltNo);
+
+ if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) {
+ // If this is an insert to a variable element, we don't know what it is.
+ if (!isa<ConstantInt>(III->getOperand(2)))
+ return nullptr;
+ unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue();
+
+ // If this is an insert to the element we are looking for, return the
+ // inserted value.
+ if (EltNo == IIElt)
+ return III->getOperand(1);
+
+ // Otherwise, the insertelement doesn't modify the value, recurse on its
+ // vector input.
+ return findScalarElement(III->getOperand(0), EltNo);
+ }
+
+ if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) {
+ unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements();
+ int InEl = SVI->getMaskValue(EltNo);
+ if (InEl < 0)
+ return UndefValue::get(VTy->getElementType());
+ if (InEl < (int)LHSWidth)
+ return findScalarElement(SVI->getOperand(0), InEl);
+ return findScalarElement(SVI->getOperand(1), InEl - LHSWidth);
+ }
+
+ // Extract a value from a vector add operation with a constant zero.
+ // TODO: Use getBinOpIdentity() to generalize this.
+ Value *Val; Constant *C;
+ if (match(V, m_Add(m_Value(Val), m_Constant(C))))
+ if (Constant *Elt = C->getAggregateElement(EltNo))
+ if (Elt->isNullValue())
+ return findScalarElement(Val, EltNo);
+
+ // Otherwise, we don't know.
+ return nullptr;
+}
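+
+// Example on illustrative IR: given
+//   %v = insertelement <4 x i32> %w, i32 %s, i32 1
+// findScalarElement(%v, 1) returns %s, while findScalarElement(%v, 0) recurses
+// into %w.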
+
+/// Get splat value if the input is a splat vector or return nullptr.
+/// This function is not fully general. It checks only 2 cases:
+/// the input value is (1) a splat constant vector or (2) a sequence
+/// of instructions that broadcasts a scalar at element 0.
+const llvm::Value *llvm::getSplatValue(const Value *V) {
+ if (isa<VectorType>(V->getType()))
+ if (auto *C = dyn_cast<Constant>(V))
+ return C->getSplatValue();
+
+ // shuf (inselt ?, Splat, 0), ?, <0, undef, 0, ...>
+ Value *Splat;
+ if (match(V, m_ShuffleVector(m_InsertElement(m_Value(), m_Value(Splat),
+ m_ZeroInt()),
+ m_Value(), m_ZeroInt())))
+ return Splat;
+
+ return nullptr;
+}
+
+// This setting is based on its counterpart in value tracking, but it could be
+// adjusted if needed.
+const unsigned MaxDepth = 6;
+
+bool llvm::isSplatValue(const Value *V, unsigned Depth) {
+ assert(Depth <= MaxDepth && "Limit Search Depth");
+
+ if (isa<VectorType>(V->getType())) {
+ if (isa<UndefValue>(V))
+ return true;
+ // FIXME: Constant splat analysis does not allow undef elements.
+ if (auto *C = dyn_cast<Constant>(V))
+ return C->getSplatValue() != nullptr;
+ }
+
+ // FIXME: Constant splat analysis does not allow undef elements.
+ Constant *Mask;
+ if (match(V, m_ShuffleVector(m_Value(), m_Value(), m_Constant(Mask))))
+ return Mask->getSplatValue() != nullptr;
+
+ // The remaining tests are all recursive, so bail out if we hit the limit.
+ if (Depth++ == MaxDepth)
+ return false;
+
+ // If both operands of a binop are splats, the result is a splat.
+ Value *X, *Y, *Z;
+ if (match(V, m_BinOp(m_Value(X), m_Value(Y))))
+ return isSplatValue(X, Depth) && isSplatValue(Y, Depth);
+
+ // If all operands of a select are splats, the result is a splat.
+ if (match(V, m_Select(m_Value(X), m_Value(Y), m_Value(Z))))
+ return isSplatValue(X, Depth) && isSplatValue(Y, Depth) &&
+ isSplatValue(Z, Depth);
+
+ // TODO: Add support for unary ops (fneg), casts, intrinsics (overflow ops).
+
+ return false;
+}
+
+MapVector<Instruction *, uint64_t>
+llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
+ const TargetTransformInfo *TTI) {
+
+ // DemandedBits will give us every value's live-out bits. But we want
+ // to ensure no extra casts would need to be inserted, so every DAG
+ // of connected values must have the same minimum bitwidth.
+ EquivalenceClasses<Value *> ECs;
+ SmallVector<Value *, 16> Worklist;
+ SmallPtrSet<Value *, 4> Roots;
+ SmallPtrSet<Value *, 16> Visited;
+ DenseMap<Value *, uint64_t> DBits;
+ SmallPtrSet<Instruction *, 4> InstructionSet;
+ MapVector<Instruction *, uint64_t> MinBWs;
+
+ // Determine the roots. We work bottom-up, from truncs or icmps.
+ bool SeenExtFromIllegalType = false;
+ for (auto *BB : Blocks)
+ for (auto &I : *BB) {
+ InstructionSet.insert(&I);
+
+ if (TTI && (isa<ZExtInst>(&I) || isa<SExtInst>(&I)) &&
+ !TTI->isTypeLegal(I.getOperand(0)->getType()))
+ SeenExtFromIllegalType = true;
+
+ // Only deal with non-vector integers up to 64-bits wide.
+ if ((isa<TruncInst>(&I) || isa<ICmpInst>(&I)) &&
+ !I.getType()->isVectorTy() &&
+ I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) {
+ // Don't make work for ourselves. If we know the loaded type is legal,
+ // don't add it to the worklist.
+ if (TTI && isa<TruncInst>(&I) && TTI->isTypeLegal(I.getType()))
+ continue;
+
+ Worklist.push_back(&I);
+ Roots.insert(&I);
+ }
+ }
+ // Early exit.
+ if (Worklist.empty() || (TTI && !SeenExtFromIllegalType))
+ return MinBWs;
+
+ // Now proceed breadth-first, unioning values together.
+ while (!Worklist.empty()) {
+ Value *Val = Worklist.pop_back_val();
+ Value *Leader = ECs.getOrInsertLeaderValue(Val);
+
+ if (Visited.count(Val))
+ continue;
+ Visited.insert(Val);
+
+ // Non-instructions terminate a chain successfully.
+ if (!isa<Instruction>(Val))
+ continue;
+ Instruction *I = cast<Instruction>(Val);
+
+ // If we encounter a type that is larger than 64 bits, we can't represent
+ // it so bail out.
+ if (DB.getDemandedBits(I).getBitWidth() > 64)
+ return MapVector<Instruction *, uint64_t>();
+
+ uint64_t V = DB.getDemandedBits(I).getZExtValue();
+ DBits[Leader] |= V;
+ DBits[I] = V;
+
+ // Casts, loads and instructions outside of our range terminate a chain
+ // successfully.
+ if (isa<SExtInst>(I) || isa<ZExtInst>(I) || isa<LoadInst>(I) ||
+ !InstructionSet.count(I))
+ continue;
+
+ // Unsafe casts terminate a chain unsuccessfully. We can't do anything
+ // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to
+ // transform anything that relies on them.
+ if (isa<BitCastInst>(I) || isa<PtrToIntInst>(I) || isa<IntToPtrInst>(I) ||
+ !I->getType()->isIntegerTy()) {
+ DBits[Leader] |= ~0ULL;
+ continue;
+ }
+
+ // We don't modify the types of PHIs. Reductions will already have been
+ // truncated if possible, and inductions' sizes will have been chosen by
+ // indvars.
+ if (isa<PHINode>(I))
+ continue;
+
+ if (DBits[Leader] == ~0ULL)
+ // All bits demanded, no point continuing.
+ continue;
+
+ for (Value *O : cast<User>(I)->operands()) {
+ ECs.unionSets(Leader, O);
+ Worklist.push_back(O);
+ }
+ }
+
+ // Now we've discovered all values, walk them to see if there are
+ // any users we didn't see. If there are, we can't optimize that
+ // chain.
+ for (auto &I : DBits)
+ for (auto *U : I.first->users())
+ if (U->getType()->isIntegerTy() && DBits.count(U) == 0)
+ DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL;
+
+ for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) {
+ uint64_t LeaderDemandedBits = 0;
+ for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI)
+ LeaderDemandedBits |= DBits[*MI];
+
+ uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) -
+ llvm::countLeadingZeros(LeaderDemandedBits);
+ // Round up to a power of 2
+ if (!isPowerOf2_64((uint64_t)MinBW))
+ MinBW = NextPowerOf2(MinBW);
+
+ // We don't modify the types of PHIs. Reductions will already have been
+ // truncated if possible, and inductions' sizes will have been chosen by
+ // indvars.
+ // If we are required to shrink a PHI, abandon this entire equivalence class.
+ bool Abort = false;
+ for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI)
+ if (isa<PHINode>(*MI) && MinBW < (*MI)->getType()->getScalarSizeInBits()) {
+ Abort = true;
+ break;
+ }
+ if (Abort)
+ continue;
+
+ for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) {
+ if (!isa<Instruction>(*MI))
+ continue;
+ Type *Ty = (*MI)->getType();
+ if (Roots.count(*MI))
+ Ty = cast<Instruction>(*MI)->getOperand(0)->getType();
+ if (MinBW < Ty->getScalarSizeInBits())
+ MinBWs[cast<Instruction>(*MI)] = MinBW;
+ }
+ }
+
+ return MinBWs;
+}
+
+/// Add all access groups in @p AccGroups to @p List.
+template <typename ListT>
+static void addToAccessGroupList(ListT &List, MDNode *AccGroups) {
+ // Interpret an access group as a list containing itself.
+ if (AccGroups->getNumOperands() == 0) {
+ assert(isValidAsAccessGroup(AccGroups) && "Node must be an access group");
+ List.insert(AccGroups);
+ return;
+ }
+
+ for (auto &AccGroupListOp : AccGroups->operands()) {
+ auto *Item = cast<MDNode>(AccGroupListOp.get());
+ assert(isValidAsAccessGroup(Item) && "List item must be an access group");
+ List.insert(Item);
+ }
+}
+
+MDNode *llvm::uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2) {
+ if (!AccGroups1)
+ return AccGroups2;
+ if (!AccGroups2)
+ return AccGroups1;
+ if (AccGroups1 == AccGroups2)
+ return AccGroups1;
+
+ SmallSetVector<Metadata *, 4> Union;
+ addToAccessGroupList(Union, AccGroups1);
+ addToAccessGroupList(Union, AccGroups2);
+
+ if (Union.size() == 0)
+ return nullptr;
+ if (Union.size() == 1)
+ return cast<MDNode>(Union.front());
+
+ LLVMContext &Ctx = AccGroups1->getContext();
+ return MDNode::get(Ctx, Union.getArrayRef());
+}
+
+MDNode *llvm::intersectAccessGroups(const Instruction *Inst1,
+ const Instruction *Inst2) {
+ bool MayAccessMem1 = Inst1->mayReadOrWriteMemory();
+ bool MayAccessMem2 = Inst2->mayReadOrWriteMemory();
+
+ if (!MayAccessMem1 && !MayAccessMem2)
+ return nullptr;
+ if (!MayAccessMem1)
+ return Inst2->getMetadata(LLVMContext::MD_access_group);
+ if (!MayAccessMem2)
+ return Inst1->getMetadata(LLVMContext::MD_access_group);
+
+ MDNode *MD1 = Inst1->getMetadata(LLVMContext::MD_access_group);
+ MDNode *MD2 = Inst2->getMetadata(LLVMContext::MD_access_group);
+ if (!MD1 || !MD2)
+ return nullptr;
+ if (MD1 == MD2)
+ return MD1;
+
+ // Use set for scalable 'contains' check.
+ SmallPtrSet<Metadata *, 4> AccGroupSet2;
+ addToAccessGroupList(AccGroupSet2, MD2);
+
+ SmallVector<Metadata *, 4> Intersection;
+ if (MD1->getNumOperands() == 0) {
+ assert(isValidAsAccessGroup(MD1) && "Node must be an access group");
+ if (AccGroupSet2.count(MD1))
+ Intersection.push_back(MD1);
+ } else {
+ for (const MDOperand &Node : MD1->operands()) {
+ auto *Item = cast<MDNode>(Node.get());
+ assert(isValidAsAccessGroup(Item) && "List item must be an access group");
+ if (AccGroupSet2.count(Item))
+ Intersection.push_back(Item);
+ }
+ }
+
+ if (Intersection.size() == 0)
+ return nullptr;
+ if (Intersection.size() == 1)
+ return cast<MDNode>(Intersection.front());
+
+ LLVMContext &Ctx = Inst1->getContext();
+ return MDNode::get(Ctx, Intersection);
+}
+
+/// \returns \p Inst after propagating metadata from \p VL.
+Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) {
+ Instruction *I0 = cast<Instruction>(VL[0]);
+ SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata;
+ I0->getAllMetadataOtherThanDebugLoc(Metadata);
+
+ for (auto Kind : {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
+ LLVMContext::MD_noalias, LLVMContext::MD_fpmath,
+ LLVMContext::MD_nontemporal, LLVMContext::MD_invariant_load,
+ LLVMContext::MD_access_group}) {
+ MDNode *MD = I0->getMetadata(Kind);
+
+ for (int J = 1, E = VL.size(); MD && J != E; ++J) {
+ const Instruction *IJ = cast<Instruction>(VL[J]);
+ MDNode *IMD = IJ->getMetadata(Kind);
+ switch (Kind) {
+ case LLVMContext::MD_tbaa:
+ MD = MDNode::getMostGenericTBAA(MD, IMD);
+ break;
+ case LLVMContext::MD_alias_scope:
+ MD = MDNode::getMostGenericAliasScope(MD, IMD);
+ break;
+ case LLVMContext::MD_fpmath:
+ MD = MDNode::getMostGenericFPMath(MD, IMD);
+ break;
+ case LLVMContext::MD_noalias:
+ case LLVMContext::MD_nontemporal:
+ case LLVMContext::MD_invariant_load:
+ MD = MDNode::intersect(MD, IMD);
+ break;
+ case LLVMContext::MD_access_group:
+ MD = intersectAccessGroups(Inst, IJ);
+ break;
+ default:
+ llvm_unreachable("unhandled metadata");
+ }
+ }
+
+ Inst->setMetadata(Kind, MD);
+ }
+
+ return Inst;
+}
+
+Constant *
+llvm::createBitMaskForGaps(IRBuilder<> &Builder, unsigned VF,
+ const InterleaveGroup<Instruction> &Group) {
+ // All 1's means mask is not needed.
+ if (Group.getNumMembers() == Group.getFactor())
+ return nullptr;
+
+ // TODO: support reversed access.
+ assert(!Group.isReverse() && "Reversed group not supported.");
+
+ SmallVector<Constant *, 16> Mask;
+ for (unsigned i = 0; i < VF; i++)
+ for (unsigned j = 0; j < Group.getFactor(); ++j) {
+ unsigned HasMember = Group.getMember(j) ? 1 : 0;
+ Mask.push_back(Builder.getInt1(HasMember));
+ }
+
+ return ConstantVector::get(Mask);
+}
+
+Constant *llvm::createReplicatedMask(IRBuilder<> &Builder,
+ unsigned ReplicationFactor, unsigned VF) {
+ SmallVector<Constant *, 16> MaskVec;
+ for (unsigned i = 0; i < VF; i++)
+ for (unsigned j = 0; j < ReplicationFactor; j++)
+ MaskVec.push_back(Builder.getInt32(i));
+
+ return ConstantVector::get(MaskVec);
+}
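+
+// E.g. createReplicatedMask(Builder, /*ReplicationFactor=*/3, /*VF=*/2)
+// returns the constant vector <0, 0, 0, 1, 1, 1>.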
+
+Constant *llvm::createInterleaveMask(IRBuilder<> &Builder, unsigned VF,
+ unsigned NumVecs) {
+ SmallVector<Constant *, 16> Mask;
+ for (unsigned i = 0; i < VF; i++)
+ for (unsigned j = 0; j < NumVecs; j++)
+ Mask.push_back(Builder.getInt32(j * VF + i));
+
+ return ConstantVector::get(Mask);
+}
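+
+// E.g. createInterleaveMask(Builder, /*VF=*/4, /*NumVecs=*/2) returns
+// <0, 4, 1, 5, 2, 6, 3, 7>, interleaving the lanes of two 4-element vectors.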
+
+Constant *llvm::createStrideMask(IRBuilder<> &Builder, unsigned Start,
+ unsigned Stride, unsigned VF) {
+ SmallVector<Constant *, 16> Mask;
+ for (unsigned i = 0; i < VF; i++)
+ Mask.push_back(Builder.getInt32(Start + i * Stride));
+
+ return ConstantVector::get(Mask);
+}
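+
+// E.g. createStrideMask(Builder, /*Start=*/0, /*Stride=*/2, /*VF=*/4) returns
+// <0, 2, 4, 6>, selecting every other lane starting at lane 0.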
+
+Constant *llvm::createSequentialMask(IRBuilder<> &Builder, unsigned Start,
+ unsigned NumInts, unsigned NumUndefs) {
+ SmallVector<Constant *, 16> Mask;
+ for (unsigned i = 0; i < NumInts; i++)
+ Mask.push_back(Builder.getInt32(Start + i));
+
+ Constant *Undef = UndefValue::get(Builder.getInt32Ty());
+ for (unsigned i = 0; i < NumUndefs; i++)
+ Mask.push_back(Undef);
+
+ return ConstantVector::get(Mask);
+}
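+
+// E.g. createSequentialMask(Builder, /*Start=*/0, /*NumInts=*/4,
+// /*NumUndefs=*/2) returns <0, 1, 2, 3, undef, undef>.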
+
+/// A helper function for concatenating vectors. This function concatenates two
+/// vectors having the same element type. If the second vector has fewer
+/// elements than the first, it is padded with undefs.
+static Value *concatenateTwoVectors(IRBuilder<> &Builder, Value *V1,
+ Value *V2) {
+ VectorType *VecTy1 = dyn_cast<VectorType>(V1->getType());
+ VectorType *VecTy2 = dyn_cast<VectorType>(V2->getType());
+ assert(VecTy1 && VecTy2 &&
+ VecTy1->getScalarType() == VecTy2->getScalarType() &&
+ "Expect two vectors with the same element type");
+
+ unsigned NumElts1 = VecTy1->getNumElements();
+ unsigned NumElts2 = VecTy2->getNumElements();
+ assert(NumElts1 >= NumElts2 && "Expected the first vector to have at least as many elements");
+
+ if (NumElts1 > NumElts2) {
+ // Extend with UNDEFs.
+ Constant *ExtMask =
+ createSequentialMask(Builder, 0, NumElts2, NumElts1 - NumElts2);
+ V2 = Builder.CreateShuffleVector(V2, UndefValue::get(VecTy2), ExtMask);
+ }
+
+ Constant *Mask = createSequentialMask(Builder, 0, NumElts1 + NumElts2, 0);
+ return Builder.CreateShuffleVector(V1, V2, Mask);
+}
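+
+// E.g. concatenating a <4 x i32> %a with a <2 x i32> %b: %b is first widened
+// to <4 x i32> with two undef lanes, and the final shuffle yields a <6 x i32>
+// holding a0..a3 followed by b0, b1.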
+
+Value *llvm::concatenateVectors(IRBuilder<> &Builder, ArrayRef<Value *> Vecs) {
+ unsigned NumVecs = Vecs.size();
+ assert(NumVecs > 1 && "Should be at least two vectors");
+
+ SmallVector<Value *, 8> ResList;
+ ResList.append(Vecs.begin(), Vecs.end());
+ do {
+ SmallVector<Value *, 8> TmpList;
+ for (unsigned i = 0; i < NumVecs - 1; i += 2) {
+ Value *V0 = ResList[i], *V1 = ResList[i + 1];
+ assert((V0->getType() == V1->getType() || i == NumVecs - 2) &&
+ "Only the last vector may have a different type");
+
+ TmpList.push_back(concatenateTwoVectors(Builder, V0, V1));
+ }
+
+ // Push the last vector if the total number of vectors is odd.
+ if (NumVecs % 2 != 0)
+ TmpList.push_back(ResList[NumVecs - 1]);
+
+ ResList = TmpList;
+ NumVecs = ResList.size();
+ } while (NumVecs > 1);
+
+ return ResList[0];
+}
+
+bool llvm::maskIsAllZeroOrUndef(Value *Mask) {
+ auto *ConstMask = dyn_cast<Constant>(Mask);
+ if (!ConstMask)
+ return false;
+ if (ConstMask->isNullValue() || isa<UndefValue>(ConstMask))
+ return true;
+ for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E;
+ ++I) {
+ if (auto *MaskElt = ConstMask->getAggregateElement(I))
+ if (MaskElt->isNullValue() || isa<UndefValue>(MaskElt))
+ continue;
+ return false;
+ }
+ return true;
+}
+
+
+bool llvm::maskIsAllOneOrUndef(Value *Mask) {
+ auto *ConstMask = dyn_cast<Constant>(Mask);
+ if (!ConstMask)
+ return false;
+ if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask))
+ return true;
+ for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E;
+ ++I) {
+ if (auto *MaskElt = ConstMask->getAggregateElement(I))
+ if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt))
+ continue;
+ return false;
+ }
+ return true;
+}
+
+/// TODO: This is a lot like known bits, but for
+/// vectors. Is there something we can common this with?
+APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
+
+ const unsigned VWidth = cast<VectorType>(Mask->getType())->getNumElements();
+ APInt DemandedElts = APInt::getAllOnesValue(VWidth);
+ if (auto *CV = dyn_cast<ConstantVector>(Mask))
+ for (unsigned i = 0; i < VWidth; i++)
+ if (CV->getAggregateElement(i)->isNullValue())
+ DemandedElts.clearBit(i);
+ return DemandedElts;
+}
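+
+// E.g. for the constant mask <i1 1, i1 0, i1 1, i1 0> this returns the APInt
+// 0b0101, i.e. only lanes 0 and 2 may be demanded.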
+
+bool InterleavedAccessInfo::isStrided(int Stride) {
+ unsigned Factor = std::abs(Stride);
+ return Factor >= 2 && Factor <= MaxInterleaveGroupFactor;
+}
+
+void InterleavedAccessInfo::collectConstStrideAccesses(
+ MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
+ const ValueToValueMap &Strides) {
+ auto &DL = TheLoop->getHeader()->getModule()->getDataLayout();
+
+ // Since it's desired that the load/store instructions be maintained in
+ // "program order" for the interleaved access analysis, we have to visit the
+ // blocks in the loop in reverse postorder (i.e., in a topological order).
+ // Such an ordering will ensure that any load/store that may be executed
+ // before a second load/store will precede the second load/store in
+ // AccessStrideInfo.
+ LoopBlocksDFS DFS(TheLoop);
+ DFS.perform(LI);
+ for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO()))
+ for (auto &I : *BB) {
+ auto *LI = dyn_cast<LoadInst>(&I);
+ auto *SI = dyn_cast<StoreInst>(&I);
+ if (!LI && !SI)
+ continue;
+
+ Value *Ptr = getLoadStorePointerOperand(&I);
+ // We don't check wrapping here because we don't know yet if Ptr will be
+ // part of a full group or a group with gaps. Checking wrapping for all
+ // pointers (even those that end up in groups with no gaps) will be overly
+ // conservative. For full groups, wrapping should be ok since if we would
+ // wrap around the address space we would do a memory access at nullptr
+ // even without the transformation. The wrapping checks are therefore
+ // deferred until after we've formed the interleaved groups.
+ int64_t Stride = getPtrStride(PSE, Ptr, TheLoop, Strides,
+ /*Assume=*/true, /*ShouldCheckWrap=*/false);
+
+ const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
+ PointerType *PtrTy = cast<PointerType>(Ptr->getType());
+ uint64_t Size = DL.getTypeAllocSize(PtrTy->getElementType());
+
+ // An alignment of 0 means target ABI alignment.
+ MaybeAlign Alignment = MaybeAlign(getLoadStoreAlignment(&I));
+ if (!Alignment)
+ Alignment = Align(DL.getABITypeAlignment(PtrTy->getElementType()));
+
+ AccessStrideInfo[&I] = StrideDescriptor(Stride, Scev, Size, *Alignment);
+ }
+}
+
+// Analyze interleaved accesses and collect them into interleaved load and
+// store groups.
+//
+// When generating code for an interleaved load group, we effectively hoist all
+// loads in the group to the location of the first load in program order. When
+// generating code for an interleaved store group, we sink all stores to the
+// location of the last store. This code motion can change the order of load
+// and store instructions and may break dependences.
+//
+// The code generation strategy mentioned above ensures that we won't violate
+// any write-after-read (WAR) dependences.
+//
+// E.g., for the WAR dependence: a = A[i]; // (1)
+// A[i] = b; // (2)
+//
+// The store group of (2) is always inserted at or below (2), and the load
+// group of (1) is always inserted at or above (1). Thus, the instructions will
+// never be reordered. All other dependences are checked to ensure the
+// correctness of the instruction reordering.
+//
+// The algorithm visits all memory accesses in the loop in bottom-up program
+// order. Program order is established by traversing the blocks in the loop in
+// reverse postorder when collecting the accesses.
+//
+// We visit the memory accesses in bottom-up order because it can simplify the
+// construction of store groups in the presence of write-after-write (WAW)
+// dependences.
+//
+// E.g., for the WAW dependence: A[i] = a; // (1)
+// A[i] = b; // (2)
+// A[i + 1] = c; // (3)
+//
+// We will first create a store group with (3) and (2). (1) can't be added to
+// this group because it and (2) are dependent. However, (1) can be grouped
+// with other accesses that may precede it in program order. Note that a
+// bottom-up order does not imply that WAW dependences should not be checked.
+void InterleavedAccessInfo::analyzeInterleaving(
+ bool EnablePredicatedInterleavedMemAccesses) {
+ LLVM_DEBUG(dbgs() << "LV: Analyzing interleaved accesses...\n");
+ const ValueToValueMap &Strides = LAI->getSymbolicStrides();
+
+ // Holds all accesses with a constant stride.
+ MapVector<Instruction *, StrideDescriptor> AccessStrideInfo;
+ collectConstStrideAccesses(AccessStrideInfo, Strides);
+
+ if (AccessStrideInfo.empty())
+ return;
+
+ // Collect the dependences in the loop.
+ collectDependences();
+
+ // Holds all interleaved store groups temporarily.
+ SmallSetVector<InterleaveGroup<Instruction> *, 4> StoreGroups;
+ // Holds all interleaved load groups temporarily.
+ SmallSetVector<InterleaveGroup<Instruction> *, 4> LoadGroups;
+
+ // Search in bottom-up program order for pairs of accesses (A and B) that can
+ // form interleaved load or store groups. In the algorithm below, access A
+ // precedes access B in program order. We initialize a group for B in the
+ // outer loop of the algorithm, and then in the inner loop, we attempt to
+ // insert each A into B's group if:
+ //
+ // 1. A and B have the same stride,
+ // 2. A and B have the same memory object size, and
+ // 3. A belongs in B's group according to its distance from B.
+ //
+ // Special care is taken to ensure group formation will not break any
+ // dependences.
+ for (auto BI = AccessStrideInfo.rbegin(), E = AccessStrideInfo.rend();
+ BI != E; ++BI) {
+ Instruction *B = BI->first;
+ StrideDescriptor DesB = BI->second;
+
+ // Initialize a group for B if it has an allowable stride. Even if we don't
+ // create a group for B, we continue with the bottom-up algorithm to ensure
+ // we don't break any of B's dependences.
+ InterleaveGroup<Instruction> *Group = nullptr;
+ if (isStrided(DesB.Stride) &&
+ (!isPredicated(B->getParent()) || EnablePredicatedInterleavedMemAccesses)) {
+ Group = getInterleaveGroup(B);
+ if (!Group) {
+ LLVM_DEBUG(dbgs() << "LV: Creating an interleave group with:" << *B
+ << '\n');
+ Group = createInterleaveGroup(B, DesB.Stride, DesB.Alignment);
+ }
+ if (B->mayWriteToMemory())
+ StoreGroups.insert(Group);
+ else
+ LoadGroups.insert(Group);
+ }
+
+ for (auto AI = std::next(BI); AI != E; ++AI) {
+ Instruction *A = AI->first;
+ StrideDescriptor DesA = AI->second;
+
+ // Our code motion strategy implies that we can't have dependences
+ // between accesses in an interleaved group and other accesses located
+ // between the first and last member of the group. Note that this also
+ // means that a group can't have more than one member at a given offset.
+ // The accesses in a group can have dependences with other accesses, but
+ // we must ensure we don't extend the boundaries of the group such that
+ // we encompass those dependent accesses.
+ //
+ // For example, assume we have the sequence of accesses shown below in a
+ // stride-2 loop:
+ //
+ // (1, 2) is a group | A[i] = a; // (1)
+ // | A[i-1] = b; // (2) |
+ // A[i-3] = c; // (3)
+ // A[i] = d; // (4) | (2, 4) is not a group
+ //
+ // Because accesses (2) and (3) are dependent, we can group (2) with (1)
+ // but not with (4). If we did, the dependent access (3) would be within
+ // the boundaries of the (2, 4) group.
+ if (!canReorderMemAccessesForInterleavedGroups(&*AI, &*BI)) {
+ // If a dependence exists and A is already in a group, we know that A
+ // must be a store since A precedes B and WAR dependences are allowed.
+ // Thus, A would be sunk below B. We release A's group to prevent this
+ // illegal code motion. A will then be free to form another group with
+ // instructions that precede it.
+ if (isInterleaved(A)) {
+ InterleaveGroup<Instruction> *StoreGroup = getInterleaveGroup(A);
+
+ LLVM_DEBUG(dbgs() << "LV: Invalidated store group due to "
+ "dependence between " << *A << " and "<< *B << '\n');
+
+ StoreGroups.remove(StoreGroup);
+ releaseGroup(StoreGroup);
+ }
+
+ // If a dependence exists and A is not already in a group (or it was
+ // and we just released it), B might be hoisted above A (if B is a
+ // load) or another store might be sunk below A (if B is a store). In
+ // either case, we can't add additional instructions to B's group. B
+ // will only form a group with instructions that it precedes.
+ break;
+ }
+
+ // At this point, we've checked for illegal code motion. If either A or B
+ // isn't strided, there's nothing left to do.
+ if (!isStrided(DesA.Stride) || !isStrided(DesB.Stride))
+ continue;
+
+ // Ignore A if it's already in a group or isn't the same kind of memory
+ // operation as B.
+ // Note that mayReadFromMemory() isn't mutually exclusive with
+ // mayWriteToMemory() in the case of atomic loads. We shouldn't see those
+ // here; canVectorizeMemory() should have returned false - unless we were
+ // asked for optimization remarks.
+ if (isInterleaved(A) ||
+ (A->mayReadFromMemory() != B->mayReadFromMemory()) ||
+ (A->mayWriteToMemory() != B->mayWriteToMemory()))
+ continue;
+
+ // Check rules 1 and 2. Ignore A if its stride or size is different from
+ // that of B.
+ if (DesA.Stride != DesB.Stride || DesA.Size != DesB.Size)
+ continue;
+
+ // Ignore A if the memory objects of A and B don't belong to the same
+ // address space.
+ if (getLoadStoreAddressSpace(A) != getLoadStoreAddressSpace(B))
+ continue;
+
+ // Calculate the distance from A to B.
+ const SCEVConstant *DistToB = dyn_cast<SCEVConstant>(
+ PSE.getSE()->getMinusSCEV(DesA.Scev, DesB.Scev));
+ if (!DistToB)
+ continue;
+ int64_t DistanceToB = DistToB->getAPInt().getSExtValue();
+
+ // Check rule 3. Ignore A if its distance to B is not a multiple of the
+ // size.
+ if (DistanceToB % static_cast<int64_t>(DesB.Size))
+ continue;
+
+ // All members of a predicated interleave group must have the same
+ // predicate, and currently must reside in the same BB.
+ BasicBlock *BlockA = A->getParent();
+ BasicBlock *BlockB = B->getParent();
+ if ((isPredicated(BlockA) || isPredicated(BlockB)) &&
+ (!EnablePredicatedInterleavedMemAccesses || BlockA != BlockB))
+ continue;
+
+ // The index of A is the index of B plus A's distance to B in multiples
+ // of the size.
+ int IndexA =
+ Group->getIndex(B) + DistanceToB / static_cast<int64_t>(DesB.Size);
+
+ // Try to insert A into B's group.
+ if (Group->insertMember(A, IndexA, DesA.Alignment)) {
+ LLVM_DEBUG(dbgs() << "LV: Inserted:" << *A << '\n'
+ << " into the interleave group with" << *B
+ << '\n');
+ InterleaveGroupMap[A] = Group;
+
+ // Set the first load in program order as the insert position.
+ if (A->mayReadFromMemory())
+ Group->setInsertPos(A);
+ }
+ } // Iteration over A accesses.
+ } // Iteration over B accesses.
+
+ // Remove interleaved store groups with gaps.
+ for (auto *Group : StoreGroups)
+ if (Group->getNumMembers() != Group->getFactor()) {
+ LLVM_DEBUG(
+ dbgs() << "LV: Invalidate candidate interleaved store group due "
+ "to gaps.\n");
+ releaseGroup(Group);
+ }
+ // Remove interleaved groups with gaps (currently only loads) whose memory
+ // accesses may wrap around. We have to revisit the getPtrStride analysis,
+ // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does
+ // not check wrapping (see documentation there).
+ // FORNOW we use Assume=false;
+ // TODO: Change to Assume=true, while making sure we don't exceed the
+ // threshold of runtime SCEV assumption checks (thereby potentially failing
+ // to vectorize altogether).
+ // Additional optional optimizations:
+ // TODO: If we are peeling the loop and we know that the first pointer doesn't
+ // wrap, then we can deduce that all pointers in the group don't wrap. This
+ // means that we can forcefully peel the loop in order to only have to check
+ // the first pointer for no-wrap. Once we change to Assume=true, we'll only
+ // need at most one runtime check per interleaved group.
+ for (auto *Group : LoadGroups) {
+ // Case 1: A full group. We can skip the checks; for full groups, if the
+ // wide load would wrap around the address space, we would do a memory
+ // access at nullptr even without the transformation.
+ if (Group->getNumMembers() == Group->getFactor())
+ continue;
+
+ // Case 2: If the first and last members of the group don't wrap, this
+ // implies that all the pointers in the group don't wrap. So we check only
+ // group member 0 (which is always guaranteed to exist) and group member
+ // Factor - 1. If the latter doesn't exist, we rely on peeling (if it is a
+ // non-reversed access -- see Case 3).
+ Value *FirstMemberPtr = getLoadStorePointerOperand(Group->getMember(0));
+ if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false,
+ /*ShouldCheckWrap=*/true)) {
+ LLVM_DEBUG(
+ dbgs() << "LV: Invalidate candidate interleaved group due to "
+ "first group member potentially pointer-wrapping.\n");
+ releaseGroup(Group);
+ continue;
+ }
+ Instruction *LastMember = Group->getMember(Group->getFactor() - 1);
+ if (LastMember) {
+ Value *LastMemberPtr = getLoadStorePointerOperand(LastMember);
+ if (!getPtrStride(PSE, LastMemberPtr, TheLoop, Strides, /*Assume=*/false,
+ /*ShouldCheckWrap=*/true)) {
+ LLVM_DEBUG(
+ dbgs() << "LV: Invalidate candidate interleaved group due to "
+ "last group member potentially pointer-wrapping.\n");
+ releaseGroup(Group);
+ }
+ } else {
+ // Case 3: A non-reversed interleaved load group with gaps: We need
+ // to execute at least one scalar epilogue iteration. This will ensure
+ // we don't speculatively access memory out-of-bounds. We only need
+ // to look for a member at index factor - 1, since every group must have
+ // a member at index zero.
+ if (Group->isReverse()) {
+ LLVM_DEBUG(
+ dbgs() << "LV: Invalidate candidate interleaved group due to "
+ "a reverse access with gaps.\n");
+ releaseGroup(Group);
+ continue;
+ }
+ LLVM_DEBUG(
+ dbgs() << "LV: Interleaved group requires epilogue iteration.\n");
+ RequiresScalarEpilogue = true;
+ }
+ }
+}
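
A minimal standalone sketch of the rule-3 index arithmetic above, with hypothetical names and no LLVM APIs: the constant distance from A to B must be a multiple of the common access size, and A's index is then B's index plus that quotient.

  #include <cstdint>
  #include <iostream>
  #include <optional>

  // Hypothetical helper: given B's index in the group, the constant distance
  // from A to B (A minus B, in bytes), and the common access size, decide
  // whether A can join the group and at which index.
  std::optional<int> indexForA(int IndexB, int64_t DistanceToB, int64_t Size) {
    // Rule 3: the distance must be a multiple of the access size, otherwise
    // A and B cannot be members of the same interleave group.
    if (DistanceToB % Size != 0)
      return std::nullopt;
    return IndexB + static_cast<int>(DistanceToB / Size);
  }

  int main() {
    // Example: B is at index 0, A accesses 8 bytes before B, elements are
    // 4 bytes wide, so A lands two slots before B (index -2).
    if (auto Idx = indexForA(/*IndexB=*/0, /*DistanceToB=*/-8, /*Size=*/4))
      std::cout << "A joins at index " << *Idx << "\n";
  }

A negative index is fine at this point; whether the insertion actually succeeds is still up to insertMember, which can reject A (for example when that slot is already occupied).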
+
+void InterleavedAccessInfo::invalidateGroupsRequiringScalarEpilogue() {
+ // If no group had triggered the requirement to create an epilogue loop,
+ // there is nothing to do.
+ if (!requiresScalarEpilogue())
+ return;
+
+ // Avoid releasing a Group twice.
+ SmallPtrSet<InterleaveGroup<Instruction> *, 4> DelSet;
+ for (auto &I : InterleaveGroupMap) {
+ InterleaveGroup<Instruction> *Group = I.second;
+ if (Group->requiresScalarEpilogue())
+ DelSet.insert(Group);
+ }
+ for (auto *Ptr : DelSet) {
+ LLVM_DEBUG(
+ dbgs()
+ << "LV: Invalidate candidate interleaved group due to gaps that "
+ "require a scalar epilogue (not allowed under optsize) and cannot "
+ "be masked (not enabled). \n");
+ releaseGroup(Ptr);
+ }
+
+ RequiresScalarEpilogue = false;
+}
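
The collect-then-release pattern used above (several InterleaveGroupMap entries can point at the same group, so the groups to drop are gathered into a set first and each one is released exactly once) in a minimal standalone sketch with hypothetical types:

  #include <iostream>
  #include <map>
  #include <set>

  // Hypothetical stand-in for an interleave group; only the flag needed for
  // the illustration is present.
  struct Group {
    bool NeedsEpilogue;
  };

  int main() {
    Group G1{true}, G2{false};
    // Several instructions (keyed by int here) can map to the same group, so
    // releasing a group directly while walking the map could free it twice.
    std::map<int, Group *> GroupMap = {{0, &G1}, {1, &G1}, {2, &G2}};

    // First collect the distinct groups to drop, then release each one once.
    std::set<Group *> DelSet;
    for (const auto &Entry : GroupMap)
      if (Entry.second->NeedsEpilogue)
        DelSet.insert(Entry.second);

    for (Group *G : DelSet)
      std::cout << "releasing group " << G << " exactly once\n";
  }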
+
+template <typename InstT>
+void InterleaveGroup<InstT>::addMetadata(InstT *NewInst) const {
+ llvm_unreachable("addMetadata can only be used for Instruction");
+}
+
+namespace llvm {
+template <>
+void InterleaveGroup<Instruction>::addMetadata(Instruction *NewInst) const {
+ SmallVector<Value *, 4> VL;
+ std::transform(Members.begin(), Members.end(), std::back_inserter(VL),
+ [](std::pair<int, Instruction *> p) { return p.second; });
+ propagateMetadata(NewInst, VL);
+}
+}
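
The member-gathering step of the specialization above, mirrored in a small standalone sketch: hypothetical stand-ins only, with strings in place of instructions and printing in place of propagateMetadata.

  #include <algorithm>
  #include <iostream>
  #include <iterator>
  #include <map>
  #include <string>
  #include <vector>

  int main() {
    // Hypothetical (index -> member) map, playing the role of the group's
    // Members container.
    std::map<int, std::string> Members = {{0, "load A"}, {1, "load B"}};

    // Collect just the members, as the specialization does before handing
    // them to propagateMetadata().
    std::vector<std::string> VL;
    std::transform(Members.begin(), Members.end(), std::back_inserter(VL),
                   [](const std::pair<const int, std::string> &P) {
                     return P.second;
                   });

    for (const auto &M : VL)
      std::cout << M << "\n";
  }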